Build
This view is limited to 50 files because it contains too many changes.
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/{_mamba_ssm_bft6nicqkg6ni.abi3.so → _mamba_ssm_85e64a5.abi3.so} +2 -2
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/_ops.py +3 -3
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/{_mamba_ssm_nmrmresto7zfi.abi3.so → _mamba_ssm_85e64a5.abi3.so} +2 -2
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/_ops.py +3 -3
- build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/{_mamba_ssm_fhbfq4rqrrau4.abi3.so → _mamba_ssm_85e64a5.abi3.so} +2 -2
- build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/_ops.py +3 -3
- build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/{_mamba_ssm_konfvt7wiz4bc.abi3.so → _mamba_ssm_85e64a5.abi3.so} +2 -2
- build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/_ops.py +3 -3
- build/torch25-cxx98-cu121-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so +3 -0
- build/torch25-cxx98-cu121-x86_64-linux/mamba_ssm/_mamba_ssm_b7y35xkw542po.abi3.so +0 -3
- build/torch25-cxx98-cu121-x86_64-linux/mamba_ssm/_ops.py +3 -3
- build/torch25-cxx98-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_3nr5ex3ddrv6c.abi3.so +0 -3
- build/torch25-cxx98-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so +3 -0
- build/torch25-cxx98-cu124-x86_64-linux/mamba_ssm/_ops.py +3 -3
- build/torch26-cxx11-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so +3 -0
- build/torch26-cxx11-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_w4jqdduxei7ne.abi3.so +0 -3
- build/torch26-cxx11-cu118-x86_64-linux/mamba_ssm/_ops.py +3 -3
- build/torch26-cxx11-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so +3 -0
- build/torch26-cxx11-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_h4pt4pjmzduuo.abi3.so +0 -3
- build/torch26-cxx11-cu124-x86_64-linux/mamba_ssm/_ops.py +3 -3
- build/torch26-cxx11-cu126-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so +3 -0
- build/torch26-cxx11-cu126-x86_64-linux/mamba_ssm/_mamba_ssm_ad2dqkuyppsay.abi3.so +0 -3
- build/torch26-cxx11-cu126-x86_64-linux/mamba_ssm/_ops.py +3 -3
- build/torch26-cxx98-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so +3 -0
- build/torch26-cxx98-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_g4gqbotnq7pgy.abi3.so +0 -3
- build/torch26-cxx98-cu118-x86_64-linux/mamba_ssm/_ops.py +3 -3
- build/torch26-cxx98-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so +3 -0
- build/torch26-cxx98-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_r7gpumhmqnfog.abi3.so +0 -3
- build/torch26-cxx98-cu124-x86_64-linux/mamba_ssm/_ops.py +3 -3
- build/torch26-cxx98-cu126-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so +3 -0
- build/torch26-cxx98-cu126-x86_64-linux/mamba_ssm/_mamba_ssm_ojx7o3olgtezs.abi3.so +0 -3
- build/torch26-cxx98-cu126-x86_64-linux/mamba_ssm/_ops.py +3 -3
- build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/__init__.py +14 -0
- build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so +3 -0
- build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/_ops.py +9 -0
- build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/__init__.py +0 -0
- build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/distributed_utils.py +144 -0
- build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py +326 -0
- build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/models/__init__.py +0 -0
- build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/models/config_mamba.py +18 -0
- build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py +338 -0
- build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/__init__.py +0 -0
- build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/block.py +107 -0
- build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba2.py +502 -0
- build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba2_simple.py +229 -0
- build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba_simple.py +339 -0
- build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mha.py +294 -0
- build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mlp.py +34 -0
- build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/ssd_minimal.py +111 -0
- build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/ops/__init__.py +0 -0
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/{_mamba_ssm_bft6nicqkg6ni.abi3.so → _mamba_ssm_85e64a5.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:65499c06108d39d466d78b0e628645aff076a8db1ebd6ed6c09d05ccb4c80296
+size 261767104
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_mamba_ssm_85e64a5::{op_name}"
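For orientation, a minimal sketch of what this _ops.py shim gives callers once the build-specific shared library is importable; the op name selective_scan_fwd below is purely illustrative and not taken from this diff:

    from mamba_ssm import _ops

    # The .so registers its custom ops under the per-build namespace, so callers
    # resolve them through the `ops` alias instead of hard-coding the hash.
    print(_ops.add_op_namespace_prefix("selective_scan_fwd"))
    # -> "_mamba_ssm_85e64a5::selective_scan_fwd"
    # Registered kernels would then be reachable as attributes of _ops.ops
    # (the exact op names are not shown in this diff).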
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/{_mamba_ssm_nmrmresto7zfi.abi3.so → _mamba_ssm_85e64a5.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:878481a2cddfbf12eea8a8f746abdea5bd80a8b8feeb6aa87b9561a2987bebea
+size 250394944
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/{_mamba_ssm_fhbfq4rqrrau4.abi3.so → _mamba_ssm_85e64a5.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:7987b149563bfaef257006c494496743db9eb769fb1727b0f7b260839193cb30
+size 249159400
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/{_mamba_ssm_konfvt7wiz4bc.abi3.so → _mamba_ssm_85e64a5.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:e6a97c777005c18be46b647314eb6f1659f894ef1aa06a1ea9827f1288156545
+size 261763792
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch25-cxx98-cu121-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d7826f86b0f364f57d88c307511a12d2981ae0d16b4f30913e315794c794f563
size 250387568
build/torch25-cxx98-cu121-x86_64-linux/mamba_ssm/_mamba_ssm_b7y35xkw542po.abi3.so
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c01510f5ed28163a69098ed3ac7d18b42d58d38ed5cd15cc2088c0b5056be5d5
size 247798920
build/torch25-cxx98-cu121-x86_64-linux/mamba_ssm/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch25-cxx98-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_3nr5ex3ddrv6c.abi3.so
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:856bdc363912163550832cb546797c1c668494d7a1244016e827efe9e945af4a
size 246546992
build/torch25-cxx98-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4ef95dd140a1ea41f59e4af3a6fa95b1ab14fbfb2ef529bd3005cf7471d2c5f7
size 249156128
build/torch25-cxx98-cu124-x86_64-linux/mamba_ssm/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch26-cxx11-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8eb776fd2d50d2c6a83643848baee6a09ac1b0f0ae085d90aa918fd86c13f09d
size 261771440
build/torch26-cxx11-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_w4jqdduxei7ne.abi3.so
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d87ac7d060e64fbe8651bd6a07234a49464c326a705f58b9d21f611548cbe7f6
size 258977984
build/torch26-cxx11-cu118-x86_64-linux/mamba_ssm/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch26-cxx11-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bf2b6908f571764b70f882e10c46f0261c5e9a6af22774879e490886a3fa5b9d
size 249159824
build/torch26-cxx11-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_h4pt4pjmzduuo.abi3.so
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a94840c139f6dd950a88a3d049bf301fdeffa69e5fd6d2a5227cceb904189c49
size 246550688
build/torch26-cxx11-cu124-x86_64-linux/mamba_ssm/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch26-cxx11-cu126-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:467b1ba8ff9848471a5b9d77910ef20f59a41dc532eacf11cdbc8c6d20ba5f40
size 250021072
build/torch26-cxx11-cu126-x86_64-linux/mamba_ssm/_mamba_ssm_ad2dqkuyppsay.abi3.so
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e7625fca55530d1e2f3801c430cf50c0f38f744f631ff2ff3cfd48ed88a63eb9
size 247579880
build/torch26-cxx11-cu126-x86_64-linux/mamba_ssm/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch26-cxx98-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6a3b8a82f5359fb1f1f9736d5c62d7a7449f76547715115b3f2e6cf706b06f89
size 261764024
build/torch26-cxx98-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_g4gqbotnq7pgy.abi3.so
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:966b3d0a8ea86436492ca16aead68420d2270094f4c38425b0551f556a410960
size 258970576
build/torch26-cxx98-cu118-x86_64-linux/mamba_ssm/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch26-cxx98-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:81a1cdd99de783191a0551d958e13c0ff3625864a87b3e3850db3603ba8cebb1
size 249156424
build/torch26-cxx98-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_r7gpumhmqnfog.abi3.so
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ff591b542c6a1e41c71df196d9a47966f832a77945b26684077b881f9b99e016
size 246547288
build/torch26-cxx98-cu124-x86_64-linux/mamba_ssm/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch26-cxx98-cu126-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:956f952fefd39eaed6eab1fde0f2dbb4cf6d2004cf986f40392c33ef745f084a
size 250017672
build/torch26-cxx98-cu126-x86_64-linux/mamba_ssm/_mamba_ssm_ojx7o3olgtezs.abi3.so
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b611bba2ae9af44b053ec900355dbd27ccb64d0fa85c732a261cb0b0304df24d
size 247576472
build/torch26-cxx98-cu126-x86_64-linux/mamba_ssm/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/__init__.py
ADDED
@@ -0,0 +1,14 @@
__version__ = "2.2.4"

from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from .modules.mamba_simple import Mamba
from .modules.mamba2 import Mamba2
from .models.mixer_seq_simple import MambaLMHeadModel

__all__ = [
    "selective_scan_fn",
    "mamba_inner_fn",
    "Mamba",
    "Mamba2",
    "MambaLMHeadModel",
]
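As a quick sanity check of the re-exported API above, a minimal usage sketch; the layer sizes are illustrative and a CUDA device with the bundled kernels is assumed:

    import torch
    from mamba_ssm import Mamba

    # One Mamba(1) mixer layer operating on (batch, seqlen, d_model) inputs.
    layer = Mamba(d_model=256, d_state=16, d_conv=4, expand=2).to("cuda")
    x = torch.randn(2, 64, 256, device="cuda")
    y = layer(x)
    assert y.shape == x.shape  # the mixer preserves the input shape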
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bd30bbcf05aa050cdd6472ec40e08254762525e45f3b4209d11dfef629a201fa
size 261771568
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/_ops.py
ADDED
@@ -0,0 +1,9 @@
import torch
from . import _mamba_ssm_85e64a5
ops = torch.ops._mamba_ssm_85e64a5

def add_op_namespace_prefix(op_name: str):
    """
    Prefix op by namespace.
    """
    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/__init__.py
ADDED
File without changes
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/distributed_utils.py
ADDED
@@ -0,0 +1,144 @@
from typing import Optional

import torch
from torch import Tensor
from torch.distributed import ProcessGroup

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    output = torch.empty(
        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


# Raw operation, does not support autograd, but does support async
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    output = torch.empty(
        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.reduce_scatter_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


# Raw operation, does not support autograd, but does support async
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    input_ = input_.contiguous()
    handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
    return input_, handle


class AllGatherFunc(torch.autograd.Function):
    """Gather the input from sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_gather_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
all_gather = AllGatherFunc.apply


class ReduceScatterFunc(torch.autograd.Function):
    """Reduce scatter the input from the sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = reduce_scatter_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
reduce_scatter = ReduceScatterFunc.apply


class AllReduceFunc(torch.autograd.Function):
    """Gather the input from sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_reduce_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        return grad_output, None


# Supports autograd, but does not support async
all_reduce = AllReduceFunc.apply


def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _shared_params=True in the same order,
    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
    pamams_shared = {
        name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
    }
    for _, p in sorted(pamams_shared.items()):
        with torch.no_grad():
            # Broadcast needs src to be global rank, not group rank
            torch.distributed.broadcast(
                p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
            )


# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _sequence_parallel=True in the same order,
    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
    params_seqparallel = {
        name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
    }
    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
    if grads:
        with torch.no_grad():
            coalesced = torch._utils._flatten_dense_tensors(grads)
            torch.distributed.all_reduce(coalesced, group=process_group)
            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)


def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
    """Get the dim for the local rank derived from splitting dim on world_size processes.

    The split may not be even across the world_size processes.
    """
    multiple = dim // multiple_of
    div = multiple // world_size
    mod = multiple % world_size
    local_multiple = div + int(local_rank < mod)
    return local_multiple * multiple_of
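A minimal sketch of how the autograd-aware wrappers above compose, assuming torch.distributed has already been initialized (for example via torchrun) and each rank holds a CUDA tensor:

    import torch
    import torch.distributed as dist
    from mamba_ssm.distributed.distributed_utils import all_gather, reduce_scatter

    group = dist.group.WORLD
    x = torch.randn(8, 256, device="cuda", requires_grad=True)

    gathered = all_gather(x, group)           # (world_size * 8, 256)
    local = reduce_scatter(gathered, group)   # back to (8, 256) on each rank
    local.sum().backward()                    # gradients flow through both collectives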
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py
ADDED
@@ -0,0 +1,326 @@
# Copyright (c) 2024, Tri Dao.
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from ..utils.torch import custom_bwd, custom_fwd

from einops import rearrange

from ..distributed.distributed_utils import (
    all_gather_raw,
    all_reduce,
    all_reduce_raw,
    reduce_scatter,
    reduce_scatter_raw,
)


class ParallelLinearFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
        """
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
            bias = (
                bias.to(dtype=torch.get_autocast_gpu_dtype())
                if bias is not None
                else None
            )
        weight = weight.contiguous()
        if process_group is not None and sequence_parallel:
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        output = F.linear(total_x, weight, bias)
        if ctx.compute_weight_gradient:
            ctx.save_for_backward(x, weight)
        else:
            ctx.save_for_backward(weight)
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        if ctx.compute_weight_gradient:
            x, weight = ctx.saved_tensors
            if process_group is not None and sequence_parallel:
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            else:
                total_x = x
        else:
            (weight,) = ctx.saved_tensors
            total_x = None
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[0]:
            grad_input = F.linear(grad_output, weight.t())
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(
                    grad_input, process_group, async_op=True
                )
        else:
            grad_input = None
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            if process_group is not None and sequence_parallel:
                handle_x.wait()
            grad_weight = torch.einsum(
                "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
            )
        else:
            grad_weight = None
        grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None


def parallel_linear_func(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True,
):
    return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)


class ColumnParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        if out_features % multiple_of:
            raise ValueError(
                f"out_features ({out_features}) must be a multiple of {multiple_of}"
            )
        multiple = out_features // multiple_of
        # We want to split @multiple across world_size, but it could be an uneven split
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
        super().__init__(
            in_features,
            local_multiple * multiple_of,
            bias=bias,
            device=device,
            dtype=dtype,
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
        # we do an all_gather of x before doing the matmul.
        # If not, then the input is already gathered.
        return parallel_linear_func(
            x,
            self.weight,
            self.bias,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel,
        )


class RowParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        rank = torch.distributed.get_rank(process_group)
        if in_features % multiple_of:
            raise ValueError(
                f"in_features ({in_features}) must be a multiple of {multiple_of}"
            )
        multiple = in_features // multiple_of
        # We want to split @multiple across world_size, but it could be an uneven split
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
        # Only rank 0 will have bias
        super().__init__(
            local_multiple * multiple_of,
            out_features,
            bias=bias and rank == 0,
            device=device,
            dtype=dtype,
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        """
        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
        a reduce_scatter of the result.
        """
        out = parallel_linear_func(x, self.weight, self.bias)
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)


class VocabParallelEmbedding(nn.Embedding):
    def __init__(
        self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs
    ):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if num_embeddings % world_size != 0:
                raise ValueError(
                    f"num_embeddings ({num_embeddings}) must be divisible by "
                    f"world_size ({world_size})"
                )
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")
        else:
            world_size = 1
        super().__init__(
            num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        if self.process_group is None:
            return super().forward(input)
        else:
            rank = torch.distributed.get_rank(self.process_group)
            vocab_size = self.num_embeddings
            vocab_start_index, vocab_end_index = (
                rank * vocab_size,
                (rank + 1) * vocab_size,
            )
            # Create a mask of valid vocab ids (1 means it needs to be masked).
            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
            input = input - vocab_start_index
            input[input_ids_mask] = 0
            embeddings = super().forward(input)
            embeddings[input_ids_mask] = 0.0
            return embeddings


class ColumnParallelEmbedding(nn.Embedding):
    def __init__(
        self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs
    ):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if embedding_dim % world_size != 0:
                raise ValueError(
                    f"embedding_dim ({embedding_dim}) must be divisible by "
                    f"world_size ({world_size})"
                )
        else:
            world_size = 1
        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)


class ParallelEmbeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        process_group,
        padding_idx=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size,
            embed_dim,
            padding_idx=padding_idx,
            process_group=process_group,
            **factory_kwargs,
        )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = ColumnParallelEmbedding(
                max_position_embeddings,
                embed_dim,
                process_group=process_group,
                **factory_kwargs,
            )

    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        world_size = torch.distributed.get_world_size(self.process_group)
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(
                    seqlen, dtype=torch.long, device=input_ids.device
                )
            position_embeddings = self.position_embeddings(position_ids)
            if world_size <= 1:
                embeddings = embeddings + position_embeddings
            else:
                partition_dim = self.position_embeddings.embedding_dim
                rank = torch.distributed.get_rank(self.process_group)
                embeddings[
                    ..., rank * partition_dim : (rank + 1) * partition_dim
                ] += position_embeddings
        if combine_batch_seqlen_dim:
            embeddings = rearrange(embeddings, "b s d -> (b s) d")
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return (
            embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
        )
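A minimal sketch of wiring the two linear layers above into a tensor-parallel projection pair, assuming an initialized process group and hidden sizes divisible by the world size; the dimensions are illustrative:

    import torch
    import torch.distributed as dist
    import torch.nn.functional as F
    from mamba_ssm.distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear

    group = dist.group.WORLD
    # Column-parallel up-projection followed by row-parallel down-projection:
    # each rank holds a 1/world_size slice of the 4096-wide hidden dimension.
    fc1 = ColumnParallelLinear(1024, 4096, group, sequence_parallel=True).cuda()
    fc2 = RowParallelLinear(4096, 1024, group, sequence_parallel=True).cuda()

    x = torch.randn(8, 1024, device="cuda")    # local shard of the sequence
    y = fc2(F.relu(fc1(x)))                    # reduce_scatter'd back to a local shard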
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/models/__init__.py
ADDED
File without changes
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/models/config_mamba.py
ADDED
@@ -0,0 +1,18 @@
from dataclasses import dataclass, field


@dataclass
class MambaConfig:

    d_model: int = 2560
    d_intermediate: int = 0
    n_layer: int = 64
    vocab_size: int = 50277
    ssm_cfg: dict = field(default_factory=dict)
    attn_layer_idx: list = field(default_factory=list)
    attn_cfg: dict = field(default_factory=dict)
    rms_norm: bool = True
    residual_in_fp32: bool = True
    fused_add_norm: bool = True
    pad_vocab_size_multiple: int = 8
    tie_embeddings: bool = True
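For reference, constructing the config above by hand; the values roughly match the published mamba-130m shape and are illustrative only:

    from mamba_ssm.models.config_mamba import MambaConfig

    cfg = MambaConfig(
        d_model=768,        # hidden width
        n_layer=24,         # number of mixer blocks
        vocab_size=50277,   # padded up to a multiple of 8 inside MambaLMHeadModel
        d_intermediate=0,   # 0 means no extra MLP in each block
        rms_norm=True,
        fused_add_norm=True,
    )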
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py
ADDED
@@ -0,0 +1,338 @@
# Copyright (c) 2023, Albert Gu, Tri Dao.

import math
from functools import partial
import json
import os
import copy

from collections import namedtuple

import torch
import torch.nn as nn

from .config_mamba import MambaConfig
from ..modules.mamba_simple import Mamba
from ..modules.mamba2 import Mamba2
from ..modules.mha import MHA
from ..modules.mlp import GatedMLP
from ..modules.block import Block
from ..utils.generation import GenerationMixin
from ..utils.hf import load_config_hf, load_state_dict_hf

try:
    from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


def create_block(
    d_model,
    d_intermediate,
    ssm_cfg=None,
    attn_layer_idx=None,
    attn_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    if ssm_cfg is None:
        ssm_cfg = {}
    if attn_layer_idx is None:
        attn_layer_idx = []
    if attn_cfg is None:
        attn_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    if layer_idx not in attn_layer_idx:
        # Create a copy of the config to modify
        ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
        ssm_layer = ssm_cfg.pop("layer", "Mamba1")
        if ssm_layer not in ["Mamba1", "Mamba2"]:
            raise ValueError(
                f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2"
            )
        mixer_cls = partial(
            Mamba2 if ssm_layer == "Mamba2" else Mamba,
            layer_idx=layer_idx,
            **ssm_cfg,
            **factory_kwargs,
        )
    else:
        mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    if d_intermediate == 0:
        mlp_cls = nn.Identity
    else:
        mlp_cls = partial(
            GatedMLP,
            hidden_features=d_intermediate,
            out_features=d_model,
            **factory_kwargs,
        )
    block = Block(
        d_model,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        # > -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(n_residuals_per_layer * n_layer)


class MixerModel(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_layer: int,
        d_intermediate: int,
        vocab_size: int,
        ssm_cfg=None,
        attn_layer_idx=None,
        attn_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        self.fused_add_norm = fused_add_norm
        if self.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    d_intermediate=d_intermediate,
                    ssm_cfg=ssm_cfg,
                    attn_layer_idx=attn_layer_idx,
                    attn_cfg=attn_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            d_model, eps=norm_epsilon, **factory_kwargs
        )

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
                n_residuals_per_layer=(
                    1 if d_intermediate == 0 else 2
                ),  # 2 if we have MLP
            )
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(
                batch_size, max_seqlen, dtype=dtype, **kwargs
            )
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, inference_params=None, **mixer_kwargs):
        hidden_states = self.embedding(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states,
                residual,
                inference_params=inference_params,
                **mixer_kwargs,
            )
        if not self.fused_add_norm:
            residual = (
                (hidden_states + residual) if residual is not None else hidden_states
            )
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm=False here since we don't need the residual
            hidden_states = layer_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm_f, RMSNorm),
            )
        return hidden_states


class MambaLMHeadModel(nn.Module, GenerationMixin):

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device=None,
        dtype=None,
    ) -> None:
        self.config = config
        d_model = config.d_model
        n_layer = config.n_layer
        d_intermediate = config.d_intermediate
        vocab_size = config.vocab_size
        ssm_cfg = config.ssm_cfg
        attn_layer_idx = config.attn_layer_idx
        attn_cfg = config.attn_cfg
        rms_norm = config.rms_norm
        residual_in_fp32 = config.residual_in_fp32
        fused_add_norm = config.fused_add_norm
        pad_vocab_size_multiple = config.pad_vocab_size_multiple
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (
                vocab_size % pad_vocab_size_multiple
            )
        self.backbone = MixerModel(
            d_model=d_model,
            n_layer=n_layer,
            d_intermediate=d_intermediate,
            vocab_size=vocab_size,
            ssm_cfg=ssm_cfg,
            attn_layer_idx=attn_layer_idx,
            attn_cfg=attn_cfg,
            rms_norm=rms_norm,
            initializer_cfg=initializer_cfg,
            fused_add_norm=fused_add_norm,
            residual_in_fp32=residual_in_fp32,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        if self.config.tie_embeddings:
            self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.backbone.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self,
        input_ids,
        position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        **mixer_kwargs,
    ):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        hidden_states = self.backbone(
            input_ids, inference_params=inference_params, **mixer_kwargs
        )
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
        config_data = load_config_hf(pretrained_model_name)
        config = MambaConfig(**config_data)
        model = cls(config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(
            load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
        )
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.
        Save the model and its configuration file to a directory.
        """
        # Ensure save_directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f, indent=4)
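A minimal end-to-end sketch of the model class above; the checkpoint name is an example and a CUDA device with the Triton layer-norm kernels is assumed:

    import torch
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

    model = MambaLMHeadModel.from_pretrained(
        "state-spaces/mamba-130m", device="cuda", dtype=torch.float16
    )
    input_ids = torch.randint(0, 50277, (1, 32), device="cuda")
    out = model(input_ids)        # CausalLMOutput namedtuple
    print(out.logits.shape)       # (1, 32, padded vocab size)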
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/__init__.py
ADDED
File without changes
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/block.py
ADDED
@@ -0,0 +1,107 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
from typing import Optional

import torch
from torch import nn, Tensor

from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn


class Block(nn.Module):
    def __init__(
        self,
        dim,
        mixer_cls,
        mlp_cls,
        norm_cls=nn.LayerNorm,
        fused_add_norm=False,
        residual_in_fp32=False,
    ):
        """
        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA/MLP -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Add -> LN -> Mixer, returning both
        the hidden_states (output of the mixer) and the residual.
        This is purely for performance reasons, as we can fuse add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.norm = norm_cls(dim)
        self.mixer = mixer_cls(dim)
        if mlp_cls is not nn.Identity:
            self.norm2 = norm_cls(dim)
            self.mlp = mlp_cls(dim)
        else:
            self.mlp = None
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        inference_params=None,
        **mixer_kwargs
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual))
        """
        if not self.fused_add_norm:
            residual = (
                (hidden_states + residual) if residual is not None else hidden_states
            )
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            hidden_states, residual = layer_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
                is_rms_norm=isinstance(self.norm, RMSNorm),
            )
        hidden_states = self.mixer(
            hidden_states, inference_params=inference_params, **mixer_kwargs
        )

        if self.mlp is not None:
            if not self.fused_add_norm:
                residual = hidden_states + residual
                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm2.weight,
                    self.norm2.bias,
                    residual=residual,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    eps=self.norm2.eps,
                    is_rms_norm=isinstance(self.norm2, RMSNorm),
                )
            hidden_states = self.mlp(hidden_states)

        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba2.py
ADDED
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
|
11 |
+
try:
|
12 |
+
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
13 |
+
except ImportError:
|
14 |
+
causal_conv1d_fn, causal_conv1d_update = None, None
|
15 |
+
|
16 |
+
try:
|
17 |
+
from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
|
18 |
+
except ImportError:
|
19 |
+
causal_conv1d_varlen_states = None
|
20 |
+
|
21 |
+
try:
|
22 |
+
from ..ops.triton.selective_state_update import selective_state_update
|
23 |
+
except ImportError:
|
24 |
+
selective_state_update = None
|
25 |
+
|
26 |
+
from ..ops.triton.layernorm_gated import RMSNorm as RMSNormGated
|
27 |
+
|
28 |
+
from ..distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear
|
29 |
+
from ..distributed.distributed_utils import all_reduce, reduce_scatter
|
30 |
+
|
31 |
+
from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
|
32 |
+
from ..ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
|
33 |
+
|
34 |
+
from huggingface_hub import PyTorchModelHubMixin
|
35 |
+
|
36 |
+
|
37 |
+
class Mamba2(nn.Module, PyTorchModelHubMixin):
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
d_model,
|
41 |
+
d_state=128,
|
42 |
+
d_conv=4,
|
43 |
+
conv_init=None,
|
44 |
+
expand=2,
|
45 |
+
headdim=64,
|
46 |
+
d_ssm=None, # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
|
47 |
+
ngroups=1,
|
48 |
+
A_init_range=(1, 16),
|
49 |
+
D_has_hdim=False,
|
50 |
+
rmsnorm=True,
|
51 |
+
norm_before_gate=False,
|
52 |
+
dt_min=0.001,
|
53 |
+
dt_max=0.1,
|
54 |
+
dt_init_floor=1e-4,
|
55 |
+
dt_limit=(0.0, float("inf")),
|
56 |
+
bias=False,
|
57 |
+
conv_bias=True,
|
58 |
+
# Fused kernel and sharding options
|
59 |
+
chunk_size=256,
|
60 |
+
use_mem_eff_path=True,
|
61 |
+
layer_idx=None, # Absorb kwarg for general module
|
62 |
+
process_group=None,
|
63 |
+
sequence_parallel=True,
|
64 |
+
device=None,
|
65 |
+
dtype=None,
|
66 |
+
):
|
67 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
68 |
+
super().__init__()
|
69 |
+
self.d_model = d_model
|
70 |
+
self.d_state = d_state
|
71 |
+
self.d_conv = d_conv
|
72 |
+
self.conv_init = conv_init
|
73 |
+
self.expand = expand
|
74 |
+
self.process_group = process_group
|
75 |
+
self.sequence_parallel = sequence_parallel
|
76 |
+
self.world_size = 1 if process_group is None else process_group.size()
|
77 |
+
self.local_rank = 0 if process_group is None else process_group.rank()
|
78 |
+
self.d_inner = (self.expand * self.d_model) // self.world_size
|
79 |
+
assert self.d_inner * self.world_size == self.expand * self.d_model
|
80 |
+
self.headdim = headdim
|
81 |
+
self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
|
82 |
+
assert ngroups % self.world_size == 0
|
83 |
+
self.ngroups = ngroups // self.world_size
|
84 |
+
assert self.d_ssm % self.headdim == 0
|
85 |
+
self.nheads = self.d_ssm // self.headdim
|
86 |
+
self.D_has_hdim = D_has_hdim
|
87 |
+
self.rmsnorm = rmsnorm
|
88 |
+
self.norm_before_gate = norm_before_gate
|
89 |
+
self.dt_limit = dt_limit
|
90 |
+
self.activation = "silu"
|
91 |
+
self.chunk_size = chunk_size
|
92 |
+
self.use_mem_eff_path = use_mem_eff_path
|
93 |
+
self.layer_idx = layer_idx
|
94 |
+
|
95 |
+
# Order: [z, x, B, C, dt]
|
96 |
+
d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
|
97 |
+
if self.process_group is None:
|
98 |
+
self.in_proj = nn.Linear(
|
99 |
+
self.d_model, d_in_proj, bias=bias, **factory_kwargs
|
100 |
+
)
|
101 |
+
else:
|
102 |
+
self.in_proj = ColumnParallelLinear(
|
103 |
+
self.d_model,
|
104 |
+
d_in_proj * self.world_size,
|
105 |
+
bias=bias,
|
106 |
+
process_group=self.process_group,
|
107 |
+
sequence_parallel=self.sequence_parallel,
|
108 |
+
**factory_kwargs,
|
109 |
+
)
|
110 |
+
|
111 |
+
conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
|
112 |
+
self.conv1d = nn.Conv1d(
|
113 |
+
in_channels=conv_dim,
|
114 |
+
out_channels=conv_dim,
|
115 |
+
bias=conv_bias,
|
116 |
+
kernel_size=d_conv,
|
117 |
+
groups=conv_dim,
|
118 |
+
padding=d_conv - 1,
|
119 |
+
**factory_kwargs,
|
120 |
+
)
|
121 |
+
if self.conv_init is not None:
|
122 |
+
nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
|
123 |
+
|
124 |
+
self.act = nn.SiLU()
|
125 |
+
|
126 |
+
# Initialize log dt bias
|
127 |
+
dt = torch.exp(
|
128 |
+
torch.rand(self.nheads, **factory_kwargs)
|
129 |
+
* (math.log(dt_max) - math.log(dt_min))
|
130 |
+
+ math.log(dt_min)
|
131 |
+
)
|
132 |
+
dt = torch.clamp(dt, min=dt_init_floor)
|
133 |
+
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
134 |
+
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
135 |
+
self.dt_bias = nn.Parameter(inv_dt)
|
136 |
+
# Just to be explicit. Without this we already don't put wd on dt_bias because of the check
|
137 |
+
# name.endswith("bias") in param_grouping.py
|
138 |
+
self.dt_bias._no_weight_decay = True
|
139 |
+
|
140 |
+
assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
|
141 |
+
A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
|
142 |
+
*A_init_range
|
143 |
+
)
|
144 |
+
A_log = torch.log(A).to(dtype=dtype)
|
145 |
+
self.A_log = nn.Parameter(A_log)
|
146 |
+
self.A_log._no_weight_decay = True
|
147 |
+
|
148 |
+
# D "skip" parameter
|
149 |
+
self.D = nn.Parameter(
|
150 |
+
torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device)
|
151 |
+
)
|
152 |
+
self.D._no_weight_decay = True
|
153 |
+
|
154 |
+
if self.rmsnorm:
|
155 |
+
assert RMSNormGated is not None
|
156 |
+
self.norm = RMSNormGated(
|
157 |
+
self.d_ssm,
|
158 |
+
eps=1e-5,
|
159 |
+
norm_before_gate=self.norm_before_gate,
|
160 |
+
group_size=self.d_ssm // ngroups,
|
161 |
+
**factory_kwargs,
|
162 |
+
)
|
163 |
+
|
164 |
+
if self.process_group is None:
|
165 |
+
self.out_proj = nn.Linear(
|
166 |
+
self.d_inner, self.d_model, bias=bias, **factory_kwargs
|
167 |
+
)
|
168 |
+
else:
|
169 |
+
self.out_proj = RowParallelLinear(
|
170 |
+
self.d_inner * self.world_size,
|
171 |
+
self.d_model,
|
172 |
+
bias=bias,
|
173 |
+
process_group=self.process_group,
|
174 |
+
sequence_parallel=self.sequence_parallel,
|
175 |
+
**factory_kwargs,
|
176 |
+
)
|
177 |
+
|
178 |
+
def forward(
|
179 |
+
self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None
|
180 |
+
):
|
181 |
+
"""
|
182 |
+
u: (batch, seqlen, hidden_dim) if seqlen=None.
|
183 |
+
If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we
|
184 |
+
split u during sequence parallel, we split the batch * seqlen dimension
|
185 |
+
(in case batch is small).
|
186 |
+
Returns: same shape as u
|
187 |
+
"""
|
188 |
+
seqlen_og = seqlen
|
189 |
+
if seqlen is None:
|
190 |
+
batch, seqlen, dim = u.shape
|
191 |
+
else:
|
192 |
+
batch_seqlen, dim = u.shape
|
193 |
+
batch = batch_seqlen // seqlen
|
194 |
+
|
195 |
+
conv_state, ssm_state = None, None
|
196 |
+
if inference_params is not None:
|
197 |
+
inference_batch = (
|
198 |
+
cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
|
199 |
+
)
|
200 |
+
conv_state, ssm_state = self._get_states_from_cache(
|
201 |
+
inference_params, inference_batch
|
202 |
+
)
|
203 |
+
if inference_params.seqlen_offset > 0:
|
204 |
+
# The states are updated inplace
|
205 |
+
out, _, _ = self.step(u, conv_state, ssm_state)
|
206 |
+
return out
|
207 |
+
|
208 |
+
zxbcdt = self.in_proj(u) # (B, L, d_in_proj) or (B * L, d_in_proj)
|
209 |
+
if seqlen_og is not None:
|
210 |
+
zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
|
211 |
+
# If the model is loaded in fp16, without the .float() here, A might be -inf
|
212 |
+
A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state)
|
213 |
+
dt_limit_kwargs = (
|
214 |
+
{} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
|
215 |
+
)
|
216 |
+
if self.use_mem_eff_path and inference_params is None:
|
217 |
+
out = mamba_split_conv1d_scan_combined(
|
218 |
+
zxbcdt,
|
219 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
220 |
+
self.conv1d.bias,
|
221 |
+
self.dt_bias,
|
222 |
+
A,
|
223 |
+
D=(
|
224 |
+
rearrange(self.D, "(h p) -> h p", p=self.headdim)
|
225 |
+
if self.D_has_hdim
|
226 |
+
else self.D
|
227 |
+
),
|
228 |
+
chunk_size=self.chunk_size,
|
229 |
+
seq_idx=seq_idx,
|
230 |
+
activation=self.activation,
|
231 |
+
rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
|
232 |
+
rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
|
233 |
+
outproj_weight=self.out_proj.weight,
|
234 |
+
outproj_bias=self.out_proj.bias,
|
235 |
+
headdim=None if self.D_has_hdim else self.headdim,
|
236 |
+
ngroups=self.ngroups,
|
237 |
+
norm_before_gate=self.norm_before_gate,
|
238 |
+
**dt_limit_kwargs,
|
239 |
+
)
|
240 |
+
if seqlen_og is not None:
|
241 |
+
out = rearrange(out, "b l d -> (b l) d")
|
242 |
+
if self.process_group is not None:
|
243 |
+
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
|
244 |
+
out = reduce_fn(out, self.process_group)
|
245 |
+
else:
|
246 |
+
d_mlp = (
|
247 |
+
zxbcdt.shape[-1]
|
248 |
+
- 2 * self.d_ssm
|
249 |
+
- 2 * self.ngroups * self.d_state
|
250 |
+
- self.nheads
|
251 |
+
) // 2
|
252 |
+
z0, x0, z, xBC, dt = torch.split(
|
253 |
+
zxbcdt,
|
254 |
+
[
|
255 |
+
d_mlp,
|
256 |
+
d_mlp,
|
257 |
+
self.d_ssm,
|
258 |
+
self.d_ssm + 2 * self.ngroups * self.d_state,
|
259 |
+
self.nheads,
|
260 |
+
],
|
261 |
+
dim=-1,
|
262 |
+
)
|
263 |
+
if conv_state is not None:
|
264 |
+
if cu_seqlens is None:
|
265 |
+
# If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
|
266 |
+
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
|
267 |
+
xBC_t = rearrange(xBC, "b l d -> b d l")
|
268 |
+
conv_state.copy_(
|
269 |
+
F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))
|
270 |
+
) # Update state (B D W)
|
271 |
+
else:
|
272 |
+
assert (
|
273 |
+
causal_conv1d_varlen_states is not None
|
274 |
+
), "varlen inference requires causal_conv1d package"
|
275 |
+
assert (
|
276 |
+
batch == 1
|
277 |
+
), "varlen inference only supports batch dimension 1"
|
278 |
+
conv_varlen_states = causal_conv1d_varlen_states(
|
279 |
+
xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
|
280 |
+
)
|
281 |
+
conv_state.copy_(conv_varlen_states)
|
282 |
+
assert self.activation in ["silu", "swish"]
|
283 |
+
if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
|
284 |
+
assert (
|
285 |
+
seq_idx is None
|
286 |
+
), "varlen conv1d requires the causal_conv1d package"
|
287 |
+
xBC = self.act(
|
288 |
+
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[
|
289 |
+
:, : -(self.d_conv - 1)
|
290 |
+
]
|
291 |
+
) # (B, L, self.d_ssm + 2 * ngroups * d_state)
|
292 |
+
else:
|
293 |
+
xBC = causal_conv1d_fn(
|
294 |
+
xBC.transpose(1, 2),
|
295 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
296 |
+
bias=self.conv1d.bias,
|
297 |
+
activation=self.activation,
|
298 |
+
seq_idx=seq_idx,
|
299 |
+
).transpose(1, 2)
|
300 |
+
x, B, C = torch.split(
|
301 |
+
xBC,
|
302 |
+
[self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
|
303 |
+
dim=-1,
|
304 |
+
)
|
305 |
+
y = mamba_chunk_scan_combined(
|
306 |
+
rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
|
307 |
+
dt,
|
308 |
+
A,
|
309 |
+
rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
|
310 |
+
rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
|
311 |
+
chunk_size=self.chunk_size,
|
312 |
+
D=(
|
313 |
+
rearrange(self.D, "(h p) -> h p", p=self.headdim)
|
314 |
+
if self.D_has_hdim
|
315 |
+
else self.D
|
316 |
+
),
|
317 |
+
z=(
|
318 |
+
rearrange(z, "b l (h p) -> b l h p", p=self.headdim)
|
319 |
+
if not self.rmsnorm
|
320 |
+
else None
|
321 |
+
),
|
322 |
+
dt_bias=self.dt_bias,
|
323 |
+
dt_softplus=True,
|
324 |
+
seq_idx=seq_idx,
|
325 |
+
cu_seqlens=cu_seqlens,
|
326 |
+
**dt_limit_kwargs,
|
327 |
+
return_final_states=ssm_state is not None,
|
328 |
+
return_varlen_states=cu_seqlens is not None
|
329 |
+
and inference_params is not None,
|
330 |
+
)
|
331 |
+
if ssm_state is not None:
|
332 |
+
y, last_state, *rest = y
|
333 |
+
if cu_seqlens is None:
|
334 |
+
ssm_state.copy_(last_state)
|
335 |
+
else:
|
336 |
+
varlen_states = rest[0]
|
337 |
+
ssm_state.copy_(varlen_states)
|
338 |
+
y = rearrange(y, "b l h p -> b l (h p)")
|
339 |
+
if self.rmsnorm:
|
340 |
+
y = self.norm(y, z)
|
341 |
+
if d_mlp > 0:
|
342 |
+
y = torch.cat([F.silu(z0) * x0, y], dim=-1)
|
343 |
+
if seqlen_og is not None:
|
344 |
+
y = rearrange(y, "b l d -> (b l) d")
|
345 |
+
out = self.out_proj(y)
|
346 |
+
return out
|
347 |
+
|
348 |
+
def step(self, hidden_states, conv_state, ssm_state):
|
349 |
+
dtype = hidden_states.dtype
|
350 |
+
assert (
|
351 |
+
hidden_states.shape[1] == 1
|
352 |
+
), "Only support decoding with 1 token at a time for now"
|
353 |
+
zxbcdt = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
|
354 |
+
d_mlp = (
|
355 |
+
zxbcdt.shape[-1]
|
356 |
+
- 2 * self.d_ssm
|
357 |
+
- 2 * self.ngroups * self.d_state
|
358 |
+
- self.nheads
|
359 |
+
) // 2
|
360 |
+
z0, x0, z, xBC, dt = torch.split(
|
361 |
+
zxbcdt,
|
362 |
+
[
|
363 |
+
d_mlp,
|
364 |
+
d_mlp,
|
365 |
+
self.d_ssm,
|
366 |
+
self.d_ssm + 2 * self.ngroups * self.d_state,
|
367 |
+
self.nheads,
|
368 |
+
],
|
369 |
+
dim=-1,
|
370 |
+
)
|
371 |
+
|
372 |
+
# Conv step
|
373 |
+
if causal_conv1d_update is None:
|
374 |
+
conv_state.copy_(
|
375 |
+
torch.roll(conv_state, shifts=-1, dims=-1)
|
376 |
+
) # Update state (B D W)
|
377 |
+
conv_state[:, :, -1] = xBC
|
378 |
+
xBC = torch.sum(
|
379 |
+
conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
|
380 |
+
) # (B D)
|
381 |
+
if self.conv1d.bias is not None:
|
382 |
+
xBC = xBC + self.conv1d.bias
|
383 |
+
xBC = self.act(xBC).to(dtype=dtype)
|
384 |
+
else:
|
385 |
+
xBC = causal_conv1d_update(
|
386 |
+
xBC,
|
387 |
+
conv_state,
|
388 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
389 |
+
self.conv1d.bias,
|
390 |
+
self.activation,
|
391 |
+
)
|
392 |
+
|
393 |
+
x, B, C = torch.split(
|
394 |
+
xBC,
|
395 |
+
[self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
|
396 |
+
dim=-1,
|
397 |
+
)
|
398 |
+
A = -torch.exp(self.A_log.float()) # (nheads,)
|
399 |
+
|
400 |
+
# SSM step
|
401 |
+
if selective_state_update is None:
|
402 |
+
assert (
|
403 |
+
self.ngroups == 1
|
404 |
+
), "Only support ngroups=1 for this inference code path"
|
405 |
+
# Discretize A and B
|
406 |
+
dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
|
407 |
+
dA = torch.exp(dt * A) # (batch, nheads)
|
408 |
+
x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
|
409 |
+
dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
|
410 |
+
ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
|
411 |
+
y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
|
412 |
+
y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
|
413 |
+
y = rearrange(y, "b h p -> b (h p)")
|
414 |
+
if not self.rmsnorm:
|
415 |
+
y = y * self.act(z) # (B D)
|
416 |
+
else:
|
417 |
+
A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(
|
418 |
+
dtype=torch.float32
|
419 |
+
)
|
420 |
+
dt = repeat(dt, "b h -> b h p", p=self.headdim)
|
421 |
+
dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
|
422 |
+
D = repeat(self.D, "h -> h p", p=self.headdim)
|
423 |
+
B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
|
424 |
+
C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
|
425 |
+
x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
|
426 |
+
if not self.rmsnorm:
|
427 |
+
z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
|
428 |
+
y = selective_state_update(
|
429 |
+
ssm_state,
|
430 |
+
x_reshaped,
|
431 |
+
dt,
|
432 |
+
A,
|
433 |
+
B,
|
434 |
+
C,
|
435 |
+
D,
|
436 |
+
z=z if not self.rmsnorm else None,
|
437 |
+
dt_bias=dt_bias,
|
438 |
+
dt_softplus=True,
|
439 |
+
)
|
440 |
+
y = rearrange(y, "b h p -> b (h p)")
|
441 |
+
if self.rmsnorm:
|
442 |
+
y = self.norm(y, z)
|
443 |
+
if d_mlp > 0:
|
444 |
+
y = torch.cat([F.silu(z0) * x0, y], dim=-1)
|
445 |
+
out = self.out_proj(y)
|
446 |
+
return out.unsqueeze(1), conv_state, ssm_state
|
447 |
+
|
448 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
449 |
+
device = self.out_proj.weight.device
|
450 |
+
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
|
451 |
+
conv_state = torch.zeros(
|
452 |
+
batch_size,
|
453 |
+
self.d_conv,
|
454 |
+
self.conv1d.weight.shape[0],
|
455 |
+
device=device,
|
456 |
+
dtype=conv_dtype,
|
457 |
+
).transpose(1, 2)
|
458 |
+
ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
|
459 |
+
ssm_state = torch.zeros(
|
460 |
+
batch_size,
|
461 |
+
self.nheads,
|
462 |
+
self.headdim,
|
463 |
+
self.d_state,
|
464 |
+
device=device,
|
465 |
+
dtype=ssm_dtype,
|
466 |
+
)
|
467 |
+
return conv_state, ssm_state
|
468 |
+
|
469 |
+
def _get_states_from_cache(
|
470 |
+
self, inference_params, batch_size, initialize_states=False
|
471 |
+
):
|
472 |
+
assert self.layer_idx is not None
|
473 |
+
if self.layer_idx not in inference_params.key_value_memory_dict:
|
474 |
+
batch_shape = (batch_size,)
|
475 |
+
conv_state = torch.zeros(
|
476 |
+
batch_size,
|
477 |
+
self.d_conv,
|
478 |
+
self.conv1d.weight.shape[0],
|
479 |
+
device=self.conv1d.weight.device,
|
480 |
+
dtype=self.conv1d.weight.dtype,
|
481 |
+
).transpose(1, 2)
|
482 |
+
ssm_state = torch.zeros(
|
483 |
+
batch_size,
|
484 |
+
self.nheads,
|
485 |
+
self.headdim,
|
486 |
+
self.d_state,
|
487 |
+
device=self.in_proj.weight.device,
|
488 |
+
dtype=self.in_proj.weight.dtype,
|
489 |
+
)
|
490 |
+
inference_params.key_value_memory_dict[self.layer_idx] = (
|
491 |
+
conv_state,
|
492 |
+
ssm_state,
|
493 |
+
)
|
494 |
+
else:
|
495 |
+
conv_state, ssm_state = inference_params.key_value_memory_dict[
|
496 |
+
self.layer_idx
|
497 |
+
]
|
498 |
+
# TODO: What if batch size changes between generation, and we reuse the same states?
|
499 |
+
if initialize_states:
|
500 |
+
conv_state.zero_()
|
501 |
+
ssm_state.zero_()
|
502 |
+
return conv_state, ssm_state
|
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba2_simple.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
import math
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from einops import rearrange, repeat
|
9 |
+
|
10 |
+
try:
|
11 |
+
from causal_conv1d import causal_conv1d_fn
|
12 |
+
except ImportError:
|
13 |
+
causal_conv1d_fn = None
|
14 |
+
|
15 |
+
try:
|
16 |
+
from ..ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
|
17 |
+
except ImportError:
|
18 |
+
RMSNormGated, LayerNorm = None, None
|
19 |
+
|
20 |
+
from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
|
21 |
+
from ..ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
|
22 |
+
|
23 |
+
|
24 |
+
class Mamba2Simple(nn.Module):
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
d_model,
|
28 |
+
d_state=64,
|
29 |
+
d_conv=4,
|
30 |
+
conv_init=None,
|
31 |
+
expand=2,
|
32 |
+
headdim=128,
|
33 |
+
ngroups=1,
|
34 |
+
A_init_range=(1, 16),
|
35 |
+
dt_min=0.001,
|
36 |
+
dt_max=0.1,
|
37 |
+
dt_init_floor=1e-4,
|
38 |
+
dt_limit=(0.0, float("inf")),
|
39 |
+
learnable_init_states=False,
|
40 |
+
activation="swish",
|
41 |
+
bias=False,
|
42 |
+
conv_bias=True,
|
43 |
+
# Fused kernel and sharding options
|
44 |
+
chunk_size=256,
|
45 |
+
use_mem_eff_path=True,
|
46 |
+
layer_idx=None, # Absorb kwarg for general module
|
47 |
+
device=None,
|
48 |
+
dtype=None,
|
49 |
+
):
|
50 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
51 |
+
super().__init__()
|
52 |
+
self.d_model = d_model
|
53 |
+
self.d_state = d_state
|
54 |
+
self.d_conv = d_conv
|
55 |
+
self.conv_init = conv_init
|
56 |
+
self.expand = expand
|
57 |
+
self.d_inner = self.expand * self.d_model
|
58 |
+
self.headdim = headdim
|
59 |
+
self.ngroups = ngroups
|
60 |
+
assert self.d_inner % self.headdim == 0
|
61 |
+
self.nheads = self.d_inner // self.headdim
|
62 |
+
self.dt_limit = dt_limit
|
63 |
+
self.learnable_init_states = learnable_init_states
|
64 |
+
self.activation = activation
|
65 |
+
self.chunk_size = chunk_size
|
66 |
+
self.use_mem_eff_path = use_mem_eff_path
|
67 |
+
self.layer_idx = layer_idx
|
68 |
+
|
69 |
+
# Order: [z, x, B, C, dt]
|
70 |
+
d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
|
71 |
+
self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
|
72 |
+
|
73 |
+
conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
|
74 |
+
self.conv1d = nn.Conv1d(
|
75 |
+
in_channels=conv_dim,
|
76 |
+
out_channels=conv_dim,
|
77 |
+
bias=conv_bias,
|
78 |
+
kernel_size=d_conv,
|
79 |
+
groups=conv_dim,
|
80 |
+
padding=d_conv - 1,
|
81 |
+
**factory_kwargs,
|
82 |
+
)
|
83 |
+
if self.conv_init is not None:
|
84 |
+
nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
|
85 |
+
# self.conv1d.weight._no_weight_decay = True
|
86 |
+
|
87 |
+
if self.learnable_init_states:
|
88 |
+
self.init_states = nn.Parameter(
|
89 |
+
torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs)
|
90 |
+
)
|
91 |
+
self.init_states._no_weight_decay = True
|
92 |
+
|
93 |
+
self.act = nn.SiLU()
|
94 |
+
|
95 |
+
# Initialize log dt bias
|
96 |
+
dt = torch.exp(
|
97 |
+
torch.rand(self.nheads, **factory_kwargs)
|
98 |
+
* (math.log(dt_max) - math.log(dt_min))
|
99 |
+
+ math.log(dt_min)
|
100 |
+
)
|
101 |
+
dt = torch.clamp(dt, min=dt_init_floor)
|
102 |
+
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
103 |
+
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
104 |
+
self.dt_bias = nn.Parameter(inv_dt)
|
105 |
+
# Just to be explicit. Without this we already don't put wd on dt_bias because of the check
|
106 |
+
# name.endswith("bias") in param_grouping.py
|
107 |
+
self.dt_bias._no_weight_decay = True
|
108 |
+
|
109 |
+
# A parameter
|
110 |
+
assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
|
111 |
+
A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
|
112 |
+
*A_init_range
|
113 |
+
)
|
114 |
+
A_log = torch.log(A).to(dtype=dtype)
|
115 |
+
self.A_log = nn.Parameter(A_log)
|
116 |
+
# self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)
|
117 |
+
self.A_log._no_weight_decay = True
|
118 |
+
|
119 |
+
# D "skip" parameter
|
120 |
+
self.D = nn.Parameter(torch.ones(self.nheads, device=device))
|
121 |
+
self.D._no_weight_decay = True
|
122 |
+
|
123 |
+
# Extra normalization layer right before output projection
|
124 |
+
assert RMSNormGated is not None
|
125 |
+
self.norm = RMSNormGated(
|
126 |
+
self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs
|
127 |
+
)
|
128 |
+
|
129 |
+
self.out_proj = nn.Linear(
|
130 |
+
self.d_inner, self.d_model, bias=bias, **factory_kwargs
|
131 |
+
)
|
132 |
+
|
133 |
+
def forward(self, u, seq_idx=None):
|
134 |
+
"""
|
135 |
+
u: (B, L, D)
|
136 |
+
Returns: same shape as u
|
137 |
+
"""
|
138 |
+
batch, seqlen, dim = u.shape
|
139 |
+
|
140 |
+
zxbcdt = self.in_proj(u) # (B, L, d_in_proj)
|
141 |
+
A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state)
|
142 |
+
initial_states = (
|
143 |
+
repeat(self.init_states, "... -> b ...", b=batch)
|
144 |
+
if self.learnable_init_states
|
145 |
+
else None
|
146 |
+
)
|
147 |
+
dt_limit_kwargs = (
|
148 |
+
{} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
|
149 |
+
)
|
150 |
+
|
151 |
+
if self.use_mem_eff_path:
|
152 |
+
# Fully fused path
|
153 |
+
out = mamba_split_conv1d_scan_combined(
|
154 |
+
zxbcdt,
|
155 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
156 |
+
self.conv1d.bias,
|
157 |
+
self.dt_bias,
|
158 |
+
A,
|
159 |
+
D=self.D,
|
160 |
+
chunk_size=self.chunk_size,
|
161 |
+
seq_idx=seq_idx,
|
162 |
+
activation=self.activation,
|
163 |
+
rmsnorm_weight=self.norm.weight,
|
164 |
+
rmsnorm_eps=self.norm.eps,
|
165 |
+
outproj_weight=self.out_proj.weight,
|
166 |
+
outproj_bias=self.out_proj.bias,
|
167 |
+
headdim=self.headdim,
|
168 |
+
ngroups=self.ngroups,
|
169 |
+
norm_before_gate=False,
|
170 |
+
initial_states=initial_states,
|
171 |
+
**dt_limit_kwargs,
|
172 |
+
)
|
173 |
+
else:
|
174 |
+
z, xBC, dt = torch.split(
|
175 |
+
zxbcdt,
|
176 |
+
[
|
177 |
+
self.d_inner,
|
178 |
+
self.d_inner + 2 * self.ngroups * self.d_state,
|
179 |
+
self.nheads,
|
180 |
+
],
|
181 |
+
dim=-1,
|
182 |
+
)
|
183 |
+
dt = F.softplus(dt + self.dt_bias) # (B, L, nheads)
|
184 |
+
assert self.activation in ["silu", "swish"]
|
185 |
+
|
186 |
+
# 1D Convolution
|
187 |
+
if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
|
188 |
+
xBC = self.act(
|
189 |
+
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
|
190 |
+
) # (B, L, self.d_inner + 2 * ngroups * d_state)
|
191 |
+
xBC = xBC[:, :seqlen, :]
|
192 |
+
else:
|
193 |
+
xBC = causal_conv1d_fn(
|
194 |
+
x=xBC.transpose(1, 2),
|
195 |
+
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
196 |
+
bias=self.conv1d.bias,
|
197 |
+
activation=self.activation,
|
198 |
+
).transpose(1, 2)
|
199 |
+
|
200 |
+
# Split into 3 main branches: X, B, C
|
201 |
+
# These correspond to V, K, Q respectively in the SSM/attention duality
|
202 |
+
x, B, C = torch.split(
|
203 |
+
xBC,
|
204 |
+
[
|
205 |
+
self.d_inner,
|
206 |
+
self.ngroups * self.d_state,
|
207 |
+
self.ngroups * self.d_state,
|
208 |
+
],
|
209 |
+
dim=-1,
|
210 |
+
)
|
211 |
+
y = mamba_chunk_scan_combined(
|
212 |
+
rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
|
213 |
+
dt,
|
214 |
+
A,
|
215 |
+
rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
|
216 |
+
rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
|
217 |
+
chunk_size=self.chunk_size,
|
218 |
+
D=self.D,
|
219 |
+
z=None,
|
220 |
+
seq_idx=seq_idx,
|
221 |
+
initial_states=initial_states,
|
222 |
+
**dt_limit_kwargs,
|
223 |
+
)
|
224 |
+
y = rearrange(y, "b l h p -> b l (h p)")
|
225 |
+
|
226 |
+
# Multiply "gate" branch and apply extra normalization layer
|
227 |
+
y = self.norm(y, z)
|
228 |
+
out = self.out_proj(y)
|
229 |
+
return out
|
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba_simple.py
ADDED
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
import math
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch import Tensor
|
10 |
+
|
11 |
+
from einops import rearrange, repeat
|
12 |
+
|
13 |
+
from ..ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
|
14 |
+
|
15 |
+
try:
|
16 |
+
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
17 |
+
except ImportError:
|
18 |
+
causal_conv1d_fn, causal_conv1d_update = None, None
|
19 |
+
|
20 |
+
try:
|
21 |
+
from ..ops.triton.selective_state_update import selective_state_update
|
22 |
+
except ImportError:
|
23 |
+
selective_state_update = None
|
24 |
+
|
25 |
+
try:
|
26 |
+
from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
|
27 |
+
except ImportError:
|
28 |
+
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
29 |
+
|
30 |
+
|
31 |
+
class Mamba(nn.Module):
|
32 |
+
def __init__(
|
33 |
+
self,
|
34 |
+
d_model,
|
35 |
+
d_state=16,
|
36 |
+
d_conv=4,
|
37 |
+
expand=2,
|
38 |
+
dt_rank="auto",
|
39 |
+
dt_min=0.001,
|
40 |
+
dt_max=0.1,
|
41 |
+
dt_init="random",
|
42 |
+
dt_scale=1.0,
|
43 |
+
dt_init_floor=1e-4,
|
44 |
+
conv_bias=True,
|
45 |
+
bias=False,
|
46 |
+
use_fast_path=True, # Fused kernel options
|
47 |
+
layer_idx=None,
|
48 |
+
device=None,
|
49 |
+
dtype=None,
|
50 |
+
):
|
51 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
52 |
+
super().__init__()
|
53 |
+
self.d_model = d_model
|
54 |
+
self.d_state = d_state
|
55 |
+
self.d_conv = d_conv
|
56 |
+
self.expand = expand
|
57 |
+
self.d_inner = int(self.expand * self.d_model)
|
58 |
+
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
|
59 |
+
self.use_fast_path = use_fast_path
|
60 |
+
self.layer_idx = layer_idx
|
61 |
+
|
62 |
+
self.in_proj = nn.Linear(
|
63 |
+
self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs
|
64 |
+
)
|
65 |
+
|
66 |
+
self.conv1d = nn.Conv1d(
|
67 |
+
in_channels=self.d_inner,
|
68 |
+
out_channels=self.d_inner,
|
69 |
+
bias=conv_bias,
|
70 |
+
kernel_size=d_conv,
|
71 |
+
groups=self.d_inner,
|
72 |
+
padding=d_conv - 1,
|
73 |
+
**factory_kwargs,
|
74 |
+
)
|
75 |
+
|
76 |
+
self.activation = "silu"
|
77 |
+
self.act = nn.SiLU()
|
78 |
+
|
79 |
+
self.x_proj = nn.Linear(
|
80 |
+
self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
|
81 |
+
)
|
82 |
+
self.dt_proj = nn.Linear(
|
83 |
+
self.dt_rank, self.d_inner, bias=True, **factory_kwargs
|
84 |
+
)
|
85 |
+
|
86 |
+
# Initialize special dt projection to preserve variance at initialization
|
87 |
+
dt_init_std = self.dt_rank**-0.5 * dt_scale
|
88 |
+
if dt_init == "constant":
|
89 |
+
nn.init.constant_(self.dt_proj.weight, dt_init_std)
|
90 |
+
elif dt_init == "random":
|
91 |
+
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
|
92 |
+
else:
|
93 |
+
raise NotImplementedError
|
94 |
+
|
95 |
+
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
|
96 |
+
dt = torch.exp(
|
97 |
+
torch.rand(self.d_inner, **factory_kwargs)
|
98 |
+
* (math.log(dt_max) - math.log(dt_min))
|
99 |
+
+ math.log(dt_min)
|
100 |
+
).clamp(min=dt_init_floor)
|
101 |
+
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
102 |
+
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
103 |
+
with torch.no_grad():
|
104 |
+
self.dt_proj.bias.copy_(inv_dt)
|
105 |
+
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
|
106 |
+
self.dt_proj.bias._no_reinit = True
|
107 |
+
|
108 |
+
# S4D real initialization
|
109 |
+
A = repeat(
|
110 |
+
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
|
111 |
+
"n -> d n",
|
112 |
+
d=self.d_inner,
|
113 |
+
).contiguous()
|
114 |
+
A_log = torch.log(A) # Keep A_log in fp32
|
115 |
+
self.A_log = nn.Parameter(A_log)
|
116 |
+
self.A_log._no_weight_decay = True
|
117 |
+
|
118 |
+
# D "skip" parameter
|
119 |
+
self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
|
120 |
+
self.D._no_weight_decay = True
|
121 |
+
|
122 |
+
self.out_proj = nn.Linear(
|
123 |
+
self.d_inner, self.d_model, bias=bias, **factory_kwargs
|
124 |
+
)
|
125 |
+
|
126 |
+
def forward(self, hidden_states, inference_params=None):
|
127 |
+
"""
|
128 |
+
hidden_states: (B, L, D)
|
129 |
+
Returns: same shape as hidden_states
|
130 |
+
"""
|
131 |
+
batch, seqlen, dim = hidden_states.shape
|
132 |
+
|
133 |
+
conv_state, ssm_state = None, None
|
134 |
+
if inference_params is not None:
|
135 |
+
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
|
136 |
+
if inference_params.seqlen_offset > 0:
|
137 |
+
# The states are updated inplace
|
138 |
+
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
|
139 |
+
return out
|
140 |
+
|
141 |
+
# We do matmul and transpose BLH -> HBL at the same time
|
142 |
+
xz = rearrange(
|
143 |
+
self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
|
144 |
+
"d (b l) -> b d l",
|
145 |
+
l=seqlen,
|
146 |
+
)
|
147 |
+
if self.in_proj.bias is not None:
|
148 |
+
xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
|
149 |
+
|
150 |
+
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
151 |
+
# In the backward pass we write dx and dz next to each other to avoid torch.cat
|
152 |
+
if (
|
153 |
+
self.use_fast_path
|
154 |
+
and causal_conv1d_fn is not None
|
155 |
+
and inference_params is None
|
156 |
+
): # Doesn't support outputting the states
|
157 |
+
out = mamba_inner_fn(
|
158 |
+
xz,
|
159 |
+
self.conv1d.weight,
|
160 |
+
self.conv1d.bias,
|
161 |
+
self.x_proj.weight,
|
162 |
+
self.dt_proj.weight,
|
163 |
+
self.out_proj.weight,
|
164 |
+
self.out_proj.bias,
|
165 |
+
A,
|
166 |
+
None, # input-dependent B
|
167 |
+
None, # input-dependent C
|
168 |
+
self.D.float(),
|
169 |
+
delta_bias=self.dt_proj.bias.float(),
|
170 |
+
delta_softplus=True,
|
171 |
+
)
|
172 |
+
else:
|
173 |
+
x, z = xz.chunk(2, dim=1)
|
174 |
+
# Compute short convolution
|
175 |
+
if conv_state is not None:
|
176 |
+
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
|
177 |
+
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
|
178 |
+
conv_state.copy_(
|
179 |
+
F.pad(x, (self.d_conv - x.shape[-1], 0))
|
180 |
+
) # Update state (B D W)
|
181 |
+
if causal_conv1d_fn is None:
|
182 |
+
x = self.act(self.conv1d(x)[..., :seqlen])
|
183 |
+
else:
|
184 |
+
assert self.activation in ["silu", "swish"]
|
185 |
+
x = causal_conv1d_fn(
|
186 |
+
x=x,
|
187 |
+
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
188 |
+
bias=self.conv1d.bias,
|
189 |
+
activation=self.activation,
|
190 |
+
)
|
191 |
+
|
192 |
+
# We're careful here about the layout, to avoid extra transposes.
|
193 |
+
# We want dt to have d as the slowest moving dimension
|
194 |
+
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
195 |
+
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
|
196 |
+
dt, B, C = torch.split(
|
197 |
+
x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
|
198 |
+
)
|
199 |
+
dt = self.dt_proj.weight @ dt.t()
|
200 |
+
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
|
201 |
+
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
202 |
+
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
203 |
+
assert self.activation in ["silu", "swish"]
|
204 |
+
y = selective_scan_fn(
|
205 |
+
x,
|
206 |
+
dt,
|
207 |
+
A,
|
208 |
+
B,
|
209 |
+
C,
|
210 |
+
self.D.float(),
|
211 |
+
z=z,
|
212 |
+
delta_bias=self.dt_proj.bias.float(),
|
213 |
+
delta_softplus=True,
|
214 |
+
return_last_state=ssm_state is not None,
|
215 |
+
)
|
216 |
+
if ssm_state is not None:
|
217 |
+
y, last_state = y
|
218 |
+
ssm_state.copy_(last_state)
|
219 |
+
y = rearrange(y, "b d l -> b l d")
|
220 |
+
out = self.out_proj(y)
|
221 |
+
return out
|
222 |
+
|
223 |
+
def step(self, hidden_states, conv_state, ssm_state):
|
224 |
+
dtype = hidden_states.dtype
|
225 |
+
assert (
|
226 |
+
hidden_states.shape[1] == 1
|
227 |
+
), "Only support decoding with 1 token at a time for now"
|
228 |
+
xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
|
229 |
+
x, z = xz.chunk(2, dim=-1) # (B D)
|
230 |
+
|
231 |
+
# Conv step
|
232 |
+
if causal_conv1d_update is None:
|
233 |
+
conv_state.copy_(
|
234 |
+
torch.roll(conv_state, shifts=-1, dims=-1)
|
235 |
+
) # Update state (B D W)
|
236 |
+
conv_state[:, :, -1] = x
|
237 |
+
x = torch.sum(
|
238 |
+
conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
|
239 |
+
) # (B D)
|
240 |
+
if self.conv1d.bias is not None:
|
241 |
+
x = x + self.conv1d.bias
|
242 |
+
x = self.act(x).to(dtype=dtype)
|
243 |
+
else:
|
244 |
+
x = causal_conv1d_update(
|
245 |
+
x,
|
246 |
+
conv_state,
|
247 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
248 |
+
self.conv1d.bias,
|
249 |
+
self.activation,
|
250 |
+
)
|
251 |
+
|
252 |
+
x_db = self.x_proj(x) # (B dt_rank+2*d_state)
|
253 |
+
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
254 |
+
# Don't add dt_bias here
|
255 |
+
dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
|
256 |
+
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
257 |
+
|
258 |
+
# SSM step
|
259 |
+
if selective_state_update is None:
|
260 |
+
# Discretize A and B
|
261 |
+
dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
|
262 |
+
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
|
263 |
+
dB = torch.einsum("bd,bn->bdn", dt, B)
|
264 |
+
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
|
265 |
+
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
|
266 |
+
y = y + self.D.to(dtype) * x
|
267 |
+
y = y * self.act(z) # (B D)
|
268 |
+
else:
|
269 |
+
y = selective_state_update(
|
270 |
+
ssm_state,
|
271 |
+
x,
|
272 |
+
dt,
|
273 |
+
A,
|
274 |
+
B,
|
275 |
+
C,
|
276 |
+
self.D,
|
277 |
+
z=z,
|
278 |
+
dt_bias=self.dt_proj.bias,
|
279 |
+
dt_softplus=True,
|
280 |
+
)
|
281 |
+
|
282 |
+
out = self.out_proj(y)
|
283 |
+
return out.unsqueeze(1), conv_state, ssm_state
|
284 |
+
|
285 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
286 |
+
device = self.out_proj.weight.device
|
287 |
+
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
|
288 |
+
conv_state = torch.zeros(
|
289 |
+
batch_size,
|
290 |
+
self.d_model * self.expand,
|
291 |
+
self.d_conv,
|
292 |
+
device=device,
|
293 |
+
dtype=conv_dtype,
|
294 |
+
)
|
295 |
+
ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
|
296 |
+
# ssm_dtype = torch.float32
|
297 |
+
ssm_state = torch.zeros(
|
298 |
+
batch_size,
|
299 |
+
self.d_model * self.expand,
|
300 |
+
self.d_state,
|
301 |
+
device=device,
|
302 |
+
dtype=ssm_dtype,
|
303 |
+
)
|
304 |
+
return conv_state, ssm_state
|
305 |
+
|
306 |
+
def _get_states_from_cache(
|
307 |
+
self, inference_params, batch_size, initialize_states=False
|
308 |
+
):
|
309 |
+
assert self.layer_idx is not None
|
310 |
+
if self.layer_idx not in inference_params.key_value_memory_dict:
|
311 |
+
batch_shape = (batch_size,)
|
312 |
+
conv_state = torch.zeros(
|
313 |
+
batch_size,
|
314 |
+
self.d_model * self.expand,
|
315 |
+
self.d_conv,
|
316 |
+
device=self.conv1d.weight.device,
|
317 |
+
dtype=self.conv1d.weight.dtype,
|
318 |
+
)
|
319 |
+
ssm_state = torch.zeros(
|
320 |
+
batch_size,
|
321 |
+
self.d_model * self.expand,
|
322 |
+
self.d_state,
|
323 |
+
device=self.dt_proj.weight.device,
|
324 |
+
dtype=self.dt_proj.weight.dtype,
|
325 |
+
# dtype=torch.float32,
|
326 |
+
)
|
327 |
+
inference_params.key_value_memory_dict[self.layer_idx] = (
|
328 |
+
conv_state,
|
329 |
+
ssm_state,
|
330 |
+
)
|
331 |
+
else:
|
332 |
+
conv_state, ssm_state = inference_params.key_value_memory_dict[
|
333 |
+
self.layer_idx
|
334 |
+
]
|
335 |
+
# TODO: What if batch size changes between generation, and we reuse the same states?
|
336 |
+
if initialize_states:
|
337 |
+
conv_state.zero_()
|
338 |
+
ssm_state.zero_()
|
339 |
+
return conv_state, ssm_state
|
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mha.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from einops import rearrange
|
9 |
+
|
10 |
+
try:
|
11 |
+
from flash_attn import flash_attn_with_kvcache
|
12 |
+
except ImportError:
|
13 |
+
flash_attn_with_kvcache = None
|
14 |
+
|
15 |
+
try:
|
16 |
+
from flash_attn.layers.rotary import RotaryEmbedding
|
17 |
+
except ImportError:
|
18 |
+
RotaryEmbedding = None
|
19 |
+
|
20 |
+
try:
|
21 |
+
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
22 |
+
except ImportError:
|
23 |
+
causal_conv1d_fn, causal_conv1d_update = None, None
|
24 |
+
|
25 |
+
|
26 |
+
def _update_kv_cache(kv, inference_params, layer_idx):
|
27 |
+
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
|
28 |
+
# Pre-allocate memory for key-values for inference.
|
29 |
+
num_heads, head_dim = kv.shape[-2:]
|
30 |
+
assert layer_idx in inference_params.key_value_memory_dict
|
31 |
+
kv_cache, _ = inference_params.key_value_memory_dict[layer_idx]
|
32 |
+
# Adjust key and value for inference
|
33 |
+
batch_start = inference_params.batch_size_offset
|
34 |
+
batch_end = batch_start + kv.shape[0]
|
35 |
+
sequence_start = inference_params.seqlen_offset
|
36 |
+
sequence_end = sequence_start + kv.shape[1]
|
37 |
+
assert batch_end <= kv_cache.shape[0]
|
38 |
+
assert sequence_end <= kv_cache.shape[1]
|
39 |
+
assert kv_cache is not None
|
40 |
+
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
41 |
+
return kv_cache[batch_start:batch_end, :sequence_end, ...]
|
42 |
+
|
43 |
+
|
44 |
+
class MHA(nn.Module):
|
45 |
+
"""Multi-head self-attention and cross-attention"""
|
46 |
+
|
47 |
+
def __init__(
|
48 |
+
self,
|
49 |
+
embed_dim,
|
50 |
+
num_heads,
|
51 |
+
num_heads_kv=None,
|
52 |
+
head_dim=None, # If None, use embed_dim // num_heads
|
53 |
+
mlp_dim=0,
|
54 |
+
qkv_proj_bias=True,
|
55 |
+
out_proj_bias=True,
|
56 |
+
softmax_scale=None,
|
57 |
+
causal=False,
|
58 |
+
layer_idx=None,
|
59 |
+
d_conv=0,
|
60 |
+
rotary_emb_dim=0,
|
61 |
+
rotary_emb_base=10000.0,
|
62 |
+
rotary_emb_interleaved=False,
|
63 |
+
device=None,
|
64 |
+
dtype=None,
|
65 |
+
) -> None:
|
66 |
+
"""
|
67 |
+
num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
|
68 |
+
return_residual: whether to return the input x along with the output. This is for
|
69 |
+
performance reason: for post-norm architecture, returning the input allows us
|
70 |
+
to fuse the backward of nn.Linear with the residual connection.
|
71 |
+
"""
|
72 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
73 |
+
super().__init__()
|
74 |
+
self.embed_dim = embed_dim
|
75 |
+
self.layer_idx = layer_idx
|
76 |
+
self.d_conv = d_conv
|
77 |
+
self.rotary_emb_dim = rotary_emb_dim
|
78 |
+
self.softmax_scale = softmax_scale
|
79 |
+
self.causal = causal
|
80 |
+
|
81 |
+
self.num_heads = num_heads
|
82 |
+
self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
|
83 |
+
assert (
|
84 |
+
self.num_heads % self.num_heads_kv == 0
|
85 |
+
), "num_heads must be divisible by num_heads_kv"
|
86 |
+
if head_dim is None:
|
87 |
+
assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
|
88 |
+
self.head_dim = head_dim if head_dim is not None else self.embed_dim // num_heads
|
89 |
+
self.mlp_dim = math.ceil(mlp_dim / 256) * 256
|
90 |
+
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
|
91 |
+
out_dim = self.head_dim * self.num_heads
|
92 |
+
|
93 |
+
if self.rotary_emb_dim > 0:
|
94 |
+
assert RotaryEmbedding is not None, "rotary requires flash_attn to be installed"
|
95 |
+
self.rotary_emb = RotaryEmbedding(
|
96 |
+
self.rotary_emb_dim,
|
97 |
+
base=rotary_emb_base,
|
98 |
+
interleaved=rotary_emb_interleaved,
|
99 |
+
device=device,
|
100 |
+
)
|
101 |
+
|
102 |
+
self.in_proj = nn.Linear(embed_dim, qkv_dim + self.mlp_dim, bias=qkv_proj_bias, **factory_kwargs)
|
103 |
+
if self.d_conv > 0:
|
104 |
+
self.conv1d = nn.Conv1d(
|
105 |
+
qkv_dim, qkv_dim, kernel_size=self.d_conv, padding=self.d_conv - 1, groups=qkv_dim,
|
106 |
+
**factory_kwargs
|
107 |
+
)
|
108 |
+
self.out_proj = nn.Linear(out_dim + self.mlp_dim // 2, embed_dim, bias=out_proj_bias, **factory_kwargs)
|
109 |
+
|
110 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
|
111 |
+
dtype = self.out_proj.weight.dtype if dtype is None else dtype
|
112 |
+
device = self.out_proj.weight.device
|
113 |
+
if self.d_conv > 0:
|
114 |
+
conv_state = torch.zeros(
|
115 |
+
batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=dtype
|
116 |
+
)
|
117 |
+
else:
|
118 |
+
conv_state = None
|
119 |
+
kv_cache = torch.empty(
|
120 |
+
batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype, device=device,
|
121 |
+
)
|
122 |
+
return kv_cache, conv_state
|
123 |
+
|
124 |
+
def _update_kv_cache(self, kv, inference_params):
|
125 |
+
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
|
126 |
+
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
|
127 |
+
return _update_kv_cache(kv, inference_params, self.layer_idx)
|
128 |
+
|
129 |
+
def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
|
130 |
+
"""
|
131 |
+
Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
|
132 |
+
q: (batch_size, seqlen_q, nheads, head_dim)
|
133 |
+
kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
|
134 |
+
"""
|
135 |
+
assert inference_params is not None and inference_params.seqlen_offset > 0
|
136 |
+
if self.rotary_emb_dim > 0:
|
137 |
+
self.rotary_emb._update_cos_sin_cache(
|
138 |
+
inference_params.max_seqlen, device=q.device, dtype=q.dtype
|
139 |
+
)
|
140 |
+
rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
|
141 |
+
else:
|
142 |
+
rotary_cos, rotary_sin = None, None
|
143 |
+
batch = q.shape[0]
|
144 |
+
kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
|
145 |
+
kv_cache = kv_cache[:batch]
|
146 |
+
cache_seqlens = (
|
147 |
+
inference_params.lengths_per_sample[:batch]
|
148 |
+
if inference_params.lengths_per_sample is not None
|
149 |
+
else inference_params.seqlen_offset
|
150 |
+
)
|
151 |
+
assert flash_attn_with_kvcache is not None, "flash_attn must be installed"
|
152 |
+
context = flash_attn_with_kvcache(
|
153 |
+
q,
|
154 |
+
kv_cache[:, :, 0],
|
155 |
+
kv_cache[:, :, 1],
|
156 |
+
kv[:, :, 0],
|
157 |
+
kv[:, :, 1],
|
158 |
+
rotary_cos=rotary_cos,
|
159 |
+
rotary_sin=rotary_sin,
|
160 |
+
cache_seqlens=cache_seqlens,
|
161 |
+
softmax_scale=self.softmax_scale,
|
162 |
+
causal=self.causal,
|
163 |
+
rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
|
164 |
+
)
|
165 |
+
return context
|
166 |
+
|
167 |
+
def _update_kvcache_attention(self, q, kv, inference_params):
|
168 |
+
"""Write kv to inference_params, then do attention"""
|
169 |
+
if (
|
170 |
+
            inference_params.seqlen_offset == 0
            or flash_attn_with_kvcache is None
        ):
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            k, v = kv.unbind(dim=-3)
            k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
            v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
            return F.scaled_dot_product_attention(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
            ).transpose(1, 2)
        else:
            batch = q.shape[0]
            kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
            kv_cache = kv_cache[:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            return flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.softmax_scale,
                causal=self.causal,
            )

    def forward(self, x, inference_params=None):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        """
        if inference_params is not None and self.layer_idx not in inference_params.key_value_memory_dict:
            inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache(
                x.shape[0], inference_params.max_seqlen, dtype=x.dtype
            )
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        qkv = self.in_proj(x)
        if self.mlp_dim > 0:
            qkv, x_mlp = qkv.split([qkv.shape[-1] - self.mlp_dim, self.mlp_dim], dim=-1)
            x_mlp_up, x_mlp_gate = x_mlp.chunk(2, dim=-1)
            x_mlp = x_mlp_up * F.silu(x_mlp_gate)
        if self.d_conv > 0:
            # The inference code for conv1d is pretty messy, should clean it up
            if (inference_params is None or inference_params.seqlen_offset == 0):
                if causal_conv1d_fn is None:
                    qkv = rearrange(
                        self.conv1d(rearrange(qkv, "b s d -> b d s"))[..., :-(self.d_conv - 1)], "b d s -> b s d"
                    ).contiguous()
                else:
                    qkv = causal_conv1d_fn(
                        qkv.transpose(1, 2),
                        rearrange(self.conv1d.weight, "d 1 w -> d w"),
                        self.conv1d.bias
                    ).transpose(1, 2)
                if inference_params is not None:
                    _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
                    # If we just take qkv[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                    # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                    qkv_t = rearrange(qkv, "b l d -> b d l")
                    conv_state.copy_(F.pad(qkv_t, (self.d_conv - qkv_t.shape[-1], 0)))  # Update state (B D W)
            else:
                _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
                assert qkv.shape[1] == 1, "Only support decoding with 1 token at a time for now"
                qkv = qkv.squeeze(1)
                # Conv step
                if causal_conv1d_update is None:
                    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
                    conv_state[:, :, -1] = qkv
                    qkv = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
                    if self.conv1d.bias is not None:
                        qkv = qkv + self.conv1d.bias
                else:
                    qkv = causal_conv1d_update(
                        qkv,
                        conv_state,
                        rearrange(self.conv1d.weight, "d 1 w -> d w"),
                        self.conv1d.bias
                    )
                qkv = qkv.unsqueeze(1)
        q, kv = qkv.split([self.num_heads * self.head_dim, self.num_heads_kv * 2 * self.head_dim], dim=-1)
        q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
        kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
        if (
            inference_params is None
            or inference_params.seqlen_offset == 0
            or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
        ):
            if self.rotary_emb_dim > 0:
                q, kv = self.rotary_emb(
                    q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                )
            if inference_params is None:
                k, v = kv.unbind(dim=-3)
                k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
                v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
                context = F.scaled_dot_product_attention(
                    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
                ).transpose(1, 2)
            else:
                context = self._update_kvcache_attention(q, kv, inference_params)
        else:
            context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        context = rearrange(context, "... h d -> ... (h d)")
        if self.mlp_dim > 0:
            context = torch.cat([context, x_mlp], dim=-1)
        out = self.out_proj(context)
        return out
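A minimal standalone sketch (not part of the diff) of the grouped-query attention expansion used above: each KV head is repeated with torch.repeat_interleave so the KV head count matches the query head count, then PyTorch's scaled_dot_product_attention is called. The shapes and head counts below are illustrative assumptions, not values from this repository.

import torch
import torch.nn.functional as F

batch, seqlen, num_heads, num_heads_kv, head_dim = 2, 16, 8, 2, 64
q = torch.randn(batch, seqlen, num_heads, head_dim)
kv = torch.randn(batch, seqlen, 2, num_heads_kv, head_dim)

k, v = kv.unbind(dim=-3)
# Expand the KV heads to match the number of query heads (GQA -> MHA layout).
k = torch.repeat_interleave(k, dim=2, repeats=num_heads // num_heads_kv)
v = torch.repeat_interleave(v, dim=2, repeats=num_heads // num_heads_kv)
# SDPA expects (batch, heads, seqlen, head_dim), hence the transposes.
context = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
).transpose(1, 2)
assert context.shape == (batch, seqlen, num_heads, head_dim)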
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mlp.py
ADDED
@@ -0,0 +1,34 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
from torch import nn
from torch.nn import functional as F


class GatedMLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.silu,
        bias=False,
        multiple_of=128,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        y, gate = y.chunk(2, dim=-1)
        y = y * self.activation(gate)
        y = self.fc2(y)
        return y
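A small usage sketch (not part of the diff), assuming GatedMLP above is importable: fc1 projects to 2 * hidden_features, one half gates the other through the activation, and fc2 projects back down. The dimensions below are illustrative.

import torch

mlp = GatedMLP(in_features=256)  # hidden_features defaults to int(8/3 * 256), rounded up to a multiple of 128
x = torch.randn(2, 16, 256)      # (batch, seqlen, in_features)
y = mlp(x)
assert y.shape == (2, 16, 256)   # out_features defaults to in_features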
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/ssd_minimal.py
ADDED
@@ -0,0 +1,111 @@
# Copyright (c) 2024, Albert Gu and Tri Dao.
"""Minimal implementation of SSD.

This is the same as Listing 1 from the paper.
"""

import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from ..ops.triton.ssd_combined import mamba_chunk_scan_combined


def segsum_unstable(x):
    """Naive segment sum calculation."""
    T = x.size(-1)
    x_cumsum = torch.cumsum(x, dim=-1)
    x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum


def segsum(x):
    """More stable segment sum calculation."""
    T = x.size(-1)
    x = repeat(x, "... d -> ... d e", e=T)
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
    x = x.masked_fill(~mask, 0)
    x_segsum = torch.cumsum(x, dim=-2)
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum


def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
    """
    Arguments:
        X: (batch, length, n_heads, d_head)
        A: (batch, length, n_heads)
        B: (batch, length, n_heads, d_state)
        C: (batch, length, n_heads, d_state)
    Return:
        Y: (batch, length, n_heads, d_head)
    """
    assert X.dtype == A.dtype == B.dtype == C.dtype
    assert X.shape[1] % block_len == 0

    # Rearrange into blocks/chunks
    X, A, B, C = [
        rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)
    ]

    A = rearrange(A, "b c l h -> b h c l")
    A_cumsum = torch.cumsum(A, dim=-1)

    # 1. Compute the output for each intra-chunk (diagonal blocks)
    L = torch.exp(segsum(A))
    Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)

    # 2. Compute the state for each intra-chunk
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
    states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)

    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
    new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
    states, final_state = new_states[:, :-1], new_states[:, -1]

    # 4. Compute state -> output conversion per chunk
    # (left term of low-rank factorization of off-diagonal blocks; C terms)
    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)

    # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
    Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
    return Y, final_state


# Simple test
def test_correctness():
    torch.manual_seed(42)

    ## Dimensions
    # Denoted (B, T, Q, D, P) in the paper
    batch, seqlen, chunk_size, dim, headdim = 1, 2048, 64, 2048, 64
    nheads = dim // headdim  # (H) in the paper
    ngroups = 1  # (G) in the paper
    dstate = 64  # (N) in the paper
    dtype = torch.float32
    device = "cuda"

    x = torch.randn(batch, seqlen, nheads, headdim, dtype=dtype, device=device)
    dt = F.softplus(
        torch.randn(batch, seqlen, nheads, dtype=torch.float32, device=device) - 4
    ).requires_grad_()
    A = (
        -torch.exp(torch.rand(nheads, dtype=torch.float32, device=device))
    ).requires_grad_()
    B = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
    C = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
    D = torch.randn(nheads, dtype=dtype, device=device)

    # Comparing fused version and minimal version
    y = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None)
    y_min, _ = ssd_minimal_discrete(x * dt.unsqueeze(-1), A * dt, B, C, chunk_size)
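A minimal sketch (not part of the diff), assuming segsum and segsum_unstable above are in scope: both compute the same lower-triangular matrix of segment sums, with entry (i, j) equal to the sum of x over positions j+1..i and -inf above the diagonal, so they can be checked against each other on small random input without a GPU.

import torch

x = torch.randn(4, 8)                 # (..., T)
out_stable = segsum(x)                # (..., T, T)
out_naive = segsum_unstable(x)        # same values, computed via cumsum differences
assert torch.allclose(out_stable, out_naive, atol=1e-6)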
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/ops/__init__.py
ADDED
File without changes