"""sedd_wrapper.py
=========================================
This module provides a minimal HuggingFace-compatible wrapper around the
`SEDD` architecture that is implemented in :pyfile:`model/transformer.py`.
The wrapper closely follows the design used in the Aero implementation that
lives in this code-base (see :pyfile:`configuration_aero.py` and
:pyfile:`modeling_aero.py`). Concretely we expose three public objects:
* ``SEDDConfig`` – a :class:`transformers.PretrainedConfig` subclass that
  stores the hyper-parameters needed to instantiate a ``SEDD`` model.
* ``SEDDModel`` – a :class:`transformers.PreTrainedModel` subclass that
  internally contains an instance of the original ``SEDD`` network and maps
  from ``input_ids`` + ``sigma`` to the vocabulary logits.
* ``SEDDOutput`` – a thin :class:`transformers.modeling_outputs.ModelOutput`
  dataclass that mirrors the usual "loss / logits" structure.
With this wrapper a trained model checkpoint can be pushed to / loaded from
the 🤗 Hub via ``SEDDModel.push_to_hub`` / ``SEDDModel.from_pretrained`` in the
same way as any other ``transformers`` model.
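
A minimal usage sketch (the repository id below is a placeholder; the example
assumes ``omegaconf`` and the original ``model`` package are importable)::

    import torch

    from sedd_wrapper import SEDDConfig, SEDDModel

    config = SEDDConfig()          # defaults roughly match the "small" model
    model = SEDDModel(config)

    input_ids = torch.randint(0, config.tokens, (2, config.model_length))
    sigma = torch.rand(2)
    out = model(input_ids=input_ids, sigma=sigma)
    out.logits.shape               # (2, model_length, vocab_size)

    # model.push_to_hub("your-username/your-sedd-model")      # placeholder repo id
    # model = SEDDModel.from_pretrained("your-username/your-sedd-model")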
"""
from dataclasses import dataclass
from typing import Optional, Tuple, List, Dict, Any, Union
import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
# Original SEDD implementation
from model.transformer import SEDD as _OrigSEDD
try:
from omegaconf import OmegaConf
except ImportError: # pragma: no cover – omegaconf is an explicit dependency of SEDD
OmegaConf = None # type: ignore
logger = logging.get_logger(__name__)
###############################################################################
# Configuration #
###############################################################################
class SEDDConfig(PretrainedConfig):
"""Configuration class for the SEDD architecture.
The defaults reproduce *roughly* the "small" configuration shipped in
``configs/model/small.yaml``. Additional keys that are present in the
original Hydra config but not required for instantiation (e.g. *training*
hyper-parameters) are deliberately omitted here – they can still be stored
as *extra* fields in the underlying JSON if a user wishes to preserve them.
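
    A brief sketch of that behaviour (``work_dir`` is an arbitrary illustrative
    extra field, not a recognised hyper-parameter)::

        config = SEDDConfig(model_hidden_size=768, work_dir="runs/exp1")
        config.save_pretrained("sedd-config")   # extra field ends up in config.json
        reloaded = SEDDConfig.from_pretrained("sedd-config")
        reloaded.work_dir                       # "runs/exp1"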
"""
model_type: str = "sedd"
def __init__(
self,
*,
tokens: int = 50257,
# graph section
graph_type: str = "absorb",
# model section (mirrors configs/model/*.yaml)
model_hidden_size: int = 768,
model_cond_dim: int = 128,
model_length: int = 1024,
model_n_blocks: int = 12,
model_n_heads: int = 12,
model_scale_by_sigma: bool = True,
model_dropout: float = 0.10,
# miscellaneous
tie_word_embeddings: bool = False,
**kwargs,
) -> None:
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
# Top-level attributes (kept flat for simplicity)
self.tokens = tokens
self.graph_type = graph_type
# Model hyper-parameters
self.model_hidden_size = model_hidden_size
self.model_cond_dim = model_cond_dim
self.model_length = model_length
self.model_n_blocks = model_n_blocks
self.model_n_heads = model_n_heads
self.model_scale_by_sigma = model_scale_by_sigma
self.model_dropout = model_dropout
# ---------------------------------------------------------------------
# Serialization helpers – these optionally bridge to the original Hydra
# config structure that the reference implementation expects.
# ---------------------------------------------------------------------
def to_hydra(self):
"""Convert this *flat* config to the nested OmegaConf structure that
the reference ``SEDD`` implementation expects.
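
        A small sketch of the resulting nested layout::

            cfg = SEDDConfig(model_hidden_size=768, model_n_heads=12)
            hydra_cfg = cfg.to_hydra()
            hydra_cfg.model.hidden_size   # 768
            hydra_cfg.graph.type          # "absorb"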
"""
if OmegaConf is None:
raise RuntimeError("`omegaconf` is required to build a Hydra config")
nested: Dict[str, Any] = {
"tokens": self.tokens,
"graph": {
"type": self.graph_type,
},
"model": {
"hidden_size": self.model_hidden_size,
"cond_dim": self.model_cond_dim,
"length": self.model_length,
"n_blocks": self.model_n_blocks,
"n_heads": self.model_n_heads,
"scale_by_sigma": self.model_scale_by_sigma,
"dropout": self.model_dropout,
},
}
return OmegaConf.create(nested)
###############################################################################
# Output container #
###############################################################################
@dataclass
class SEDDOutput(ModelOutput):
"""Standard output for :class:`SEDDModel`.
Attributes
----------
loss:
*Optional* scalar returned when ``labels`` are provided.
logits:
The raw vocabulary logits computed by the model of shape
``(batch_size, sequence_length, vocab_size)``.
"""
loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
###############################################################################
# Model #
###############################################################################
class SEDDModel(PreTrainedModel):
"""HuggingFace *Transformers* wrapper around the original ``SEDD`` model."""
config_class = SEDDConfig
base_model_prefix = "score_model"
_no_split_modules: List[str] = [
"DDiTBlock", # ensure these blocks are not split when using FSDP/TP
]
def __init__(self, config: SEDDConfig):
super().__init__(config)
# ------------------------------------------------------------------
# Instantiate the original SEDD architecture using the Hydra cfg that
# the implementation expects.
# ------------------------------------------------------------------
if OmegaConf is None:
raise RuntimeError("`omegaconf` is required to instantiate SEDD")
hydra_cfg = config.to_hydra()
self.score_model = _OrigSEDD(hydra_cfg)
# Make sure parameters are created on the right device / dtype.
self.post_init()
# ------------------------------------------------------------------
# Forward pass
# ------------------------------------------------------------------
def forward(
self,
input_ids: torch.LongTensor,
sigma: torch.FloatTensor,
labels: Optional[torch.LongTensor] = None,
**kwargs: Any,
) -> Union[SEDDOutput, Tuple]:
"""Run a forward pass.
Parameters
----------
input_ids:
Token indices of shape ``(batch_size, seq_len)``.
sigma:
Noise level ("time-step") of shape ``(batch_size,)``.
labels:
*Optional* label tensor used to compute a cross-entropy training
loss. If provided the returned :class:`SEDDOutput` will contain a
``loss`` field.
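
        A minimal sketch with random inputs (assuming ``model`` / ``config`` were
        built as in the module docstring)::

            input_ids = torch.randint(0, config.tokens, (2, config.model_length))
            sigma = torch.rand(2)
            labels = torch.randint(0, config.tokens, (2, config.model_length))
            out = model(input_ids=input_ids, sigma=sigma, labels=labels)
            out.loss     # scalar cross-entropy loss
            out.logits   # (2, model_length, vocab_size)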
"""
logits = self.score_model(indices=input_ids, sigma=sigma)
loss: Optional[torch.Tensor] = None
if labels is not None:
# Standard CE loss over the last dimension (vocab)
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        if not self.config.use_return_dict:
output: Tuple[Any, ...] = (logits,)
return ((loss,) + output) if loss is not None else output
return SEDDOutput(loss=loss, logits=logits)
# ------------------------------------------------------------------
# Weight loading helpers – we delegate to the *original* SEDD mixin so that
# checkpoints trained with the previous implementation can be re-used.
# ------------------------------------------------------------------
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
*model_args: Any,
**kwargs: Any,
) -> "SEDDModel":
"""Overrides the default method to allow loading legacy SEDD checkpoints
whose weights are saved via ``torch.save({'model': state_dict, ...})``.
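
        A sketch of the legacy directory layout this fallback expects (file names
        taken from the loading code below)::

            <checkpoint_dir>/
                config.json                 # read by SEDDConfig.from_pretrained
                checkpoints-meta/
                    checkpoint.pth          # torch.save({"model": state_dict, ...})

        Loading then reduces to ``SEDDModel.from_pretrained("<checkpoint_dir>")``.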
"""
try:
# First try the regular *transformers* loading routine – this will
# succeed if the repository follows the standard file-naming
# conventions (i.e. contains a ``pytorch_model.bin`` / safetensors).
return super().from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)
except (EnvironmentError, RuntimeError) as e:
logger.info(
"Falling back to legacy SEDD checkpoint format because standard "
"loading raised: %s", e,
)
# ----------------------------------------------------------
# 1. Load config the usual way so we get a `SEDDConfig` instance.
# ----------------------------------------------------------
config = kwargs.pop("config", None) or SEDDConfig.from_pretrained(
pretrained_model_name_or_path
)
            # ``kwargs`` still contains loading options (e.g. ``cache_dir``) that the
            # wrapper's ``__init__`` does not accept, so only forward ``model_args``.
            model = cls(config, *model_args)
# ----------------------------------------------------------
# 2. Attempt to locate the legacy *.pth* checkpoint and load it.
# ----------------------------------------------------------
import os
import torch as _torch
checkpoint_path = os.path.join(
pretrained_model_name_or_path, "checkpoints-meta", "checkpoint.pth"
)
if not os.path.isfile(checkpoint_path):
raise FileNotFoundError(
"Could not find legacy SEDD checkpoint at " f"{checkpoint_path}"
)
ckpt = _torch.load(checkpoint_path, map_location="cpu")
state_dict = ckpt.get("model", ckpt)
            # Strip the (Distributed)DataParallel prefix if present (weights are
            # sometimes stored under "module.").
            state_dict = {
                (k[len("module."):] if k.startswith("module.") else k): v
                for k, v in state_dict.items()
            }
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing:
logger.warning("Missing keys when loading SEDD weights: %s", missing)
if unexpected:
logger.warning(
"Unexpected keys when loading SEDD weights: %s", unexpected
)
return model
###############################################################################
# Public API #
###############################################################################
__all__ = [
"SEDDConfig",
"SEDDModel",
"SEDDOutput",
]