reference the flash attention GitHub
- bert_padding.py +5 -0
- block.py +5 -0
- embedding.py +5 -0
- mha.py +9 -0
- mlp.py +5 -0
bert_padding.py
CHANGED
@@ -1,5 +1,10 @@
 # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
 
+"""
+The implementation was further adapted from
+https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c5bfe3cecd0
+"""
+
 import torch
 import torch.nn.functional as F
 from einops import rearrange, repeat
block.py
CHANGED
@@ -1,5 +1,10 @@
 # Copyright (c) 2024, Tri Dao.
 
+"""
+The implementation was adapted from
+https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c5bfe3cecd0
+"""
+
 from functools import partial
 from typing import Optional
 
embedding.py
CHANGED
@@ -1,5 +1,10 @@
 # Copyright (c) 2022, Tri Dao.
 
+"""
+The implementation was adapted from
+https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c5bfe3cecd0/flash_attn/models/bert.py
+"""
+
 import torch
 import torch.nn as nn
 from torch import Tensor
mha.py
CHANGED
@@ -1,5 +1,14 @@
 # Copyright (c) 2023, Tri Dao.
 
+"""
+The implementation was adapted from
+https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c5bfe3cecd0
+with modifications to
+- support QK normalization
+- make ALiBi run with MHA (the ALiBi slopes need to be cast to fp32)
+- make ALiBi run on CPU
+"""
+
 import math
 from functools import partial
 
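For context on the mha.py changes listed above, here is a minimal, hypothetical sketch in plain PyTorch (not code from this repository or from flash-attention; all names are illustrative assumptions) of what QK normalization and an fp32 ALiBi bias that also runs on CPU can look like:

# Hypothetical sketch (assumption, not code from this repo or flash-attention):
# QK normalization plus an ALiBi bias built in float32 that also runs on CPU.
import torch
import torch.nn as nn
import torch.nn.functional as F


def alibi_slopes(nheads: int) -> torch.Tensor:
    # Geometric ALiBi slopes, kept in float32 so the values stay exact
    # even when the rest of the model runs in fp16/bf16.
    start = 2.0 ** (-8.0 / nheads)
    return torch.tensor([start ** (i + 1) for i in range(nheads)], dtype=torch.float32)


class MHAWithQKNormAndAlibi(nn.Module):
    def __init__(self, dim: int, nheads: int):
        super().__init__()
        self.nheads = nheads
        self.head_dim = dim // nheads
        self.Wqkv = nn.Linear(dim, 3 * dim)
        self.out_proj = nn.Linear(dim, dim)
        # QK normalization: normalize q and k per head before the dot product.
        self.q_norm = nn.LayerNorm(self.head_dim)
        self.k_norm = nn.LayerNorm(self.head_dim)
        self.register_buffer("slopes", alibi_slopes(nheads), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, _ = x.shape
        qkv = self.Wqkv(x).view(b, s, 3, self.nheads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)                      # each (b, s, h, d)
        q, k = self.q_norm(q), self.k_norm(k)            # QK normalization
        # ALiBi bias computed in float32; works the same on CPU and GPU.
        pos = torch.arange(s, device=x.device)
        dist = (pos[None, :] - pos[:, None]).abs().to(torch.float32)   # (s, s)
        bias = -self.slopes.view(self.nheads, 1, 1) * dist             # (h, s, s)
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))               # (b, h, s, d)
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias.to(q.dtype))
        return self.out_proj(out.transpose(1, 2).reshape(b, s, -1))

The sketch mirrors the bullet list: q and k are normalized per head before attention, and the ALiBi slopes and bias are built in float32 so they stay exact under fp16/bf16 and the same code path runs on CPU.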
mlp.py
CHANGED
@@ -1,5 +1,10 @@
 # Copyright (c) 2023, Tri Dao.
 
+"""
+The implementation was adapted from
+https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c5bfe3cecd0
+"""
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F