mobicham committed (verified) · commit e075870 · parent: b8693ed

Create README.md
---
license: llama3
train: false
inference: false
pipeline_tag: text-generation
---
This is a calibrated, weight-only NVFP4-quantized <a href="https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct">Meta-Llama-3.1-8B-Instruct</a> model, as presented in our <a href="https://mobiusml.github.io/fp4_blogpost/">blog post</a>.
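
FP4 here is the 4-bit E2M1 floating-point format: each weight is stored as a 4-bit code indexing one of only 16 representable values (±{0, 0.5, 1, 1.5, 2, 3, 4, 6}), and each group of weights shares a scale. As a rough orientation, below is a minimal round-to-nearest sketch of weight-only FP4 group quantization; the helper `quantize_fp4_group` and the absmax scale choice are illustrative assumptions, not the calibrated procedure used for this checkpoint (see the blog post for that):

```Python
import torch

# The 16 representable FP4 (E2M1) values; the same table appears as `fp4_values`
# in the loading code below.
FP4_VALUES = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6,
                           -0, -0.5, -1, -1.5, -2, -3, -4, -6])

def quantize_fp4_group(w):
    # Illustrative round-to-nearest quantization of one weight group:
    # choose a per-group scale so the largest magnitude maps to +/-6 (absmax),
    # then snap each scaled weight to the nearest representable FP4 value.
    scale = w.abs().max() / 6.0
    codes = (w / scale - FP4_VALUES.view(-1, 1)).abs().argmin(dim=0)
    return codes.to(torch.uint8), scale

w = torch.randn(32)                      # one weight group
codes, scale = quantize_fp4_group(w)
w_deq = FP4_VALUES[codes.int()] * scale  # dequantization, as done at inference time
print((w - w_deq).abs().mean())          # average quantization error of this group
```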

## Usage

### Installation
```
pip install safetensors==0.6.0.dev0
```
The code below also relies on `torch`, `transformers`, `accelerate`, and `huggingface_hub`, so make sure those packages are installed as well.

```Python
import os, torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from accelerate import init_empty_weights
from huggingface_hub import snapshot_download
from os.path import join as pjoin
from safetensors import safe_open

@torch.compile(fullgraph=True)
def matmul_fp4(x, W_q, scales, group_size, fp4_values):
    def unpack_over_cols(W_q_packed, W_nbits, num_output_cols, dtype):
        n_rows, n_cols = W_q_packed.shape
        device = W_q_packed.device
        shifts = torch.arange(num_output_cols // n_cols, device=device, dtype=W_q_packed.dtype) * W_nbits
        W_q_unpacked = ((W_q_packed.unsqueeze(-1) >> shifts) & ((1 << W_nbits) - 1)).to(dtype)
        W_q_unpacked = W_q_unpacked.view(n_rows, num_output_cols)
        return W_q_unpacked

    N, K = W_q.shape[0], W_q.shape[1] * 2
    W_q = fp4_values[unpack_over_cols(W_q, W_nbits=4, num_output_cols=K, dtype=torch.int32)]
    W_r = (W_q.float().view([-1, group_size]) * scales.float()).reshape([N, K]).to(x.dtype).T
    return torch.matmul(x, W_r)


class AutoModelForCausalLMFP4:

    @classmethod
    def from_pretrained(
        cls,
        save_dir_or_hub,
        torch_dtype=torch.bfloat16,
        cache_dir=None,
        device_map="cuda:0",
        *args,
        **kwargs
    ):

        # Download snapshot
        if os.path.exists(save_dir_or_hub):
            save_dir = save_dir_or_hub
        else:
            save_dir = snapshot_download(repo_id=save_dir_or_hub, cache_dir=cache_dir)

        # Create model from config
        config = AutoConfig.from_pretrained(pjoin(save_dir, "config.json"))
        config.torch_dtype = str(torch_dtype).split('.')[-1]
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config)

        # Load and patch
        state_dict = {}
        with safe_open(pjoin(save_dir, "model.safetensors"), framework="pt", device="cpu") as f:
            for key in f.keys():
                tensor = f.get_tensor(key)
                dtype = torch_dtype if tensor.is_floating_point() else tensor.dtype
                state_dict[key] = tensor.to(device=device_map, dtype=dtype, non_blocking=True)

        cls.patch_model_for_fp4_inference(model=model, torch_dtype=torch_dtype, device=device_map, state_dict=state_dict)

        return model

    @classmethod
    def patch_model_for_fp4_inference(cls, model, torch_dtype, device, state_dict):

        model.fp4_values = torch.tensor(
            [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6],
            dtype=torch_dtype,
            device=device,
        )

        def patch_linearlayers(model, fct):
            for name, layer in model.named_children():
                if isinstance(layer, torch.nn.Linear):
                    setattr(model, name, fct(layer, name))
                else:
                    patch_linearlayers(layer, fct)

        def patch_enable_fp4(layer, arg):
            # Load params
            if('lm_head' in layer.name):
                return layer

            if(hasattr(layer, 'weight')):
                del layer.weight
            for key in ['W_q', 'scales', 'shift', 'post_scale', 'meta']:
                param_tag, param = layer.name + '.' + key, None
                if(param_tag in state_dict):
                    param = state_dict[param_tag].tolist() if key in ["meta"] else state_dict[param_tag]
                setattr(layer, key, param)

            # Set forward pass
            def forward(self, x):
                if(hasattr(self, 'weight')):
                    out = torch.matmul(x, self.weight.data.T)
                else:
                    out = matmul_fp4(x, self.W_q, self.scales, self.meta[-1], model.fp4_values)
                if(self.post_scale is not None):
                    out *= self.post_scale
                if(self.shift is not None):
                    out += self.shift
                if(self.bias is not None):
                    out += self.bias
                return out

            layer.forward = lambda x: forward(layer, x)

            return layer

        try:  # FP4 params will fail here
            model.load_state_dict(state_dict, assign=True)
        except:
            pass

        for name, module in model.named_modules():
            module.name = name
        patch_linearlayers(model, patch_enable_fp4)
        model = model.to(device)
```
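
In the checkpoint, `W_q` stores two 4-bit codes per element, which is why `K = W_q.shape[1] * 2` above. The standalone snippet below is only a sanity check of that packing convention (low nibble first, matching the shift-and-mask logic in `unpack_over_cols`); the packing expression itself is an illustrative assumption, not code from this repo:

```Python
import torch

codes  = torch.randint(0, 16, (4, 8), dtype=torch.uint8)  # unpacked 4-bit codes
packed = codes[:, 0::2] | (codes[:, 1::2] << 4)           # two columns -> one byte, low nibble first

# Same unpacking logic as `unpack_over_cols` above (W_nbits=4, num_output_cols=8)
shifts   = torch.arange(2, dtype=torch.uint8) * 4          # [0, 4]
unpacked = ((packed.unsqueeze(-1) >> shifts) & 0xF).view(4, 8)

assert torch.equal(unpacked, codes)
```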

### Usage
```Python
model_id = "mobiuslabsgmbh/Llama-3.1-8B-Instruct_mxfp4_weights_calib_demo"
model = AutoModelForCausalLMFP4.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map='cuda')
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Check the trained params
# print(model.model.layers[-1].self_attn.v_proj.shift)
# tensor([ 0.0034, -0.0036,  0.0054,  ...,  0.0036, -0.0076, -0.0068],
#        device='cuda:0', dtype=torch.bfloat16)

# print(model.model.layers[-1].self_attn.v_proj.post_scale)
# tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0', dtype=torch.bfloat16)

outputs = model.generate(
    tokenizer.apply_chat_template(
        [{"role": "user", "content": "Solve the following equation: x^2 + 1 = -1"}],
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device),
    max_new_tokens=256,
)
print(tokenizer.decode(outputs[0]))
```
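
The quantization tensors loaded by `patch_enable_fp4` are attached to each patched linear layer as plain attributes, so you can inspect them directly. A small, illustrative inspection (exact shapes and dtypes depend on the checkpoint):

```Python
layer = model.model.layers[-1].self_attn.v_proj
print(layer.W_q.shape, layer.W_q.dtype)        # packed 4-bit codes, two per element
print(layer.scales.shape, layer.scales.dtype)  # per-group scales
print(layer.meta)                              # quantization metadata; meta[-1] is the group size used by matmul_fp4
```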