import os
import warnings
from operator import attrgetter
from typing import Dict, List, Optional

import torch
import torch.nn.functional as F
from torchtyping import TensorType
from transformers import TextIteratorStreamer
from transformers import AutoTokenizer, BatchEncoding

import nnsight
from nnsight import LanguageModel
from nnsight.intervention import Envoy

# Process-wide side effects: silence all Python warnings and disable the HF
# tokenizers' internal parallelism.
warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# nnsight with multi-threading: https://github.com/ndif-team/nnsight/issues/280
nnsight.CONFIG.APP.GLOBAL_TRACING = False

# Default model and on-disk locations of the steering artifacts used by load_model().
config = {
    "model_name": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    "steering_vec": "activations/candidate_vectors.pt",
    "offset": "activations/offsets.pt",
}


def detect_module_attrs(model: LanguageModel) -> str:
    """Return the dotted attribute path of the model's list of transformer blocks.

    The known layouts are tried in order and the first whose container exists
    on ``model`` wins:

    * ``model.layers`` (e.g. Llama/Qwen-style HF causal LMs)
    * ``transformer.h`` (e.g. GPT-2-style HF causal LMs)

    Raises:
        ValueError: if none of the known layouts is present.
    """
    candidates = (("model", "layers"), ("transformer", "h"))
    for root, blocks in candidates:
        if root in model._modules and blocks in getattr(model, root)._modules:
            return f"{root}.{blocks}"
    raise ValueError("Failed to detect module attributes.")


class ModelBase:
    """An nnsight-wrapped causal LM plus per-layer steering vectors and offsets.

    ``steer_generation`` edits one block's output during sampling along a
    normalized steering direction; ``run_generation`` samples without any edit.
    """

    def __init__(
        self,
        model_name: str,
        steering_vecs: TensorType,
        offsets: TensorType,
        tokenizer: Optional[AutoTokenizer] = None,
        block_module_attr: Optional[str] = None,
    ):
        """Load the model (and tokenizer if not given) and prepare steering tensors.

        Args:
            model_name: HF hub id or local path of the model.
            steering_vecs: steering directions, indexed by layer in
                ``steer_generation``; each is L2-normalized along the last dim.
                Presumably shaped (n_layers, hidden_size) — TODO confirm.
            offsets: per-layer offsets, indexed by layer in ``steer_generation``.
            tokenizer: pre-built tokenizer; loaded from ``model_name`` when None.
            block_module_attr: dotted path to the list of blocks (e.g.
                ``"model.layers"``); auto-detected when None.
        """
        self.tokenizer = tokenizer if tokenizer is not None else self._load_tokenizer(model_name)
        self.model = self._load_model(model_name, self.tokenizer)
        self.device = self.model.device
        self.hidden_size = self.model.config.hidden_size

        attr = block_module_attr if block_module_attr is not None else detect_module_attrs(self.model)
        self.block_modules = self.get_module(attr)

        # Normalize before the dtype cast, so the norm is computed at the input precision.
        self.steering_vecs = F.normalize(steering_vecs, dim=-1)
        # NOTE(review): only the dtype is changed, not the device; the vectors keep
        # whatever device torch.load restored them to.
        self.steering_vecs, self.offsets = self.set_dtype(self.steering_vecs, offsets)

    def _load_model(self, model_name: str, tokenizer: AutoTokenizer) -> LanguageModel:
        """Load ``model_name`` as a dispatched nnsight LanguageModel in bfloat16."""
        return LanguageModel(
            model_name,
            tokenizer=tokenizer,
            dispatch=True,
            trust_remote_code=True,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )

    def _load_tokenizer(self, model_name) -> AutoTokenizer:
        """Load the tokenizer with left padding, a pad token, and a patched chat template."""
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        tokenizer.padding_side = "left"
        if not tokenizer.pad_token:
            # No pad token configured: reuse EOS so that padding=True in tokenize() works.
            tokenizer.pad_token_id = tokenizer.eos_token_id
            tokenizer.pad_token = tokenizer.eos_token
        # The template may be missing (None) or a dict of named templates; only a
        # plain string template can be patched.
        if isinstance(tokenizer.chat_template, str):
            # The pattern is "<|Assistant|>" followed by a literal backslash and "n"
            # (not a newline character), i.e. an escaped \n inside the template source.
            # NOTE(review): confirm this matches the target model's template text.
            tokenizer.chat_template = tokenizer.chat_template.replace("<|Assistant|>\\n", "<|Assistant|>")
        return tokenizer

    def tokenize(self, prompt: str) -> BatchEncoding:
        """Tokenize ``prompt`` into padded, untruncated PyTorch tensors."""
        return self.tokenizer(prompt, padding=True, truncation=False, return_tensors="pt")

    def get_module(self, attr: str) -> Envoy:
        """Resolve a dotted attribute path (e.g. ``"model.layers"``) on the wrapped model."""
        return attrgetter(attr)(self.model)

    def set_dtype(self, *tensors):
        """Cast every tensor in ``tensors`` to the model's dtype.

        Returns the cast tensor itself for a single argument, otherwise a tuple
        of cast tensors in argument order (so the result can be unpacked).
        """
        cast = tuple(t.to(self.model.dtype) for t in tensors)
        return cast[0] if len(cast) == 1 else cast

    def apply_chat_template(self, instruction: str) -> str:
        """Render ``instruction`` as one user turn plus the generation prompt, as text."""
        messages = [{"role": "user", "content": instruction}]
        return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    def run_generation(self, inputs, streamer: TextIteratorStreamer, generation_config: Dict):
        """Sample from the underlying HF model without any intervention.

        Tokens are delivered through ``streamer``; nothing is returned.
        ``do_sample`` defaults to True but may be overridden in ``generation_config``.
        """
        inputs = inputs.to(self.device)
        gen_kwargs = {"do_sample": True, **generation_config}
        self.model._model.generate(**inputs, streamer=streamer, **gen_kwargs)

    def steer_generation(
        self,
        inputs,
        streamer: TextIteratorStreamer,
        k: float,
        layer: int,
        coeff: float,
        generation_config: Dict,
    ):
        """Sample while editing block ``layer``'s output along its steering direction.

        For the block's output ``acts``, the component of ``acts - offset`` along
        the steering direction ``u`` is replaced by ``coeff * k`` (``u`` is
        unit-norm, normalized in ``__init__``); components orthogonal to ``u`` are
        unchanged. Tokens are delivered through ``streamer``; nothing is returned.
        ``do_sample`` defaults to True but may be overridden in ``generation_config``.
        """
        layer_block = self.block_modules[layer]
        unit_vec = self.steering_vecs[layer]
        offset = self.offsets[layer]
        gen_kwargs = {"do_sample": True, **generation_config}

        with self.model.generate(inputs, streamer=streamer, **gen_kwargs):
            # nnsight .all(): apply the intervention on every generation step,
            # not only the prompt forward pass.
            with self.block_modules.all():
                # Assumes output[0] is the block's hidden-state tensor — TODO confirm.
                acts = layer_block.output[0].clone()
                # Vector projection of (acts - offset) onto unit_vec.
                proj = (acts - offset) @ unit_vec.unsqueeze(-1) * unit_vec
                # Remove that component and substitute coeff * k along unit_vec, in place.
                layer_block.output[0][:] = acts - proj + coeff * k * unit_vec


def load_model() -> ModelBase:
    """Build a ModelBase from the paths and model name in the module-level ``config``."""
    steering_vecs = torch.load(config['steering_vec'], weights_only=True)
    offsets = torch.load(config['offset'], weights_only=True)
    model = ModelBase(config['model_name'], steering_vecs=steering_vecs, offsets=offsets)
    return model