Update modeling_jamba.py - LoRA support in Mamba (#6)
- Update modeling_jamba.py - LoRA support in Mamba (409c904957803838229e49676ec3958c2205783d)
- modeling_jamba.py +12 -4
modeling_jamba.py CHANGED
@@ -943,14 +943,22 @@ class JambaMambaMixer(nn.Module):
         # in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized
         # linear layers, and requires to call the forward pass directly.
         # The original code here was: ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)```
-        time_proj_bias = self.dt_proj.bias
-        self.dt_proj.bias = None
+        if hasattr(self.dt_proj, "base_layer"):
+            # In case of LoRA, we need to access the base layer to get the weight
+            time_proj_bias = self.dt_proj.base_layer.bias
+            self.dt_proj.base_layer.bias = None
+        else:
+            time_proj_bias = self.dt_proj.bias
+            self.dt_proj.bias = None
         discrete_time_step = self.dt_proj(time_step).transpose(1, 2)
-        self.dt_proj.bias = time_proj_bias
+        if hasattr(self.dt_proj, "base_layer"):
+            self.dt_proj.base_layer.bias = time_proj_bias
+        else:
+            self.dt_proj.bias = time_proj_bias
 
         A = -torch.exp(self.A_log.float())
         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
-        time_proj_bias =
+        time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
         if cache_params is not None and cache_params.seqlen_offset > 0:
             scan_outputs = selective_state_update(
                 cache_params.ssm_states[self.layer_idx],
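
Context for the `hasattr(self.dt_proj, "base_layer")` check: when `dt_proj` is wrapped by a PEFT LoRA adapter, the original `nn.Linear` (and its bias) is kept under `base_layer`, so the temporary bias removal and restore have to happen on that base layer rather than on the wrapper. Below is a minimal sketch of that wrapping, assuming a recent version of the `peft` library; `TinyMixer` is a hypothetical stand-in for the Mamba mixer, not code from this repo.

```python
# Minimal sketch (not part of this commit): shows that a LoRA-wrapped Linear keeps
# the original module, including its bias, reachable as `.base_layer`.
# Assumes the `peft` library; `TinyMixer` is a hypothetical stand-in for the mixer.
import torch.nn as nn
from peft import LoraConfig, inject_adapter_in_model


class TinyMixer(nn.Module):
    def __init__(self):
        super().__init__()
        self.dt_proj = nn.Linear(16, 32, bias=True)


model = TinyMixer()
lora_config = LoraConfig(r=4, lora_alpha=8, target_modules=["dt_proj"])
model = inject_adapter_in_model(lora_config, model)

# The LoRA wrapper replaces dt_proj; the original Linear and its bias now live on
# dt_proj.base_layer, which is what the patched forward pass reads and restores.
print(hasattr(model.dt_proj, "base_layer"))       # True for the LoRA-wrapped layer
print(model.dt_proj.base_layer.bias is not None)  # True: bias is kept on the base layer
```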