from collections import defaultdict
from collections.abc import Callable, Sequence
from dataclasses import dataclass
import torch
from torch import Tensor, nn
from transformers import PreTrainedTokenizerBase
from steering_vectors.aggregators import Aggregator, mean_aggregator
from steering_vectors.token_utils import adjust_read_indices_for_padding, fix_pad_token
from steering_vectors.utils import batchify
from .layer_matching import LayerType, ModelLayerConfig, guess_and_enhance_layer_config
from .record_activations import record_activations
from .steering_vector import SteeringVector

@dataclass
class SteeringVectorTrainingSample:
    """A pair of contrastive prompts, with optional per-prompt token indices to read
    activations from. If an index is None, the ``read_token_index`` passed to
    ``train_steering_vector`` is used for that prompt."""

    positive_str: str
    negative_str: str
    read_positive_token_index: int | None = None
    read_negative_token_index: int | None = None
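
# Example (illustrative only, not part of the module): a contrastive pair that reads
# activations at the default token index, and one with explicit per-prompt read indices.
#
#     sample = SteeringVectorTrainingSample(
#         positive_str="The movie was wonderful.",
#         negative_str="The movie was terrible.",
#     )
#     pinned = SteeringVectorTrainingSample(
#         "2 + 2 = 4", "2 + 2 = 5",
#         read_positive_token_index=-1,
#         read_negative_token_index=-1,
#     )
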
def aggregate_activations(
pos_acts_by_layer: dict[int, list[Tensor]],
neg_acts_by_layer: dict[int, list[Tensor]],
aggregator: Aggregator = mean_aggregator(),
) -> dict[int, Tensor]:
"""
Apply the aggregator to the positive and negative activations for each layer.
Args:
pos_acts_by_layer: A dictionary mapping layer numbers to lists of positive
activations.
neg_acts_by_layer: A dictionary mapping layer numbers to lists of negative
activations.
aggregator: A function that takes the positive and negative activations for a
layer and returns a single vector.
Returns:
A dictionary mapping layer numbers to the aggregated activations.
"""
layer_activations = {}
for layer_num in pos_acts_by_layer.keys():
layer_pos_acts = pos_acts_by_layer[layer_num]
layer_neg_acts = neg_acts_by_layer[layer_num]
direction_vec = aggregator(
torch.concat(layer_pos_acts), torch.concat(layer_neg_acts)
)
layer_activations[layer_num] = direction_vec
return layer_activations
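
# Example (illustrative sketch, not part of the module): aggregating activations that
# were collected for a single layer from two prompt pairs, assuming a hidden size of 4.
#
#     pos_acts = {0: [torch.randn(1, 4), torch.randn(1, 4)]}
#     neg_acts = {0: [torch.randn(1, 4), torch.randn(1, 4)]}
#     directions = aggregate_activations(pos_acts, neg_acts)
#     # directions[0] is a single direction vector of hidden size 4 (with the default
#     # mean aggregator, roughly the mean difference between positive and negative).
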
@torch.no_grad()
def train_steering_vector(
model: nn.Module,
tokenizer: PreTrainedTokenizerBase,
training_samples: Sequence[SteeringVectorTrainingSample | tuple[str, str]],
layers: list[int] | None = None,
layer_type: LayerType = "decoder_block",
layer_config: ModelLayerConfig | None = None,
move_to_cpu: bool = False,
read_token_index: int | Callable[[str], int] = -1,
show_progress: bool = False,
aggregator: Aggregator = mean_aggregator(),
batch_size: int = 1,
tqdm_desc: str = "Training steering vector",
) -> SteeringVector:
"""
Train a steering vector for the given model.
Args:
model: The model to train the steering vector for
tokenizer: The tokenizer to use
        training_samples: A list of training samples, where each sample is either a
            SteeringVectorTrainingSample or a tuple of (positive_str, negative_str).
            The steering vector approximates the difference between the positive
            prompt and negative prompt activations.
layers: A list of layer numbers to train the steering vector on. If None, train
on all layers.
layer_type: The type of layer to train the steering vector on. Default is
"decoder_block".
layer_config: A dictionary mapping layer types to layer matching functions.
If not provided, this will be inferred automatically.
move_to_cpu: If True, move the activations to the CPU before training. Default False.
        read_token_index: The index of the token to read the activations from, or a
            callable that maps a prompt string to such an index. Default -1, meaning the final token.
show_progress: If True, show a progress bar. Default False.
        aggregator: A function that takes the positive and negative activations for a
            layer and returns a single vector. Default is mean_aggregator.
        batch_size: The number of training samples to process in each forward pass. Default 1.
        tqdm_desc: The description shown on the progress bar. Default "Training steering vector".
    Returns:
        A SteeringVector built from the aggregated activation differences for each layer.
    """
pos_acts, neg_acts = extract_activations(
model,
tokenizer,
training_samples,
layers=layers,
layer_type=layer_type,
layer_config=layer_config,
move_to_cpu=move_to_cpu,
read_token_index=read_token_index,
show_progress=show_progress,
batch_size=batch_size,
tqdm_desc=tqdm_desc,
)
layer_activations = aggregate_activations(pos_acts, neg_acts, aggregator)
return SteeringVector(layer_activations, layer_type)
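
# Example (illustrative sketch, not part of the module): training a steering vector for
# a small Hugging Face causal LM. The model name, layers, and prompts are placeholders.
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     model = AutoModelForCausalLM.from_pretrained("gpt2")
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
#
#     steering_vector = train_steering_vector(
#         model,
#         tokenizer,
#         [
#             ("I love this movie.", "I hate this movie."),
#             ("You are very kind.", "You are very cruel."),
#         ],
#         layers=[4, 5, 6],
#         show_progress=True,
#     )
#     # The returned SteeringVector holds one aggregated direction tensor per layer.
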
def _formalize_batch(
batch: Sequence[SteeringVectorTrainingSample | tuple[str, str]],
) -> list[SteeringVectorTrainingSample]:
return [_formalize_sample(sample) for sample in batch]

def _formalize_sample(
sample: SteeringVectorTrainingSample | tuple[str, str],
) -> SteeringVectorTrainingSample:
if isinstance(sample, tuple):
return SteeringVectorTrainingSample(sample[0], sample[1])
else:
return sample

def _extract_activations(
    model: nn.Module,
    tokenizer: PreTrainedTokenizerBase,
    prompts: Sequence[str],
    layer_type: LayerType,
    layer_config: ModelLayerConfig,
    layers: list[int] | None,
    read_token_indices: Sequence[int],
) -> dict[int, Tensor]:
    """Run a single batch of prompts through the model and return, for each layer,
    the activation at the requested read token of every prompt."""
    inputs = tokenizer(prompts, return_tensors="pt", padding=True)
    # Shift the requested read indices so they still point at the intended tokens
    # once the prompts have been padded to a common length.
    adjusted_read_indices = adjust_read_indices_for_padding(
        torch.tensor(read_token_indices), inputs["attention_mask"]
    )
    batch_indices = torch.arange(len(prompts))
    results = {}
    with record_activations(
        model, layer_type, layer_config, layer_nums=layers
    ) as record:
        model(**inputs.to(model.device))
    for layer_num, activation in record.items():
        # `activation` holds one tensor per forward pass; take the last (only) one and
        # select the read token of each prompt, giving a (batch, hidden_dim) tensor.
        results[layer_num] = activation[-1][
            batch_indices.to(activation[-1].device),
            adjusted_read_indices.to(activation[-1].device),
        ].detach()
    return results
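
# Example (illustrative only, not part of the module): the advanced indexing used above
# picks one token position per prompt. With a (batch=2, seq=4, hidden=3) activation
# tensor and per-prompt read indices [3, 2]:
#
#     acts = torch.randn(2, 4, 3)
#     read = torch.tensor([3, 2])
#     picked = acts[torch.arange(2), read]
#     # picked has shape (2, 3): row 0 is acts[0, 3], row 1 is acts[1, 2]
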
def _get_token_index(
custom_idx: int | None, default_idx: int | Callable[[str], int], prompt: str
) -> int:
    """Return the per-sample override index if provided, otherwise resolve the default."""
if custom_idx is None:
if isinstance(default_idx, int):
return default_idx
else:
return default_idx(prompt)
else:
return custom_idx
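
# Example (illustrative only, not part of the module): the per-sample index overrides the
# default, and the default may be either an int or a callable on the prompt string.
#
#     _get_token_index(None, -1, "Hello world")                        # -> -1
#     _get_token_index(5, -1, "Hello world")                           # -> 5
#     _get_token_index(None, lambda p: len(p.split()), "Hello world")  # -> 2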