Files changed (5) hide show
  1. README.md +32 -8
  2. config.json +1 -1
  3. modeling_drama.py +111 -152
  4. modeling_drama_nested.py +639 -0
  5. modeling_drama_non_nested.py +184 -0
README.md CHANGED
@@ -60,9 +60,10 @@ model_name = "facebook/drama-base"
60
  device = "cuda" if torch.cuda.is_available() else "cpu"
61
  tokenizer = AutoTokenizer.from_pretrained(model_name)
62
  model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)
63
 
64
- query_embs = model.encode_queries(tokenizer, queries)
65
- doc_embs = model.encode_documents(tokenizer, documents)
66
 
67
  scores = query_embs @ doc_embs.T
68
  print(scores.tolist())
@@ -77,8 +78,8 @@ print(scores.tolist())
77
  DRAMA models are trained using Matryoshka Representation Learning ([MRL](https://github.com/RAIVNLab/MRL)) to support flexible dimensionality. Both queries and documents can be encoded into smaller dimensions, such as 256, using the following:
78
 
79
  ```python
80
- query_embs = model.encode_queries(tokenizer, queries, dim=256)
81
- doc_embs = model.encode_documents(tokenizer, documents, dim=256)
82
 
83
  scores = query_embs @ doc_embs.T
84
  print(scores.tolist())
@@ -101,8 +102,8 @@ documents = [
101
 
102
  model = SentenceTransformer("facebook/drama-base", trust_remote_code=True)
103
 
104
- query_embs = model.encode(queries, prompt_name="query")
105
- doc_embs = model.encode(documents)
106
 
107
  scores = model.similarity(query_embs, doc_embs)
108
  print(scores.tolist())
@@ -128,8 +129,8 @@ documents = [
128
 
129
  model = SentenceTransformer("facebook/drama-base", truncate_dim=256, trust_remote_code=True)
130
 
131
- query_embs = model.encode(queries, prompt_name="query")
132
- doc_embs = model.encode(documents)
133
 
134
  scores = model.similarity(query_embs, doc_embs)
135
  print(scores.tolist())
@@ -165,3 +166,26 @@ If you find our paper or models helpful, please consider cite as follows:
165
  year={2025}
166
  }
167
  ```
60
  device = "cuda" if torch.cuda.is_available() else "cpu"
61
  tokenizer = AutoTokenizer.from_pretrained(model_name)
62
  model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)
63
+ use_nested = False
64
+ query_embs = model.encode_queries(tokenizer, queries, use_nested=use_nested)
65
+ doc_embs = model.encode_documents(tokenizer, documents, use_nested=use_nested)
66
 
 
 
67
 
68
  scores = query_embs @ doc_embs.T
69
  print(scores.tolist())
 
78
  DRAMA models are trained using Matryoshka Representation Learning ([MRL](https://github.com/RAIVNLab/MRL)) to support flexible dimensionality. Both queries and documents can be encoded into smaller dimensions, such as 256, using the following:
79
 
80
  ```python
81
+ query_embs = model.encode_queries(tokenizer, queries, dim=256, use_nested=use_nested)
82
+ doc_embs = model.encode_documents(tokenizer, documents, dim=256, use_nested=use_nested)
83
 
84
  scores = query_embs @ doc_embs.T
85
  print(scores.tolist())
 
102
 
103
  model = SentenceTransformer("facebook/drama-base", trust_remote_code=True)
104
 
105
+ query_embs = model.encode(queries, prompt_name="query", use_nested=use_nested)
106
+ doc_embs = model.encode(documents, use_nested=use_nested)
107
 
108
  scores = model.similarity(query_embs, doc_embs)
109
  print(scores.tolist())
 
129
 
130
  model = SentenceTransformer("facebook/drama-base", truncate_dim=256, trust_remote_code=True)
131
 
132
+ query_embs = model.encode(queries, prompt_name="query", use_nested=use_nested)
133
+ doc_embs = model.encode(documents, use_nested=use_nested)
134
 
135
  scores = model.similarity(query_embs, doc_embs)
136
  print(scores.tolist())
 
166
  year={2025}
167
  }
168
  ```
169
+
170
+ ## Efficient DRAMA
171
+ ### Nested Tensors
172
+ [Nested Tensors](https://docs.pytorch.org/docs/stable/nested.html) provide a way to handle ragged-shaped data within a single tensor, allowing for efficient operations on such data.
173
+ They store data in a compact packed representation while offering a standard PyTorch tensor interface, making it easy to apply various
174
+ operations.
175
+ Nested Tensors are particularly advantageous for model deployments that perform inference on large batches of sequences with varying
176
+ lengths. Traditional tensors require padding all sequences in a batch to the same length, which can be inefficient, especially when
177
+ the batch includes many short sequences and a single long sequence. Nested Tensors eliminate the need for padding, thus avoiding
178
+ unnecessary computation on extra pad tokens. This results in more efficient processing of batches with varying sequence lengths.
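As a minimal illustration of the difference (the token ids below are made up), a padded batch stretches every sequence to the length of the longest one, while a jagged Nested Tensor stores only the actual tokens:

```python
import torch

# Three "tokenized" sequences of different lengths (token ids are made up).
seqs = [
    torch.tensor([101, 7592, 102]),
    torch.tensor([101, 2023, 2003, 1037, 2936, 6251, 102]),
    torch.tensor([101, 102]),
]

# Padded batch: every row is stretched to the longest sequence with pad tokens.
padded = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True)
print(padded.shape)  # torch.Size([3, 7])

# Jagged Nested Tensor: the same data packed with no padding at all.
nested = torch.nested.nested_tensor(seqs, layout=torch.jagged)
print(nested.is_nested)  # True; the second dimension is ragged
```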
179
+
180
+ ### Performance
181
+ Experiments have demonstrated a 1.7x to 2.3x improvement in queries per second (QPS) across the base, large, and 1B models for batch inference on sequences of varied lengths.
182
+
183
+ ### Usage
184
+ To enable Nested Tensors, set the `use_nested` variable to `True`. This activates jagged Nested Tensors and lets you
185
+ take advantage of the more efficient inference path (a short sketch follows the prerequisites below).
186
+
187
+ > Prerequisites: this code has been tested with the package versions below. Please use these or newer versions to avoid compatibility issues.
188
+
189
+ >- Python: 3.12
190
+ >- Transformers: 4.51.1
191
+ >- PyTorch: 2.7.1
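Putting it together, a minimal sketch of the nested path (the query and document strings are placeholders; on older Python/PyTorch/Transformers the wrapper loads the non-nested implementation instead, which warns and ignores `use_nested`):

```python
import torch
from transformers import AutoModel, AutoTokenizer

model_name = "facebook/drama-base"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)

queries = ["What is a jagged tensor?", "How long is the Great Wall of China?"]
documents = [
    "A jagged (nested) tensor packs sequences of different lengths without padding.",
    "The Great Wall of China is over 21,000 km long when all branches are counted.",
]

use_nested = True  # requires Python >= 3.12, PyTorch >= 2.7.1, Transformers >= 4.51.1
query_embs = model.encode_queries(tokenizer, queries, use_nested=use_nested)
doc_embs = model.encode_documents(tokenizer, documents, use_nested=use_nested)

scores = query_embs @ doc_embs.T
print(scores.tolist())
```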
config.json CHANGED
@@ -4,7 +4,7 @@
4
  "DramaModel"
5
  ],
6
  "auto_map": {
7
- "AutoModel": "modeling_drama.DramaModel"
8
  },
9
  "attention_bias": false,
10
  "attention_dropout": 0.0,
 
4
  "DramaModel"
5
  ],
6
  "auto_map": {
7
+ "AutoModel": "modeling_drama.DramaModelWrapper"
8
  },
9
  "attention_bias": false,
10
  "attention_dropout": 0.0,
modeling_drama.py CHANGED
@@ -1,166 +1,125 @@
1
- from __future__ import annotations
 
2
 
3
- import torch
4
- import torch.nn.functional as F
5
 
6
- from transformers import LlamaModel, LlamaConfig, PreTrainedTokenizer
7
- from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 
 
8
 
 
 
 
9
 
10
- class DramaModel(LlamaModel):
11
  """
12
- DramaModel is a modified version of the LlamaModel that supports bi-directional attention
13
- and provides query and document encoding functionalities.
 
 
 
14
  """
15
 
16
- def __init__(self, config: LlamaConfig):
 
17
  """
18
- Initializes the DramaModel by disabling causal masking in self-attention layers.
19
- """
20
- super().__init__(config)
21
- for layer in self.layers:
22
- layer.self_attn.is_causal = False
23
- # query prefix
24
- self.query_prefix = "Query: "
25
- self.max_seq_len = 8192
26
- self.hidden_size = config.hidden_size
27
-
28
- def _update_causal_mask(
29
- self,
30
- attention_mask: torch.Tensor,
31
- input_tensor: torch.Tensor,
32
- cache_position: torch.Tensor,
33
- past_seen_tokens=None,
34
- output_attentions=False,
35
- ):
36
- """
37
- Updates the causal mask for attention computations.
38
- """
39
- if self.config._attn_implementation == "flash_attention_2":
40
- if attention_mask is not None and (attention_mask == 0.0).any():
41
- return attention_mask
42
- return None
43
- if attention_mask is None or attention_mask.dim() == 4:
44
- return attention_mask
45
-
46
- return AttentionMaskConverter._expand_mask(
47
- mask=attention_mask,
48
- dtype=input_tensor.dtype,
49
- )
50
 
51
- def _average_pool(
52
- self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
53
- ) -> torch.Tensor:
54
- """
55
- Computes the average pooled representation of the last hidden states.
56
- """
57
- last_hidden = last_hidden_states.masked_fill(
58
- ~attention_mask[..., None].bool(), 0.0
59
- )
60
- return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
61
 
62
- def _tokenize(
63
- self,
64
- tokenizer: PreTrainedTokenizer,
65
- texts: list[str],
66
- max_seq_len: int = None,
67
- ):
68
- """
69
- Tokenizes input text sequences with optional sequence length restriction.
70
- """
71
- if max_seq_len is None:
72
- max_seq_len = self.max_seq_len
73
- tokenized = tokenizer(
74
- texts,
75
- padding=True,
76
- truncation=True,
77
- max_length=max_seq_len,
78
- return_tensors='pt',
79
- ).to(self.device)
80
- return tokenized
81
-
82
- def encode(self, input_ids, attention_mask, dim, *args, **kwargs):
83
- """
84
- Pass through the model and compute normalized embeddings.
85
-
86
  Args:
87
- input_ids (torch.Tensor): Input token IDs.
88
- attention_mask (torch.Tensor): Attention mask tensor.
89
- dim (int): Dimensionality for output embeddings.
90
-
91
- Returns:
92
- torch.Tensor: Normalized output embeddings.
93
- """
94
- outputs = self.forward(
95
- input_ids, attention_mask, *args, **kwargs
96
- )
97
- embeddings = self._average_pool(
98
- outputs.last_hidden_state[:, :, :dim], attention_mask
99
- )
100
- # normalize embeddings
101
- embeddings = F.normalize(embeddings, p=2, dim=1)
102
- return embeddings
103
-
104
- def encode_queries(
105
- self,
106
- tokenizer: PreTrainedTokenizer,
107
- queries: list[str],
108
- max_seq_len: int = None,
109
- dim: int = None,
110
- ):
111
- """
112
- Encodes a list of queries into embeddings.
113
-
114
- Args:
115
- tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
116
- queries (list[str]): List of query texts.
117
- max_seq_len (int, optional): Maximum sequence length.
118
- dim (int, optional): Dimensionality for output embeddings.
119
-
120
- Returns:
121
- torch.Tensor: Encoded query embeddings in shape (num_queries, dim).
122
- """
123
- if not queries:
124
- raise ValueError("queries must not be empty.")
125
- if not isinstance(queries, list) or not all(isinstance(q, str) for q in queries):
126
- raise ValueError("queries must be a list of strings.")
127
- if tokenizer is None:
128
- raise ValueError("tokenizer must not be None.")
129
- if dim is not None and (dim < 1 or dim > self.hidden_size):
130
- raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
131
- queries = [self.query_prefix + query for query in queries]
132
- tokenized_queries = self._tokenize(tokenizer, queries, max_seq_len)
133
- embeddings = self.encode(**tokenized_queries, dim=dim)
134
- return embeddings
135
-
136
- def encode_documents(
137
- self,
138
- tokenizer: PreTrainedTokenizer,
139
- documents: list[str],
140
- max_seq_len: int = None,
141
- dim: int = None,
142
- ):
143
- """
144
- Encodes a list of documents into embeddings.
145
-
146
- Args:
147
- tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
148
- documents (list[str]): List of document texts.
149
- max_seq_len (int, optional): Maximum sequence length.
150
- dim (int, optional): Dimensionality for output embeddings.
151
-
152
  Returns:
153
- torch.Tensor: Encoded document embeddings in shape (num_documents, dim).
154
  """
155
- if not documents:
156
- raise ValueError("documents must not be empty.")
157
- if not isinstance(documents, list) or not all(isinstance(d, str) for d in documents):
158
- raise ValueError("documents must be a list of strings.")
159
- if tokenizer is None:
160
- raise ValueError("tokenizer must not be None.")
161
- if dim is not None and (dim < 1 or dim > self.hidden_size):
162
- raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
163
- tokenized_documents = self._tokenize(tokenizer, documents, max_seq_len)
164
- embeddings = self.encode(**tokenized_documents, dim=dim)
165
- return embeddings
1
+ import sys
2
+ import warnings
3
 
 
 
4
 
5
+ def _check_torch_version():
6
+ """Check if PyTorch version is >= 2.7.1"""
7
+ try:
8
+ import torch
9
 
10
+ # Simple version comparison
11
+ version_str = torch.__version__.split("+")[0] # Remove any suffixes like +cu118
12
+ version_parts = version_str.split(".")
13
 
14
+ # Compare major version
15
+ if int(version_parts[0]) > 2:
16
+ return True
17
+ # Compare minor version
18
+ elif int(version_parts[0]) == 2 and int(version_parts[1]) > 7:
19
+ return True
20
+ # Compare patch version
21
+ elif (
22
+ int(version_parts[0]) == 2
23
+ and int(version_parts[1]) == 7
24
+ and int(version_parts[2]) >= 1
25
+ ):
26
+ return True
27
+
28
+ return False
29
+ except (ImportError, AttributeError, IndexError, ValueError):
30
+ return False
31
+
32
+
33
+ def _check_transformers_version():
34
+ """Check if Transformers version is >= 4.51.1"""
35
+ try:
36
+ import transformers
37
+
38
+ # Simple version comparison
39
+ version_str = transformers.__version__.split("+")[0] # Remove any suffixes
40
+ version_parts = version_str.split(".")
41
+
42
+ # Compare major version
43
+ if int(version_parts[0]) > 4:
44
+ return True
45
+ # Compare minor version
46
+ elif int(version_parts[0]) == 4 and int(version_parts[1]) > 51:
47
+ return True
48
+ # Compare patch version
49
+ elif (
50
+ int(version_parts[0]) == 4
51
+ and int(version_parts[1]) == 51
52
+ and int(version_parts[2]) >= 1
53
+ ):
54
+ return True
55
+
56
+ return False
57
+ except (ImportError, AttributeError, IndexError, ValueError):
58
+ return False
59
+
60
+
61
+ class DramaModelWrapper:
62
  """
63
+ Factory class for DramaModel that returns the appropriate implementation
64
+ based on the installed Python, PyTorch, and Transformers versions.
65
+
66
+ If Python >= 3.12, PyTorch >= 2.7.1, and Transformers >= 4.51.1, returns an instance of the nested tensor implementation.
67
+ Otherwise, returns an instance of the non-nested implementation.
68
  """
69
 
70
+ @classmethod
71
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
72
  """
73
+ Instantiate a pretrained model from a pre-trained model configuration.
74
 
75
+ This method is required by the transformers library's auto model loading mechanism.
76
+
77
  Args:
78
+ pretrained_model_name_or_path: Path to the pretrained model or its name
79
+ *model_args: Additional positional arguments to pass to the implementation
80
+ **kwargs: Additional keyword arguments to pass to the implementation
81
+
82
  Returns:
83
+ An instance of the appropriate DramaModel implementation.
84
  """
85
+ # Check Python version
86
+ use_nested = sys.version_info >= (3, 15)
87
+ if not use_nested:
88
+ warnings.warn(
89
+ "Python version < 3.12 detected. Using non-nested implementation."
90
+ )
91
+ # For Python versions below 3.12, use the non-nested implementation
92
+ from .modeling_drama_non_nested import DramaModel as NonNestedDramaModel
93
+
94
+ return NonNestedDramaModel.from_pretrained(
95
+ pretrained_model_name_or_path, *model_args, **kwargs
96
+ )
97
+
98
+ # Check PyTorch version
99
+ if not _check_torch_version():
100
+ warnings.warn(
101
+ "PyTorch version < 2.7.1 detected. Falling back to non-nested implementation."
102
+ )
103
+ from .modeling_drama_non_nested import DramaModel as NonNestedDramaModel
104
 
105
+ return NonNestedDramaModel.from_pretrained(
106
+ pretrained_model_name_or_path, *model_args, **kwargs
107
+ )
108
+
109
+ # Check Transformers version
110
+ if not _check_transformers_version():
111
+ warnings.warn(
112
+ "Transformers version < 4.51.1 detected. Falling back to non-nested implementation."
113
+ )
114
+ from .modeling_drama_non_nested import DramaModel as NonNestedDramaModel
115
+
116
+ return NonNestedDramaModel.from_pretrained(
117
+ pretrained_model_name_or_path, *model_args, **kwargs
118
+ )
119
+
120
+ # Use the nested tensor implementation if all requirements are met
121
+ from .modeling_drama_nested import DramaModel as NestedDramaModel
122
+
123
+ return NestedDramaModel.from_pretrained(
124
+ pretrained_model_name_or_path, *model_args, **kwargs
125
+ )
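Because config.json now maps `AutoModel` to `modeling_drama.DramaModelWrapper`, loading through `AutoModel.from_pretrained` goes through this factory and transparently returns either the nested or the non-nested implementation. A small sketch of what that looks like from the caller's side (the printed module name is illustrative):

```python
from transformers import AutoModel

# trust_remote_code=True lets transformers import modeling_drama.py from the model
# repo; AutoModel then calls DramaModelWrapper.from_pretrained, which picks the
# nested or non-nested DramaModel based on the installed versions.
model = AutoModel.from_pretrained("facebook/drama-base", trust_remote_code=True)

# The class's module reveals which implementation was selected,
# e.g. "...modeling_drama_nested" or "...modeling_drama_non_nested".
print(type(model).__module__)
```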
modeling_drama_nested.py ADDED
@@ -0,0 +1,639 @@
1
+ from __future__ import annotations
2
+
+ import warnings
+ from typing import Any, List, Optional, Tuple, Union
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.nested._internal.nested_tensor import nested_from_padded
7
+
8
+ from transformers import (
9
+ LlamaConfig,
10
+ LlamaModel,
11
+ LlamaPreTrainedModel,
12
+ PreTrainedTokenizer,
13
+ )
14
+ from transformers.cache_utils import Cache, DynamicCache
15
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
16
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
17
+ from transformers.modeling_outputs import BaseModelOutputWithPast
18
+ from transformers.models.llama.modeling_llama import (
19
+ LlamaAttention,
20
+ LlamaDecoderLayer,
21
+ LlamaMLP,
22
+ LlamaRMSNorm,
23
+ LlamaRotaryEmbedding,
24
+ rotate_half,
25
+ )
26
+ from transformers.processing_utils import Unpack
27
+
28
+
29
+ class ModifiedLlamaAttention(LlamaAttention):
30
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
31
+ super().__init__(*args, **kwargs)
32
+ self.is_causal = False
33
+
34
+ def forward(
35
+ self,
36
+ hidden_states: torch.Tensor,
37
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
38
+ attention_mask: Optional[torch.Tensor],
39
+ past_key_value: Optional[Cache] = None,
40
+ cache_position: Optional[torch.LongTensor] = None,
41
+ **kwargs: Unpack[FlashAttentionKwargs],
42
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
43
+ input_shape = hidden_states.shape[:-1]
44
+ hidden_shape = (*input_shape, -1, self.head_dim)
45
+
46
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
47
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
48
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
49
+
50
+ cos, sin = position_embeddings
51
+ query_states, key_states = apply_rotary_pos_emb(
52
+ query_states, key_states, cos, sin
53
+ )
54
+
55
+ if past_key_value is not None:
56
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
57
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
58
+ key_states, value_states = past_key_value.update(
59
+ key_states, value_states, self.layer_idx, cache_kwargs
60
+ )
61
+
62
+ if self.config._attn_implementation != "eager":
63
+ if self.config._attn_implementation == "sdpa" and kwargs.get(
64
+ "output_attentions", False
65
+ ):
66
+ warnings.warn(
67
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
68
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
69
+ )
70
+
71
+ attn_output, attn_weights = sdpa_attention_forward(
72
+ self,
73
+ query_states,
74
+ key_states,
75
+ value_states,
76
+ attention_mask,
77
+ dropout=0.0,
78
+ scaling=self.scaling,
79
+ is_causal=False,
80
+ **kwargs,
81
+ )
82
+
83
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
84
+ attn_output = self.o_proj(attn_output)
85
+ return attn_output, attn_weights
86
+
87
+
88
+ def sdpa_attention_forward(
89
+ module: torch.nn.Module,
90
+ query: torch.Tensor,
91
+ key: torch.Tensor,
92
+ value: torch.Tensor,
93
+ attention_mask: torch.Tensor,
94
+ dropout: float = 0.0,
95
+ scaling: Optional[float] = None,
96
+ is_causal: Optional[bool] = None,
97
+ **kwargs: Any,
98
+ ) -> Tuple[torch.Tensor, None]:
99
+ if hasattr(module, "num_key_value_groups"):
100
+ if key.is_nested:
101
+ key = repeat_jagged_kv(key, module.num_key_value_groups)
102
+ value = repeat_jagged_kv(value, module.num_key_value_groups)
103
+ else:
104
+ key = repeat_dense_kv(key, module.num_key_value_groups)
105
+ value = repeat_dense_kv(value, module.num_key_value_groups)
106
+
107
+ causal_mask = attention_mask
108
+ if attention_mask is not None and causal_mask.ndim == 4:
109
+ causal_mask = causal_mask[:, :, :, : key.shape[-2]]
110
+
111
+ # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
112
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
113
+ query = query.contiguous()
114
+ key = key.contiguous()
115
+ value = value.contiguous()
116
+
117
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
118
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
119
+ # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
120
+ if is_causal is None:
121
+ is_causal = query.shape[2] > 1 and causal_mask is None
122
+
123
+ # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
124
+ # We convert it to a bool for the SDPA kernel that only accepts bools.
125
+ if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
126
+ is_causal = is_causal.item()
127
+
128
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
129
+ query,
130
+ key,
131
+ value,
132
+ attn_mask=causal_mask,
133
+ dropout_p=dropout,
134
+ scale=scaling,
135
+ is_causal=is_causal,
136
+ )
137
+ attn_output = attn_output.transpose(1, 2).contiguous()
138
+
139
+ return attn_output, None
140
+
141
+
142
+ def repeat_jagged_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
143
+ """
144
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
145
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
146
+ """
147
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
148
+ expand_shape = (batch, num_key_value_heads, -1, n_rep, head_dim)
149
+ if n_rep == 1:
150
+ return hidden_states
151
+ hidden_states = (
152
+ hidden_states.unsqueeze(3)
153
+ .expand(expand_shape)
154
+ .transpose(1, 2)
155
+ .flatten(2, 3)
156
+ .transpose(1, 2)
157
+ )
158
+ return hidden_states
159
+
160
+
161
+ def repeat_dense_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
162
+ """
163
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
164
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
165
+ """
166
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
167
+ if n_rep == 1:
168
+ return hidden_states
169
+ hidden_states = hidden_states[:, :, None, :, :].expand(
170
+ batch, num_key_value_heads, n_rep, slen, head_dim
171
+ )
172
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
173
+
174
+
175
+ def apply_rotary_pos_emb(
176
+ q: torch.Tensor,
177
+ k: torch.Tensor,
178
+ cos: torch.Tensor,
179
+ sin: torch.Tensor,
180
+ unsqueeze_dim: int = 1,
181
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
182
+ """Applies Rotary Position Embedding to the query and key tensors.
183
+
184
+ Args:
185
+ q (`torch.Tensor`): The query tensor.
186
+ k (`torch.Tensor`): The key tensor.
187
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
188
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
189
+ position_ids (`torch.Tensor`, *optional*):
190
+ Deprecated and unused.
191
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
192
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
193
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
194
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
195
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
196
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
197
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
198
+ Returns:
199
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
200
+ """
201
+ cos = cos.unsqueeze(unsqueeze_dim)
202
+ sin = sin.unsqueeze(unsqueeze_dim)
203
+ if q.is_nested and k.is_nested:
204
+ if q.layout != torch.jagged:
205
+ raise NotImplementedError(f"Unsupported layout: {q.layout}")
206
+ if k.layout != torch.jagged:
207
+ raise NotImplementedError(f"Unsupported layout: {k.layout}")
208
+ return _jagged_tensor_forward(q, k, cos, sin)
209
+ else:
210
+ return _padded_tensor_forward(q, k, cos, sin)
211
+
212
+
213
+ def _jagged_tensor_forward(
214
+ q: torch.Tensor,
215
+ k: torch.Tensor,
216
+ cos: torch.Tensor,
217
+ sin: torch.Tensor,
218
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
219
+ q_dense = q.to_padded_tensor(0.0)
220
+ k_dense = k.to_padded_tensor(0.0)
221
+ q_dense_embed = (q_dense * cos) + (rotate_half(q_dense) * sin)
222
+ k_dense_embed = (k_dense * cos) + (rotate_half(k_dense) * sin)
223
+ q_jagged_embed = convert_dense_to_jagged(q, q_dense_embed)
224
+ k_jagged_embed = convert_dense_to_jagged(k, k_dense_embed)
225
+ return q_jagged_embed, k_jagged_embed
226
+
227
+
228
+ def _padded_tensor_forward(
229
+ q: torch.Tensor,
230
+ k: torch.Tensor,
231
+ cos: torch.Tensor,
232
+ sin: torch.Tensor,
233
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
234
+ q_embed = (q * cos) + (rotate_half(q) * sin)
235
+ k_embed = (k * cos) + (rotate_half(k) * sin)
236
+ return q_embed, k_embed
237
+
238
+
239
+ def convert_dense_to_jagged(nested_q: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
240
+ padded_max_S = nested_q._get_max_seqlen()
241
+ total_L = nested_q._values.shape[nested_q._ragged_idx - 1]
242
+ if padded_max_S is None:
243
+ # use upper bound on max seqlen if it's not present
244
+ padded_max_S = total_L
245
+
246
+ # convert dense tensor -> jagged
247
+ q = q.expand(
248
+ [
249
+ x if i != nested_q._ragged_idx else padded_max_S
250
+ for i, x in enumerate(q.shape)
251
+ ]
252
+ )
253
+ nested_result = nested_from_padded(
254
+ q,
255
+ offsets=nested_q._offsets,
256
+ ragged_idx=nested_q._ragged_idx,
257
+ sum_S=total_L,
258
+ min_seqlen=nested_q._get_min_seqlen(),
259
+ max_seqlen=padded_max_S,
260
+ )
261
+ return nested_result
262
+
263
+
264
+ class ModifiedLlamaDecoderLayer(LlamaDecoderLayer):
265
+ def __init__(self, config: LlamaConfig, layer_idx: int) -> None:
266
+ nn.Module.__init__(self)
267
+ self.hidden_size: int = config.hidden_size
268
+
269
+ self.self_attn = ModifiedLlamaAttention(config=config, layer_idx=layer_idx)
270
+
271
+ self.mlp = LlamaMLP(config)
272
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
273
+ self.post_attention_layernorm = LlamaRMSNorm(
274
+ config.hidden_size, eps=config.rms_norm_eps
275
+ )
276
+
277
+
278
+ class LlamaBiModel(LlamaModel):
279
+ def __init__(self, config: LlamaConfig) -> None:
280
+ LlamaPreTrainedModel.__init__(self, config)
281
+ self.padding_idx: int = config.pad_token_id
282
+ self.vocab_size: int = config.vocab_size
283
+
284
+ self.embed_tokens = nn.Embedding(
285
+ config.vocab_size, config.hidden_size, self.padding_idx
286
+ )
287
+ self.layers = nn.ModuleList(
288
+ [
289
+ ModifiedLlamaDecoderLayer(config, layer_idx)
290
+ for layer_idx in range(config.num_hidden_layers)
291
+ ]
292
+ )
293
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
294
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
295
+ self.gradient_checkpointing = False
296
+
297
+ # Initialize weights and apply final processing
298
+ self.post_init()
299
+
300
+ def _update_causal_mask(
301
+ self,
302
+ attention_mask: torch.Tensor,
303
+ input_tensor: torch.Tensor,
304
+ cache_position: torch.Tensor,
305
+ past_seen_tokens=None,
306
+ output_attentions=False,
307
+ ):
308
+ """
309
+ Updates the causal mask for attention computations.
310
+ """
311
+ if self.config._attn_implementation == "flash_attention_2":
312
+ if attention_mask is not None and (attention_mask == 0.0).any():
313
+ return attention_mask
314
+ return None
315
+ if attention_mask is None or attention_mask.dim() == 4:
316
+ return attention_mask
317
+
318
+ return AttentionMaskConverter._expand_mask(
319
+ mask=attention_mask,
320
+ dtype=input_tensor.dtype,
321
+ )
322
+
323
+ def forward(
324
+ self,
325
+ input_ids: Optional[torch.LongTensor] = None,
326
+ attention_mask: Optional[torch.Tensor] = None,
327
+ position_ids: Optional[torch.LongTensor] = None,
328
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
329
+ inputs_embeds: Optional[torch.FloatTensor] = None,
330
+ use_cache: Optional[bool] = None,
331
+ output_attentions: Optional[bool] = None,
332
+ output_hidden_states: Optional[bool] = None,
333
+ return_dict: Optional[bool] = None,
334
+ cache_position: Optional[torch.LongTensor] = None,
335
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
336
+ output_attentions = (
337
+ output_attentions
338
+ if output_attentions is not None
339
+ else self.config.output_attentions
340
+ )
341
+ output_hidden_states = (
342
+ output_hidden_states
343
+ if output_hidden_states is not None
344
+ else self.config.output_hidden_states
345
+ )
346
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
347
+ use_cache = False
348
+ return_dict = (
349
+ return_dict if return_dict is not None else self.config.use_return_dict
350
+ )
351
+
352
+ if (input_ids is None) ^ (inputs_embeds is not None):
353
+ raise ValueError(
354
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
355
+ )
356
+ if self.gradient_checkpointing and self.training and use_cache:
357
+ warnings.warn(
358
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.",
359
+ DeprecationWarning,
360
+ stacklevel=2,
361
+ )
362
+ use_cache = False
363
+
364
+ if inputs_embeds is None:
365
+ inputs_embeds = self.embed_tokens(input_ids)
366
+
367
+ return_legacy_cache = False
368
+ if (
369
+ use_cache and not isinstance(past_key_values, Cache) and not self.training
370
+ ): # kept for BC (non `Cache` `past_key_values` inputs)
371
+ return_legacy_cache = True
372
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
373
+ warnings.warn(
374
+ "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
375
+ "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)",
376
+ DeprecationWarning,
377
+ stacklevel=2,
378
+ )
379
+
380
+ if cache_position is None:
381
+ past_seen_tokens = (
382
+ past_key_values.get_seq_length() if past_key_values is not None else 0
383
+ )
384
+ if inputs_embeds.is_nested:
385
+ seq_len = inputs_embeds._get_max_seqlen()
386
+ else:
387
+ seq_len = inputs_embeds.shape[1]
388
+ cache_position = torch.arange(
389
+ past_seen_tokens,
390
+ past_seen_tokens + seq_len,
391
+ device=inputs_embeds.device,
392
+ )
393
+ if position_ids is None:
394
+ position_ids = cache_position.unsqueeze(0)
395
+ if not inputs_embeds.is_nested:
396
+ causal_mask = self._update_causal_mask(
397
+ attention_mask,
398
+ inputs_embeds,
399
+ cache_position,
400
+ past_key_values,
401
+ )
402
+
403
+ else:
404
+ causal_mask = None
405
+ hidden_states = inputs_embeds
406
+
407
+ # create position embeddings to be shared across the decoder layers
408
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
409
+
410
+ # decoder layers
411
+ all_hidden_states = () if output_hidden_states else None
412
+ all_self_attns = () if output_attentions else None
413
+ next_decoder_cache = None
414
+
415
+ for decoder_layer in self.layers:
416
+ if output_hidden_states:
417
+ all_hidden_states += (hidden_states,)
418
+
419
+ if self.gradient_checkpointing and self.training:
420
+ layer_outputs = self._gradient_checkpointing_func(
421
+ decoder_layer.__call__,
422
+ hidden_states,
423
+ causal_mask,
424
+ position_ids,
425
+ past_key_values,
426
+ output_attentions,
427
+ use_cache,
428
+ cache_position,
429
+ position_embeddings,
430
+ )
431
+ else:
432
+ layer_outputs = decoder_layer(
433
+ hidden_states,
434
+ attention_mask=causal_mask,
435
+ position_ids=position_ids,
436
+ past_key_value=past_key_values,
437
+ output_attentions=output_attentions,
438
+ use_cache=use_cache,
439
+ cache_position=cache_position,
440
+ position_embeddings=position_embeddings,
441
+ )
442
+
443
+ hidden_states = layer_outputs[0]
444
+
445
+ if use_cache:
446
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
447
+
448
+ if output_attentions:
449
+ all_self_attns += (layer_outputs[1],)
450
+
451
+ hidden_states = self.norm(hidden_states)
452
+
453
+ # add hidden states from the last decoder layer
454
+ if output_hidden_states:
455
+ all_hidden_states += (hidden_states,)
456
+
457
+ next_cache = next_decoder_cache if use_cache else None
458
+ if return_legacy_cache:
459
+ next_cache = next_cache.to_legacy_cache()
460
+
461
+ if not return_dict:
462
+ return tuple(
463
+ v
464
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
465
+ if v is not None
466
+ )
467
+ return BaseModelOutputWithPast(
468
+ last_hidden_state=hidden_states,
469
+ past_key_values=next_cache,
470
+ hidden_states=all_hidden_states,
471
+ attentions=all_self_attns,
472
+ )
473
+
474
+
475
+ class DramaModel(LlamaBiModel):
476
+ """
477
+ DramaModel is a modified version of the LlamaModel that supports bi-directional attention
478
+ and provides query and document encoding functionalities.
479
+ """
480
+
481
+ def __init__(self, config: LlamaConfig):
482
+ """
483
+ Initializes the DramaModel by disabling causal masking in self-attention layers.
484
+ """
485
+ super().__init__(config)
486
+ for layer in self.layers:
487
+ layer.self_attn.is_causal = False
488
+ # query prefix
489
+ self.query_prefix = "Query: "
490
+ self.max_seq_len = 8192
491
+ self.hidden_size = config.hidden_size
492
+
493
+ def _average_pool(
494
+ self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
495
+ ) -> torch.Tensor:
496
+ """
497
+ Computes the average pooled representation of the last hidden states.
498
+ """
499
+ last_hidden = last_hidden_states.masked_fill(
500
+ ~attention_mask[..., None].bool(), 0.0
501
+ )
502
+ return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
503
+
504
+ def _tokenize(
505
+ self,
506
+ tokenizer: PreTrainedTokenizer,
507
+ texts: list[str],
508
+ max_seq_len: int = None,
509
+ use_nested: bool = False,
510
+ ):
511
+ """
512
+ Tokenizes input text sequences with optional sequence length restriction.
513
+ """
514
+ if max_seq_len is None:
515
+ max_seq_len = self.max_seq_len
516
+ if use_nested:
517
+ tokenized = tokenizer(
518
+ texts,
519
+ truncation=True,
520
+ max_length=max_seq_len,
521
+ return_length=True,
522
+ )
523
+ tokenized.input_ids = torch.nested.nested_tensor(
524
+ tokenized.input_ids, layout=torch.jagged
525
+ ).to(self.device)
526
+ tokenized.attention_mask = None
527
+ else:
528
+ tokenized = tokenizer(
529
+ texts,
530
+ padding=True,
531
+ truncation=True,
532
+ max_length=max_seq_len,
533
+ return_tensors="pt",
534
+ ).to(self.device)
535
+ tokenizer_output = {}
536
+ tokenizer_output["input_ids"] = tokenized.input_ids
537
+ tokenizer_output["attention_mask"] = tokenized.attention_mask
538
+ return tokenizer_output
539
+
540
+ def encode(self, input_ids, attention_mask, dim, *args, **kwargs):
541
+ """
542
+ Pass through the model and compute normalized embeddings.
543
+
544
+ Args:
545
+ input_ids (torch.Tensor): Input token IDs.
546
+ attention_mask (torch.Tensor): Attention mask tensor.
547
+ dim (int): Dimensionality for output embeddings.
548
+
549
+ Returns:
550
+ torch.Tensor: Normalized output embeddings.
551
+ """
552
+
553
+ outputs = self.forward(
554
+ input_ids, attention_mask, *args, **kwargs
555
+ ).last_hidden_state
556
+ if not outputs.is_nested:
557
+ if dim is not None:
558
+ outputs = outputs[:, :, :dim]
559
+ embeddings = self._average_pool(outputs, attention_mask)
560
+ else:
561
+ if dim is not None:
562
+ outputs, _ = outputs.split_with_sizes(
563
+ split_sizes=[dim, outputs.shape[-1] - dim], dim=-1
564
+ )
565
+ embeddings = outputs.sum(dim=-2)
566
+ # normalize embeddings
567
+ embeddings = F.normalize(embeddings, p=2, dim=1)
568
+ return embeddings
569
+
570
+ def encode_queries(
571
+ self,
572
+ tokenizer: PreTrainedTokenizer,
573
+ queries: list[str],
574
+ max_seq_len: int = None,
575
+ dim: int = None,
576
+ use_nested: bool = False,
577
+ ):
578
+ """
579
+ Encodes a list of queries into embeddings.
580
+
581
+ Args:
582
+ tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
583
+ queries (list[str]): List of query texts.
584
+ max_seq_len (int, optional): Maximum sequence length.
585
+ dim (int, optional): Dimensionality for output embeddings.
586
+
587
+ Returns:
588
+ torch.Tensor: Encoded query embeddings in shape (num_queries, dim).
589
+ """
590
+ if not queries:
591
+ raise ValueError("queries must not be empty.")
592
+ if not isinstance(queries, list) or not all(
593
+ isinstance(q, str) for q in queries
594
+ ):
595
+ raise ValueError("queries must be a list of strings.")
596
+ if tokenizer is None:
597
+ raise ValueError("tokenizer must not be None.")
598
+ if dim is not None and (dim < 1 or dim > self.hidden_size):
599
+ raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
600
+ queries = [self.query_prefix + query for query in queries]
601
+ tokenized_queries = self._tokenize(tokenizer, queries, max_seq_len, use_nested)
602
+ embeddings = self.encode(**tokenized_queries, dim=dim)
603
+ return embeddings
604
+
605
+ def encode_documents(
606
+ self,
607
+ tokenizer: PreTrainedTokenizer,
608
+ documents: list[str],
609
+ max_seq_len: int = None,
610
+ dim: int = None,
611
+ use_nested: bool = False,
612
+ ):
613
+ """
614
+ Encodes a list of documents into embeddings.
615
+
616
+ Args:
617
+ tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
618
+ documents (list[str]): List of document texts.
619
+ max_seq_len (int, optional): Maximum sequence length.
620
+ dim (int, optional): Dimensionality for output embeddings.
621
+
622
+ Returns:
623
+ torch.Tensor: Encoded document embeddings in shape (num_documents, dim).
624
+ """
625
+ if not documents:
626
+ raise ValueError("documents must not be empty.")
627
+ if not isinstance(documents, list) or not all(
628
+ isinstance(d, str) for d in documents
629
+ ):
630
+ raise ValueError("documents must be a list of strings.")
631
+ if tokenizer is None:
632
+ raise ValueError("tokenizer must not be None.")
633
+ if dim is not None and (dim < 1 or dim > self.hidden_size):
634
+ raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
635
+ tokenized_documents = self._tokenize(
636
+ tokenizer, documents, max_seq_len, use_nested
637
+ )
638
+ embeddings = self.encode(**tokenized_documents, dim=dim)
639
+ return embeddings
modeling_drama_non_nested.py ADDED
@@ -0,0 +1,184 @@
1
+ from __future__ import annotations
2
+
3
+ import warnings
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from transformers import LlamaConfig, LlamaModel, PreTrainedTokenizer
9
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
10
+
11
+
12
+ class DramaModel(LlamaModel):
13
+ """
14
+ DramaModel is a modified version of the LlamaModel that supports bi-directional attention
15
+ and provides query and document encoding functionalities.
16
+ """
17
+
18
+ def __init__(self, config: LlamaConfig):
19
+ """
20
+ Initializes the DramaModel by disabling causal masking in self-attention layers.
21
+ """
22
+ super().__init__(config)
23
+ for layer in self.layers:
24
+ layer.self_attn.is_causal = False
25
+ # query prefix
26
+ self.query_prefix = "Query: "
27
+ self.max_seq_len = 8192
28
+ self.hidden_size = config.hidden_size
29
+
30
+ def _update_causal_mask(
31
+ self,
32
+ attention_mask: torch.Tensor,
33
+ input_tensor: torch.Tensor,
34
+ cache_position: torch.Tensor,
35
+ past_seen_tokens=None,
36
+ output_attentions=False,
37
+ ):
38
+ """
39
+ Updates the causal mask for attention computations.
40
+ """
41
+ if self.config._attn_implementation == "flash_attention_2":
42
+ if attention_mask is not None and (attention_mask == 0.0).any():
43
+ return attention_mask
44
+ return None
45
+ if attention_mask is None or attention_mask.dim() == 4:
46
+ return attention_mask
47
+
48
+ return AttentionMaskConverter._expand_mask(
49
+ mask=attention_mask,
50
+ dtype=input_tensor.dtype,
51
+ )
52
+
53
+ def _average_pool(
54
+ self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
55
+ ) -> torch.Tensor:
56
+ """
57
+ Computes the average pooled representation of the last hidden states.
58
+ """
59
+ last_hidden = last_hidden_states.masked_fill(
60
+ ~attention_mask[..., None].bool(), 0.0
61
+ )
62
+ return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
63
+
64
+ def _tokenize(
65
+ self,
66
+ tokenizer: PreTrainedTokenizer,
67
+ texts: list[str],
68
+ max_seq_len: int = None,
69
+ use_nested: bool = False, # Added for API compatibility with nested version
70
+ ):
71
+ """
72
+ Tokenizes input text sequences with optional sequence length restriction.
73
+ """
74
+ if max_seq_len is None:
75
+ max_seq_len = self.max_seq_len
76
+ tokenized = tokenizer(
77
+ texts,
78
+ padding=True,
79
+ truncation=True,
80
+ max_length=max_seq_len,
81
+ return_tensors="pt",
82
+ ).to(self.device)
83
+ return tokenized
84
+
85
+ def encode(self, input_ids, attention_mask, dim, *args, **kwargs):
86
+ """
87
+ Pass through the model and compute normalized embeddings.
88
+
89
+ Args:
90
+ input_ids (torch.Tensor): Input token IDs.
91
+ attention_mask (torch.Tensor): Attention mask tensor.
92
+ dim (int): Dimensionality for output embeddings.
93
+
94
+ Returns:
95
+ torch.Tensor: Normalized output embeddings.
96
+ """
97
+ outputs = self.forward(input_ids, attention_mask, *args, **kwargs)
98
+ embeddings = self._average_pool(
99
+ outputs.last_hidden_state[:, :, :dim], attention_mask
100
+ )
101
+ # normalize embeddings
102
+ embeddings = F.normalize(embeddings, p=2, dim=1)
103
+ return embeddings
104
+
105
+ def encode_queries(
106
+ self,
107
+ tokenizer: PreTrainedTokenizer,
108
+ queries: list[str],
109
+ max_seq_len: int = None,
110
+ dim: int = None,
111
+ use_nested: bool = False, # Added for API compatibility with nested version
112
+ ):
113
+ """
114
+ Encodes a list of queries into embeddings.
115
+
116
+ Args:
117
+ tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
118
+ queries (list[str]): List of query texts.
119
+ max_seq_len (int, optional): Maximum sequence length.
120
+ dim (int, optional): Dimensionality for output embeddings.
121
+
122
+ Returns:
123
+ torch.Tensor: Encoded query embeddings in shape (num_queries, dim).
124
+ """
125
+ if not queries:
126
+ raise ValueError("queries must not be empty.")
127
+ if not isinstance(queries, list) or not all(
128
+ isinstance(q, str) for q in queries
129
+ ):
130
+ raise ValueError("queries must be a list of strings.")
131
+ if tokenizer is None:
132
+ raise ValueError("tokenizer must not be None.")
133
+ if dim is not None and (dim < 1 or dim > self.hidden_size):
134
+ raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
135
+ if use_nested:
136
+ warnings.warn(
137
+ "use_nested is not supported due to package import versions.",
138
+ UserWarning,
139
+ )
140
+ queries = [self.query_prefix + query for query in queries]
141
+ tokenized_queries = self._tokenize(tokenizer, queries, max_seq_len, use_nested)
142
+ embeddings = self.encode(**tokenized_queries, dim=dim)
143
+ return embeddings
144
+
145
+ def encode_documents(
146
+ self,
147
+ tokenizer: PreTrainedTokenizer,
148
+ documents: list[str],
149
+ max_seq_len: int = None,
150
+ dim: int = None,
151
+ use_nested: bool = False, # Added for API compatibility with nested version
152
+ ):
153
+ """
154
+ Encodes a list of documents into embeddings.
155
+
156
+ Args:
157
+ tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
158
+ documents (list[str]): List of document texts.
159
+ max_seq_len (int, optional): Maximum sequence length.
160
+ dim (int, optional): Dimensionality for output embeddings.
161
+
162
+ Returns:
163
+ torch.Tensor: Encoded document embeddings in shape (num_documents, dim).
164
+ """
165
+ if not documents:
166
+ raise ValueError("documents must not be empty.")
167
+ if not isinstance(documents, list) or not all(
168
+ isinstance(d, str) for d in documents
169
+ ):
170
+ raise ValueError("documents must be a list of strings.")
171
+ if tokenizer is None:
172
+ raise ValueError("tokenizer must not be None.")
173
+ if dim is not None and (dim < 1 or dim > self.hidden_size):
174
+ raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
175
+ if use_nested:
176
+ warnings.warn(
177
+ "use_nested is not supported due to package import versions.",
178
+ UserWarning,
179
+ )
180
+ tokenized_documents = self._tokenize(
181
+ tokenizer, documents, max_seq_len, use_nested
182
+ )
183
+ embeddings = self.encode(**tokenized_documents, dim=dim)
184
+ return embeddings