Shreya Goyal committed
Commit 842c1e0 · 1 Parent(s): 1d0338c

adding support for NJTs (jagged-layout nested tensors)

Files changed (2)
  1. README.md +9 -8
  2. modeling_drama.py +530 -57
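The "NJTs" in the commit title are PyTorch's jagged-layout nested tensors (`torch.nested` with `layout=torch.jagged`), which batch variable-length sequences without padding. As a quick orientation before the diffs, here is a minimal sketch of that data structure; the token-id lists are made up for illustration and nothing in this snippet comes from the repository itself.

```python
import torch

# Hypothetical token-id lists of different lengths (illustrative only).
seqs = [[101, 2009, 2003, 102], [101, 2129, 102]]

njt = torch.nested.nested_tensor(seqs, layout=torch.jagged)
print(njt.is_nested)            # True: one ragged "sequence length" dimension
print(njt.to_padded_tensor(0))  # 2 x 4 dense tensor; the shorter row is zero-padded
```

The new `_tokenize` path in modeling_drama.py builds exactly this kind of batch when `use_nested=True`, instead of padding every sequence to the longest one.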
README.md CHANGED
@@ -60,9 +60,10 @@ model_name = "facebook/drama-base"
 device = "cuda" if torch.cuda.is_available() else "cpu"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)
+use_nested = False
+query_embs = model.encode_queries(tokenizer, queries, use_nested=use_nested)
+doc_embs = model.encode_documents(tokenizer, documents, use_nested=use_nested)
 
-query_embs = model.encode_queries(tokenizer, queries)
-doc_embs = model.encode_documents(tokenizer, documents)
 
 scores = query_embs @ doc_embs.T
 print(scores.tolist())
@@ -77,8 +78,8 @@ print(scores.tolist())
 DRAMA models are trained using Matryoshka Representation Learning ([MRL](https://github.com/RAIVNLab/MRL)) to support flexible dimensionality. Both queries and documents can be encoded into smaller dimensions, such as 256, using the following:
 
 ```python
-query_embs = model.encode_queries(tokenizer, queries, dim=256)
-doc_embs = model.encode_documents(tokenizer, documents, dim=256)
+query_embs = model.encode_queries(tokenizer, queries, dim=256, use_nested=use_nested)
+doc_embs = model.encode_documents(tokenizer, documents, dim=256, use_nested=use_nested)
 
 scores = query_embs @ doc_embs.T
 print(scores.tolist())
@@ -101,8 +102,8 @@ documents = [
 
 model = SentenceTransformer("facebook/drama-base", trust_remote_code=True)
 
-query_embs = model.encode(queries, prompt_name="query")
-doc_embs = model.encode(documents)
+query_embs = model.encode(queries, prompt_name="query", use_nested=use_nested)
+doc_embs = model.encode(documents, use_nested=use_nested)
 
 scores = model.similarity(query_embs, doc_embs)
 print(scores.tolist())
@@ -128,8 +129,8 @@ documents = [
 
 model = SentenceTransformer("facebook/drama-base", truncate_dim=256, trust_remote_code=True)
 
-query_embs = model.encode(queries, prompt_name="query")
-doc_embs = model.encode(documents)
+query_embs = model.encode(queries, prompt_name="query", use_nested=use_nested)
+doc_embs = model.encode(documents, use_nested=use_nested)
 
 scores = model.similarity(query_embs, doc_embs)
 print(scores.tolist())
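The README hunks above only touch the encode calls. Pieced together with the surrounding README code that the hunks do not show, usage looks roughly like the sketch below; the `queries`/`documents` contents are illustrative stand-ins, while `encode_queries`, `encode_documents`, and the new `use_nested` flag are taken from the diff.

```python
import torch
from transformers import AutoTokenizer, AutoModel

# Illustrative inputs; the actual README defines its own queries/documents lists.
queries = ["What percentage of the Earth's surface is covered by water?"]
documents = ["About 71% of the Earth's surface is water-covered."]

model_name = "facebook/drama-base"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)

# use_nested=True tokenizes without padding and feeds jagged nested tensors (NJTs)
# through the model; use_nested=False keeps the original padded path.
use_nested = True
query_embs = model.encode_queries(tokenizer, queries, use_nested=use_nested)
doc_embs = model.encode_documents(tokenizer, documents, use_nested=use_nested)

scores = query_embs @ doc_embs.T
print(scores.tolist())
```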
modeling_drama.py CHANGED
@@ -1,30 +1,302 @@
 from __future__ import annotations
 
+import warnings
+from typing import Any, List, Optional, Tuple, Union
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
+from torch.nested._internal.nested_tensor import nested_from_padded
 
-from transformers import LlamaModel, LlamaConfig, PreTrainedTokenizer
+from transformers import (
+    LlamaConfig,
+    LlamaModel,
+    LlamaPreTrainedModel,
+    PreTrainedTokenizer,
+)
+from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    LlamaDecoderLayer,
+    LlamaMLP,
+    LlamaRMSNorm,
+    LlamaRotaryEmbedding,
+    rotate_half,
+)
+from transformers.processing_utils import Unpack
 
 
-class DramaModel(LlamaModel):
+class ModifiedLlamaAttention(LlamaAttention):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self.is_causal = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_value: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin
+        )
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
+
+        if self.config._attn_implementation != "eager":
+            if self.config._attn_implementation == "sdpa" and kwargs.get(
+                "output_attentions", False
+            ):
+                warnings.warn(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                )
+
+        attn_output, attn_weights = sdpa_attention_forward(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0,
+            scaling=self.scaling,
+            is_causal=False,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+
+def sdpa_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: torch.Tensor,
+    dropout: float = 0.0,
+    scaling: Optional[float] = None,
+    is_causal: Optional[bool] = None,
+    **kwargs: Any,
+) -> Tuple[torch.Tensor, None]:
+    if hasattr(module, "num_key_value_groups"):
+        if key.is_nested:
+            key = repeat_jagged_kv(key, module.num_key_value_groups)
+            value = repeat_jagged_kv(value, module.num_key_value_groups)
+        else:
+            key = repeat_dense_kv(key, module.num_key_value_groups)
+            value = repeat_dense_kv(value, module.num_key_value_groups)
+
+    causal_mask = attention_mask
+    if attention_mask is not None and causal_mask.ndim == 4:
+        causal_mask = causal_mask[:, :, :, : key.shape[-2]]
+
+    # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
+    # Reference: https://github.com/pytorch/pytorch/issues/112577.
+    query = query.contiguous()
+    key = key.contiguous()
+    value = value.contiguous()
+
+    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+    # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
+    if is_causal is None:
+        is_causal = query.shape[2] > 1 and causal_mask is None
+
+    # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
+    # We convert it to a bool for the SDPA kernel that only accepts bools.
+    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
+        is_causal = is_causal.item()
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=causal_mask,
+        dropout_p=dropout,
+        scale=scaling,
+        is_causal=is_causal,
+    )
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, None
+
+
+def repeat_jagged_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
-    DramaModel is a modified version of the LlamaModel that supports bi-directional attention
-    and provides query and document encoding functionalities.
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
     """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    expand_shape = (batch, num_key_value_heads, -1, n_rep, head_dim)
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = (
+        hidden_states.unsqueeze(3)
+        .expand(expand_shape)
+        .transpose(1, 2)
+        .flatten(2, 3)
+        .transpose(1, 2)
+    )
+    return hidden_states
+
+
+def repeat_dense_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def apply_rotary_pos_emb(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    unsqueeze_dim: int = 1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    if q.is_nested and k.is_nested:
+        if q.layout != torch.jagged:
+            raise NotImplementedError(f"Unsupported layout: {q.layout}")
+        if k.layout != torch.jagged:
+            raise NotImplementedError(f"Unsupported layout: {k.layout}")
+        return _jagged_tensor_forward(q, k, cos, sin)
+    else:
+        return _padded_tensor_forward(q, k, cos, sin)
+
+
+def _jagged_tensor_forward(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    q_dense = q.to_padded_tensor(0.0)
+    k_dense = k.to_padded_tensor(0.0)
+    q_dense_embed = (q_dense * cos) + (rotate_half(q_dense) * sin)
+    k_dense_embed = (k_dense * cos) + (rotate_half(k_dense) * sin)
+    q_jagged_embed = convert_dense_to_jagged(q, q_dense_embed)
+    k_jagged_embed = convert_dense_to_jagged(k, k_dense_embed)
+    return q_jagged_embed, k_jagged_embed
+
+
+def _padded_tensor_forward(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def convert_dense_to_jagged(nested_q: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
+    padded_max_S = nested_q._get_max_seqlen()
+    total_L = nested_q._values.shape[nested_q._ragged_idx - 1]
+    if padded_max_S is None:
+        # use upper bound on max seqlen if it's not present
+        padded_max_S = total_L
+
+    # convert dense tensor -> jagged
+    q = q.expand(
+        [
+            x if i != nested_q._ragged_idx else padded_max_S
+            for i, x in enumerate(q.shape)
+        ]
+    )
+    nested_result = nested_from_padded(
+        q,
+        offsets=nested_q._offsets,
+        ragged_idx=nested_q._ragged_idx,
+        sum_S=total_L,
+        min_seqlen=nested_q._get_min_seqlen(),
+        max_seqlen=padded_max_S,
+    )
+    return nested_result
+
+
+class ModifiedLlamaDecoderLayer(LlamaDecoderLayer):
+    def __init__(self, config: LlamaConfig, layer_idx: int) -> None:
+        nn.Module.__init__(self)
+        self.hidden_size: int = config.hidden_size
+
+        self.self_attn = ModifiedLlamaAttention(config=config, layer_idx=layer_idx)
+
+        self.mlp = LlamaMLP(config)
+        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = LlamaRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+
+class LlamaBiModel(LlamaModel):
+    def __init__(self, config: LlamaConfig) -> None:
+        LlamaPreTrainedModel.__init__(self, config)
+        self.padding_idx: int = config.pad_token_id
+        self.vocab_size: int = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(
+            config.vocab_size, config.hidden_size, self.padding_idx
+        )
+        self.layers = nn.ModuleList(
+            [
+                ModifiedLlamaDecoderLayer(config, layer_idx)
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = LlamaRotaryEmbedding(config=config)
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
 
-    def __init__(self, config: LlamaConfig):
-        """
-        Initializes the DramaModel by disabling causal masking in self-attention layers.
-        """
-        super().__init__(config)
-        for layer in self.layers:
-            layer.self_attn.is_causal = False
-        # query prefix
-        self.query_prefix = "Query: "
-        self.max_seq_len = 8192
-        self.hidden_size = config.hidden_size
-
     def _update_causal_mask(
         self,
         attention_mask: torch.Tensor,
@@ -42,12 +314,182 @@ class DramaModel(LlamaModel):
             return None
         if attention_mask is None or attention_mask.dim() == 4:
            return attention_mask
-
+
        return AttentionMaskConverter._expand_mask(
            mask=attention_mask,
            dtype=input_tensor.dtype,
        )
 
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        # use_cache = use_cache if use_cache is not None else self.config.use_cache
+        use_cache = False
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+            )
+        if self.gradient_checkpointing and self.training and use_cache:
+            warnings.warn(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            use_cache = False
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        return_legacy_cache = False
+        if (
+            use_cache and not isinstance(past_key_values, Cache) and not self.training
+        ):  # kept for BC (non `Cache` `past_key_values` inputs)
+            return_legacy_cache = True
+            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            warnings.warn(
+                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
+                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
+        if cache_position is None:
+            past_seen_tokens = (
+                past_key_values.get_seq_length() if past_key_values is not None else 0
+            )
+            if inputs_embeds.is_nested:
+                seq_len = inputs_embeds._get_max_seqlen()
+            else:
+                seq_len = inputs_embeds.shape[1]
+            cache_position = torch.arange(
+                past_seen_tokens,
+                past_seen_tokens + seq_len,
+                device=inputs_embeds.device,
+            )
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+        if not inputs_embeds.is_nested:
+            causal_mask = self._update_causal_mask(
+                attention_mask,
+                inputs_embeds,
+                cache_position,
+                past_key_values,
+            )
+
+        else:
+            causal_mask = None
+        hidden_states = inputs_embeds
+
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = None
+
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    causal_mask,
+                    position_ids,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
+                    cache_position,
+                    position_embeddings,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=causal_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_values,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                    cache_position=cache_position,
+                    position_embeddings=position_embeddings,
+                )

+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                if v is not None
+            )
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+
+class DramaModel(LlamaBiModel):
+    """
+    DramaModel is a modified version of the LlamaModel that supports bi-directional attention
+    and provides query and document encoding functionalities.
+    """
+
+    def __init__(self, config: LlamaConfig):
+        """
+        Initializes the DramaModel by disabling causal masking in self-attention layers.
+        """
+        super().__init__(config)
+        for layer in self.layers:
+            layer.self_attn.is_causal = False
+        # query prefix
+        self.query_prefix = "Query: "
+        self.max_seq_len = 8192
+        self.hidden_size = config.hidden_size
+
     def _average_pool(
         self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
     ) -> torch.Tensor:
@@ -60,107 +502,138 @@ class DramaModel(LlamaModel):
         return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
 
     def _tokenize(
-        self,
-        tokenizer: PreTrainedTokenizer,
-        texts: list[str],
-        max_seq_len: int = None,
-    ):
+        self,
+        tokenizer: PreTrainedTokenizer,
+        texts: list[str],
+        max_seq_len: int = None,
+        use_nested: bool = False,
+    ):
         """
         Tokenizes input text sequences with optional sequence length restriction.
         """
         if max_seq_len is None:
             max_seq_len = self.max_seq_len
-        tokenized = tokenizer(
-            texts,
-            padding=True,
-            truncation=True,
-            max_length=max_seq_len,
-            return_tensors='pt',
-        ).to(self.device)
-        return tokenized
+        if use_nested:
+            tokenized = tokenizer(
+                texts,
+                truncation=True,
+                max_length=max_seq_len,
+                return_length=True,
+            )
+            tokenized.input_ids = torch.nested.nested_tensor(
+                tokenized.input_ids, layout=torch.jagged
+            ).to(self.device)
+            tokenized.attention_mask = None
+        else:
+            tokenized = tokenizer(
+                texts,
+                padding=True,
+                truncation=True,
+                max_length=max_seq_len,
+                return_tensors="pt",
+            ).to(self.device)
+        tokenizer_output = {}
+        tokenizer_output["input_ids"] = tokenized.input_ids
+        tokenizer_output["attention_mask"] = tokenized.attention_mask
+        return tokenizer_output
 
     def encode(self, input_ids, attention_mask, dim, *args, **kwargs):
         """
         Pass through the model and compute normalized embeddings.
-
+
         Args:
             input_ids (torch.Tensor): Input token IDs.
             attention_mask (torch.Tensor): Attention mask tensor.
             dim (int): Dimensionality for output embeddings.
-
+
         Returns:
             torch.Tensor: Normalized output embeddings.
         """
+
         outputs = self.forward(
             input_ids, attention_mask, *args, **kwargs
-        )
-        embeddings = self._average_pool(
-            outputs.last_hidden_state[:, :, :dim], attention_mask
-        )
+        ).last_hidden_state
+        if not outputs.is_nested:
+            if dim is not None:
+                outputs = outputs[:, :, :dim]
+            embeddings = self._average_pool(outputs, attention_mask)
+        else:
+            if dim is not None:
+                outputs, _ = outputs.split_with_sizes(
+                    split_sizes=[dim, outputs.shape[-1] - dim], dim=-1
+                )
+            embeddings = outputs.sum(dim=-2)
         # normalize embeddings
         embeddings = F.normalize(embeddings, p=2, dim=1)
         return embeddings
 
     def encode_queries(
-        self,
-        tokenizer: PreTrainedTokenizer,
-        queries: list[str],
-        max_seq_len: int = None,
-        dim: int = None,
-    ):
+        self,
+        tokenizer: PreTrainedTokenizer,
+        queries: list[str],
+        max_seq_len: int = None,
+        dim: int = None,
+        use_nested: bool = False,
+    ):
         """
         Encodes a list of queries into embeddings.
-
+
         Args:
             tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
             queries (list[str]): List of query texts.
             max_seq_len (int, optional): Maximum sequence length.
             dim (int, optional): Dimensionality for output embeddings.
-
+
         Returns:
             torch.Tensor: Encoded query embeddings in shape (num_queries, dim).
         """
         if not queries:
             raise ValueError("queries must not be empty.")
-        if not isinstance(queries, list) or not all(isinstance(q, str) for q in queries):
+        if not isinstance(queries, list) or not all(
+            isinstance(q, str) for q in queries
+        ):
             raise ValueError("queries must be a list of strings.")
         if tokenizer is None:
             raise ValueError("tokenizer must not be None.")
         if dim is not None and (dim < 1 or dim > self.hidden_size):
             raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
         queries = [self.query_prefix + query for query in queries]
-        tokenized_queries = self._tokenize(tokenizer, queries, max_seq_len)
+        tokenized_queries = self._tokenize(tokenizer, queries, max_seq_len, use_nested)
         embeddings = self.encode(**tokenized_queries, dim=dim)
         return embeddings
 
     def encode_documents(
-        self,
-        tokenizer: PreTrainedTokenizer,
-        documents: list[str],
-        max_seq_len: int = None,
-        dim: int = None,
-    ):
+        self,
+        tokenizer: PreTrainedTokenizer,
+        documents: list[str],
+        max_seq_len: int = None,
+        dim: int = None,
+        use_nested: bool = False,
+    ):
         """
         Encodes a list of documents into embeddings.
-
+
         Args:
             tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
             documents (list[str]): List of document texts.
            max_seq_len (int, optional): Maximum sequence length.
            dim (int, optional): Dimensionality for output embeddings.
-
+
        Returns:
            torch.Tensor: Encoded document embeddings in shape (num_documents, dim).
        """
        if not documents:
            raise ValueError("documents must not be empty.")
-        if not isinstance(documents, list) or not all(isinstance(d, str) for d in documents):
+        if not isinstance(documents, list) or not all(
+            isinstance(d, str) for d in documents
+        ):
            raise ValueError("documents must be a list of strings.")
        if tokenizer is None:
            raise ValueError("tokenizer must not be None.")
        if dim is not None and (dim < 1 or dim > self.hidden_size):
            raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
-        tokenized_documents = self._tokenize(tokenizer, documents, max_seq_len)
+        tokenized_documents = self._tokenize(
+            tokenizer, documents, max_seq_len, use_nested
+        )
        embeddings = self.encode(**tokenized_documents, dim=dim)
        return embeddings
-
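Beyond the NJT path, the other half of this change is that every layer now routes attention through sdpa_attention_forward with is_causal=False, so the encoder attends bidirectionally. The snippet below is an independent, minimal illustration of what that flag changes in torch.nn.functional.scaled_dot_product_attention; it is not code from the model, and the random tensors are placeholders.

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 2, 4, 8)  # (batch, heads, seq, head_dim)

causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
bidirectional = F.scaled_dot_product_attention(q, k, v, is_causal=False)

# The last position attends to every token in both cases, so its output matches
# (up to numerical tolerance); earlier positions see "future" tokens only in the
# bidirectional case, so their outputs generally differ.
print(torch.allclose(causal[:, :, -1], bidirectional[:, :, -1], atol=1e-5))  # True
print(torch.allclose(causal[:, :, 0], bidirectional[:, :, 0], atol=1e-5))    # False in general
```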