Enable mambav2 compat
Browse files- modeling_rcps.py +5 -2
modeling_rcps.py
CHANGED
|
@@ -10,9 +10,12 @@ from torch import nn
|
|
| 10 |
from torch.nn import functional as F
|
| 11 |
|
| 12 |
try:
|
| 13 |
-
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
|
| 14 |
except ImportError:
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
class RCPSEmbedding(nn.Module):
|
|
|
|
| 10 |
from torch.nn import functional as F
|
| 11 |
|
| 12 |
try:
|
| 13 |
+
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn # Legacy mambav1 file structure
|
| 14 |
except ImportError:
|
| 15 |
+
try:
|
| 16 |
+
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn # mambav2 file structure
|
| 17 |
+
except ImportError:
|
| 18 |
+
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
| 19 |
|
| 20 |
|
| 21 |
class RCPSEmbedding(nn.Module):
|