danieldk (HF Staff) committed on
Commit 2e3f8b7 · 1 parent: 85e64a5
This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/{_mamba_ssm_bft6nicqkg6ni.abi3.so → _mamba_ssm_85e64a5.abi3.so} +2 -2
  2. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/_ops.py +3 -3
  3. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/{_mamba_ssm_nmrmresto7zfi.abi3.so → _mamba_ssm_85e64a5.abi3.so} +2 -2
  4. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/_ops.py +3 -3
  5. build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/{_mamba_ssm_fhbfq4rqrrau4.abi3.so → _mamba_ssm_85e64a5.abi3.so} +2 -2
  6. build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/_ops.py +3 -3
  7. build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/{_mamba_ssm_konfvt7wiz4bc.abi3.so → _mamba_ssm_85e64a5.abi3.so} +2 -2
  8. build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/_ops.py +3 -3
  9. build/torch25-cxx98-cu121-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so +3 -0
  10. build/torch25-cxx98-cu121-x86_64-linux/mamba_ssm/_mamba_ssm_b7y35xkw542po.abi3.so +0 -3
  11. build/torch25-cxx98-cu121-x86_64-linux/mamba_ssm/_ops.py +3 -3
  12. build/torch25-cxx98-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_3nr5ex3ddrv6c.abi3.so +0 -3
  13. build/torch25-cxx98-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so +3 -0
  14. build/torch25-cxx98-cu124-x86_64-linux/mamba_ssm/_ops.py +3 -3
  15. build/torch26-cxx11-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so +3 -0
  16. build/torch26-cxx11-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_w4jqdduxei7ne.abi3.so +0 -3
  17. build/torch26-cxx11-cu118-x86_64-linux/mamba_ssm/_ops.py +3 -3
  18. build/torch26-cxx11-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so +3 -0
  19. build/torch26-cxx11-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_h4pt4pjmzduuo.abi3.so +0 -3
  20. build/torch26-cxx11-cu124-x86_64-linux/mamba_ssm/_ops.py +3 -3
  21. build/torch26-cxx11-cu126-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so +3 -0
  22. build/torch26-cxx11-cu126-x86_64-linux/mamba_ssm/_mamba_ssm_ad2dqkuyppsay.abi3.so +0 -3
  23. build/torch26-cxx11-cu126-x86_64-linux/mamba_ssm/_ops.py +3 -3
  24. build/torch26-cxx98-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so +3 -0
  25. build/torch26-cxx98-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_g4gqbotnq7pgy.abi3.so +0 -3
  26. build/torch26-cxx98-cu118-x86_64-linux/mamba_ssm/_ops.py +3 -3
  27. build/torch26-cxx98-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so +3 -0
  28. build/torch26-cxx98-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_r7gpumhmqnfog.abi3.so +0 -3
  29. build/torch26-cxx98-cu124-x86_64-linux/mamba_ssm/_ops.py +3 -3
  30. build/torch26-cxx98-cu126-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so +3 -0
  31. build/torch26-cxx98-cu126-x86_64-linux/mamba_ssm/_mamba_ssm_ojx7o3olgtezs.abi3.so +0 -3
  32. build/torch26-cxx98-cu126-x86_64-linux/mamba_ssm/_ops.py +3 -3
  33. build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/__init__.py +14 -0
  34. build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so +3 -0
  35. build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/_ops.py +9 -0
  36. build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/__init__.py +0 -0
  37. build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/distributed_utils.py +144 -0
  38. build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py +326 -0
  39. build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/models/__init__.py +0 -0
  40. build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/models/config_mamba.py +18 -0
  41. build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py +338 -0
  42. build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/__init__.py +0 -0
  43. build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/block.py +107 -0
  44. build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba2.py +502 -0
  45. build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba2_simple.py +229 -0
  46. build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba_simple.py +339 -0
  47. build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mha.py +294 -0
  48. build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mlp.py +34 -0
  49. build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/ssd_minimal.py +111 -0
  50. build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/ops/__init__.py +0 -0
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/{_mamba_ssm_bft6nicqkg6ni.abi3.so → _mamba_ssm_85e64a5.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c0e8bc801359703c8d092b7c8c9906bd59c083d94e6778b621ba709d79fff5a0
-size 258973648
+oid sha256:65499c06108d39d466d78b0e628645aff076a8db1ebd6ed6c09d05ccb4c80296
+size 261767104
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _mamba_ssm_bft6nicqkg6ni
-ops = torch.ops._mamba_ssm_bft6nicqkg6ni
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_mamba_ssm_bft6nicqkg6ni::{op_name}"
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/{_mamba_ssm_nmrmresto7zfi.abi3.so → _mamba_ssm_85e64a5.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:854cfdd1c899869de1c88a6a56de1494a3d4a0edd1a04412167599485bc1093e
-size 247806288
+oid sha256:878481a2cddfbf12eea8a8f746abdea5bd80a8b8feeb6aa87b9561a2987bebea
+size 250394944
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _mamba_ssm_nmrmresto7zfi
-ops = torch.ops._mamba_ssm_nmrmresto7zfi
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_mamba_ssm_nmrmresto7zfi::{op_name}"
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/{_mamba_ssm_fhbfq4rqrrau4.abi3.so → _mamba_ssm_85e64a5.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:14d2f2d58ab71802b3bb7a21d3ee808fb501d207b0c437fd95637ef9f0a348f2
-size 246550264
+oid sha256:7987b149563bfaef257006c494496743db9eb769fb1727b0f7b260839193cb30
+size 249159400
build/torch25-cxx11-cu124-x86_64-linux/mamba_ssm/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _mamba_ssm_fhbfq4rqrrau4
-ops = torch.ops._mamba_ssm_fhbfq4rqrrau4
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_mamba_ssm_fhbfq4rqrrau4::{op_name}"
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/{_mamba_ssm_konfvt7wiz4bc.abi3.so → _mamba_ssm_85e64a5.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:6180eb17a4eddc08eb3f313ec2835464f0ee2e7418ccc09f4fe2183c4165933a
-size 258974432
+oid sha256:e6a97c777005c18be46b647314eb6f1659f894ef1aa06a1ea9827f1288156545
+size 261763792
build/torch25-cxx98-cu118-x86_64-linux/mamba_ssm/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _mamba_ssm_konfvt7wiz4bc
-ops = torch.ops._mamba_ssm_konfvt7wiz4bc
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_mamba_ssm_konfvt7wiz4bc::{op_name}"
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch25-cxx98-cu121-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d7826f86b0f364f57d88c307511a12d2981ae0d16b4f30913e315794c794f563
+size 250387568
build/torch25-cxx98-cu121-x86_64-linux/mamba_ssm/_mamba_ssm_b7y35xkw542po.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c01510f5ed28163a69098ed3ac7d18b42d58d38ed5cd15cc2088c0b5056be5d5
-size 247798920
build/torch25-cxx98-cu121-x86_64-linux/mamba_ssm/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _mamba_ssm_b7y35xkw542po
-ops = torch.ops._mamba_ssm_b7y35xkw542po
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_mamba_ssm_b7y35xkw542po::{op_name}"
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch25-cxx98-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_3nr5ex3ddrv6c.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:856bdc363912163550832cb546797c1c668494d7a1244016e827efe9e945af4a
-size 246546992
build/torch25-cxx98-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4ef95dd140a1ea41f59e4af3a6fa95b1ab14fbfb2ef529bd3005cf7471d2c5f7
+size 249156128
build/torch25-cxx98-cu124-x86_64-linux/mamba_ssm/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _mamba_ssm_3nr5ex3ddrv6c
-ops = torch.ops._mamba_ssm_3nr5ex3ddrv6c
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_mamba_ssm_3nr5ex3ddrv6c::{op_name}"
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch26-cxx11-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8eb776fd2d50d2c6a83643848baee6a09ac1b0f0ae085d90aa918fd86c13f09d
+size 261771440
build/torch26-cxx11-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_w4jqdduxei7ne.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d87ac7d060e64fbe8651bd6a07234a49464c326a705f58b9d21f611548cbe7f6
-size 258977984
build/torch26-cxx11-cu118-x86_64-linux/mamba_ssm/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _mamba_ssm_w4jqdduxei7ne
-ops = torch.ops._mamba_ssm_w4jqdduxei7ne
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_mamba_ssm_w4jqdduxei7ne::{op_name}"
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch26-cxx11-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bf2b6908f571764b70f882e10c46f0261c5e9a6af22774879e490886a3fa5b9d
+size 249159824
build/torch26-cxx11-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_h4pt4pjmzduuo.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a94840c139f6dd950a88a3d049bf301fdeffa69e5fd6d2a5227cceb904189c49
-size 246550688
build/torch26-cxx11-cu124-x86_64-linux/mamba_ssm/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _mamba_ssm_h4pt4pjmzduuo
-ops = torch.ops._mamba_ssm_h4pt4pjmzduuo
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_mamba_ssm_h4pt4pjmzduuo::{op_name}"
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch26-cxx11-cu126-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:467b1ba8ff9848471a5b9d77910ef20f59a41dc532eacf11cdbc8c6d20ba5f40
+size 250021072
build/torch26-cxx11-cu126-x86_64-linux/mamba_ssm/_mamba_ssm_ad2dqkuyppsay.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:e7625fca55530d1e2f3801c430cf50c0f38f744f631ff2ff3cfd48ed88a63eb9
-size 247579880
build/torch26-cxx11-cu126-x86_64-linux/mamba_ssm/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _mamba_ssm_ad2dqkuyppsay
-ops = torch.ops._mamba_ssm_ad2dqkuyppsay
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_mamba_ssm_ad2dqkuyppsay::{op_name}"
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch26-cxx98-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6a3b8a82f5359fb1f1f9736d5c62d7a7449f76547715115b3f2e6cf706b06f89
+size 261764024
build/torch26-cxx98-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_g4gqbotnq7pgy.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:966b3d0a8ea86436492ca16aead68420d2270094f4c38425b0551f556a410960
-size 258970576
build/torch26-cxx98-cu118-x86_64-linux/mamba_ssm/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _mamba_ssm_g4gqbotnq7pgy
-ops = torch.ops._mamba_ssm_g4gqbotnq7pgy
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_mamba_ssm_g4gqbotnq7pgy::{op_name}"
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch26-cxx98-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:81a1cdd99de783191a0551d958e13c0ff3625864a87b3e3850db3603ba8cebb1
+size 249156424
build/torch26-cxx98-cu124-x86_64-linux/mamba_ssm/_mamba_ssm_r7gpumhmqnfog.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:ff591b542c6a1e41c71df196d9a47966f832a77945b26684077b881f9b99e016
-size 246547288
build/torch26-cxx98-cu124-x86_64-linux/mamba_ssm/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _mamba_ssm_r7gpumhmqnfog
-ops = torch.ops._mamba_ssm_r7gpumhmqnfog
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_mamba_ssm_r7gpumhmqnfog::{op_name}"
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch26-cxx98-cu126-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:956f952fefd39eaed6eab1fde0f2dbb4cf6d2004cf986f40392c33ef745f084a
+size 250017672
build/torch26-cxx98-cu126-x86_64-linux/mamba_ssm/_mamba_ssm_ojx7o3olgtezs.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b611bba2ae9af44b053ec900355dbd27ccb64d0fa85c732a261cb0b0304df24d
-size 247576472
build/torch26-cxx98-cu126-x86_64-linux/mamba_ssm/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _mamba_ssm_ojx7o3olgtezs
-ops = torch.ops._mamba_ssm_ojx7o3olgtezs
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_mamba_ssm_ojx7o3olgtezs::{op_name}"
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/__init__.py ADDED
@@ -0,0 +1,14 @@
+__version__ = "2.2.4"
+
+from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
+from .modules.mamba_simple import Mamba
+from .modules.mamba2 import Mamba2
+from .models.mixer_seq_simple import MambaLMHeadModel
+
+__all__ = [
+    "selective_scan_fn",
+    "mamba_inner_fn",
+    "Mamba",
+    "Mamba2",
+    "MambaLMHeadModel",
+]
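The new `__init__.py` re-exports the public API. A minimal usage sketch, assuming a CUDA device and the compiled kernels are available (tensor sizes are illustrative):

import torch
from mamba_ssm import Mamba2

# Batch of 2 sequences, length 64, model dimension 256 (illustrative sizes).
x = torch.randn(2, 64, 256, device="cuda", dtype=torch.float16)
layer = Mamba2(d_model=256, d_state=64, headdim=64).to("cuda", dtype=torch.float16)
y = layer(x)  # same shape as x: (batch, seqlen, d_model)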
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_85e64a5.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bd30bbcf05aa050cdd6472ec40e08254762525e45f3b4209d11dfef629a201fa
+size 261771568
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _mamba_ssm_85e64a5
+ops = torch.ops._mamba_ssm_85e64a5
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_mamba_ssm_85e64a5::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/__init__.py ADDED
File without changes
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/distributed_utils.py ADDED
@@ -0,0 +1,144 @@
+from typing import Optional
+
+import torch
+from torch import Tensor
+from torch.distributed import ProcessGroup
+
+# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
+# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
+# version of PyTorch. The following 4 lines are for backward compatibility with
+# older PyTorch.
+if "all_gather_into_tensor" not in dir(torch.distributed):
+    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
+if "reduce_scatter_tensor" not in dir(torch.distributed):
+    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
+
+
+# Raw operation, does not support autograd, but does support async
+def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
+    world_size = torch.distributed.get_world_size(process_group)
+    output = torch.empty(
+        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
+    )
+    handle = torch.distributed.all_gather_into_tensor(
+        output, input_.contiguous(), group=process_group, async_op=async_op
+    )
+    return output, handle
+
+
+# Raw operation, does not support autograd, but does support async
+def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
+    world_size = torch.distributed.get_world_size(process_group)
+    assert input_.shape[0] % world_size == 0
+    output = torch.empty(
+        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
+    )
+    handle = torch.distributed.reduce_scatter_tensor(
+        output, input_.contiguous(), group=process_group, async_op=async_op
+    )
+    return output, handle
+
+
+# Raw operation, does not support autograd, but does support async
+def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
+    input_ = input_.contiguous()
+    handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
+    return input_, handle
+
+
+class AllGatherFunc(torch.autograd.Function):
+    """Gather the input from sequence parallel region and concatenate."""
+
+    @staticmethod
+    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
+        ctx.process_group = process_group
+        output, _ = all_gather_raw(input_, process_group)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output: Tensor):
+        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
+        return grad_input, None
+
+
+# Supports autograd, but does not support async
+all_gather = AllGatherFunc.apply
+
+
+class ReduceScatterFunc(torch.autograd.Function):
+    """Reduce scatter the input from the sequence parallel region and concatenate."""
+
+    @staticmethod
+    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
+        ctx.process_group = process_group
+        output, _ = reduce_scatter_raw(input_, process_group)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output: Tensor):
+        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
+        return grad_input, None
+
+
+# Supports autograd, but does not support async
+reduce_scatter = ReduceScatterFunc.apply
+
+
+class AllReduceFunc(torch.autograd.Function):
+    """Gather the input from sequence parallel region and concatenate."""
+
+    @staticmethod
+    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
+        ctx.process_group = process_group
+        output, _ = all_reduce_raw(input_, process_group)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output: Tensor):
+        return grad_output, None
+
+
+# Supports autograd, but does not support async
+all_reduce = AllReduceFunc.apply
+
+
+def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
+    # We want to iterate over parameters with _shared_params=True in the same order,
+    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
+    pamams_shared = {
+        name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
+    }
+    for _, p in sorted(pamams_shared.items()):
+        with torch.no_grad():
+            # Broadcast needs src to be global rank, not group rank
+            torch.distributed.broadcast(
+                p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
+            )
+
+
+# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
+def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
+    # We want to iterate over parameters with _sequence_parallel=True in the same order,
+    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
+    params_seqparallel = {
+        name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
+    }
+    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
+    if grads:
+        with torch.no_grad():
+            coalesced = torch._utils._flatten_dense_tensors(grads)
+            torch.distributed.all_reduce(coalesced, group=process_group)
+            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
+                buf.copy_(synced)
+
+
+def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
+    """Get the dim for the local rank derived from splitting dim on world_size processes.
+
+    The split may not be even across the world_size processes.
+    """
+    multiple = dim // multiple_of
+    div = multiple // world_size
+    mod = multiple % world_size
+    local_multiple = div + int(local_rank < mod)
+    return local_multiple * multiple_of
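`get_dim_for_local_rank` implements the same uneven-split rule used by the tensor-parallel layers in the next file: the first `mod` ranks receive one extra multiple. A small worked example (values are illustrative):

from mamba_ssm.distributed.distributed_utils import get_dim_for_local_rank

# Splitting dim=10 with multiple_of=2 gives 5 multiples across world_size=4:
# rank 0 gets 2 multiples, ranks 1-3 get 1 each, i.e. local dims 4, 2, 2, 2.
dims = [get_dim_for_local_rank(10, 4, r, multiple_of=2) for r in range(4)]
assert dims == [4, 2, 2, 2] and sum(dims) == 10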
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py ADDED
@@ -0,0 +1,326 @@
+# Copyright (c) 2024, Tri Dao.
+# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.distributed import ProcessGroup
+from ..utils.torch import custom_bwd, custom_fwd
+
+from einops import rearrange
+
+from ..distributed.distributed_utils import (
+    all_gather_raw,
+    all_reduce,
+    all_reduce_raw,
+    reduce_scatter,
+    reduce_scatter_raw,
+)
+
+
+class ParallelLinearFunc(torch.autograd.Function):
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
+        """
+        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
+        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
+        """
+        ctx.compute_weight_gradient = weight.requires_grad
+        ctx.process_group = process_group
+        ctx.sequence_parallel = sequence_parallel
+
+        if torch.is_autocast_enabled():
+            x = x.to(dtype=torch.get_autocast_gpu_dtype())
+        x = x.contiguous()
+        if process_group is not None and sequence_parallel:
+            # We want to kick off the all_gather early, before weight dtype conversion
+            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
+        else:
+            total_x = x
+
+        if torch.is_autocast_enabled():
+            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
+            bias = (
+                bias.to(dtype=torch.get_autocast_gpu_dtype())
+                if bias is not None
+                else None
+            )
+        weight = weight.contiguous()
+        if process_group is not None and sequence_parallel:
+            handle_x.wait()
+        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
+        batch_dim = batch_shape.numel()
+        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
+        output = F.linear(total_x, weight, bias)
+        if ctx.compute_weight_gradient:
+            ctx.save_for_backward(x, weight)
+        else:
+            ctx.save_for_backward(weight)
+        return output
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_output):
+        grad_output = grad_output.contiguous()
+        process_group = ctx.process_group
+        sequence_parallel = ctx.sequence_parallel
+        if ctx.compute_weight_gradient:
+            x, weight = ctx.saved_tensors
+            if process_group is not None and sequence_parallel:
+                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
+            else:
+                total_x = x
+        else:
+            (weight,) = ctx.saved_tensors
+            total_x = None
+        batch_shape = grad_output.shape[:-1]
+        batch_dim = batch_shape.numel()
+        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
+        if ctx.needs_input_grad[0]:
+            grad_input = F.linear(grad_output, weight.t())
+            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
+            if process_group is not None:
+                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
+                grad_input, handle_grad_input = reduce_fn(
+                    grad_input, process_group, async_op=True
+                )
+        else:
+            grad_input = None
+        if ctx.needs_input_grad[1]:
+            assert ctx.compute_weight_gradient
+            if process_group is not None and sequence_parallel:
+                handle_x.wait()
+            grad_weight = torch.einsum(
+                "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
+            )
+        else:
+            grad_weight = None
+        grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
+        if process_group is not None and ctx.needs_input_grad[0]:
+            handle_grad_input.wait()
+        return grad_input, grad_weight, grad_bias, None, None
+
+
+def parallel_linear_func(
+    x: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    process_group: Optional[ProcessGroup] = None,
+    sequence_parallel: bool = True,
+):
+    return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)
+
+
+class ColumnParallelLinear(nn.Linear):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        process_group: ProcessGroup,
+        bias: bool = True,
+        sequence_parallel=True,
+        multiple_of=1,
+        device=None,
+        dtype=None,
+    ) -> None:
+        world_size = torch.distributed.get_world_size(process_group)
+        if out_features % multiple_of:
+            raise ValueError(
+                f"out_features ({out_features}) must be a multiple of {multiple_of}"
+            )
+        multiple = out_features // multiple_of
+        # We want to split @multiple across world_size, but it could be an uneven split
+        div = multiple // world_size
+        mod = multiple % world_size
+        # The first @mod ranks get @div + 1 copies, the rest get @div copies
+        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
+        super().__init__(
+            in_features,
+            local_multiple * multiple_of,
+            bias=bias,
+            device=device,
+            dtype=dtype,
+        )
+        self.process_group = process_group
+        self.sequence_parallel = sequence_parallel
+
+    def forward(self, x):
+        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
+        # we do an all_gather of x before doing the matmul.
+        # If not, then the input is already gathered.
+        return parallel_linear_func(
+            x,
+            self.weight,
+            self.bias,
+            process_group=self.process_group,
+            sequence_parallel=self.sequence_parallel,
+        )
+
+
+class RowParallelLinear(nn.Linear):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        process_group: ProcessGroup,
+        bias: bool = True,
+        sequence_parallel=True,
+        multiple_of=1,
+        device=None,
+        dtype=None,
+    ) -> None:
+        world_size = torch.distributed.get_world_size(process_group)
+        rank = torch.distributed.get_rank(process_group)
+        if in_features % multiple_of:
+            raise ValueError(
+                f"in_features ({in_features}) must be a multiple of {multiple_of}"
+            )
+        multiple = in_features // multiple_of
+        # We want to split @multiple across world_size, but it could be an uneven split
+        div = multiple // world_size
+        mod = multiple % world_size
+        # The first @mod ranks get @div + 1 copies, the rest get @div copies
+        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
+        # Only rank 0 will have bias
+        super().__init__(
+            local_multiple * multiple_of,
+            out_features,
+            bias=bias and rank == 0,
+            device=device,
+            dtype=dtype,
+        )
+        self.process_group = process_group
+        self.sequence_parallel = sequence_parallel
+
+    def forward(self, x):
+        """
+        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
+        a reduce_scatter of the result.
+        """
+        out = parallel_linear_func(x, self.weight, self.bias)
+        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
+        return reduce_fn(out, self.process_group)
+
+
+class VocabParallelEmbedding(nn.Embedding):
+    def __init__(
+        self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs
+    ):
+        self.process_group = process_group
+        if process_group is not None:
+            world_size = torch.distributed.get_world_size(process_group)
+            if num_embeddings % world_size != 0:
+                raise ValueError(
+                    f"num_embeddings ({num_embeddings}) must be divisible by "
+                    f"world_size ({world_size})"
+                )
+            if world_size > 1 and padding_idx is not None:
+                raise RuntimeError("ParallelEmbedding does not support padding_idx")
+        else:
+            world_size = 1
+        super().__init__(
+            num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs
+        )
+
+    def forward(self, input: Tensor) -> Tensor:
+        if self.process_group is None:
+            return super().forward(input)
+        else:
+            rank = torch.distributed.get_rank(self.process_group)
+            vocab_size = self.num_embeddings
+            vocab_start_index, vocab_end_index = (
+                rank * vocab_size,
+                (rank + 1) * vocab_size,
+            )
+            # Create a mask of valid vocab ids (1 means it needs to be masked).
+            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
+            input = input - vocab_start_index
+            input[input_ids_mask] = 0
+            embeddings = super().forward(input)
+            embeddings[input_ids_mask] = 0.0
+            return embeddings
+
+
+class ColumnParallelEmbedding(nn.Embedding):
+    def __init__(
+        self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs
+    ):
+        self.process_group = process_group
+        if process_group is not None:
+            world_size = torch.distributed.get_world_size(process_group)
+            if embedding_dim % world_size != 0:
+                raise ValueError(
+                    f"embedding_dim ({embedding_dim}) must be divisible by "
+                    f"world_size ({world_size})"
+                )
+        else:
+            world_size = 1
+        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
+
+
+class ParallelEmbeddings(nn.Module):
+    def __init__(
+        self,
+        embed_dim,
+        vocab_size,
+        max_position_embeddings,
+        process_group,
+        padding_idx=None,
+        sequence_parallel=True,
+        device=None,
+        dtype=None,
+    ):
+        """
+        If max_position_embeddings <= 0, there's no position embeddings
+        """
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.process_group = process_group
+        self.sequence_parallel = sequence_parallel
+        self.word_embeddings = VocabParallelEmbedding(
+            vocab_size,
+            embed_dim,
+            padding_idx=padding_idx,
+            process_group=process_group,
+            **factory_kwargs,
+        )
+        self.max_position_embeddings = max_position_embeddings
+        if self.max_position_embeddings > 0:
+            self.position_embeddings = ColumnParallelEmbedding(
+                max_position_embeddings,
+                embed_dim,
+                process_group=process_group,
+                **factory_kwargs,
+            )
+
+    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
+        """
+        input_ids: (batch, seqlen)
+        position_ids: (batch, seqlen)
+        """
+        batch_size, seqlen = input_ids.shape
+        world_size = torch.distributed.get_world_size(self.process_group)
+        embeddings = self.word_embeddings(input_ids)
+        if self.max_position_embeddings > 0:
+            if position_ids is None:
+                position_ids = torch.arange(
+                    seqlen, dtype=torch.long, device=input_ids.device
+                )
+            position_embeddings = self.position_embeddings(position_ids)
+            if world_size <= 1:
+                embeddings = embeddings + position_embeddings
+            else:
+                partition_dim = self.position_embeddings.embedding_dim
+                rank = torch.distributed.get_rank(self.process_group)
+                embeddings[
+                    ..., rank * partition_dim : (rank + 1) * partition_dim
+                ] += position_embeddings
+        if combine_batch_seqlen_dim:
+            embeddings = rearrange(embeddings, "b s d -> (b s) d")
+        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
+        return (
+            embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
+        )
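`ColumnParallelLinear` shards `out_features` and `RowParallelLinear` shards `in_features`, using the same uneven-split rule as `get_dim_for_local_rank`. A sketch of how they are typically paired under tensor parallelism; it assumes a process group already initialized by a launcher such as `torchrun`, and the layer sizes are illustrative:

import torch
import torch.distributed as dist
from mamba_ssm.distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear

pg = dist.group.WORLD  # assumes dist.init_process_group() has already run
fc1 = ColumnParallelLinear(1024, 4096, process_group=pg, sequence_parallel=True)
fc2 = RowParallelLinear(4096, 1024, process_group=pg, sequence_parallel=True)

# With sequence_parallel=True, x is the local shard of the sequence;
# fc1 all-gathers it before the matmul and fc2 reduce-scatters its output.
x = torch.randn(8, 1024)
y = fc2(torch.nn.functional.silu(fc1(x)))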
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/models/__init__.py ADDED
File without changes
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/models/config_mamba.py ADDED
@@ -0,0 +1,18 @@
+from dataclasses import dataclass, field
+
+
+@dataclass
+class MambaConfig:
+
+    d_model: int = 2560
+    d_intermediate: int = 0
+    n_layer: int = 64
+    vocab_size: int = 50277
+    ssm_cfg: dict = field(default_factory=dict)
+    attn_layer_idx: list = field(default_factory=list)
+    attn_cfg: dict = field(default_factory=dict)
+    rms_norm: bool = True
+    residual_in_fp32: bool = True
+    fused_add_norm: bool = True
+    pad_vocab_size_multiple: int = 8
+    tie_embeddings: bool = True
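`MambaConfig` is a plain dataclass, so a config can be built directly or from a parsed `config.json` dict; the values below are illustrative:

from mamba_ssm.models.config_mamba import MambaConfig

cfg = MambaConfig(d_model=768, n_layer=24, vocab_size=50277, ssm_cfg={"layer": "Mamba2"})
print(cfg.pad_vocab_size_multiple)  # 8 (default)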
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py ADDED
@@ -0,0 +1,338 @@
+# Copyright (c) 2023, Albert Gu, Tri Dao.
+
+import math
+from functools import partial
+import json
+import os
+import copy
+
+from collections import namedtuple
+
+import torch
+import torch.nn as nn
+
+from .config_mamba import MambaConfig
+from ..modules.mamba_simple import Mamba
+from ..modules.mamba2 import Mamba2
+from ..modules.mha import MHA
+from ..modules.mlp import GatedMLP
+from ..modules.block import Block
+from ..utils.generation import GenerationMixin
+from ..utils.hf import load_config_hf, load_state_dict_hf
+
+try:
+    from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
+except ImportError:
+    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
+
+
+def create_block(
+    d_model,
+    d_intermediate,
+    ssm_cfg=None,
+    attn_layer_idx=None,
+    attn_cfg=None,
+    norm_epsilon=1e-5,
+    rms_norm=False,
+    residual_in_fp32=False,
+    fused_add_norm=False,
+    layer_idx=None,
+    device=None,
+    dtype=None,
+):
+    if ssm_cfg is None:
+        ssm_cfg = {}
+    if attn_layer_idx is None:
+        attn_layer_idx = []
+    if attn_cfg is None:
+        attn_cfg = {}
+    factory_kwargs = {"device": device, "dtype": dtype}
+    if layer_idx not in attn_layer_idx:
+        # Create a copy of the config to modify
+        ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
+        ssm_layer = ssm_cfg.pop("layer", "Mamba1")
+        if ssm_layer not in ["Mamba1", "Mamba2"]:
+            raise ValueError(
+                f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2"
+            )
+        mixer_cls = partial(
+            Mamba2 if ssm_layer == "Mamba2" else Mamba,
+            layer_idx=layer_idx,
+            **ssm_cfg,
+            **factory_kwargs,
+        )
+    else:
+        mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
+    norm_cls = partial(
+        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
+    )
+    if d_intermediate == 0:
+        mlp_cls = nn.Identity
+    else:
+        mlp_cls = partial(
+            GatedMLP,
+            hidden_features=d_intermediate,
+            out_features=d_model,
+            **factory_kwargs,
+        )
+    block = Block(
+        d_model,
+        mixer_cls,
+        mlp_cls,
+        norm_cls=norm_cls,
+        fused_add_norm=fused_add_norm,
+        residual_in_fp32=residual_in_fp32,
+    )
+    block.layer_idx = layer_idx
+    return block
+
+
+# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
+def _init_weights(
+    module,
+    n_layer,
+    initializer_range=0.02,  # Now only used for embedding layer.
+    rescale_prenorm_residual=True,
+    n_residuals_per_layer=1,  # Change to 2 if we have MLP
+):
+    if isinstance(module, nn.Linear):
+        if module.bias is not None:
+            if not getattr(module.bias, "_no_reinit", False):
+                nn.init.zeros_(module.bias)
+    elif isinstance(module, nn.Embedding):
+        nn.init.normal_(module.weight, std=initializer_range)
+
+    if rescale_prenorm_residual:
+        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+        # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+        # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+        # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+        #
+        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+        for name, p in module.named_parameters():
+            if name in ["out_proj.weight", "fc2.weight"]:
+                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+                # We need to reinit p since this code could be called multiple times
+                # Having just p *= scale would repeatedly scale it down
+                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
+                with torch.no_grad():
+                    p /= math.sqrt(n_residuals_per_layer * n_layer)
+
+
+class MixerModel(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        n_layer: int,
+        d_intermediate: int,
+        vocab_size: int,
+        ssm_cfg=None,
+        attn_layer_idx=None,
+        attn_cfg=None,
+        norm_epsilon: float = 1e-5,
+        rms_norm: bool = False,
+        initializer_cfg=None,
+        fused_add_norm=False,
+        residual_in_fp32=False,
+        device=None,
+        dtype=None,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.residual_in_fp32 = residual_in_fp32
+
+        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
+
+        # We change the order of residual and layer norm:
+        # Instead of LN -> Attn / MLP -> Add, we do:
+        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
+        # the main branch (output of MLP / Mixer). The model definition is unchanged.
+        # This is for performance reason: we can fuse add + layer_norm.
+        self.fused_add_norm = fused_add_norm
+        if self.fused_add_norm:
+            if layer_norm_fn is None or rms_norm_fn is None:
+                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
+
+        self.layers = nn.ModuleList(
+            [
+                create_block(
+                    d_model,
+                    d_intermediate=d_intermediate,
+                    ssm_cfg=ssm_cfg,
+                    attn_layer_idx=attn_layer_idx,
+                    attn_cfg=attn_cfg,
+                    norm_epsilon=norm_epsilon,
+                    rms_norm=rms_norm,
+                    residual_in_fp32=residual_in_fp32,
+                    fused_add_norm=fused_add_norm,
+                    layer_idx=i,
+                    **factory_kwargs,
+                )
+                for i in range(n_layer)
+            ]
+        )
+
+        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
+            d_model, eps=norm_epsilon, **factory_kwargs
+        )
+
+        self.apply(
+            partial(
+                _init_weights,
+                n_layer=n_layer,
+                **(initializer_cfg if initializer_cfg is not None else {}),
+                n_residuals_per_layer=(
+                    1 if d_intermediate == 0 else 2
+                ),  # 2 if we have MLP
+            )
+        )
+
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        return {
+            i: layer.allocate_inference_cache(
+                batch_size, max_seqlen, dtype=dtype, **kwargs
+            )
+            for i, layer in enumerate(self.layers)
+        }
+
+    def forward(self, input_ids, inference_params=None, **mixer_kwargs):
+        hidden_states = self.embedding(input_ids)
+        residual = None
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                hidden_states,
+                residual,
+                inference_params=inference_params,
+                **mixer_kwargs,
+            )
+        if not self.fused_add_norm:
+            residual = (
+                (hidden_states + residual) if residual is not None else hidden_states
+            )
+            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
+        else:
+            # Set prenorm=False here since we don't need the residual
+            hidden_states = layer_norm_fn(
+                hidden_states,
+                self.norm_f.weight,
+                self.norm_f.bias,
+                eps=self.norm_f.eps,
+                residual=residual,
+                prenorm=False,
+                residual_in_fp32=self.residual_in_fp32,
+                is_rms_norm=isinstance(self.norm_f, RMSNorm),
+            )
+        return hidden_states
+
+
+class MambaLMHeadModel(nn.Module, GenerationMixin):
+
+    def __init__(
+        self,
+        config: MambaConfig,
+        initializer_cfg=None,
+        device=None,
+        dtype=None,
+    ) -> None:
+        self.config = config
+        d_model = config.d_model
+        n_layer = config.n_layer
+        d_intermediate = config.d_intermediate
+        vocab_size = config.vocab_size
+        ssm_cfg = config.ssm_cfg
+        attn_layer_idx = config.attn_layer_idx
+        attn_cfg = config.attn_cfg
+        rms_norm = config.rms_norm
+        residual_in_fp32 = config.residual_in_fp32
+        fused_add_norm = config.fused_add_norm
+        pad_vocab_size_multiple = config.pad_vocab_size_multiple
+        factory_kwargs = {"device": device, "dtype": dtype}
+
+        super().__init__()
+        if vocab_size % pad_vocab_size_multiple != 0:
+            vocab_size += pad_vocab_size_multiple - (
+                vocab_size % pad_vocab_size_multiple
+            )
+        self.backbone = MixerModel(
+            d_model=d_model,
+            n_layer=n_layer,
+            d_intermediate=d_intermediate,
+            vocab_size=vocab_size,
+            ssm_cfg=ssm_cfg,
+            attn_layer_idx=attn_layer_idx,
+            attn_cfg=attn_cfg,
+            rms_norm=rms_norm,
+            initializer_cfg=initializer_cfg,
+            fused_add_norm=fused_add_norm,
+            residual_in_fp32=residual_in_fp32,
+            **factory_kwargs,
+        )
+        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
+
+        # Initialize weights and apply final processing
+        self.apply(
+            partial(
+                _init_weights,
+                n_layer=n_layer,
+                **(initializer_cfg if initializer_cfg is not None else {}),
+            )
+        )
+        self.tie_weights()
+
+    def tie_weights(self):
+        if self.config.tie_embeddings:
+            self.lm_head.weight = self.backbone.embedding.weight
+
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        return self.backbone.allocate_inference_cache(
+            batch_size, max_seqlen, dtype=dtype, **kwargs
+        )
+
+    def forward(
+        self,
+        input_ids,
+        position_ids=None,
+        inference_params=None,
+        num_last_tokens=0,
+        **mixer_kwargs,
+    ):
+        """
+        "position_ids" is just to be compatible with Transformer generation. We don't use it.
+        num_last_tokens: if > 0, only return the logits for the last n tokens
+        """
+        hidden_states = self.backbone(
+            input_ids, inference_params=inference_params, **mixer_kwargs
+        )
+        if num_last_tokens > 0:
+            hidden_states = hidden_states[:, -num_last_tokens:]
+        lm_logits = self.lm_head(hidden_states)
+        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
+        return CausalLMOutput(logits=lm_logits)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
+        config_data = load_config_hf(pretrained_model_name)
+        config = MambaConfig(**config_data)
+        model = cls(config, device=device, dtype=dtype, **kwargs)
+        model.load_state_dict(
+            load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
+        )
+        return model
+
+    def save_pretrained(self, save_directory):
+        """
+        Minimal implementation of save_pretrained for MambaLMHeadModel.
+        Save the model and its configuration file to a directory.
+        """
+        # Ensure save_directory exists
+        os.makedirs(save_directory, exist_ok=True)
+
+        # Save the model's state_dict
+        model_path = os.path.join(save_directory, "pytorch_model.bin")
+        torch.save(self.state_dict(), model_path)
+
+        # Save the configuration of the model
+        config_path = os.path.join(save_directory, "config.json")
+        with open(config_path, "w") as f:
+            json.dump(self.config.__dict__, f, indent=4)
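`MambaLMHeadModel` pads `vocab_size` up to a multiple of `pad_vocab_size_multiple` and ties the LM head to the embedding when `tie_embeddings=True`. A typical loading and forward-pass sketch, assuming a checkpoint in the format `from_pretrained` expects (the model name and sizes are illustrative):

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m", device="cuda", dtype=torch.float16
)
input_ids = torch.randint(0, model.config.vocab_size, (1, 32), device="cuda")
logits = model(input_ids, num_last_tokens=1).logits  # (1, 1, padded_vocab_size)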
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/__init__.py ADDED
File without changes
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/block.py ADDED
@@ -0,0 +1,107 @@
+# Copyright (c) 2024, Tri Dao, Albert Gu.
+from typing import Optional
+
+import torch
+from torch import nn, Tensor
+
+from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn
+
+
+class Block(nn.Module):
+    def __init__(
+        self,
+        dim,
+        mixer_cls,
+        mlp_cls,
+        norm_cls=nn.LayerNorm,
+        fused_add_norm=False,
+        residual_in_fp32=False,
+    ):
+        """
+        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
+
+        This Block has a slightly different structure compared to a regular
+        prenorm Transformer block.
+        The standard block is: LN -> MHA/MLP -> Add.
+        [Ref: https://arxiv.org/abs/2002.04745]
+        Here we have: Add -> LN -> Mixer, returning both
+        the hidden_states (output of the mixer) and the residual.
+        This is purely for performance reasons, as we can fuse add and LayerNorm.
+        The residual needs to be provided (except for the very first block).
+        """
+        super().__init__()
+        self.residual_in_fp32 = residual_in_fp32
+        self.fused_add_norm = fused_add_norm
+        self.norm = norm_cls(dim)
+        self.mixer = mixer_cls(dim)
+        if mlp_cls is not nn.Identity:
+            self.norm2 = norm_cls(dim)
+            self.mlp = mlp_cls(dim)
+        else:
+            self.mlp = None
+        if self.fused_add_norm:
+            assert RMSNorm is not None, "RMSNorm import fails"
+            assert isinstance(
+                self.norm, (nn.LayerNorm, RMSNorm)
+            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
+
+    def forward(
+        self,
+        hidden_states: Tensor,
+        residual: Optional[Tensor] = None,
+        inference_params=None,
+        **mixer_kwargs
+    ):
+        r"""Pass the input through the encoder layer.
+
+        Args:
+            hidden_states: the sequence to the encoder layer (required).
+            residual: hidden_states = Mixer(LN(residual))
+        """
+        if not self.fused_add_norm:
+            residual = (
+                (hidden_states + residual) if residual is not None else hidden_states
+            )
+            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
+            if self.residual_in_fp32:
+                residual = residual.to(torch.float32)
+        else:
+            hidden_states, residual = layer_norm_fn(
+                hidden_states,
+                self.norm.weight,
+                self.norm.bias,
+                residual=residual,
+                prenorm=True,
+                residual_in_fp32=self.residual_in_fp32,
+                eps=self.norm.eps,
+                is_rms_norm=isinstance(self.norm, RMSNorm),
+            )
+        hidden_states = self.mixer(
+            hidden_states, inference_params=inference_params, **mixer_kwargs
+        )
+
+        if self.mlp is not None:
+            if not self.fused_add_norm:
+                residual = hidden_states + residual
+                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
+                if self.residual_in_fp32:
+                    residual = residual.to(torch.float32)
+            else:
+                hidden_states, residual = layer_norm_fn(
+                    hidden_states,
+                    self.norm2.weight,
+                    self.norm2.bias,
+                    residual=residual,
+                    prenorm=True,
+                    residual_in_fp32=self.residual_in_fp32,
+                    eps=self.norm2.eps,
+                    is_rms_norm=isinstance(self.norm2, RMSNorm),
+                )
+            hidden_states = self.mlp(hidden_states)
+
+        return hidden_states, residual
+
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        return self.mixer.allocate_inference_cache(
+            batch_size, max_seqlen, dtype=dtype, **kwargs
+        )
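The block's docstring describes the reordered residual: the returned `residual` is the input to the norm, and the next block adds it back before normalizing. A plain-PyTorch sketch of the non-fused path's semantics, with a toy mixer and illustrative sizes (not the actual Block class):

import torch
import torch.nn as nn

d = 16
norm = nn.LayerNorm(d)
mixer = nn.Linear(d, d)  # stand-in for a Mamba/MHA mixer

def block(hidden_states, residual=None):
    # Add -> LN -> Mixer, returning both branches (cf. Block.forward with fused_add_norm=False)
    residual = hidden_states + residual if residual is not None else hidden_states
    hidden_states = mixer(norm(residual))
    return hidden_states, residual

h, r = block(torch.randn(2, 8, d))
h, r = block(h, r)  # the next block consumes both branches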
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba2.py ADDED
@@ -0,0 +1,502 @@
+# Copyright (c) 2024, Tri Dao, Albert Gu.
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from einops import rearrange, repeat
+
+try:
+    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
+except ImportError:
+    causal_conv1d_fn, causal_conv1d_update = None, None
+
+try:
+    from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
+except ImportError:
+    causal_conv1d_varlen_states = None
+
+try:
+    from ..ops.triton.selective_state_update import selective_state_update
+except ImportError:
+    selective_state_update = None
+
+from ..ops.triton.layernorm_gated import RMSNorm as RMSNormGated
+
+from ..distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear
+from ..distributed.distributed_utils import all_reduce, reduce_scatter
+
+from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
+from ..ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
+
+from huggingface_hub import PyTorchModelHubMixin
+
+
+class Mamba2(nn.Module, PyTorchModelHubMixin):
+    def __init__(
+        self,
+        d_model,
+        d_state=128,
+        d_conv=4,
+        conv_init=None,
+        expand=2,
+        headdim=64,
+        d_ssm=None,  # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
+        ngroups=1,
+        A_init_range=(1, 16),
+        D_has_hdim=False,
+        rmsnorm=True,
+        norm_before_gate=False,
+        dt_min=0.001,
+        dt_max=0.1,
+        dt_init_floor=1e-4,
+        dt_limit=(0.0, float("inf")),
+        bias=False,
+        conv_bias=True,
+        # Fused kernel and sharding options
+        chunk_size=256,
+        use_mem_eff_path=True,
+        layer_idx=None,  # Absorb kwarg for general module
+        process_group=None,
+        sequence_parallel=True,
+        device=None,
+        dtype=None,
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.d_model = d_model
+        self.d_state = d_state
+        self.d_conv = d_conv
+        self.conv_init = conv_init
+        self.expand = expand
+        self.process_group = process_group
+        self.sequence_parallel = sequence_parallel
+        self.world_size = 1 if process_group is None else process_group.size()
+        self.local_rank = 0 if process_group is None else process_group.rank()
+        self.d_inner = (self.expand * self.d_model) // self.world_size
+        assert self.d_inner * self.world_size == self.expand * self.d_model
+        self.headdim = headdim
+        self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
+        assert ngroups % self.world_size == 0
+        self.ngroups = ngroups // self.world_size
+        assert self.d_ssm % self.headdim == 0
+        self.nheads = self.d_ssm // self.headdim
+        self.D_has_hdim = D_has_hdim
+        self.rmsnorm = rmsnorm
+        self.norm_before_gate = norm_before_gate
+        self.dt_limit = dt_limit
+        self.activation = "silu"
+        self.chunk_size = chunk_size
+        self.use_mem_eff_path = use_mem_eff_path
+        self.layer_idx = layer_idx
+
+        # Order: [z, x, B, C, dt]
+        d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
+        if self.process_group is None:
+            self.in_proj = nn.Linear(
+                self.d_model, d_in_proj, bias=bias, **factory_kwargs
+            )
+        else:
+            self.in_proj = ColumnParallelLinear(
+                self.d_model,
+                d_in_proj * self.world_size,
+                bias=bias,
+                process_group=self.process_group,
+                sequence_parallel=self.sequence_parallel,
+                **factory_kwargs,
+            )
+
+        conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
+        self.conv1d = nn.Conv1d(
+            in_channels=conv_dim,
+            out_channels=conv_dim,
+            bias=conv_bias,
+            kernel_size=d_conv,
+            groups=conv_dim,
+            padding=d_conv - 1,
+            **factory_kwargs,
+        )
+        if self.conv_init is not None:
+            nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
+
+        self.act = nn.SiLU()
+
+        # Initialize log dt bias
+        dt = torch.exp(
+            torch.rand(self.nheads, **factory_kwargs)
+            * (math.log(dt_max) - math.log(dt_min))
+            + math.log(dt_min)
+        )
+        dt = torch.clamp(dt, min=dt_init_floor)
+        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+        inv_dt = dt + torch.log(-torch.expm1(-dt))
+        self.dt_bias = nn.Parameter(inv_dt)
+        # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
+        # name.endswith("bias") in param_grouping.py
+        self.dt_bias._no_weight_decay = True
+
+        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
+        A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
+            *A_init_range
+        )
+        A_log = torch.log(A).to(dtype=dtype)
+        self.A_log = nn.Parameter(A_log)
+        self.A_log._no_weight_decay = True
+
+        # D "skip" parameter
+        self.D = nn.Parameter(
+            torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device)
+        )
+        self.D._no_weight_decay = True
+
+        if self.rmsnorm:
+            assert RMSNormGated is not None
+            self.norm = RMSNormGated(
+                self.d_ssm,
+                eps=1e-5,
+                norm_before_gate=self.norm_before_gate,
+                group_size=self.d_ssm // ngroups,
+                **factory_kwargs,
+            )
+
+        if self.process_group is None:
+            self.out_proj = nn.Linear(
+                self.d_inner, self.d_model, bias=bias, **factory_kwargs
+            )
+        else:
+            self.out_proj = RowParallelLinear(
+                self.d_inner * self.world_size,
+                self.d_model,
+                bias=bias,
+                process_group=self.process_group,
+                sequence_parallel=self.sequence_parallel,
+                **factory_kwargs,
+            )
+
+    def forward(
+        self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None
+    ):
+        """
+        u: (batch, seqlen, hidden_dim) if seqlen=None.
+        If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we
184
+ split u during sequence parallel, we split the batch * seqlen dimension
185
+ (in case batch is small).
186
+ Returns: same shape as u
187
+ """
188
+ seqlen_og = seqlen
189
+ if seqlen is None:
190
+ batch, seqlen, dim = u.shape
191
+ else:
192
+ batch_seqlen, dim = u.shape
193
+ batch = batch_seqlen // seqlen
194
+
195
+ conv_state, ssm_state = None, None
196
+ if inference_params is not None:
197
+ inference_batch = (
198
+ cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
199
+ )
200
+ conv_state, ssm_state = self._get_states_from_cache(
201
+ inference_params, inference_batch
202
+ )
203
+ if inference_params.seqlen_offset > 0:
204
+ # The states are updated inplace
205
+ out, _, _ = self.step(u, conv_state, ssm_state)
206
+ return out
207
+
208
+ zxbcdt = self.in_proj(u) # (B, L, d_in_proj) or (B * L, d_in_proj)
209
+ if seqlen_og is not None:
210
+ zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
211
+ # If the model is loaded in fp16, without the .float() here, A might be -inf
212
+ A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state)
213
+ dt_limit_kwargs = (
214
+ {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
215
+ )
216
+ if self.use_mem_eff_path and inference_params is None:
217
+ out = mamba_split_conv1d_scan_combined(
218
+ zxbcdt,
219
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
220
+ self.conv1d.bias,
221
+ self.dt_bias,
222
+ A,
223
+ D=(
224
+ rearrange(self.D, "(h p) -> h p", p=self.headdim)
225
+ if self.D_has_hdim
226
+ else self.D
227
+ ),
228
+ chunk_size=self.chunk_size,
229
+ seq_idx=seq_idx,
230
+ activation=self.activation,
231
+ rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
232
+ rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
233
+ outproj_weight=self.out_proj.weight,
234
+ outproj_bias=self.out_proj.bias,
235
+ headdim=None if self.D_has_hdim else self.headdim,
236
+ ngroups=self.ngroups,
237
+ norm_before_gate=self.norm_before_gate,
238
+ **dt_limit_kwargs,
239
+ )
240
+ if seqlen_og is not None:
241
+ out = rearrange(out, "b l d -> (b l) d")
242
+ if self.process_group is not None:
243
+ reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
244
+ out = reduce_fn(out, self.process_group)
245
+ else:
246
+ d_mlp = (
247
+ zxbcdt.shape[-1]
248
+ - 2 * self.d_ssm
249
+ - 2 * self.ngroups * self.d_state
250
+ - self.nheads
251
+ ) // 2
252
+ z0, x0, z, xBC, dt = torch.split(
253
+ zxbcdt,
254
+ [
255
+ d_mlp,
256
+ d_mlp,
257
+ self.d_ssm,
258
+ self.d_ssm + 2 * self.ngroups * self.d_state,
259
+ self.nheads,
260
+ ],
261
+ dim=-1,
262
+ )
263
+ if conv_state is not None:
264
+ if cu_seqlens is None:
265
+ # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
266
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
267
+ xBC_t = rearrange(xBC, "b l d -> b d l")
268
+ conv_state.copy_(
269
+ F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))
270
+ ) # Update state (B D W)
271
+ else:
272
+ assert (
273
+ causal_conv1d_varlen_states is not None
274
+ ), "varlen inference requires causal_conv1d package"
275
+ assert (
276
+ batch == 1
277
+ ), "varlen inference only supports batch dimension 1"
278
+ conv_varlen_states = causal_conv1d_varlen_states(
279
+ xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
280
+ )
281
+ conv_state.copy_(conv_varlen_states)
282
+ assert self.activation in ["silu", "swish"]
283
+ if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
284
+ assert (
285
+ seq_idx is None
286
+ ), "varlen conv1d requires the causal_conv1d package"
287
+ xBC = self.act(
288
+ self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[
289
+ :, : -(self.d_conv - 1)
290
+ ]
291
+ ) # (B, L, self.d_ssm + 2 * ngroups * d_state)
292
+ else:
293
+ xBC = causal_conv1d_fn(
294
+ xBC.transpose(1, 2),
295
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
296
+ bias=self.conv1d.bias,
297
+ activation=self.activation,
298
+ seq_idx=seq_idx,
299
+ ).transpose(1, 2)
300
+ x, B, C = torch.split(
301
+ xBC,
302
+ [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
303
+ dim=-1,
304
+ )
305
+ y = mamba_chunk_scan_combined(
306
+ rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
307
+ dt,
308
+ A,
309
+ rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
310
+ rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
311
+ chunk_size=self.chunk_size,
312
+ D=(
313
+ rearrange(self.D, "(h p) -> h p", p=self.headdim)
314
+ if self.D_has_hdim
315
+ else self.D
316
+ ),
317
+ z=(
318
+ rearrange(z, "b l (h p) -> b l h p", p=self.headdim)
319
+ if not self.rmsnorm
320
+ else None
321
+ ),
322
+ dt_bias=self.dt_bias,
323
+ dt_softplus=True,
324
+ seq_idx=seq_idx,
325
+ cu_seqlens=cu_seqlens,
326
+ **dt_limit_kwargs,
327
+ return_final_states=ssm_state is not None,
328
+ return_varlen_states=cu_seqlens is not None
329
+ and inference_params is not None,
330
+ )
331
+ if ssm_state is not None:
332
+ y, last_state, *rest = y
333
+ if cu_seqlens is None:
334
+ ssm_state.copy_(last_state)
335
+ else:
336
+ varlen_states = rest[0]
337
+ ssm_state.copy_(varlen_states)
338
+ y = rearrange(y, "b l h p -> b l (h p)")
339
+ if self.rmsnorm:
340
+ y = self.norm(y, z)
341
+ if d_mlp > 0:
342
+ y = torch.cat([F.silu(z0) * x0, y], dim=-1)
343
+ if seqlen_og is not None:
344
+ y = rearrange(y, "b l d -> (b l) d")
345
+ out = self.out_proj(y)
346
+ return out
347
+
348
+ def step(self, hidden_states, conv_state, ssm_state):
349
+ dtype = hidden_states.dtype
350
+ assert (
351
+ hidden_states.shape[1] == 1
352
+ ), "Only support decoding with 1 token at a time for now"
353
+ zxbcdt = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
354
+ d_mlp = (
355
+ zxbcdt.shape[-1]
356
+ - 2 * self.d_ssm
357
+ - 2 * self.ngroups * self.d_state
358
+ - self.nheads
359
+ ) // 2
360
+ z0, x0, z, xBC, dt = torch.split(
361
+ zxbcdt,
362
+ [
363
+ d_mlp,
364
+ d_mlp,
365
+ self.d_ssm,
366
+ self.d_ssm + 2 * self.ngroups * self.d_state,
367
+ self.nheads,
368
+ ],
369
+ dim=-1,
370
+ )
371
+
372
+ # Conv step
373
+ if causal_conv1d_update is None:
374
+ conv_state.copy_(
375
+ torch.roll(conv_state, shifts=-1, dims=-1)
376
+ ) # Update state (B D W)
377
+ conv_state[:, :, -1] = xBC
378
+ xBC = torch.sum(
379
+ conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
380
+ ) # (B D)
381
+ if self.conv1d.bias is not None:
382
+ xBC = xBC + self.conv1d.bias
383
+ xBC = self.act(xBC).to(dtype=dtype)
384
+ else:
385
+ xBC = causal_conv1d_update(
386
+ xBC,
387
+ conv_state,
388
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
389
+ self.conv1d.bias,
390
+ self.activation,
391
+ )
392
+
393
+ x, B, C = torch.split(
394
+ xBC,
395
+ [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
396
+ dim=-1,
397
+ )
398
+ A = -torch.exp(self.A_log.float()) # (nheads,)
399
+
400
+ # SSM step
401
+ if selective_state_update is None:
402
+ assert (
403
+ self.ngroups == 1
404
+ ), "Only support ngroups=1 for this inference code path"
405
+ # Discretize A and B
406
+ dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
407
+ dA = torch.exp(dt * A) # (batch, nheads)
408
+ x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
409
+ dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
410
+ ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
411
+ y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
412
+ y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
413
+ y = rearrange(y, "b h p -> b (h p)")
414
+ if not self.rmsnorm:
415
+ y = y * self.act(z) # (B D)
416
+ else:
417
+ A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(
418
+ dtype=torch.float32
419
+ )
420
+ dt = repeat(dt, "b h -> b h p", p=self.headdim)
421
+ dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
422
+ D = repeat(self.D, "h -> h p", p=self.headdim)
423
+ B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
424
+ C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
425
+ x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
426
+ if not self.rmsnorm:
427
+ z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
428
+ y = selective_state_update(
429
+ ssm_state,
430
+ x_reshaped,
431
+ dt,
432
+ A,
433
+ B,
434
+ C,
435
+ D,
436
+ z=z if not self.rmsnorm else None,
437
+ dt_bias=dt_bias,
438
+ dt_softplus=True,
439
+ )
440
+ y = rearrange(y, "b h p -> b (h p)")
441
+ if self.rmsnorm:
442
+ y = self.norm(y, z)
443
+ if d_mlp > 0:
444
+ y = torch.cat([F.silu(z0) * x0, y], dim=-1)
445
+ out = self.out_proj(y)
446
+ return out.unsqueeze(1), conv_state, ssm_state
447
+
448
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
449
+ device = self.out_proj.weight.device
450
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
451
+ conv_state = torch.zeros(
452
+ batch_size,
453
+ self.d_conv,
454
+ self.conv1d.weight.shape[0],
455
+ device=device,
456
+ dtype=conv_dtype,
457
+ ).transpose(1, 2)
458
+ ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
459
+ ssm_state = torch.zeros(
460
+ batch_size,
461
+ self.nheads,
462
+ self.headdim,
463
+ self.d_state,
464
+ device=device,
465
+ dtype=ssm_dtype,
466
+ )
467
+ return conv_state, ssm_state
468
+
469
+ def _get_states_from_cache(
470
+ self, inference_params, batch_size, initialize_states=False
471
+ ):
472
+ assert self.layer_idx is not None
473
+ if self.layer_idx not in inference_params.key_value_memory_dict:
474
+ batch_shape = (batch_size,)
475
+ conv_state = torch.zeros(
476
+ batch_size,
477
+ self.d_conv,
478
+ self.conv1d.weight.shape[0],
479
+ device=self.conv1d.weight.device,
480
+ dtype=self.conv1d.weight.dtype,
481
+ ).transpose(1, 2)
482
+ ssm_state = torch.zeros(
483
+ batch_size,
484
+ self.nheads,
485
+ self.headdim,
486
+ self.d_state,
487
+ device=self.in_proj.weight.device,
488
+ dtype=self.in_proj.weight.dtype,
489
+ )
490
+ inference_params.key_value_memory_dict[self.layer_idx] = (
491
+ conv_state,
492
+ ssm_state,
493
+ )
494
+ else:
495
+ conv_state, ssm_state = inference_params.key_value_memory_dict[
496
+ self.layer_idx
497
+ ]
498
+ # TODO: What if batch size changes between generation, and we reuse the same states?
499
+ if initialize_states:
500
+ conv_state.zero_()
501
+ ssm_state.zero_()
502
+ return conv_state, ssm_state
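A minimal usage sketch for the Mamba2 block defined above; the sizes, dtype, and the mamba_ssm.modules.mamba2 import path are assumptions for illustration, and the fused forward path presumes the Triton and causal-conv1d kernels are available on a CUDA device.

import torch
from mamba_ssm.modules.mamba2 import Mamba2  # assumed import path for this build

# Illustrative sizes: d_inner = expand * d_model = 1024, so nheads = d_inner / headdim = 16.
batch, seqlen, d_model = 2, 1024, 512
x = torch.randn(batch, seqlen, d_model, device="cuda", dtype=torch.bfloat16)
block = Mamba2(d_model=d_model, d_state=128, headdim=64, layer_idx=0).to("cuda", torch.bfloat16)
y = block(x)  # output has the same shape as x: (batch, seqlen, d_model)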
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba2_simple.py ADDED
@@ -0,0 +1,229 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from einops import rearrange, repeat
9
+
10
+ try:
11
+ from causal_conv1d import causal_conv1d_fn
12
+ except ImportError:
13
+ causal_conv1d_fn = None
14
+
15
+ try:
16
+ from ..ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
17
+ except ImportError:
18
+ RMSNormGated, LayerNorm = None, None
19
+
20
+ from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
21
+ from ..ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
22
+
23
+
24
+ class Mamba2Simple(nn.Module):
25
+ def __init__(
26
+ self,
27
+ d_model,
28
+ d_state=64,
29
+ d_conv=4,
30
+ conv_init=None,
31
+ expand=2,
32
+ headdim=128,
33
+ ngroups=1,
34
+ A_init_range=(1, 16),
35
+ dt_min=0.001,
36
+ dt_max=0.1,
37
+ dt_init_floor=1e-4,
38
+ dt_limit=(0.0, float("inf")),
39
+ learnable_init_states=False,
40
+ activation="swish",
41
+ bias=False,
42
+ conv_bias=True,
43
+ # Fused kernel and sharding options
44
+ chunk_size=256,
45
+ use_mem_eff_path=True,
46
+ layer_idx=None, # Absorb kwarg for general module
47
+ device=None,
48
+ dtype=None,
49
+ ):
50
+ factory_kwargs = {"device": device, "dtype": dtype}
51
+ super().__init__()
52
+ self.d_model = d_model
53
+ self.d_state = d_state
54
+ self.d_conv = d_conv
55
+ self.conv_init = conv_init
56
+ self.expand = expand
57
+ self.d_inner = self.expand * self.d_model
58
+ self.headdim = headdim
59
+ self.ngroups = ngroups
60
+ assert self.d_inner % self.headdim == 0
61
+ self.nheads = self.d_inner // self.headdim
62
+ self.dt_limit = dt_limit
63
+ self.learnable_init_states = learnable_init_states
64
+ self.activation = activation
65
+ self.chunk_size = chunk_size
66
+ self.use_mem_eff_path = use_mem_eff_path
67
+ self.layer_idx = layer_idx
68
+
69
+ # Order: [z, x, B, C, dt]
70
+ d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
71
+ self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
72
+
73
+ conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
74
+ self.conv1d = nn.Conv1d(
75
+ in_channels=conv_dim,
76
+ out_channels=conv_dim,
77
+ bias=conv_bias,
78
+ kernel_size=d_conv,
79
+ groups=conv_dim,
80
+ padding=d_conv - 1,
81
+ **factory_kwargs,
82
+ )
83
+ if self.conv_init is not None:
84
+ nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
85
+ # self.conv1d.weight._no_weight_decay = True
86
+
87
+ if self.learnable_init_states:
88
+ self.init_states = nn.Parameter(
89
+ torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs)
90
+ )
91
+ self.init_states._no_weight_decay = True
92
+
93
+ self.act = nn.SiLU()
94
+
95
+ # Initialize log dt bias
96
+ dt = torch.exp(
97
+ torch.rand(self.nheads, **factory_kwargs)
98
+ * (math.log(dt_max) - math.log(dt_min))
99
+ + math.log(dt_min)
100
+ )
101
+ dt = torch.clamp(dt, min=dt_init_floor)
102
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
103
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
104
+ self.dt_bias = nn.Parameter(inv_dt)
105
+ # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
106
+ # name.endswith("bias") in param_grouping.py
107
+ self.dt_bias._no_weight_decay = True
108
+
109
+ # A parameter
110
+ assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
111
+ A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
112
+ *A_init_range
113
+ )
114
+ A_log = torch.log(A).to(dtype=dtype)
115
+ self.A_log = nn.Parameter(A_log)
116
+ # self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)
117
+ self.A_log._no_weight_decay = True
118
+
119
+ # D "skip" parameter
120
+ self.D = nn.Parameter(torch.ones(self.nheads, device=device))
121
+ self.D._no_weight_decay = True
122
+
123
+ # Extra normalization layer right before output projection
124
+ assert RMSNormGated is not None
125
+ self.norm = RMSNormGated(
126
+ self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs
127
+ )
128
+
129
+ self.out_proj = nn.Linear(
130
+ self.d_inner, self.d_model, bias=bias, **factory_kwargs
131
+ )
132
+
133
+ def forward(self, u, seq_idx=None):
134
+ """
135
+ u: (B, L, D)
136
+ Returns: same shape as u
137
+ """
138
+ batch, seqlen, dim = u.shape
139
+
140
+ zxbcdt = self.in_proj(u) # (B, L, d_in_proj)
141
+ A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state)
142
+ initial_states = (
143
+ repeat(self.init_states, "... -> b ...", b=batch)
144
+ if self.learnable_init_states
145
+ else None
146
+ )
147
+ dt_limit_kwargs = (
148
+ {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
149
+ )
150
+
151
+ if self.use_mem_eff_path:
152
+ # Fully fused path
153
+ out = mamba_split_conv1d_scan_combined(
154
+ zxbcdt,
155
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
156
+ self.conv1d.bias,
157
+ self.dt_bias,
158
+ A,
159
+ D=self.D,
160
+ chunk_size=self.chunk_size,
161
+ seq_idx=seq_idx,
162
+ activation=self.activation,
163
+ rmsnorm_weight=self.norm.weight,
164
+ rmsnorm_eps=self.norm.eps,
165
+ outproj_weight=self.out_proj.weight,
166
+ outproj_bias=self.out_proj.bias,
167
+ headdim=self.headdim,
168
+ ngroups=self.ngroups,
169
+ norm_before_gate=False,
170
+ initial_states=initial_states,
171
+ **dt_limit_kwargs,
172
+ )
173
+ else:
174
+ z, xBC, dt = torch.split(
175
+ zxbcdt,
176
+ [
177
+ self.d_inner,
178
+ self.d_inner + 2 * self.ngroups * self.d_state,
179
+ self.nheads,
180
+ ],
181
+ dim=-1,
182
+ )
183
+ dt = F.softplus(dt + self.dt_bias) # (B, L, nheads)
184
+ assert self.activation in ["silu", "swish"]
185
+
186
+ # 1D Convolution
187
+ if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
188
+ xBC = self.act(
189
+ self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
190
+ ) # (B, L, self.d_inner + 2 * ngroups * d_state)
191
+ xBC = xBC[:, :seqlen, :]
192
+ else:
193
+ xBC = causal_conv1d_fn(
194
+ x=xBC.transpose(1, 2),
195
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
196
+ bias=self.conv1d.bias,
197
+ activation=self.activation,
198
+ ).transpose(1, 2)
199
+
200
+ # Split into 3 main branches: X, B, C
201
+ # These correspond to V, K, Q respectively in the SSM/attention duality
202
+ x, B, C = torch.split(
203
+ xBC,
204
+ [
205
+ self.d_inner,
206
+ self.ngroups * self.d_state,
207
+ self.ngroups * self.d_state,
208
+ ],
209
+ dim=-1,
210
+ )
211
+ y = mamba_chunk_scan_combined(
212
+ rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
213
+ dt,
214
+ A,
215
+ rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
216
+ rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
217
+ chunk_size=self.chunk_size,
218
+ D=self.D,
219
+ z=None,
220
+ seq_idx=seq_idx,
221
+ initial_states=initial_states,
222
+ **dt_limit_kwargs,
223
+ )
224
+ y = rearrange(y, "b l h p -> b l (h p)")
225
+
226
+ # Multiply "gate" branch and apply extra normalization layer
227
+ y = self.norm(y, z)
228
+ out = self.out_proj(y)
229
+ return out
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba_simple.py ADDED
@@ -0,0 +1,339 @@
1
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch import Tensor
10
+
11
+ from einops import rearrange, repeat
12
+
13
+ from ..ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
14
+
15
+ try:
16
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
17
+ except ImportError:
18
+ causal_conv1d_fn, causal_conv1d_update = None, None
19
+
20
+ try:
21
+ from ..ops.triton.selective_state_update import selective_state_update
22
+ except ImportError:
23
+ selective_state_update = None
24
+
25
+ try:
26
+ from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
27
+ except ImportError:
28
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
29
+
30
+
31
+ class Mamba(nn.Module):
32
+ def __init__(
33
+ self,
34
+ d_model,
35
+ d_state=16,
36
+ d_conv=4,
37
+ expand=2,
38
+ dt_rank="auto",
39
+ dt_min=0.001,
40
+ dt_max=0.1,
41
+ dt_init="random",
42
+ dt_scale=1.0,
43
+ dt_init_floor=1e-4,
44
+ conv_bias=True,
45
+ bias=False,
46
+ use_fast_path=True, # Fused kernel options
47
+ layer_idx=None,
48
+ device=None,
49
+ dtype=None,
50
+ ):
51
+ factory_kwargs = {"device": device, "dtype": dtype}
52
+ super().__init__()
53
+ self.d_model = d_model
54
+ self.d_state = d_state
55
+ self.d_conv = d_conv
56
+ self.expand = expand
57
+ self.d_inner = int(self.expand * self.d_model)
58
+ self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
59
+ self.use_fast_path = use_fast_path
60
+ self.layer_idx = layer_idx
61
+
62
+ self.in_proj = nn.Linear(
63
+ self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs
64
+ )
65
+
66
+ self.conv1d = nn.Conv1d(
67
+ in_channels=self.d_inner,
68
+ out_channels=self.d_inner,
69
+ bias=conv_bias,
70
+ kernel_size=d_conv,
71
+ groups=self.d_inner,
72
+ padding=d_conv - 1,
73
+ **factory_kwargs,
74
+ )
75
+
76
+ self.activation = "silu"
77
+ self.act = nn.SiLU()
78
+
79
+ self.x_proj = nn.Linear(
80
+ self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
81
+ )
82
+ self.dt_proj = nn.Linear(
83
+ self.dt_rank, self.d_inner, bias=True, **factory_kwargs
84
+ )
85
+
86
+ # Initialize special dt projection to preserve variance at initialization
87
+ dt_init_std = self.dt_rank**-0.5 * dt_scale
88
+ if dt_init == "constant":
89
+ nn.init.constant_(self.dt_proj.weight, dt_init_std)
90
+ elif dt_init == "random":
91
+ nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
92
+ else:
93
+ raise NotImplementedError
94
+
95
+ # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
96
+ dt = torch.exp(
97
+ torch.rand(self.d_inner, **factory_kwargs)
98
+ * (math.log(dt_max) - math.log(dt_min))
99
+ + math.log(dt_min)
100
+ ).clamp(min=dt_init_floor)
101
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
102
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
103
+ with torch.no_grad():
104
+ self.dt_proj.bias.copy_(inv_dt)
105
+ # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
106
+ self.dt_proj.bias._no_reinit = True
107
+
108
+ # S4D real initialization
109
+ A = repeat(
110
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
111
+ "n -> d n",
112
+ d=self.d_inner,
113
+ ).contiguous()
114
+ A_log = torch.log(A) # Keep A_log in fp32
115
+ self.A_log = nn.Parameter(A_log)
116
+ self.A_log._no_weight_decay = True
117
+
118
+ # D "skip" parameter
119
+ self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
120
+ self.D._no_weight_decay = True
121
+
122
+ self.out_proj = nn.Linear(
123
+ self.d_inner, self.d_model, bias=bias, **factory_kwargs
124
+ )
125
+
126
+ def forward(self, hidden_states, inference_params=None):
127
+ """
128
+ hidden_states: (B, L, D)
129
+ Returns: same shape as hidden_states
130
+ """
131
+ batch, seqlen, dim = hidden_states.shape
132
+
133
+ conv_state, ssm_state = None, None
134
+ if inference_params is not None:
135
+ conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
136
+ if inference_params.seqlen_offset > 0:
137
+ # The states are updated inplace
138
+ out, _, _ = self.step(hidden_states, conv_state, ssm_state)
139
+ return out
140
+
141
+ # We do matmul and transpose BLH -> HBL at the same time
142
+ xz = rearrange(
143
+ self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
144
+ "d (b l) -> b d l",
145
+ l=seqlen,
146
+ )
147
+ if self.in_proj.bias is not None:
148
+ xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
149
+
150
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
151
+ # In the backward pass we write dx and dz next to each other to avoid torch.cat
152
+ if (
153
+ self.use_fast_path
154
+ and causal_conv1d_fn is not None
155
+ and inference_params is None
156
+ ): # Doesn't support outputting the states
157
+ out = mamba_inner_fn(
158
+ xz,
159
+ self.conv1d.weight,
160
+ self.conv1d.bias,
161
+ self.x_proj.weight,
162
+ self.dt_proj.weight,
163
+ self.out_proj.weight,
164
+ self.out_proj.bias,
165
+ A,
166
+ None, # input-dependent B
167
+ None, # input-dependent C
168
+ self.D.float(),
169
+ delta_bias=self.dt_proj.bias.float(),
170
+ delta_softplus=True,
171
+ )
172
+ else:
173
+ x, z = xz.chunk(2, dim=1)
174
+ # Compute short convolution
175
+ if conv_state is not None:
176
+ # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
177
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
178
+ conv_state.copy_(
179
+ F.pad(x, (self.d_conv - x.shape[-1], 0))
180
+ ) # Update state (B D W)
181
+ if causal_conv1d_fn is None:
182
+ x = self.act(self.conv1d(x)[..., :seqlen])
183
+ else:
184
+ assert self.activation in ["silu", "swish"]
185
+ x = causal_conv1d_fn(
186
+ x=x,
187
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
188
+ bias=self.conv1d.bias,
189
+ activation=self.activation,
190
+ )
191
+
192
+ # We're careful here about the layout, to avoid extra transposes.
193
+ # We want dt to have d as the slowest moving dimension
194
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
195
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
196
+ dt, B, C = torch.split(
197
+ x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
198
+ )
199
+ dt = self.dt_proj.weight @ dt.t()
200
+ dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
201
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
202
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
203
+ assert self.activation in ["silu", "swish"]
204
+ y = selective_scan_fn(
205
+ x,
206
+ dt,
207
+ A,
208
+ B,
209
+ C,
210
+ self.D.float(),
211
+ z=z,
212
+ delta_bias=self.dt_proj.bias.float(),
213
+ delta_softplus=True,
214
+ return_last_state=ssm_state is not None,
215
+ )
216
+ if ssm_state is not None:
217
+ y, last_state = y
218
+ ssm_state.copy_(last_state)
219
+ y = rearrange(y, "b d l -> b l d")
220
+ out = self.out_proj(y)
221
+ return out
222
+
223
+ def step(self, hidden_states, conv_state, ssm_state):
224
+ dtype = hidden_states.dtype
225
+ assert (
226
+ hidden_states.shape[1] == 1
227
+ ), "Only support decoding with 1 token at a time for now"
228
+ xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
229
+ x, z = xz.chunk(2, dim=-1) # (B D)
230
+
231
+ # Conv step
232
+ if causal_conv1d_update is None:
233
+ conv_state.copy_(
234
+ torch.roll(conv_state, shifts=-1, dims=-1)
235
+ ) # Update state (B D W)
236
+ conv_state[:, :, -1] = x
237
+ x = torch.sum(
238
+ conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
239
+ ) # (B D)
240
+ if self.conv1d.bias is not None:
241
+ x = x + self.conv1d.bias
242
+ x = self.act(x).to(dtype=dtype)
243
+ else:
244
+ x = causal_conv1d_update(
245
+ x,
246
+ conv_state,
247
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
248
+ self.conv1d.bias,
249
+ self.activation,
250
+ )
251
+
252
+ x_db = self.x_proj(x) # (B dt_rank+2*d_state)
253
+ dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
254
+ # Don't add dt_bias here
255
+ dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
256
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
257
+
258
+ # SSM step
259
+ if selective_state_update is None:
260
+ # Discretize A and B
261
+ dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
262
+ dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
263
+ dB = torch.einsum("bd,bn->bdn", dt, B)
264
+ ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
265
+ y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
266
+ y = y + self.D.to(dtype) * x
267
+ y = y * self.act(z) # (B D)
268
+ else:
269
+ y = selective_state_update(
270
+ ssm_state,
271
+ x,
272
+ dt,
273
+ A,
274
+ B,
275
+ C,
276
+ self.D,
277
+ z=z,
278
+ dt_bias=self.dt_proj.bias,
279
+ dt_softplus=True,
280
+ )
281
+
282
+ out = self.out_proj(y)
283
+ return out.unsqueeze(1), conv_state, ssm_state
284
+
285
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
286
+ device = self.out_proj.weight.device
287
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
288
+ conv_state = torch.zeros(
289
+ batch_size,
290
+ self.d_model * self.expand,
291
+ self.d_conv,
292
+ device=device,
293
+ dtype=conv_dtype,
294
+ )
295
+ ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
296
+ # ssm_dtype = torch.float32
297
+ ssm_state = torch.zeros(
298
+ batch_size,
299
+ self.d_model * self.expand,
300
+ self.d_state,
301
+ device=device,
302
+ dtype=ssm_dtype,
303
+ )
304
+ return conv_state, ssm_state
305
+
306
+ def _get_states_from_cache(
307
+ self, inference_params, batch_size, initialize_states=False
308
+ ):
309
+ assert self.layer_idx is not None
310
+ if self.layer_idx not in inference_params.key_value_memory_dict:
311
+ batch_shape = (batch_size,)
312
+ conv_state = torch.zeros(
313
+ batch_size,
314
+ self.d_model * self.expand,
315
+ self.d_conv,
316
+ device=self.conv1d.weight.device,
317
+ dtype=self.conv1d.weight.dtype,
318
+ )
319
+ ssm_state = torch.zeros(
320
+ batch_size,
321
+ self.d_model * self.expand,
322
+ self.d_state,
323
+ device=self.dt_proj.weight.device,
324
+ dtype=self.dt_proj.weight.dtype,
325
+ # dtype=torch.float32,
326
+ )
327
+ inference_params.key_value_memory_dict[self.layer_idx] = (
328
+ conv_state,
329
+ ssm_state,
330
+ )
331
+ else:
332
+ conv_state, ssm_state = inference_params.key_value_memory_dict[
333
+ self.layer_idx
334
+ ]
335
+ # TODO: What if batch size changes between generation, and we reuse the same states?
336
+ if initialize_states:
337
+ conv_state.zero_()
338
+ ssm_state.zero_()
339
+ return conv_state, ssm_state
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mha.py ADDED
@@ -0,0 +1,294 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+
10
+ try:
11
+ from flash_attn import flash_attn_with_kvcache
12
+ except ImportError:
13
+ flash_attn_with_kvcache = None
14
+
15
+ try:
16
+ from flash_attn.layers.rotary import RotaryEmbedding
17
+ except ImportError:
18
+ RotaryEmbedding = None
19
+
20
+ try:
21
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
22
+ except ImportError:
23
+ causal_conv1d_fn, causal_conv1d_update = None, None
24
+
25
+
26
+ def _update_kv_cache(kv, inference_params, layer_idx):
27
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
28
+ # Pre-allocate memory for key-values for inference.
29
+ num_heads, head_dim = kv.shape[-2:]
30
+ assert layer_idx in inference_params.key_value_memory_dict
31
+ kv_cache, _ = inference_params.key_value_memory_dict[layer_idx]
32
+ # Adjust key and value for inference
33
+ batch_start = inference_params.batch_size_offset
34
+ batch_end = batch_start + kv.shape[0]
35
+ sequence_start = inference_params.seqlen_offset
36
+ sequence_end = sequence_start + kv.shape[1]
37
+ assert batch_end <= kv_cache.shape[0]
38
+ assert sequence_end <= kv_cache.shape[1]
39
+ assert kv_cache is not None
40
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
41
+ return kv_cache[batch_start:batch_end, :sequence_end, ...]
42
+
43
+
44
+ class MHA(nn.Module):
45
+ """Multi-head self-attention and cross-attention"""
46
+
47
+ def __init__(
48
+ self,
49
+ embed_dim,
50
+ num_heads,
51
+ num_heads_kv=None,
52
+ head_dim=None, # If None, use embed_dim // num_heads
53
+ mlp_dim=0,
54
+ qkv_proj_bias=True,
55
+ out_proj_bias=True,
56
+ softmax_scale=None,
57
+ causal=False,
58
+ layer_idx=None,
59
+ d_conv=0,
60
+ rotary_emb_dim=0,
61
+ rotary_emb_base=10000.0,
62
+ rotary_emb_interleaved=False,
63
+ device=None,
64
+ dtype=None,
65
+ ) -> None:
66
+ """
67
+ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
68
+ return_residual: whether to return the input x along with the output. This is for
69
+ performance reason: for post-norm architecture, returning the input allows us
70
+ to fuse the backward of nn.Linear with the residual connection.
71
+ """
72
+ factory_kwargs = {"device": device, "dtype": dtype}
73
+ super().__init__()
74
+ self.embed_dim = embed_dim
75
+ self.layer_idx = layer_idx
76
+ self.d_conv = d_conv
77
+ self.rotary_emb_dim = rotary_emb_dim
78
+ self.softmax_scale = softmax_scale
79
+ self.causal = causal
80
+
81
+ self.num_heads = num_heads
82
+ self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
83
+ assert (
84
+ self.num_heads % self.num_heads_kv == 0
85
+ ), "num_heads must be divisible by num_heads_kv"
86
+ if head_dim is None:
87
+ assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
88
+ self.head_dim = head_dim if head_dim is not None else self.embed_dim // num_heads
89
+ self.mlp_dim = math.ceil(mlp_dim / 256) * 256
90
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
91
+ out_dim = self.head_dim * self.num_heads
92
+
93
+ if self.rotary_emb_dim > 0:
94
+ assert RotaryEmbedding is not None, "rotary requires flash_attn to be installed"
95
+ self.rotary_emb = RotaryEmbedding(
96
+ self.rotary_emb_dim,
97
+ base=rotary_emb_base,
98
+ interleaved=rotary_emb_interleaved,
99
+ device=device,
100
+ )
101
+
102
+ self.in_proj = nn.Linear(embed_dim, qkv_dim + self.mlp_dim, bias=qkv_proj_bias, **factory_kwargs)
103
+ if self.d_conv > 0:
104
+ self.conv1d = nn.Conv1d(
105
+ qkv_dim, qkv_dim, kernel_size=self.d_conv, padding=self.d_conv - 1, groups=qkv_dim,
106
+ **factory_kwargs
107
+ )
108
+ self.out_proj = nn.Linear(out_dim + self.mlp_dim // 2, embed_dim, bias=out_proj_bias, **factory_kwargs)
109
+
110
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
111
+ dtype = self.out_proj.weight.dtype if dtype is None else dtype
112
+ device = self.out_proj.weight.device
113
+ if self.d_conv > 0:
114
+ conv_state = torch.zeros(
115
+ batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=dtype
116
+ )
117
+ else:
118
+ conv_state = None
119
+ kv_cache = torch.empty(
120
+ batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype, device=device,
121
+ )
122
+ return kv_cache, conv_state
123
+
124
+ def _update_kv_cache(self, kv, inference_params):
125
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
126
+ assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
127
+ return _update_kv_cache(kv, inference_params, self.layer_idx)
128
+
129
+ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
130
+ """
131
+ Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
132
+ q: (batch_size, seqlen_q, nheads, head_dim)
133
+ kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
134
+ """
135
+ assert inference_params is not None and inference_params.seqlen_offset > 0
136
+ if self.rotary_emb_dim > 0:
137
+ self.rotary_emb._update_cos_sin_cache(
138
+ inference_params.max_seqlen, device=q.device, dtype=q.dtype
139
+ )
140
+ rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
141
+ else:
142
+ rotary_cos, rotary_sin = None, None
143
+ batch = q.shape[0]
144
+ kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
145
+ kv_cache = kv_cache[:batch]
146
+ cache_seqlens = (
147
+ inference_params.lengths_per_sample[:batch]
148
+ if inference_params.lengths_per_sample is not None
149
+ else inference_params.seqlen_offset
150
+ )
151
+ assert flash_attn_with_kvcache is not None, "flash_attn must be installed"
152
+ context = flash_attn_with_kvcache(
153
+ q,
154
+ kv_cache[:, :, 0],
155
+ kv_cache[:, :, 1],
156
+ kv[:, :, 0],
157
+ kv[:, :, 1],
158
+ rotary_cos=rotary_cos,
159
+ rotary_sin=rotary_sin,
160
+ cache_seqlens=cache_seqlens,
161
+ softmax_scale=self.softmax_scale,
162
+ causal=self.causal,
163
+ rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
164
+ )
165
+ return context
166
+
167
+ def _update_kvcache_attention(self, q, kv, inference_params):
168
+ """Write kv to inference_params, then do attention"""
169
+ if (
170
+ inference_params.seqlen_offset == 0
171
+ or flash_attn_with_kvcache is None
172
+ ):
173
+ # TODO: this only uses seqlen_offset and not lengths_per_sample.
174
+ kv = self._update_kv_cache(kv, inference_params)
175
+ k, v = kv.unbind(dim=-3)
176
+ k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
177
+ v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
178
+ return F.scaled_dot_product_attention(
179
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
180
+ ).transpose(1, 2)
181
+ else:
182
+ batch = q.shape[0]
183
+ kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
184
+ kv_cache = kv_cache[:batch]
185
+ cache_seqlens = (
186
+ inference_params.lengths_per_sample[:batch]
187
+ if inference_params.lengths_per_sample is not None
188
+ else inference_params.seqlen_offset
189
+ )
190
+ return flash_attn_with_kvcache(
191
+ q,
192
+ kv_cache[:, :, 0],
193
+ kv_cache[:, :, 1],
194
+ kv[:, :, 0],
195
+ kv[:, :, 1],
196
+ cache_seqlens=cache_seqlens,
197
+ softmax_scale=self.softmax_scale,
198
+ causal=self.causal,
199
+ )
200
+
201
+ def forward(self, x, inference_params=None):
202
+ """
203
+ Arguments:
204
+ x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
205
+ cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
206
+ is the is the sum of the sequence lengths in the batch.
207
+ inference_params: for generation. Adapted from Megatron-LM (and Apex)
208
+ https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
209
+ """
210
+ if inference_params is not None and self.layer_idx not in inference_params.key_value_memory_dict:
211
+ inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache(
212
+ x.shape[0], inference_params.max_seqlen, dtype=x.dtype
213
+ )
214
+ seqlen_offset = (
215
+ 0
216
+ if inference_params is None
217
+ else (
218
+ inference_params.lengths_per_sample
219
+ if inference_params.lengths_per_sample is not None
220
+ else inference_params.seqlen_offset
221
+ )
222
+ )
223
+ rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
224
+ qkv = self.in_proj(x)
225
+ if self.mlp_dim > 0:
226
+ qkv, x_mlp = qkv.split([qkv.shape[-1] - self.mlp_dim, self.mlp_dim], dim=-1)
227
+ x_mlp_up, x_mlp_gate = x_mlp.chunk(2, dim=-1)
228
+ x_mlp = x_mlp_up * F.silu(x_mlp_gate)
229
+ if self.d_conv > 0:
230
+ # The inference code for conv1d is pretty messy, should clean it up
231
+ if (inference_params is None or inference_params.seqlen_offset == 0):
232
+ if causal_conv1d_fn is None:
233
+ qkv = rearrange(
234
+ self.conv1d(rearrange(qkv, "b s d -> b d s"))[..., :-(self.d_conv - 1)], "b d s -> b s d"
235
+ ).contiguous()
236
+ else:
237
+ qkv = causal_conv1d_fn(
238
+ qkv.transpose(1, 2),
239
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
240
+ self.conv1d.bias
241
+ ).transpose(1, 2)
242
+ if inference_params is not None:
243
+ _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
244
+ # If we just take qkv[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
245
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
246
+ qkv_t = rearrange(qkv, "b l d -> b d l")
247
+ conv_state.copy_(F.pad(qkv_t, (self.d_conv - qkv_t.shape[-1], 0))) # Update state (B D W)
248
+ else:
249
+ _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
250
+ assert qkv.shape[1] == 1, "Only support decoding with 1 token at a time for now"
251
+ qkv = qkv.squeeze(1)
252
+ # Conv step
253
+ if causal_conv1d_update is None:
254
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
255
+ conv_state[:, :, -1] = qkv
256
+ qkv = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
257
+ if self.conv1d.bias is not None:
258
+ qkv = qkv + self.conv1d.bias
259
+ else:
260
+ qkv = causal_conv1d_update(
261
+ qkv,
262
+ conv_state,
263
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
264
+ self.conv1d.bias
265
+ )
266
+ qkv = qkv.unsqueeze(1)
267
+ q, kv = qkv.split([self.num_heads * self.head_dim, self.num_heads_kv * 2 * self.head_dim], dim=-1)
268
+ q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
269
+ kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
270
+ if (
271
+ inference_params is None
272
+ or inference_params.seqlen_offset == 0
273
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
274
+ ):
275
+ if self.rotary_emb_dim > 0:
276
+ q, kv = self.rotary_emb(
277
+ q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
278
+ )
279
+ if inference_params is None:
280
+ k, v = kv.unbind(dim=-3)
281
+ k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
282
+ v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
283
+ context = F.scaled_dot_product_attention(
284
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
285
+ ).transpose(1, 2)
286
+ else:
287
+ context = self._update_kvcache_attention(q, kv, inference_params)
288
+ else:
289
+ context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
290
+ context = rearrange(context, "... h d -> ... (h d)")
291
+ if self.mlp_dim > 0:
292
+ context = torch.cat([context, x_mlp], dim=-1)
293
+ out = self.out_proj(context)
294
+ return out
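A small usage sketch for the MHA block defined above; the sizes and the mamba_ssm.modules.mha import path are assumptions. With no inference cache, rotary embeddings, or local conv, the forward pass runs through F.scaled_dot_product_attention, so flash-attn is not needed for this path.

import torch
from mamba_ssm.modules.mha import MHA  # assumed import path

# GQA configuration: 8 query heads share 2 KV heads (head_dim defaults to embed_dim // num_heads = 64).
attn = MHA(embed_dim=512, num_heads=8, num_heads_kv=2, causal=True)
x = torch.randn(2, 128, 512)
y = attn(x)  # (2, 128, 512), same shape as the input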
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mlp.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+
6
+ class GatedMLP(nn.Module):
7
+ def __init__(
8
+ self,
9
+ in_features,
10
+ hidden_features=None,
11
+ out_features=None,
12
+ activation=F.silu,
13
+ bias=False,
14
+ multiple_of=128,
15
+ device=None,
16
+ dtype=None,
17
+ ):
18
+ factory_kwargs = {"device": device, "dtype": dtype}
19
+ super().__init__()
20
+ out_features = out_features if out_features is not None else in_features
21
+ hidden_features = (
22
+ hidden_features if hidden_features is not None else int(8 * in_features / 3)
23
+ )
24
+ hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
25
+ self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)
26
+ self.activation = activation
27
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
28
+
29
+ def forward(self, x):
30
+ y = self.fc1(x)
31
+ y, gate = y.chunk(2, dim=-1)
32
+ y = y * self.activation(gate)
33
+ y = self.fc2(y)
34
+ return y
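A short shape sketch for GatedMLP above; the numbers and import path are illustrative assumptions. The default hidden width int(8 * in_features / 3) is rounded up to a multiple of multiple_of.

import torch
from mamba_ssm.modules.mlp import GatedMLP  # assumed import path

mlp = GatedMLP(in_features=512)  # hidden_features defaults to int(8 * 512 / 3) = 1365, rounded up to 1408
x = torch.randn(4, 128, 512)
y = mlp(x)  # (4, 128, 512): fc1 -> split into value/gate -> value * silu(gate) -> fc2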
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/modules/ssd_minimal.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright (c) 2024, Albert Gu and Tri Dao.
2
+ """Minimal implementation of SSD.
3
+
4
+ This is the same as Listing 1 from the paper.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+
11
+ from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
12
+
13
+
14
+ def segsum_unstable(x):
15
+ """Naive segment sum calculation."""
16
+ T = x.size(-1)
17
+ x_cumsum = torch.cumsum(x, dim=-1)
18
+ x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
19
+ mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
20
+ x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
21
+ return x_segsum
22
+
23
+
24
+ def segsum(x):
25
+ """More stable segment sum calculation."""
26
+ T = x.size(-1)
27
+ x = repeat(x, "... d -> ... d e", e=T)
28
+ mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
29
+ x = x.masked_fill(~mask, 0)
30
+ x_segsum = torch.cumsum(x, dim=-2)
31
+ mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
32
+ x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
33
+ return x_segsum
34
+
35
+
36
+ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
37
+ """
38
+ Arguments:
39
+ X: (batch, length, n_heads, d_head)
40
+ A: (batch, length, n_heads)
41
+ B: (batch, length, n_heads, d_state)
42
+ C: (batch, length, n_heads, d_state)
43
+ Return:
44
+ Y: (batch, length, n_heads, d_head)
45
+ """
46
+ assert X.dtype == A.dtype == B.dtype == C.dtype
47
+ assert X.shape[1] % block_len == 0
48
+
49
+ # Rearrange into blocks/chunks
50
+ X, A, B, C = [
51
+ rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)
52
+ ]
53
+
54
+ A = rearrange(A, "b c l h -> b h c l")
55
+ A_cumsum = torch.cumsum(A, dim=-1)
56
+
57
+ # 1. Compute the output for each intra-chunk (diagonal blocks)
58
+ L = torch.exp(segsum(A))
59
+ Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
60
+
61
+ # 2. Compute the state for each intra-chunk
62
+ # (right term of low-rank factorization of off-diagonal blocks; B terms)
63
+ decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
64
+ states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
65
+
66
+ # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
67
+ # (middle term of factorization of off-diag blocks; A terms)
68
+ if initial_states is None:
69
+ initial_states = torch.zeros_like(states[:, :1])
70
+ states = torch.cat([initial_states, states], dim=1)
71
+ decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
72
+ new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
73
+ states, final_state = new_states[:, :-1], new_states[:, -1]
74
+
75
+ # 4. Compute state -> output conversion per chunk
76
+ # (left term of low-rank factorization of off-diagonal blocks; C terms)
77
+ state_decay_out = torch.exp(A_cumsum)
78
+ Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)
79
+
80
+ # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
81
+ Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
82
+ return Y, final_state
83
+
84
+
85
+ # Simple test
86
+ def test_correctness():
87
+ torch.manual_seed(42)
88
+
89
+ ## Dimensions
90
+ # Denoted (B, T, Q, D, P) in the paper
91
+ batch, seqlen, chunk_size, dim, headdim = 1, 2048, 64, 2048, 64
92
+ nheads = dim // headdim # (H) in the paper
93
+ ngroups = 1 # (G) in the paper
94
+ dstate = 64 # (N) in the paper
95
+ dtype = torch.float32
96
+ device = "cuda"
97
+
98
+ x = torch.randn(batch, seqlen, nheads, headdim, dtype=dtype, device=device)
99
+ dt = F.softplus(
100
+ torch.randn(batch, seqlen, nheads, dtype=torch.float32, device=device) - 4
101
+ ).requires_grad_()
102
+ A = (
103
+ -torch.exp(torch.rand(nheads, dtype=torch.float32, device=device))
104
+ ).requires_grad_()
105
+ B = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
106
+ C = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
107
+ D = torch.randn(nheads, dtype=dtype, device=device)
108
+
109
+ # Comparing fused version and minimal version
110
+ y = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None)
111
+ y_min, _ = ssd_minimal_discrete(x * dt.unsqueeze(-1), A * dt, B, C, chunk_size)
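test_correctness above computes both the fused and the minimal outputs but stops before comparing them. Since the fused kernel discretizes with dt internally while the minimal version receives the pre-scaled x * dt and A * dt, the two results are expected to match, and one way to finish the check (the tolerances are illustrative assumptions) is:

# Follows the last two lines of test_correctness(); y and y_min have shape (batch, seqlen, nheads, headdim).
torch.testing.assert_close(y, y_min, rtol=1e-3, atol=1e-3)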
build/torch27-cxx11-cu118-x86_64-linux/mamba_ssm/ops/__init__.py ADDED
File without changes