"""sedd_wrapper.py
=========================================

This module provides a minimal HuggingFace-compatible wrapper around the
``SEDD`` architecture that is implemented in :pyfile:`model/transformer.py`.

The wrapper closely follows the design used in the Aero implementation that
lives in this code-base (see :pyfile:`configuration_aero.py` and
:pyfile:`modeling_aero.py`). Concretely, we expose three public objects:

* ``SEDDConfig``: a :class:`transformers.PretrainedConfig` subclass that
  stores the hyper-parameters needed to instantiate a ``SEDD`` model.
* ``SEDDModel``: a :class:`transformers.PreTrainedModel` subclass that
  internally contains an instance of the original ``SEDD`` network and maps
  from ``input_ids`` + ``sigma`` to the vocabulary logits.
* ``SEDDOutput``: a thin :class:`transformers.modeling_outputs.ModelOutput`
  dataclass that mirrors the usual "logits / loss" structure.

With this wrapper a trained model checkpoint can be pushed to / loaded from
the Hugging Face Hub via ``SEDDModel.push_to_hub`` /
``SEDDModel.from_pretrained``, the same way as any other ``transformers``
model.
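
Example
-------
A minimal round-trip sketch (the output directory is a hypothetical
placeholder)::

    config = SEDDConfig()  # "small"-ish defaults
    model = SEDDModel(config)
    model.save_pretrained("./sedd-small")
    reloaded = SEDDModel.from_pretrained("./sedd-small")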
"""

from __future__ import annotations

import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from model.transformer import SEDD as _OrigSEDD

try:
    from omegaconf import OmegaConf
except ImportError:
    # ``omegaconf`` is optional at import time; ``SEDDConfig.to_hydra`` and
    # ``SEDDModel.__init__`` raise a ``RuntimeError`` if it is missing.
    OmegaConf = None

logger = logging.get_logger(__name__)


class SEDDConfig(PretrainedConfig):
    """Configuration class for the SEDD architecture.

    The defaults reproduce *roughly* the "small" configuration shipped in
    ``configs/model/small.yaml``. Additional keys that are present in the
    original Hydra config but not required for instantiation (e.g. *training*
    hyper-parameters) are deliberately omitted here; they can still be stored
    as *extra* fields in the underlying JSON if a user wishes to preserve
    them.
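
    Example (a sketch; the override values are illustrative)::

        config = SEDDConfig(model_hidden_size=512, model_n_blocks=8)
        config.save_pretrained("./my-sedd-config")  # hypothetical directory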
    """

    model_type: str = "sedd"

    def __init__(
        self,
        *,
        tokens: int = 50257,
        graph_type: str = "absorb",
        model_hidden_size: int = 768,
        model_cond_dim: int = 128,
        model_length: int = 1024,
        model_n_blocks: int = 12,
        model_n_heads: int = 12,
        model_scale_by_sigma: bool = True,
        model_dropout: float = 0.10,
        tie_word_embeddings: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

        # Vocabulary size and diffusion-graph type.
        self.tokens = tokens
        self.graph_type = graph_type

        # Backbone hyper-parameters (mirroring ``model.*`` in the Hydra config).
        self.model_hidden_size = model_hidden_size
        self.model_cond_dim = model_cond_dim
        self.model_length = model_length
        self.model_n_blocks = model_n_blocks
        self.model_n_heads = model_n_heads
        self.model_scale_by_sigma = model_scale_by_sigma
        self.model_dropout = model_dropout

    def to_hydra(self):
        """Convert this *flat* config to the nested OmegaConf structure that
        the reference ``SEDD`` implementation expects.
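
        For example (a sketch; attribute access mirrors the nested dict)::

            hydra_cfg = SEDDConfig(model_hidden_size=256).to_hydra()
            hydra_cfg.model.hidden_size  # -> 256
            hydra_cfg.graph.type         # -> "absorb"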
        """
        if OmegaConf is None:
            raise RuntimeError("`omegaconf` is required to build a Hydra config")

        nested: Dict[str, Any] = {
            "tokens": self.tokens,
            "graph": {
                "type": self.graph_type,
            },
            "model": {
                "hidden_size": self.model_hidden_size,
                "cond_dim": self.model_cond_dim,
                "length": self.model_length,
                "n_blocks": self.model_n_blocks,
                "n_heads": self.model_n_heads,
                "scale_by_sigma": self.model_scale_by_sigma,
                "dropout": self.model_dropout,
            },
        }
        return OmegaConf.create(nested)


@dataclass
class SEDDOutput(ModelOutput):
    """Standard output for :class:`SEDDModel`.

    Attributes
    ----------
    loss:
        *Optional* scalar returned when ``labels`` are provided.
    logits:
        The raw vocabulary logits computed by the model, of shape
        ``(batch_size, sequence_length, vocab_size)``.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None


class SEDDModel(PreTrainedModel):
    """HuggingFace *Transformers* wrapper around the original ``SEDD`` model."""

    config_class = SEDDConfig
    base_model_prefix = "score_model"
    _no_split_modules: List[str] = [
        "DDiTBlock",
    ]

    def __init__(self, config: SEDDConfig):
        super().__init__(config)

        if OmegaConf is None:
            raise RuntimeError("`omegaconf` is required to instantiate SEDD")

        # Re-create the nested Hydra config and instantiate the original network.
        hydra_cfg = config.to_hydra()
        self.score_model = _OrigSEDD(hydra_cfg)

        # Standard HF weight-init / final-processing hook.
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        sigma: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        **kwargs: Any,
    ) -> Union[SEDDOutput, Tuple]:
        """Run a forward pass.

        Parameters
        ----------
        input_ids:
            Token indices of shape ``(batch_size, seq_len)``.
        sigma:
            Noise level ("time-step") of shape ``(batch_size,)``.
        labels:
            *Optional* label tensor used to compute a cross-entropy training
            loss. If provided, the returned :class:`SEDDOutput` will contain
            a ``loss`` field.
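
        A typical call (a sketch; ``model`` is an instantiated wrapper and
        ``config`` its :class:`SEDDConfig`)::

            input_ids = torch.randint(0, config.tokens, (2, 16))
            sigma = torch.rand(2)
            out = model(input_ids=input_ids, sigma=sigma)
            out.logits  # shape (2, 16, vocab_size)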
        """
        logits = self.score_model(indices=input_ids, sigma=sigma)

        loss: Optional[torch.Tensor] = None
        if labels is not None:
            # Flatten the (batch, seq) dimensions for token-level cross-entropy.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        if not self.config.return_dict:
            output: Tuple[Any, ...] = (logits,)
            return ((loss,) + output) if loss is not None else output

        return SEDDOutput(loss=loss, logits=logits)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        *model_args: Any,
        **kwargs: Any,
    ) -> "SEDDModel":
        """Override the default method to allow loading legacy SEDD checkpoints
        whose weights are saved via ``torch.save({'model': state_dict, ...})``.
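
        Expected legacy layout (a sketch, matching the lookup below)::

            <path>/
                config.json                  # read by SEDDConfig.from_pretrained
                checkpoints-meta/
                    checkpoint.pth           # torch.save({'model': state_dict, ...})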
        """
        try:
            # First attempt the standard ``transformers`` loading path
            # (safetensors / pytorch_model.bin + config.json).
            return super().from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        except (EnvironmentError, RuntimeError) as e:
            logger.info(
                "Falling back to legacy SEDD checkpoint format because standard "
                "loading raised: %s", e,
            )

        # Legacy path: build the model from the config, then load the raw
        # ``torch.save`` checkpoint produced by the original training code.
        config = kwargs.pop("config", None) or SEDDConfig.from_pretrained(
            pretrained_model_name_or_path
        )
        model = cls(config, *model_args, **kwargs)

        checkpoint_path = os.path.join(
            pretrained_model_name_or_path, "checkpoints-meta", "checkpoint.pth"
        )
        if not os.path.isfile(checkpoint_path):
            raise FileNotFoundError(
                f"Could not find legacy SEDD checkpoint at {checkpoint_path}"
            )

        ckpt = torch.load(checkpoint_path, map_location="cpu")
        state_dict = ckpt.get("model", ckpt)

        # Strip the ``module.`` prefix added by DistributedDataParallel.
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing:
            logger.warning("Missing keys when loading SEDD weights: %s", missing)
        if unexpected:
            logger.warning(
                "Unexpected keys when loading SEDD weights: %s", unexpected
            )
        return model


__all__ = [
    "SEDDConfig",
    "SEDDModel",
    "SEDDOutput",
]
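

if __name__ == "__main__":
    # Minimal smoke test (a sketch): build a deliberately tiny, untrained
    # model and run one forward pass. This assumes the reference ``SEDD``
    # implementation accepts these small dimensions; the vocabulary axis of
    # the logits may differ from ``tokens`` (e.g. +1 for an absorbing state).
    cfg = SEDDConfig(
        tokens=128,
        model_hidden_size=64,
        model_cond_dim=32,
        model_length=32,
        model_n_blocks=2,
        model_n_heads=2,
    )
    model = SEDDModel(cfg)
    input_ids = torch.randint(0, cfg.tokens, (2, 8))
    sigma = torch.rand(2)
    out = model(input_ids=input_ids, sigma=sigma)
    print("logits shape:", tuple(out.logits.shape))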