"""sedd_wrapper.py
=========================================

This module provides a minimal HuggingFace-compatible wrapper around the
``SEDD`` architecture that is implemented in :pyfile:`model/transformer.py`.
The wrapper closely follows the design used in the Aero implementation that
lives in this code-base (see :pyfile:`configuration_aero.py` and
:pyfile:`modeling_aero.py`).

Concretely we expose three public objects:

* ``SEDDConfig``
    A :class:`transformers.PretrainedConfig` subclass that stores the
    hyper-parameters needed to instantiate a ``SEDD`` model.
* ``SEDDModel``
    A :class:`transformers.PreTrainedModel` subclass that internally contains
    an instance of the original ``SEDD`` network and maps from ``input_ids``
    + ``sigma`` to the vocabulary logits.
* ``SEDDOutput``
    A thin :class:`transformers.modeling_outputs.ModelOutput` dataclass that
    mirrors the usual "logits / loss" structure.

With this wrapper a trained model checkpoint can be pushed to / loaded from
the 🤗 Hub via ``SEDDModel.push_to_hub`` / ``SEDDModel.from_pretrained`` the
same way as any other ``transformers`` model.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Tuple, List, Dict, Any, Union

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

# Original SEDD implementation
from model.transformer import SEDD as _OrigSEDD

try:
    from omegaconf import OmegaConf
except ImportError:  # pragma: no cover – omegaconf is an explicit dependency of SEDD
    OmegaConf = None  # type: ignore

logger = logging.get_logger(__name__)


###############################################################################
# Configuration                                                               #
###############################################################################


class SEDDConfig(PretrainedConfig):
    """Configuration class for the SEDD architecture.

    The defaults reproduce *roughly* the "small" configuration shipped in
    ``configs/model/small.yaml``.  Additional keys that are present in the
    original Hydra config but not required for instantiation (e.g. *training*
    hyper-parameters) are deliberately omitted here – they can still be stored
    as *extra* fields in the underlying JSON if a user wishes to preserve
    them.
    """

    model_type: str = "sedd"

    def __init__(
        self,
        *,
        tokens: int = 50257,
        # graph section
        graph_type: str = "absorb",
        # model section (mirrors configs/model/*.yaml)
        model_hidden_size: int = 768,
        model_cond_dim: int = 128,
        model_length: int = 1024,
        model_n_blocks: int = 12,
        model_n_heads: int = 12,
        model_scale_by_sigma: bool = True,
        model_dropout: float = 0.10,
        # miscellaneous
        tie_word_embeddings: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

        # Top-level attributes (kept flat for simplicity)
        self.tokens = tokens
        self.graph_type = graph_type

        # Model hyper-parameters
        self.model_hidden_size = model_hidden_size
        self.model_cond_dim = model_cond_dim
        self.model_length = model_length
        self.model_n_blocks = model_n_blocks
        self.model_n_heads = model_n_heads
        self.model_scale_by_sigma = model_scale_by_sigma
        self.model_dropout = model_dropout

    # ---------------------------------------------------------------------
    # Serialization helpers – these optionally bridge to the original Hydra
    # config structure that the reference implementation expects.
    # ---------------------------------------------------------------------
    def to_hydra(self):
        """Convert this *flat* config to the nested OmegaConf structure that
        the reference ``SEDD`` implementation expects.
        """
        if OmegaConf is None:
            raise RuntimeError("`omegaconf` is required to build a Hydra config")

        nested: Dict[str, Any] = {
            "tokens": self.tokens,
            "graph": {
                "type": self.graph_type,
            },
            "model": {
                "hidden_size": self.model_hidden_size,
                "cond_dim": self.model_cond_dim,
                "length": self.model_length,
                "n_blocks": self.model_n_blocks,
                "n_heads": self.model_n_heads,
                "scale_by_sigma": self.model_scale_by_sigma,
                "dropout": self.model_dropout,
            },
        }
        return OmegaConf.create(nested)
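

# ---------------------------------------------------------------------------
# Illustrative sketch (comments only, not executed on import): how the flat
# ``SEDDConfig`` fields map onto the nested Hydra/OmegaConf layout produced
# by ``to_hydra``.  The values shown are simply the defaults declared above.
#
#   >>> cfg = SEDDConfig(model_hidden_size=768, graph_type="absorb")
#   >>> hydra_cfg = cfg.to_hydra()
#   >>> hydra_cfg.model.hidden_size
#   768
#   >>> hydra_cfg.graph.type
#   'absorb'
# ---------------------------------------------------------------------------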


###############################################################################
# Output container                                                            #
###############################################################################


@dataclass
class SEDDOutput(ModelOutput):
    """Standard output for :class:`SEDDModel`.

    Attributes
    ----------
    loss:
        *Optional* scalar returned when ``labels`` are provided.
    logits:
        The raw vocabulary logits computed by the model of shape
        ``(batch_size, sequence_length, vocab_size)``.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None


###############################################################################
# Model                                                                       #
###############################################################################


class SEDDModel(PreTrainedModel):
    """HuggingFace *Transformers* wrapper around the original ``SEDD`` model."""

    config_class = SEDDConfig
    base_model_prefix = "score_model"
    _no_split_modules: List[str] = [
        "DDiTBlock",  # ensure these blocks are not split when using FSDP/TP
    ]

    def __init__(self, config: SEDDConfig):
        super().__init__(config)

        # ------------------------------------------------------------------
        # Instantiate the original SEDD architecture using the Hydra cfg that
        # the implementation expects.
        # ------------------------------------------------------------------
        if OmegaConf is None:
            raise RuntimeError("`omegaconf` is required to instantiate SEDD")

        hydra_cfg = config.to_hydra()
        self.score_model = _OrigSEDD(hydra_cfg)

        # Make sure parameters are created on the right device / dtype.
        self.post_init()

    # ------------------------------------------------------------------
    # Forward pass
    # ------------------------------------------------------------------
    def forward(
        self,
        input_ids: torch.LongTensor,
        sigma: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        **kwargs: Any,
    ) -> Union[SEDDOutput, Tuple]:
        """Run a forward pass.

        Parameters
        ----------
        input_ids:
            Token indices of shape ``(batch_size, seq_len)``.
        sigma:
            Noise level ("time-step") of shape ``(batch_size,)``.
        labels:
            *Optional* label tensor used to compute a cross-entropy training
            loss.  If provided the returned :class:`SEDDOutput` will contain
            a ``loss`` field.
        """
        logits = self.score_model(indices=input_ids, sigma=sigma)

        loss: Optional[torch.Tensor] = None
        if labels is not None:
            # Standard CE loss over the last dimension (vocab)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        if not self.config.return_dict:
            output: Tuple[Any, ...] = (logits,)
            return ((loss,) + output) if loss is not None else output

        return SEDDOutput(loss=loss, logits=logits)
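
    # ------------------------------------------------------------------
    # Illustrative sketch (comments only): a minimal forward call.  The
    # uniform ``sigma`` below is just a placeholder noise level, not the
    # official SEDD noise schedule, and ``labels=input_ids`` is only a
    # stand-in target.
    #
    #   >>> config = SEDDConfig()
    #   >>> model = SEDDModel(config)
    #   >>> input_ids = torch.randint(0, config.tokens, (2, config.model_length))
    #   >>> sigma = torch.rand(2)
    #   >>> out = model(input_ids=input_ids, sigma=sigma, labels=input_ids)
    #   >>> out.logits.shape   # (batch_size, seq_len, vocab_size)
    #   >>> out.loss           # scalar tensor, present because ``labels`` was given
    # ------------------------------------------------------------------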

    # ------------------------------------------------------------------
    # Weight loading helpers – we fall back to the original SEDD checkpoint
    # layout so that checkpoints trained with the previous implementation
    # can be re-used.
    # ------------------------------------------------------------------
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        *model_args: Any,
        **kwargs: Any,
    ) -> "SEDDModel":
        """Override the default method to allow loading legacy SEDD
        checkpoints whose weights were saved via
        ``torch.save({'model': state_dict, ...})``.
        """
        try:
            # First try the regular *transformers* loading routine – this will
            # succeed if the repository follows the standard file-naming
            # conventions (i.e. contains a ``pytorch_model.bin`` / safetensors).
            return super().from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        except (EnvironmentError, RuntimeError) as e:
            logger.info(
                "Falling back to legacy SEDD checkpoint format because standard "
                "loading raised: %s",
                e,
            )

        # ----------------------------------------------------------
        # 1. Load the config the usual way so we get a `SEDDConfig` instance.
        # ----------------------------------------------------------
        config = kwargs.pop("config", None) or SEDDConfig.from_pretrained(
            pretrained_model_name_or_path
        )
        model = cls(config, *model_args, **kwargs)

        # ----------------------------------------------------------
        # 2. Attempt to locate the legacy *.pth* checkpoint and load it.
        # ----------------------------------------------------------
        import os
        import torch as _torch

        checkpoint_path = os.path.join(
            pretrained_model_name_or_path, "checkpoints-meta", "checkpoint.pth"
        )
        if not os.path.isfile(checkpoint_path):
            raise FileNotFoundError(
                f"Could not find legacy SEDD checkpoint at {checkpoint_path}"
            )

        ckpt = _torch.load(checkpoint_path, map_location="cpu")
        state_dict = ckpt.get("model", ckpt)

        # Strip prefix if present (sometimes stored under "module.")
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}

        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing:
            logger.warning("Missing keys when loading SEDD weights: %s", missing)
        if unexpected:
            logger.warning(
                "Unexpected keys when loading SEDD weights: %s", unexpected
            )

        return model


###############################################################################
# Public API                                                                  #
###############################################################################

__all__ = [
    "SEDDConfig",
    "SEDDModel",
    "SEDDOutput",
]
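
# ---------------------------------------------------------------------------
# Usage sketch (comments only; ``"your-username/sedd-small"`` is a placeholder
# repository id, not an existing checkpoint).  Saving, pushing and re-loading
# go through the standard ``transformers`` machinery:
#
#   >>> model = SEDDModel(SEDDConfig())
#   >>> model.save_pretrained("sedd-small")                    # config.json + weights
#   >>> model.push_to_hub("your-username/sedd-small")          # requires a Hub login
#   >>> reloaded = SEDDModel.from_pretrained("your-username/sedd-small")
# ---------------------------------------------------------------------------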