ekurtic committed
Commit 5b6cee4 · verified · 1 Parent(s): 642153e

Upload folder using huggingface_hub

Files changed (4):
  1. config.json +56 -0
  2. eagle3.py +570 -0
  3. generation_config.json +4 -0
  4. model.safetensors +3 -0
config.json ADDED
@@ -0,0 +1,56 @@
+ {
+   "architectures": [
+     "Eagle3Speculator"
+   ],
+   "auto_map": {
+     "": "eagle3.Eagle3SpeculatorConfig"
+   },
+   "draft_vocab_size": 32000,
+   "has_no_defaults_at_init": false,
+   "norm_before_residual": true,
+   "speculators_config": {
+     "algorithm": "eagle3",
+     "default_proposal_method": "greedy",
+     "proposal_methods": [
+       {
+         "accept_tolerance": 0.0,
+         "proposal_type": "greedy",
+         "speculative_tokens": 5,
+         "verifier_accept_k": 1
+       }
+     ],
+     "verifier": {
+       "architectures": [
+         "LlamaForCausalLM"
+       ],
+       "name_or_path": "/proving-grounds/machine/eldarkurtic/hf_downloads/meta-llama/Llama-3.1-8B-Instruct"
+     }
+   },
+   "speculators_model_type": "eagle3",
+   "speculators_version": "0.2.0.dev11",
+   "target_hidden_size": null,
+   "torch_dtype": "bfloat16",
+   "transformer_layer_config": {
+     "attention_bias": false,
+     "attention_dropout": 0.0,
+     "head_dim": 128,
+     "hidden_act": "silu",
+     "hidden_size": 4096,
+     "initializer_range": 0.02,
+     "intermediate_size": 14336,
+     "max_position_embeddings": 131072,
+     "mlp_bias": false,
+     "model_type": "llama",
+     "num_attention_heads": 32,
+     "num_hidden_layers": 1,
+     "num_key_value_heads": 8,
+     "pretraining_tp": 1,
+     "rms_norm_eps": 1e-05,
+     "rope_scaling": null,
+     "rope_theta": 500000.0,
+     "torch_dtype": "bfloat16",
+     "use_cache": true,
+     "vocab_size": 128256
+   },
+   "transformers_version": "4.53.2"
+ }
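The two vocabulary fields above are the ones to notice: draft_vocab_size is 32000 while transformer_layer_config.vocab_size is 128256, so the draft head scores only a 32K subset of Llama-3.1's 128K-token vocabulary, and the d2t offset buffer (defined in eagle3.py below, with its values stored in model.safetensors) scatters those scores back into target-vocabulary positions. A minimal PyTorch sketch of that scatter, using made-up placeholder offsets rather than the checkpoint's real d2t values:

import torch

draft_vocab, target_vocab = 32000, 128256

# Placeholder offsets for illustration only: pretend draft id i corresponds to target id 4*i.
# The real offsets are the d2t tensor shipped in model.safetensors.
d2t = 3 * torch.arange(draft_vocab)

draft_logits = torch.randn(1, 1, draft_vocab)        # what lm_head would produce
target_ids = torch.arange(draft_vocab) + d2t         # draft id -> target id
mapped = draft_logits.new_full((1, 1, target_vocab), float("-inf"))
mapped[:, :, target_ids] = draft_logits              # unmapped target ids stay at -inf
print(mapped.shape)                                  # torch.Size([1, 1, 128256])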
eagle3.py ADDED
@@ -0,0 +1,570 @@
+ """
+ Speculators implementation of EAGLE-3:
+ - https://arxiv.org/abs/2503.01840
+
+ Classes:
+     Eagle3SpeculatorConfig: Configuration class for EAGLE-3 speculator model
+     Eagle3Speculator: Main model implementation for EAGLE-3 speculators
+     Eagle3Attention: Custom attention layer for EAGLE-3, processes
+         concatenated embeddings and hidden states
+     Eagle3DecoderLayer: Custom decoder layer for EAGLE-3, processes
+         concatenated embeddings and hidden states with Eagle3Attention
+         and support for moving hidden layernorm before residual
+ """
+
+ import os
+ from typing import Any, ClassVar, Literal, Optional, Union
+
+ import torch
+ from pydantic import Field, field_serializer, field_validator
+ from torch import nn
+ from transformers import PretrainedConfig, PreTrainedModel
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.models.llama.configuration_llama import LlamaConfig
+ from transformers.models.llama.modeling_llama import (
+     LlamaMLP,
+     LlamaRMSNorm,
+     apply_rotary_pos_emb,
+     repeat_kv,
+ )
+
+ from speculators import SpeculatorModel, SpeculatorModelConfig
+
+ __all__ = [
+     "Eagle3Attention",
+     "Eagle3DecoderLayer",
+     "Eagle3Speculator",
+     "Eagle3SpeculatorConfig",
+ ]
+
+
+ @SpeculatorModelConfig.register("eagle3")
+ class Eagle3SpeculatorConfig(SpeculatorModelConfig):
+     """
+     Configuration for EAGLE-3 speculator with vocabulary mapping.
+
+     EAGLE-3 features vocabulary mapping between draft (32K) and target (128K)
+     vocabularies, enabling cross-tokenizer speculation.
+
+     :param transformer_layer_config: Configuration for the transformer decoder layer
+     :param draft_vocab_size: Size of draft model vocabulary for speculation
+     :param norm_before_residual: Apply hidden_norm before storing residual
+     """
+
+     speculators_model_type: Literal["eagle3"] = "eagle3"
+     architectures: list[str] = Field(
+         default_factory=lambda: ["Eagle3Speculator"],
+         description="Model architectures that can load these weights",
+     )
+
+     transformer_layer_config: PretrainedConfig = Field(
+         default_factory=LlamaConfig,
+         description="Configuration for the transformer decoder layer",
+     )
+
+     draft_vocab_size: int = Field(
+         default=32000,
+         description="Size of draft model vocabulary for speculation",
+     )
+
+     norm_before_residual: bool = Field(
+         default=False,
+         description="Apply hidden_norm before storing residual",
+     )
+
+     target_hidden_size: Optional[int] = Field(
+         default=None,
+         description="Hidden size of the target model (if different from draft model)",
+     )
+
+     @property
+     def target_vocab_size(self) -> int:
+         """Get target vocabulary size from transformer config."""
+         return self.transformer_layer_config.vocab_size
+
+     @field_serializer("transformer_layer_config")
+     def serialize_transformer_config(self, value: PretrainedConfig) -> dict:
+         """Serialize transformer config to dict."""
+         return value.to_diff_dict()
+
+     @field_validator("transformer_layer_config", mode="before")
+     @classmethod
+     def validate_transformer_config(cls, value: Any) -> PretrainedConfig:
+         """Validate and convert transformer config."""
+         if isinstance(value, dict):
+             config_class: type[PretrainedConfig] = LlamaConfig
+             if "model_type" in value:
+                 from transformers import AutoConfig
+
+                 config_class = AutoConfig.for_model(
+                     model_type=value["model_type"]
+                 ).__class__
+             return config_class(**value)
+         return value
+
+
+ class Eagle3Attention(nn.Module):
+     """
+     Eagle-3 attention module that processes concatenated embeddings and hidden states.
+
+     Modified from standard Llama attention to accept 2x hidden_size input
+     for Q/K/V projections while maintaining standard output size.
+     """
+
+     def __init__(self, config: PretrainedConfig, layer_idx: int):
+         super().__init__()
+         self.config = config
+         self.layer_idx = layer_idx
+
+         self.num_heads = config.num_attention_heads
+         self.num_key_value_heads = config.num_key_value_heads
+         self.hidden_size = config.hidden_size
+         self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+
+         input_size = 2 * self.hidden_size
+         self.q_proj = nn.Linear(
+             input_size, self.num_heads * self.head_dim, bias=config.attention_bias
+         )
+         self.k_proj = nn.Linear(
+             input_size,
+             self.num_key_value_heads * self.head_dim,
+             bias=config.attention_bias,
+         )
+         self.v_proj = nn.Linear(
+             input_size,
+             self.num_key_value_heads * self.head_dim,
+             bias=config.attention_bias,
+         )
+         self.o_proj = nn.Linear(
+             self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+         output_attentions: bool = False,
+         use_cache: bool = False,
+         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+         **kwargs,  # noqa: ARG002
+     ) -> tuple:
+         """
+         Forward pass for Eagle-3 attention.
+         Taken from Llama Attention but modified to accept 2x hidden_size input.
+
+         :param hidden_states: Input tensor of shape [batch, seq_len, 2*hidden_size]
+         :param attention_mask: Optional attention mask
+         :param position_ids: Optional position IDs for rotary embeddings
+         :param past_key_value: Optional cached key-value pairs
+         :param output_attentions: Whether to return attention weights
+         :param use_cache: Whether to cache key-value pairs
+         :param position_embeddings: Optional precomputed rotary embeddings
+         :return: Tuple of (hidden_states, [attention_weights], [past_key_value])
+         """
+         bsz, q_len, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         query_states = query_states.view(
+             bsz, q_len, self.num_heads, self.head_dim
+         ).transpose(1, 2)
+         key_states = key_states.view(
+             bsz, q_len, self.num_key_value_heads, self.head_dim
+         ).transpose(1, 2)
+         value_states = value_states.view(
+             bsz, q_len, self.num_key_value_heads, self.head_dim
+         ).transpose(1, 2)
+
+         if position_embeddings is not None:
+             cos, sin = position_embeddings
+             query_states, key_states = apply_rotary_pos_emb(
+                 query_states, key_states, cos, sin, position_ids
+             )
+
+         past_key_value_out = None
+         if past_key_value is not None:
+             past_key = past_key_value[0]
+             past_value = past_key_value[1]
+             key_states = torch.cat([past_key, key_states], dim=2)
+             value_states = torch.cat([past_value, value_states], dim=2)
+
+         if use_cache:
+             past_key_value_out = (key_states, value_states)
+
+         key_states = repeat_kv(key_states, self.num_key_value_groups)
+         value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / (
+             self.head_dim**0.5
+         )
+
+         if attention_mask is not None:
+             attn_weights = attn_weights + attention_mask
+
+         attn_weights = nn.functional.softmax(
+             attn_weights, dim=-1, dtype=torch.float32
+         ).to(query_states.dtype)
+
+         attn_output = torch.matmul(attn_weights, value_states)
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+         attn_output = self.o_proj(attn_output)
+
+         if not output_attentions:
+             attn_weights = None
+
+         return attn_output, attn_weights, past_key_value_out
+
+
+ class Eagle3DecoderLayer(nn.Module):
+     """
+     Eagle-3 decoder layer that processes concatenated embeddings and hidden states.
+
+     Accepts 2x hidden_size input from concatenated embeddings and fused hidden states.
+     Uses Eagle3Attention for the self-attention computation.
+     """
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         layer_idx: int,
+         norm_before_residual: bool = False,
+     ):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.norm_before_residual = norm_before_residual
+
+         self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = LlamaRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+         self.self_attn = Eagle3Attention(config, layer_idx)
+
+         self.mlp = LlamaMLP(config)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+         output_attentions: Optional[bool] = False,
+         use_cache: Optional[bool] = False,
+         cache_position: Optional[torch.LongTensor] = None,  # noqa: ARG002
+         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+         **kwargs,  # noqa: ARG002
+     ) -> tuple:
+         """
+         Process concatenated embeddings and hidden states through modified decoder
+         layer.
+
+         :param hidden_states: Input tensor of shape [batch, seq_len, 2*hidden_size]
+         :return: Tuple of layer outputs
+         """
+         embeds = hidden_states[:, :, : self.hidden_size]
+         hidden = hidden_states[:, :, self.hidden_size : 2 * self.hidden_size]
+
+         if self.norm_before_residual:
+             hidden = self.hidden_norm(hidden)
+             residual = hidden
+         else:
+             residual = hidden
+             hidden = self.hidden_norm(hidden)
+
+         embeds = self.input_layernorm(embeds)
+
+         attn_input = torch.cat([embeds, hidden], dim=-1)
+
+         attn_output, attn_weights, past_key_value_out = self.self_attn(
+             hidden_states=attn_input,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+             use_cache=use_cache,
+             position_embeddings=position_embeddings,
+         )
+
+         hidden_states = residual + attn_output
+
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (attn_weights,)  # type: ignore[assignment]
+
+         if use_cache:
+             outputs += (past_key_value_out,)  # type: ignore[assignment]
+
+         return outputs
+
+
+ @SpeculatorModel.register("eagle3")
+ class Eagle3Speculator(SpeculatorModel):
+     """
+     EAGLE-3 speculator with vocabulary mapping and multi-layer fusion.
+
+     EAGLE-3 processes concatenated hidden states from multiple verifier layers
+     through a fusion layer, then combines with embeddings for a custom decoder
+     layer that accepts 2x hidden_size input.
+     """
+
+     config_class: ClassVar[type[Eagle3SpeculatorConfig]] = Eagle3SpeculatorConfig  # type: ignore[misc]
+     _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [  # type: ignore[misc]
+         "verifier*",
+     ]
+     _keys_to_ignore_on_save: ClassVar[list[str]] = []  # type: ignore[misc,assignment]
+
+     def __init__(
+         self,
+         config: Eagle3SpeculatorConfig,
+         verifier: Optional[Union[str, os.PathLike, PreTrainedModel]] = None,
+         verifier_attachment_mode: Optional[
+             Literal["detached", "full", "train_only"]
+         ] = None,
+     ):
+         """
+         Initialize Eagle3 speculator.
+
+         :param config: Eagle3SpeculatorConfig instance
+         :param verifier: Optional verifier model
+         :param verifier_attachment_mode: How to attach the verifier
+         """
+         if not isinstance(config, Eagle3SpeculatorConfig):
+             raise ValueError(
+                 f"config must be Eagle3SpeculatorConfig, got {type(config)}"
+             )
+
+         self.config: Eagle3SpeculatorConfig = config
+
+         self.hidden_size = config.transformer_layer_config.hidden_size
+         self.draft_vocab_size = config.draft_vocab_size
+         self.target_vocab_size = config.target_vocab_size
+
+         # Use target_hidden_size if specified, otherwise use draft model's hidden_size
+         self.target_hidden_size = (
+             config.target_hidden_size
+             if config.target_hidden_size is not None
+             else self.hidden_size
+         )
+
+         super().__init__(
+             config=config,
+             verifier=verifier,
+             verifier_attachment_mode=verifier_attachment_mode,
+         )
+
+         self.embed_tokens = nn.Embedding(
+             self.target_vocab_size,
+             self.hidden_size,
+             padding_idx=config.transformer_layer_config.pad_token_id
+             if hasattr(config.transformer_layer_config, "pad_token_id")
+             else None,
+         )
+
+         self.fc = nn.Linear(
+             3 * self.target_hidden_size,  # Use target model's hidden size
+             self.hidden_size,
+             bias=False,
+         )
+
+         self.layers = nn.ModuleList(
+             [
+                 Eagle3DecoderLayer(
+                     config.transformer_layer_config,
+                     layer_idx=0,
+                     norm_before_residual=config.norm_before_residual,
+                 )
+             ]
+         )
+
+         self.norm = LlamaRMSNorm(
+             self.hidden_size,
+             eps=config.transformer_layer_config.rms_norm_eps,
+         )
+
+         self.lm_head = nn.Linear(
+             self.hidden_size,
+             self.draft_vocab_size,
+             bias=False,
+         )
+
+         self.register_buffer(  # type: ignore[attr-defined]
+             "d2t",
+             torch.zeros(self.draft_vocab_size, dtype=torch.long),
+         )
+         self.register_buffer(  # type: ignore[attr-defined]
+             "t2d",
+             torch.zeros(self.target_vocab_size, dtype=torch.bool),
+         )
+
+         # Type hints for buffers
+         self.d2t: torch.Tensor
+         self.t2d: torch.Tensor
+
+         self.post_init()  # type: ignore[attr-defined]
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         hidden_states: torch.FloatTensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,  # noqa: ARG002
+         return_dict: Optional[bool] = None,
+     ) -> Union[torch.FloatTensor, CausalLMOutputWithPast]:
+         """
+         Forward pass for EAGLE-3 speculation.
+
+         :param input_ids: Input token IDs from draft vocabulary
+         :param hidden_states: Concatenated hidden states from 3 verifier layers
+             [B, L, 3*target_H] where target_H is the target model's hidden size
+         :param attention_mask: Optional attention mask
+         :param position_ids: Optional position IDs
+         :param past_key_values: Optional cached key-values
+         :param use_cache: Whether to cache key-values
+         :param output_attentions: Return attention weights
+         :param output_hidden_states: Return hidden states
+         :param return_dict: Return dict output
+         :return: Model outputs with draft vocabulary logits
+         """
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         inputs_embeds = self.embed_tokens(input_ids)
+
+         fused_hidden = self.fc(hidden_states)
+
+         layer_input = torch.cat([inputs_embeds, fused_hidden], dim=-1)
+
+         batch_size, seq_length = layer_input.shape[:2]
+         if attention_mask is not None and attention_mask.dim() == 2:  # noqa: PLR2004
+             past_key_values_length = (
+                 past_key_values[0][0].shape[2] if past_key_values else 0
+             )
+             attention_mask = _prepare_4d_causal_attention_mask(
+                 attention_mask,
+                 (batch_size, seq_length),
+                 hidden_states,
+                 past_key_values_length,
+             )
+
+         if position_ids is None:
+             device = hidden_states.device
+             position_ids = (
+                 torch.arange(  # type: ignore[assignment]
+                     seq_length, dtype=torch.long, device=device
+                 )
+                 .unsqueeze(0)
+                 .expand(batch_size, -1)
+             )
+
+         layer_outputs = self.layers[0](
+             layer_input,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_values[0] if past_key_values else None,
+             output_attentions=output_attentions,
+             use_cache=use_cache,
+         )
+
+         hidden_states = layer_outputs[0]
+
+         hidden_states = self.norm(hidden_states)
+
+         logits = self.compute_logits(hidden_states, map_to_target_vocab=True)
+
+         if not return_dict:
+             return logits
+
+         return CausalLMOutputWithPast(
+             logits=logits,
+             past_key_values=[layer_outputs[1]] if use_cache else None,  # type: ignore[arg-type]
+             hidden_states=None,
+             attentions=None,
+         )
+
+     def compute_logits(
+         self,
+         hidden_states: torch.FloatTensor,
+         map_to_target_vocab: bool = True,
+     ) -> torch.FloatTensor:
+         """
+         Compute logits with optional vocabulary mapping.
+
+         :param hidden_states: Hidden states from the model
+         :param map_to_target_vocab: Whether to map draft logits to target vocabulary
+         :return: Logits tensor
+         """
+         logits = self.lm_head(hidden_states)
+
+         if not map_to_target_vocab:
+             return logits
+
+         batch_size, seq_length, _ = logits.shape
+
+         draft_indices = torch.arange(self.draft_vocab_size, device=logits.device)
+
+         target_indices = draft_indices + self.d2t
+
+         mapped_logits = logits.new_full(
+             (batch_size, seq_length, self.target_vocab_size), float("-inf")
+         )
+
+         mapped_logits[:, :, target_indices] = logits
+
+         return mapped_logits
+
+     def map_draft_to_target_tokens(
+         self, draft_tokens: torch.LongTensor
+     ) -> torch.LongTensor:
+         """
+         Map draft token IDs to target token IDs.
+
+         :param draft_tokens: Draft vocabulary token IDs
+         :return: Target vocabulary token IDs
+         """
+         return draft_tokens + self.d2t[draft_tokens]  # type: ignore[return-value]
+
+     def check_target_token_availability(
+         self, target_tokens: torch.LongTensor
+     ) -> torch.BoolTensor:
+         """
+         Check if target tokens have draft equivalents.
+
+         :param target_tokens: Target vocabulary token IDs
+         :return: Boolean mask indicating availability in draft vocabulary
+         """
+         return self.t2d[target_tokens]  # type: ignore[return-value]
+
+     def tie_weights(self):
+         """
+         Override tie_weights to prevent vocabulary corruption in transformers 4.54.1+
+
+         Eagle3 intentionally uses different vocabulary sizes:
+         - Input embeddings (embed_tokens): 128256 (full vocabulary)
+         - Output embeddings (lm_head): 32000 (draft vocabulary)
+
+         The default tie_weights() tries to make them identical, breaking Eagle3.
+         This override preserves the intentional vocabulary size difference.
+         """
+         # Don't call super().tie_weights() - this prevents vocabulary corruption
+         # that occurs when _tie_or_clone_weights replaces lm_head.weight with
+         # embed_tokens.weight
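As a quick shape check on the code above, the sketch below drives a single Eagle3DecoderLayer with random tensors, using the same values as transformer_layer_config in config.json. It is illustrative only: it assumes eagle3.py and the `speculators` package it imports are available on the Python path, and it leaves position_embeddings and the attention mask unset (no rotary embeddings, no causal masking), which the layer tolerates.

import torch
from transformers.models.llama.configuration_llama import LlamaConfig

from eagle3 import Eagle3DecoderLayer  # importing eagle3 requires the `speculators` package

# Mirror config.json's transformer_layer_config.
cfg = LlamaConfig(
    hidden_size=4096, intermediate_size=14336, num_attention_heads=32,
    num_key_value_heads=8, head_dim=128, num_hidden_layers=1,
    rms_norm_eps=1e-05, vocab_size=128256,
)
layer = Eagle3DecoderLayer(cfg, layer_idx=0, norm_before_residual=True)

batch, seq = 1, 8
embeds = torch.randn(batch, seq, cfg.hidden_size)  # stand-in for embed_tokens(input_ids)
fused = torch.randn(batch, seq, cfg.hidden_size)   # stand-in for fc(3-layer verifier states)

out = layer(torch.cat([embeds, fused], dim=-1))    # layer expects [B, L, 2 * hidden_size]
print(out[0].shape)                                # torch.Size([1, 8, 4096])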
generation_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "_from_model_config": true,
+   "transformers_version": "4.53.2"
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:502d930bffd89a9ea220790f4b62583f2001b6b0b8caf7d93cf90a2ed708b60a
+ size 1900438376