|
|
|
|
|
|
|
|
|
|
|
import collections
|
|
import logging
|
|
|
|
|
|
import math
|
|
import os
|
|
import re
|
|
from collections import OrderedDict
|
|
from functools import partial
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from einops import rearrange, repeat
|
|
from safetensors.torch import load_file as safe_load_file
|
|
from torch.nn.modules.utils import _pair
|
|
from transformers import GPT2Config, PreTrainedModel, ViTConfig, ViTModel
|
|
from transformers.models.bert.modeling_bert import (
|
|
BaseModelOutputWithPoolingAndCrossAttentions,
|
|
MaskedLMOutput,
|
|
SequenceClassifierOutput,
|
|
)
|
|
from transformers.modeling_outputs import (
|
|
BaseModelOutput,
|
|
BaseModelOutputWithPast,
|
|
BaseModelOutputWithPooling,
|
|
MaskedLMOutput,
|
|
MultipleChoiceModelOutput,
|
|
QuestionAnsweringModelOutput,
|
|
SequenceClassifierOutput,
|
|
ModelOutput,
|
|
TokenClassifierOutput,
|
|
)
|
|
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
|
|
from transformers.utils.hub import cached_file, get_checkpoint_shard_files
|
|
|
|
from .configuration_hf_nomic_bert import NomicBertConfig
|
|
|
|
# ``scaled_dot_product_attention`` only exists in newer PyTorch releases; expose
# ``None`` when it is missing so downstream code can test for its availability.
scaled_dot_product_attention = getattr(F, "scaled_dot_product_attention", None)

logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
def state_dict_from_pretrained(model_name, safe_serialization=False, device=None, dtype=None):
    """Load a (possibly sharded) checkpoint for ``model_name`` into one flat state dict.

    Files inside a local directory are preferred, checked in the order
    ``WEIGHTS_NAME``, ``WEIGHTS_INDEX_NAME``, ``SAFE_WEIGHTS_NAME``,
    ``SAFE_WEIGHTS_INDEX_NAME``. If none exists on disk, ``cached_file`` is queried for
    ``WEIGHTS_NAME``, ``SAFE_WEIGHTS_NAME``, ``WEIGHTS_INDEX_NAME`` and
    ``SAFE_WEIGHTS_INDEX_NAME`` (in that order) and the first hit is used.

    Args:
        model_name: local directory or identifier understood by ``cached_file``.
        safe_serialization: accepted for API compatibility; not consulted here.
        device: device every tensor is moved to at the end.
        dtype: optional dtype every tensor is cast to before the device move. When it
            is neither ``None`` nor ``torch.float32`` the files are loaded onto the CPU.

    Returns:
        dict mapping parameter names to tensors.

    Raises:
        EnvironmentError: if no checkpoint file could be resolved.
    """
    # Checkpoints that will be cast to a non-fp32 dtype are first materialised on the CPU.
    load_device = device if dtype in (torch.float32, None) else "cpu"

    resolved_archive_file = None
    is_sharded = False
    load_safe = False

    # Local directory lookup: (filename, is_sharded, is_safetensors) in priority order.
    local_candidates = (
        (WEIGHTS_NAME, False, False),
        (WEIGHTS_INDEX_NAME, True, False),
        (SAFE_WEIGHTS_NAME, False, True),
        (SAFE_WEIGHTS_INDEX_NAME, True, True),
    )
    for filename, sharded, safe in local_candidates:
        if os.path.isfile(os.path.join(model_name, filename)):
            resolved_archive_file = cached_file(model_name, filename, _raise_exceptions_for_missing_entries=False)
            is_sharded, load_safe = sharded, safe
            break
    else:
        # Nothing on disk: fall back to ``cached_file`` resolution, whose priority
        # order differs from the local one above.
        for filename in (WEIGHTS_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME):
            resolved_archive_file = cached_file(model_name, filename, _raise_exceptions_for_missing_entries=False)
            if resolved_archive_file is not None:
                load_safe = filename in (SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME)
                is_sharded = filename in (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME)
                break

    if resolved_archive_file is None:
        raise EnvironmentError(f"Model name {model_name} was not found.")

    loader = (
        partial(safe_load_file, device=load_device)
        if load_safe
        else partial(torch.load, map_location=load_device)
    )

    if is_sharded:
        # The resolved file is an index; fetch every shard and merge them in order.
        shard_files, _ = get_checkpoint_shard_files(model_name, resolved_archive_file)
        state_dict = {}
        for shard in shard_files:
            state_dict.update(loader(shard))
    else:
        state_dict = loader(resolved_archive_file)

    if dtype is not None:
        state_dict = {name: tensor.to(dtype=dtype) for name, tensor in state_dict.items()}
    return {name: tensor.to(device=device) for name, tensor in state_dict.items()}
|
|
|
|
|
|
def filter_shapes(state_dict, model):
    """
    Filters the state dict to match the current model shape.

    Keeps only the entries whose key exists in ``model.state_dict()`` and whose
    tensor shape equals the model's tensor shape for that key; everything else is
    dropped.

    Args:
        state_dict: mapping of parameter names to tensors (e.g. a loaded checkpoint).
        model: module whose ``state_dict()`` defines the accepted keys and shapes.

    Returns:
        dict: the filtered subset of ``state_dict``, in its original iteration order.
    """
    # ``model.state_dict()`` builds a fresh mapping on every call; fetch it once
    # instead of twice per checkpoint entry.
    model_state = model.state_dict()
    return {
        key: value
        for key, value in state_dict.items()
        if key in model_state and value.shape == model_state[key].shape
    }
|
|
|
|
|
|
def remap_bert_state_dict(
    state_dict,
    config,
    remove_bert=False,
    remove_cls_weights=False,
    add_pooling_layer=False,
):
    """
    Map the state_dict of a Huggingface BERT model to be flash_attn compatible.

    The conversion is a sequence of whole-dict key-renaming passes plus a few tensor
    edits:

    * keys not starting with ``bert.`` or ``cls.`` receive a ``bert.`` prefix;
    * ``LayerNorm.gamma`` / ``LayerNorm.beta`` become ``LayerNorm.weight`` / ``LayerNorm.bias``;
    * ``bert.encoder.layer.`` becomes ``bert.encoder.layers.``; LayerNorms are renamed to
      ``emb_ln`` / ``norm1`` / ``norm2`` / ``layer_norm``; the feed-forward dense layers
      become ``mlp.fc1`` / ``mlp.fc2``;
    * each layer's query/key/value projections are concatenated into ``attn.Wqkv`` (or, for
      the last layer when ``config.last_layer_subset`` is truthy, into ``attn.Wq`` plus
      ``attn.Wkv``), and ``attention.output.dense`` becomes ``attn.out_proj``;
    * the NSP head and ``bert.embeddings.position_ids`` are dropped, and
      ``cls.predictions.bias`` becomes ``cls.predictions.decoder.bias``.

    Args:
        state_dict: mapping of Hugging Face BERT parameter names to tensors.
        config: object providing ``num_hidden_layers`` and optionally
            ``last_layer_subset``, ``pad_vocab_size_multiple`` and ``vocab_size``.
        remove_bert: strip the leading ``bert.`` from the final keys.
        remove_cls_weights: drop the ``cls.predictions.*`` head weights.
        add_pooling_layer: when exactly ``False`` the ``bert.pooler.dense`` weights are dropped.

    Returns:
        OrderedDict: the remapped state dict.
    """

    def rekey(sd, rename):
        # One full pass: rebuild the mapping with every key passed through ``rename``.
        return OrderedDict((rename(k), v) for k, v in sd.items())

    def substitutions(*rules):
        # Build a renaming function applying the (pattern, replacement) regex rules in order.
        def rename(key):
            for pattern, replacement in rules:
                key = re.sub(pattern, replacement, key)
            return key

        return rename

    def add_bert_prefix(key):
        # Keys already under the encoder (``bert.``) or the head (``cls.``) stay as they are.
        return key if key.startswith(("bert.", "cls.")) else f"bert.{key}"

    state_dict = rekey(state_dict, add_bert_prefix)

    # TF-style LayerNorm parameter names.
    state_dict = rekey(
        state_dict,
        substitutions(
            (r"LayerNorm.gamma$", "LayerNorm.weight"),
            (r"LayerNorm.beta$", "LayerNorm.bias"),
        ),
    )

    # Encoder block container name.
    state_dict = rekey(state_dict, substitutions((r"^bert.encoder.layer\.", "bert.encoder.layers.")))

    # LayerNorm module names.
    state_dict = rekey(
        state_dict,
        substitutions(
            (r"^bert.embeddings.LayerNorm.", "bert.emb_ln."),
            (
                r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
                r"bert.encoder.layers.\1.norm1.\2",
            ),
            (
                r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
                r"bert.encoder.layers.\1.norm2.\2",
            ),
            (
                r"^cls.predictions.transform.LayerNorm.(weight|bias)",
                r"cls.predictions.transform.layer_norm.\1",
            ),
        ),
    )

    # Feed-forward layers.
    state_dict = rekey(
        state_dict,
        substitutions(
            (
                r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
                r"bert.encoder.layers.\1.mlp.fc1.\2",
            ),
            (
                r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
                r"bert.encoder.layers.\1.mlp.fc2.\2",
            ),
        ),
    )

    # Fuse the separate query/key/value projections of every layer that has them.
    last_layer_subset = getattr(config, "last_layer_subset", False)
    projections = ("query", "key", "value")
    for layer in range(config.num_hidden_layers):
        prefix = f"bert.encoder.layers.{layer}"
        if f"{prefix}.attention.self.query.weight" not in state_dict:
            continue
        Wq, Wk, Wv = [state_dict.pop(f"{prefix}.attention.self.{p}.weight") for p in projections]
        bq, bk, bv = [state_dict.pop(f"{prefix}.attention.self.{p}.bias") for p in projections]
        if last_layer_subset and layer == config.num_hidden_layers - 1:
            # The final layer keeps the query apart from a fused key/value projection.
            state_dict[f"{prefix}.attn.Wq.weight"] = Wq
            state_dict[f"{prefix}.attn.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
            state_dict[f"{prefix}.attn.Wq.bias"] = bq
            state_dict[f"{prefix}.attn.Wkv.bias"] = torch.cat([bk, bv], dim=0)
        else:
            state_dict[f"{prefix}.attn.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
            state_dict[f"{prefix}.attn.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)

    # Attention output projection.
    state_dict = rekey(
        state_dict,
        substitutions(
            (
                r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
                r"bert.encoder.layers.\1.attn.out_proj.\2",
            )
        ),
    )

    # The next-sentence-prediction head and the cached position ids are not used.
    for unused in ("cls.seq_relationship.weight", "cls.seq_relationship.bias", "bert.embeddings.position_ids"):
        state_dict.pop(unused, None)

    state_dict = rekey(state_dict, substitutions((r"^cls.predictions.bias", "cls.predictions.decoder.bias")))

    if remove_cls_weights:
        for name in (
            "cls.predictions.decoder.bias",
            "cls.predictions.transform.dense.weight",
            "cls.predictions.transform.dense.bias",
            "cls.predictions.transform.layer_norm.weight",
            "cls.predictions.transform.layer_norm.bias",
            "cls.predictions.decoder.weight",
        ):
            state_dict.pop(name, None)

    # Zero-pad vocabulary-sized matrices up to ``config.vocab_size`` rows.
    if getattr(config, "pad_vocab_size_multiple", 1) > 1:

        def pad_rows(name):
            tensor = state_dict[name]
            state_dict[name] = F.pad(tensor, (0, 0, 0, config.vocab_size - tensor.shape[0]))

        pad_rows("bert.embeddings.word_embeddings.weight")
        if not remove_cls_weights:
            pad_rows("cls.predictions.decoder.weight")
            # Padded vocabulary slots get a strongly negative bias (-100) so the decoder
            # is unlikely to predict them.
            bias_key = "cls.predictions.decoder.bias"
            if bias_key in state_dict:
                decoder_bias = state_dict[bias_key]
                state_dict[bias_key] = F.pad(
                    decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
                )

    if add_pooling_layer is False:
        for name in ("bert.pooler.dense.weight", "bert.pooler.dense.bias"):
            state_dict.pop(name, None)

    if remove_bert:
        state_dict = rekey(state_dict, substitutions((r"^bert.", "")))

    return state_dict
|
|
|
|
|
|
def _trunc_normal_(tensor, mean, std, a, b):
|
|
|
|
|
|
def norm_cdf(x):
|
|
|
|
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
|
|
|
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
|
print(
|
|
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
|
"The distribution of values may be incorrect.",
|
|
stacklevel=2,
|
|
)
|
|
|
|
|
|
|
|
|
|
l = norm_cdf((a - mean) / std)
|
|
u = norm_cdf((b - mean) / std)
|
|
|
|
|
|
|
|
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
|
|
|
|
|
|
|
tensor.erfinv_()
|
|
|
|
|
|
tensor.mul_(std * math.sqrt(2.0))
|
|
tensor.add_(mean)
|
|
|
|
|
|
tensor.clamp_(min=a, max=b)
|
|
return tensor
|
|
|
|
|
|
def trunc_normal_tf_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    r"""Fills the input Tensor in place with values from a truncated normal distribution.

    A standard normal sample truncated to :math:`[a, b]` is drawn (inverse-CDF
    sampling via ``_trunc_normal_``). The result is subsequently scaled by ``std``
    and shifted by ``mean``.

    NOTE: this 'tf' variant behaves closer to the Tensorflow / JAX impl, where the
    bounds [a, b] are applied when sampling the normal distribution with mean=0,
    std=1.0. The final values therefore lie in
    :math:`[\text{mean} + \text{std} \cdot a, \text{mean} + \text{std} \cdot b]`
    (for positive ``std``), not in :math:`[a, b]`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value, in units of the standard normal
        b: the maximum cutoff value, in units of the standard normal

    Returns:
        The same ``tensor`` object, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> trunc_normal_tf_(w)
    """
    # The fill is an initialisation step, not part of any autograd graph.
    with torch.no_grad():
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)
    return tensor
|
|
|
|
|
|
class NomicBertPreTrainedModel(PreTrainedModel):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    config_class = NomicBertConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Block"]
    _skip_keys_device_placement = "past_key_values"

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config)
        # The config is expected to be GPT2Config-compatible (NomicBertConfig presumably derives from it).
        if not isinstance(config, GPT2Config):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        self.config = config

    @classmethod
    def from_pretrained(cls, model_name, config=None, *inputs, **kwargs):
        """
        Instantiate a NomicBertPreTrainedModel subclass and load pretrained weights into it.

        Params:
            model_name: either
                - an existing local directory containing ``pytorch_model.bin`` or
                  ``model.safetensors``; the weights are loaded as-is with ``strict=False``; or
                - any other name, resolved through ``state_dict_from_pretrained`` and
                  remapped from the Hugging Face BERT layout with ``remap_bert_state_dict``.
            config: optional config; loaded with ``cls.config_class.from_pretrained(model_name)``
                when omitted.
            *inputs: extra positional arguments forwarded to the model constructor.
            **kwargs: recognised keys are ``ignore_mismatched_sizes``, ``num_labels``,
                ``rotary_scaling_factor``, ``strict`` (used only for the non-local path),
                ``torch_dtype`` and ``add_pooling_layer``.

        Returns:
            The constructed model with the loaded weights.

        Raises:
            ValueError: if a local directory holds neither checkpoint file.
        """
        if config is None:
            config = cls.config_class.from_pretrained(model_name)
        # Only the pretraining model keeps the MLM head and the ``bert.`` prefix; task
        # heads keep the prefix but drop the head weights.
        remove_cls = cls != NomicBertForPreTraining
        remove_bert_prefix = cls not in [
            NomicBertForPreTraining,
            NomicBertForSequenceClassification,
            NomicBertForTokenClassification,
            NomicBertForMultipleChoice,
            NomicBertForQuestionAnswering,
        ]
        ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
        num_labels = kwargs.pop("num_labels", None)
        rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
        strict = kwargs.pop("strict", True)
        dtype = kwargs.pop("torch_dtype", None)
        if rotary_scaling_factor:
            config.rotary_scaling_factor = rotary_scaling_factor

        if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
            config.n_positions = 2048
        if num_labels:
            config.num_labels = num_labels

        if "add_pooling_layer" in kwargs:
            model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
        elif cls == NomicBertModel:
            model = cls(config, *inputs, add_pooling_layer=False)
        else:
            model = cls(config, *inputs)

        if dtype is not None:
            model = model.to(dtype=dtype)

        if os.path.exists(model_name):
            # Local checkpoints are assumed to already be in this model's layout.
            model_path = f"{model_name}/pytorch_model.bin"
            if os.path.exists(model_path):
                # Map to CPU so checkpoints saved from GPU also load on CPU-only hosts;
                # load_state_dict copies the values into the model's own parameters.
                state_dict = torch.load(model_path, map_location="cpu")
            else:
                model_path = f"{model_name}/model.safetensors"
                if not os.path.exists(model_path):
                    raise ValueError(f"Model path {model_path} not found")
                state_dict = safe_load_file(model_path)

            if ignore_mismatched_shapes:
                state_dict = filter_shapes(state_dict, model)
            load_return = model.load_state_dict(state_dict, strict=False)
        else:
            # Resolve the checkpoint and convert it from the Hugging Face BERT layout.
            state_dict = state_dict_from_pretrained(model_name, dtype=dtype)
            state_dict = remap_bert_state_dict(
                state_dict,
                config,
                remove_bert=remove_bert_prefix,
                remove_cls_weights=remove_cls,
                add_pooling_layer=getattr(config, "add_pooling_layer", False),
            )
            if ignore_mismatched_shapes:
                state_dict = filter_shapes(state_dict, model)

            load_return = model.load_state_dict(state_dict, strict=strict)
        # Report missing / unexpected keys.
        logger.warning(load_return)
        return model

    def _set_gradient_checkpointing(self, module, value=False):
        # Toggle the ``gradient_checkpointing`` flag on encoder modules only.
        if isinstance(module, NomicBertEncoder):
            module.gradient_checkpointing = value
|
|
|
|
|
|
|
|
def _init_weights(module, initializer_range=0.02):
|
|
if isinstance(module, nn.Linear):
|
|
nn.init.normal_(module.weight, std=initializer_range)
|
|
if module.bias is not None:
|
|
nn.init.zeros_(module.bias)
|
|
elif isinstance(module, nn.Embedding):
|
|
nn.init.normal_(module.weight, std=initializer_range)
|
|
if module.padding_idx is not None:
|
|
nn.init.zeros_(module.weight[module.padding_idx])
|
|
|
|
|
|
def _ntuple(n):
|
|
def parse(x):
|
|
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
|
return tuple(x)
|
|
return tuple(repeat(x, n))
|
|
|
|
return parse
|
|
|
|
|
|
# Converters produced by ``_ntuple`` for the fixed sizes 1-4; ``to_ntuple`` is
# ``_ntuple`` itself, for building converters of any other size.
to_1tuple, to_2tuple, to_3tuple, to_4tuple = (_ntuple(size) for size in (1, 2, 3, 4))
to_ntuple = _ntuple
|
|
|
|
|
|
def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
    """
    Create 2D sin/cos positional embeddings for a square grid.

    Args:
        embed_dim (`int`):
            Embedding dimension.
        grid_size (`int`):
            The grid height and width.
        add_cls_token (`bool`, *optional*, defaults to `False`):
            Whether or not to prepend an all-zero row for a classification (CLS) token.

    Returns:
        (`np.ndarray` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim)): the
        position embeddings (with or without the classification token row). Note this is a NumPy
        array, not a torch tensor.
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    # With numpy's default 'xy' indexing, meshgrid(grid_w, grid_h) puts the width
    # coordinates in grid[0] and the height coordinates in grid[1].
    grid = np.meshgrid(grid_w, grid_h)
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if add_cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
|
|
|
|
|
|
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Build 2D sin/cos embeddings from a stacked coordinate grid.

    Half of ``embed_dim`` encodes ``grid[0]`` and the other half encodes ``grid[1]``;
    the two 1D encodings are concatenated along axis 1.

    Raises:
        ValueError: if ``embed_dim`` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")

    half_dim = embed_dim // 2
    per_axis = [get_1d_sincos_pos_embed_from_grid(half_dim, coords) for coords in (grid[0], grid[1])]
    return np.concatenate(per_axis, axis=1)
|
|
|
|
|
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode 1D positions with sin/cos features.

    Args:
        embed_dim: output dimension ``D`` for each position (must be even).
        pos: positions to encode; flattened to shape ``(M,)``.

    Returns:
        Array of shape ``(M, D)``: the first ``D/2`` columns are sines and the last
        ``D/2`` cosines, using frequencies ``1 / 10000 ** (k / (D/2))``.

    Raises:
        ValueError: if ``embed_dim`` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")

    half_dim = embed_dim // 2
    exponents = np.arange(half_dim, dtype=float) / (embed_dim / 2.0)
    frequencies = 1.0 / 10000**exponents

    # Outer product: one row per position, one column per frequency.
    angles = pos.reshape(-1)[:, None] * frequencies[None, :]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
|
|
|
|
|
|
def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
    """Generate an N-D grid in dimension ('ij') order.

    ``ndgrid`` is ``meshgrid`` with the first two inputs swapped:
    ``[X1, X2, X3] = ndgrid(x1, x2, x3)`` gives the same result as
    ``[X2, X1, X3] = meshgrid(x2, x1, x3)``.

    The MATLAB-derived name avoids confusion from torch's change moving
    ``torch.meshgrid`` from ndgrid-style ('ij') indexing toward numpy's default ('xy').
    """
    try:
        grids = torch.meshgrid(*tensors, indexing="ij")
    except TypeError:
        # Older torch releases lack the ``indexing`` keyword but already behave like 'ij'.
        grids = torch.meshgrid(*tensors)
    return grids
|
|
|
|
|
|
def build_fourier_pos_embed(
    feat_shape: List[int],
    bands: Optional[torch.Tensor] = None,
    num_bands: int = 64,
    max_res: int = 224,
    temperature: float = 10000.0,
    linear_bands: bool = False,
    include_grid: bool = False,
    in_pixels: bool = True,
    ref_feat_shape: Optional[List[int]] = None,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
    """Build sin/cos Fourier features for every position of an N-D feature grid.

    Args:
        feat_shape: Feature shape for embedding.
        bands: Pre-calculated frequency bands.
        num_bands: Number of frequency bands (determines output dim).
        max_res: Maximum resolution for pixel based freq.
        temperature: Temperature for non-pixel freq.
        linear_bands: Linear band spacing for pixel based freq.
        include_grid: Include the spatial grid in output.
        in_pixels: Output in pixel freq. Coordinates are then ``linspace(-1, 1)`` per
            axis; otherwise they are the integer indices ``0 .. s-1``.
        ref_feat_shape: Reference feature shape for resize / fine-tune; coordinates
            are rescaled by ``ref / feat`` per axis.
        dtype: Output dtype of the sin/cos features (taken from ``bands`` if ``None``).
        device: Output device (taken from ``bands`` if ``None``).

    Returns:
        ``[pos_sin, pos_cos]``, or ``[grid, pos_sin, pos_cos]`` when ``include_grid`` is
        set. Assuming 1-D ``bands``, ``pos_sin`` / ``pos_cos`` have shape
        ``(*feat_shape, len(feat_shape), len(bands))`` and dtype ``dtype``; ``grid``
        holds the float32 coordinates with shape ``(*feat_shape, len(feat_shape), 1)``.
    """
    if bands is None:
        if in_pixels:
            bands = pixel_freq_bands(
                num_bands,
                float(max_res),
                linear_bands=linear_bands,
                device=device,
            )
        else:
            bands = freq_bands(
                num_bands,
                temperature=temperature,
                step=1,
                device=device,
            )
    else:
        # Inherit placement / precision from the supplied bands when not given.
        if device is None:
            device = bands.device
        if dtype is None:
            dtype = bands.dtype

    if in_pixels:
        t = [torch.linspace(-1.0, 1.0, steps=s, device=device, dtype=torch.float32) for s in feat_shape]
    else:
        t = [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) for s in feat_shape]

    if ref_feat_shape is not None:
        # Rescale coordinates onto the reference grid so features match the shape
        # they were trained at.
        t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)]

    # (*feat_shape, ndim, 1) coordinates times bands -> (*feat_shape, ndim, num_bands).
    grid = torch.stack(ndgrid(t), dim=-1)
    grid = grid.unsqueeze(-1)
    pos = grid * bands

    pos_sin, pos_cos = pos.sin().to(dtype=dtype), pos.cos().to(dtype=dtype)
    return [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos]
|
|
|
|
|
|
def build_rotary_pos_embed(
    feat_shape: List[int],
    bands: Optional[torch.Tensor] = None,
    dim: int = 64,
    max_res: int = 224,
    temperature: float = 10000.0,
    linear_bands: bool = False,
    in_pixels: bool = True,
    ref_feat_shape: Optional[List[int]] = None,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
):
    """Build the sin and cos tables used for rotary position embedding.

    Args:
        feat_shape: Spatial shape of the target tensor for embedding.
        bands: Optional pre-generated frequency bands
        dim: Output dimension of embedding tensor (``dim // 4`` bands are generated
            when ``bands`` is not given).
        max_res: Maximum resolution for pixel mode.
        temperature: Temperature (inv freq) for non-pixel mode
        linear_bands: Linearly (instead of log) spaced bands for pixel mode
        in_pixels: Pixel vs language (inv freq) mode.
        ref_feat_shape: Reference feature shape for resize / fine-tune.
        dtype: Output dtype.
        device: Output device.

    Returns:
        ``(sin_emb, cos_emb)``, each of shape ``(prod(feat_shape), C)`` where ``C`` is
        twice the number of Fourier features per position (every feature is repeated
        for two adjacent channels). For a 2-D ``feat_shape`` and ``dim`` divisible by 4
        without custom ``bands``, ``C == dim``.
    """
    sin_emb, cos_emb = build_fourier_pos_embed(
        feat_shape,
        bands=bands,
        num_bands=dim // 4,
        max_res=max_res,
        temperature=temperature,
        linear_bands=linear_bands,
        in_pixels=in_pixels,
        ref_feat_shape=ref_feat_shape,
        device=device,
        dtype=dtype,
    )
    # Flatten all spatial axes into one position axis.
    num_spatial_dim = 1
    for x in feat_shape:
        num_spatial_dim *= x
    # Repeat each feature for two adjacent channels, the (even, odd) pairing that
    # ``rot`` rotates.
    sin_emb = sin_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
    cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
    return sin_emb, cos_emb
|
|
|
|
|
|
def freq_bands(
    num_bands: int,
    temperature: float = 10000.0,
    step: int = 2,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Return inverse-frequency bands ``1 / temperature ** (k / num_bands)``.

    ``k`` runs over ``range(0, num_bands, step)``; the result is float32 on ``device``.
    """
    exponents = torch.arange(0, num_bands, step, dtype=torch.int64, device=device)
    exponents = exponents.to(torch.float32) / num_bands
    return 1.0 / (temperature**exponents)
|
|
|
|
|
|
def pixel_freq_bands(
    num_bands: int,
    max_freq: float = 224.0,
    linear_bands: bool = True,
    device: Optional[torch.device] = None,
):
    """Return ``num_bands`` pixel-mode frequencies, scaled by pi.

    With ``linear_bands`` they are evenly spaced in ``[1, max_freq / 2]``; otherwise
    they are powers of two spaced evenly in the exponent from ``2**0`` to ``max_freq / 2``.
    """
    spacing = (
        torch.linspace(1.0, max_freq / 2, num_bands, dtype=torch.float32, device=device)
        if linear_bands
        else 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=torch.float32, device=device)
    )
    return spacing * torch.pi
|
|
|
|
|
|
def rot(x):
    """Rotate adjacent channel pairs: ``(x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)``."""
    evens = x[..., ::2]
    odds = x[..., 1::2]
    return torch.stack((-odds, evens), dim=-1).reshape(x.shape)
|
|
|
|
|
|
def apply_rot_embed_cat(x: torch.Tensor, emb):
    """Apply a rotary embedding whose sin and cos halves are concatenated on the last axis.

    ``emb`` is split into two equal parts (sin, cos) along its last dimension. When
    those parts are 3-D, a singleton axis is inserted at position 1 and expanded to
    ``x``'s shape before the rotation ``x * cos + rot(x) * sin``.
    """
    sin_part, cos_part = emb.tensor_split(2, -1)
    if sin_part.ndim == 3:
        sin_part = sin_part.unsqueeze(1).expand_as(x)
        cos_part = cos_part.unsqueeze(1).expand_as(x)
    return x * cos_part + rot(x) * sin_part
|
|
|
|
|
|
|
|
class NomicVisionRotaryEmbeddingCat(nn.Module):
    """Rotary position embedding w/ concatenated sin & cos

    The following impl/resources were referenced for this impl:
    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
    * https://blog.eleuther.ai/rotary-embeddings/
    """

    def __init__(
        self,
        dim,
        max_res=224,
        temperature=10000,
        in_pixels=True,
        linear_bands: bool = False,
        feat_shape: Optional[List[int]] = None,
        ref_feat_shape: Optional[List[int]] = None,
    ):
        super().__init__()
        self.dim = dim
        self.max_res = max_res
        self.temperature = temperature
        self.in_pixels = in_pixels
        self.feat_shape = feat_shape
        self.ref_feat_shape = ref_feat_shape

        if feat_shape is None:
            # Shape unknown up front: cache only the frequency bands and build the
            # sin/cos table on demand in ``get_embed``.
            if in_pixels:
                bands = pixel_freq_bands(
                    dim // 4,
                    float(max_res),
                    linear_bands=linear_bands,
                )
            else:
                bands = freq_bands(
                    dim // 4,
                    temperature=temperature,
                    step=1,
                )
            self.register_buffer(
                'bands',
                bands,
                persistent=False,
            )
            self.pos_embed = None
        else:
            # Shape known: precompute the full table. ``temperature`` is forwarded so
            # this path agrees with the bands-only path above.
            embeds = build_rotary_pos_embed(
                feat_shape=feat_shape,
                dim=dim,
                max_res=max_res,
                temperature=temperature,
                linear_bands=linear_bands,
                in_pixels=in_pixels,
                ref_feat_shape=self.ref_feat_shape,
            )
            self.bands = None
            self.register_buffer(
                'pos_embed',
                torch.cat(embeds, -1),
                persistent=False,
            )

    def get_embed(self, shape: Optional[List[int]] = None):
        """Return the concatenated ``[sin, cos]`` rotary table.

        With cached bands and a ``shape``, a table for that shape is built; otherwise
        the precomputed ``pos_embed`` buffer is returned.

        Raises:
            ValueError: if neither a precomputed table nor (bands and ``shape``) is available.
        """
        if self.bands is not None and shape is not None:
            embeds = build_rotary_pos_embed(
                shape,
                self.bands,
                in_pixels=self.in_pixels,
                ref_feat_shape=self.ref_feat_shape,
            )
            return torch.cat(embeds, -1)
        if self.pos_embed is not None:
            return self.pos_embed
        # An explicit raise: ``assert`` is stripped under ``python -O`` and would
        # silently return None here.
        raise ValueError("get_embed() requires pre-computed pos_embed or valid shape w/ pre-computed bands")

    def forward(self, x):
        # Build (or fetch) the table for x's trailing spatial dims and rotate x with it.
        pos_embed = self.get_embed(x.shape[2:])
        return apply_rot_embed_cat(x, pos_embed)
|
|
|
|
|
|
class NomicVisionPatchEmbeddings(nn.Module):
    """Patchify images, project the patches to ``config.n_embd`` and add position information.

    ``forward`` returns ``(embeddings, rot_pos_embed)``; ``rot_pos_embed`` is the table
    from ``self.rope.get_embed()``, or ``None`` when rotary embeddings are disabled.
    """

    def __init__(
        self,
        config,
    ):
        super().__init__()
        img_size = _pair(config.img_size)
        patch_size = _pair(config.patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        # Each patch is flattened to (channels * p1 * p2) values and linearly projected.
        self.proj = nn.Linear(
            config.num_channels * patch_size[0] * patch_size[1], config.n_embd, bias=config.patch_embed_bias
        )

        self.learned_pos_embedding = False
        self.sinusoidal_pos_embedding = False
        self.no_embed_class = getattr(config, "no_embed_class", False)

        self.cls_token = (
            nn.Parameter(torch.zeros(1, 1, config.n_embd)) if not getattr(config, "no_cls_token", False) else None
        )
        if config.learned_pos_embedding:
            self.learned_pos_embedding = True
            # With register tokens the table has no extra (CLS) slot beyond the patches.
            num_patches = self.num_patches if getattr(config, "register_tokens", 0) > 0 else self.num_patches + 1
            self.pos_embed = (
                nn.Parameter(torch.randn(1, num_patches, config.n_embd) * 0.02)
                if getattr(config, "use_pos_embed", True)
                else None
            )
        elif getattr(config, "sinusoidal_pos_embedding", False):
            self.sinusoidal_pos_embedding = True
            if getattr(config, "use_pos_embed", True):
                # Frozen 2-D sin/cos table with a leading (zero) slot for the CLS token;
                # built from grid_size[0], so a square grid is assumed.
                self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, config.n_embd), requires_grad=False)
                pos_embed = get_2d_sincos_pos_embed(config.n_embd, self.grid_size[0], add_cls_token=True)
                self.pos_embed.data.copy_(torch.from_numpy(pos_embed).to(self.pos_embed))
            else:
                self.pos_embed = None
        else:
            self.pos_embed = (
                nn.Parameter(torch.randn(1, self.num_patches + 1, config.n_embd) * 0.02)
                if getattr(config, "use_pos_embed", True)
                else None
            )

        if getattr(config, "register_tokens", 0) > 0:
            self.reg_token = nn.Parameter(torch.randn(1, config.register_tokens, config.n_embd) * 0.02)
        else:
            self.reg_token = None

        if config.mask_token:
            self.mask_token = nn.Parameter(torch.zeros(1, config.n_embd))

        self.patch_dropout = nn.Identity()

        if getattr(config, "use_rotary_pos_emb", False):
            ref_feat_shape = getattr(config, "ref_feat_shape", None)
            ref_feat_shape = to_2tuple(ref_feat_shape) if ref_feat_shape is not None else None
            self.rope = NomicVisionRotaryEmbeddingCat(
                config.n_embd // config.n_head,
                in_pixels=False,
                feat_shape=self.grid_size,
                ref_feat_shape=ref_feat_shape,
            )
        else:
            self.rope = None

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        Assumes ``embeddings`` and ``self.pos_embed`` each carry exactly one leading
        non-patch token and that the stored patch table is square.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        num_patches = embeddings.shape[1] - 1
        num_positions = self.pos_embed.shape[1] - 1
        if num_patches == num_positions and height == width:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = embeddings.shape[-1]
        height = height // self.patch_size[0]
        width = width // self.patch_size[1]
        # Small offset carried over from the DINO reference; presumably it keeps the
        # scale_factor-derived output size from rounding down.
        height, width = height + 0.1, width + 0.1
        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)),
            mode="bicubic",
            align_corners=False,
        )
        if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
            raise ValueError("Width or height does not match with the interpolated position embeddings")
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(self, x):
        # Match the projection's dtype (e.g. fp32 inputs fed to a reduced-precision model).
        if x.dtype != self.proj.weight.dtype:
            x = x.to(dtype=self.proj.weight.dtype)

        _, _, height, width = x.shape
        x = self.proj(
            rearrange(
                x,
                "b c (h p1) (w p2) -> b h w (c p1 p2)",
                p1=self.patch_size[0],
                p2=self.patch_size[1],
            )
        )
        embeddings = rearrange(x, "b h w c -> b (h w) c")

        # Tokens prepended to the patch sequence: optional CLS, then optional registers.
        to_cat = []
        if self.cls_token is not None:
            if self.sinusoidal_pos_embedding and self.pos_embed is not None:
                # The sinusoidal table's slot 0 belongs to the CLS token.
                cls_token = self.cls_token + self.pos_embed[:, 0]
                cls_token = cls_token.expand(embeddings.shape[0], -1, -1)
                to_cat += [cls_token]
            else:
                cls_token = self.cls_token.expand(embeddings.shape[0], 1, -1)
                to_cat += [cls_token]

        if self.reg_token is not None:
            to_cat += [self.reg_token.expand(embeddings.shape[0], -1, -1)]

        rot_pos_embed = self.rope.get_embed() if self.rope is not None else None

        if self.no_embed_class:
            # Position embeddings are added to the patch tokens only; prefix tokens
            # are concatenated afterwards.
            if self.pos_embed is not None:
                if self.learned_pos_embedding:
                    embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
                else:
                    embeddings = embeddings + self.pos_embed
            if to_cat:
                embeddings = torch.cat(to_cat + [embeddings], dim=1)
        else:
            # Prefix tokens first, then position embeddings over the whole sequence.
            if to_cat:
                embeddings = torch.cat(to_cat + [embeddings], dim=1)
            if self.pos_embed is not None:
                if self.learned_pos_embedding:
                    embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
                else:
                    embeddings = embeddings + self.pos_embed

        embeddings = self.patch_dropout(embeddings)

        return embeddings, rot_pos_embed
|
|
|
|
|
|
class NomicBertEmbeddings(nn.Module):
    def __init__(self, config):
        """
        Word embeddings plus optional absolute-position and token-type embeddings.

        Learned absolute position embeddings are created only when
        ``config.max_position_embeddings > 0`` and ``config.rotary_emb_fraction <= 0``;
        otherwise ``self.max_position_embeddings`` is 0 and none are added.
        Token type embeddings are created only when ``config.type_vocab_size > 0``.
        """
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        # A positive rotary fraction disables learned absolute position embeddings.
        self.max_position_embeddings = config.max_position_embeddings if config.rotary_emb_fraction <= 0 else 0
        self.type_vocab_size = config.type_vocab_size
        # max_position_embeddings is already 0 when rotary embeddings are enabled.
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(
                config.max_position_embeddings,
                config.hidden_size,
            )
        if self.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

    def forward(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        token_type_ids: (batch, seqlen)
        inputs_embeds: optional (batch, seqlen, hidden) embeddings used instead of looking up input_ids

        Returns the summed embeddings with shape (batch, seqlen, hidden).
        """
        embeddings = self.word_embeddings(input_ids) if inputs_embeds is None else inputs_embeds
        # Unpacking also enforces a 3-D (batch, seqlen, hidden) input.
        _, seqlen, _ = embeddings.shape

        if self.type_vocab_size > 0:
            if token_type_ids is None:
                # Every position defaults to token type 0; the (seqlen,) ids broadcast over the batch.
                token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=embeddings.device)
            embeddings = embeddings + self.token_type_embeddings(token_type_ids)

        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=embeddings.device)
            embeddings = embeddings + self.position_embeddings(position_ids)
        return embeddings
|
|
|
|
|
|
class NomicBertMLP(nn.Module):
    """Two-layer feed-forward block computing ``fc2(activation(fc1(x)))``.

    Args:
        in_features: input width.
        hidden_features: hidden width; defaults to ``4 * in_features``.
        out_features: output width; defaults to ``in_features``.
        activation: a callable, or one of the GELU names ``"gelu"`` (exact GELU),
            ``"gelu_new"``, ``"gelu_fast"``, ``"gelu_pytorch_tanh"`` (tanh-approximated GELU).
        bias1: whether ``fc1`` has a bias.
        bias2: whether ``fc2`` has a bias.
        return_residual: if True, ``forward`` returns ``(y, x)`` instead of ``y``.
        fused_bias_fc: accepted for config compatibility; not used by this module.

    Raises:
        ValueError: if ``activation`` is a string that is not a supported GELU name.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        bias1=True,
        bias2=True,
        return_residual=False,
        fused_bias_fc=False,
    ):
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1)
        if isinstance(activation, str):
            # Map every supported GELU name to an nn.GELU module. Previously only
            # "gelu" was mapped, and the tanh variants were stored as a raw
            # (uncallable) string.
            if activation == "gelu":
                self.activation = nn.GELU(approximate="none")
            elif activation in ("gelu_new", "gelu_fast", "gelu_pytorch_tanh"):
                self.activation = nn.GELU(approximate="tanh")
            else:
                raise ValueError(f"Unsupported activation name: {activation!r}")
        else:
            self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)

    def forward(self, x):
        """Apply fc1 -> activation -> fc2; also return ``x`` when ``return_residual`` is set."""
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)
|
|
|
|
|
|
class NomciBertGatedMLP(nn.Module):
    """Gated feed-forward block: ``fc2(norm(fc11(x) * activation(fc12(x))))``.

    Args:
        in_features: input width.
        hidden_features: hidden width before rounding; defaults to ``int(8 * in_features / 3)``.
            The value is rounded up to a multiple of ``multiple_of``.
        out_features: output width; defaults to ``in_features``.
        activation: gate activation. ``F.sigmoid`` takes the fused ``F.glu`` path.
        bias1: whether ``fc11``/``fc12`` have a bias.
        bias2: whether ``fc2`` has a bias.
        multiple_of: rounding granularity for the hidden width.
        return_residual: if True, ``forward`` returns ``(y, x)``.
        fused_bias_fc: accepted for config compatibility; not used by this module.
        device, dtype: factory arguments forwarded to the Linear / LayerNorm layers.
        norm_layer: if True, apply LayerNorm to the gated hidden activations.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=256,
        return_residual=False,
        fused_bias_fc=True,
        device=None,
        dtype=None,
        norm_layer=False,
    ):
        super().__init__()
        # These were previously accepted but silently ignored.
        factory_kwargs = {"device": device, "dtype": dtype}
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else int(8 * in_features / 3)
        hidden_features = int((hidden_features + multiple_of - 1) // multiple_of * multiple_of)
        self.return_residual = return_residual

        self.fc11 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.fc12 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
        self.norm = nn.LayerNorm(hidden_features, **factory_kwargs) if norm_layer else nn.Identity()

    def forward(self, x):
        """Gate ``fc11(x)`` with ``activation(fc12(x))``, normalize, then project with fc2."""
        y = self.fc11(x)
        gate = self.fc12(x)
        if self.activation is F.sigmoid:
            # F.glu over [y, gate] computes y * sigmoid(gate) in one op.
            y = F.glu(torch.cat([y, gate], dim=-1), dim=-1)
        else:
            y = y * self.activation(gate)

        y = self.norm(y)

        y = self.fc2(y)
        return y if not self.return_residual else (y, x)
|
|
|
|
|
|
def rotate_half(x, interleaved=False):
    """Rotate feature pairs of ``x`` by 90 degrees along the last dimension.

    * ``interleaved=False`` (GPT-NeoX style): split the last dim into halves
      ``(a, b)`` and return ``cat(-b, a)``.
    * ``interleaved=True`` (GPT-J style): treat even/odd features as pairs
      ``(a, b)`` and return them as ``(-b, a)`` pairs, interleaved again.
    """
    if interleaved:
        evens = x[..., ::2]
        odds = x[..., 1::2]
        # (..., d, 2) -> (..., 2d) in row-major order re-interleaves the pairs.
        return torch.stack((-odds, evens), dim=-1).flatten(start_dim=-2)
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)
|
|
|
|
|
|
def apply_rotary_emb(x, cos, sin, offset=0, interleaved=False):
    """Apply rotary position embedding to the first ``rotary_dim`` features of ``x``.

    Args:
        x: (batch_size, seqlen, nheads, headdim)
        cos, sin: (seqlen_ro, rotary_dim / 2) or (batch_size, seqlen_ro, rotary_dim / 2),
            with seqlen_ro >= offset + seqlen.
        offset: position of ``x``'s first token inside the cos/sin tables (int-like).
        interleaved: rotate even/odd feature pairs (GPT-J style) instead of the two
            halves (GPT-NeoX style).

    Returns:
        A new tensor shaped like ``x``. Features past ``rotary_dim`` pass through unchanged.
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    seqlen = x.shape[1]
    # Slice along the sequence axis (second to last). This is identical to slicing
    # dim 0 for unbatched (seqlen, d) tables. For batched (batch, seqlen, d) tables
    # it correctly avoids slicing the batch axis.
    cos = cos[..., offset : offset + seqlen, :]
    sin = sin[..., offset : offset + seqlen, :]
    # Duplicate each frequency to cover both members of its pair, and add a
    # broadcast axis for nheads.
    pattern = "... d -> ... 1 (d 2)" if interleaved else "... d -> ... 1 (2 d)"
    cos = repeat(cos, pattern)
    sin = repeat(sin, pattern)
    x_rot = x[..., :ro_dim]
    return torch.cat(
        [x_rot * cos + rotate_half(x_rot, interleaved) * sin, x[..., ro_dim:]],
        dim=-1,
    )
|
|
|
|
|
|
class NomicBertRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) for the query/key slices of a packed qkv tensor.

    cos/sin tables are built lazily and cached in ``_cos_cached``/``_sin_cached``. The
    cache is rebuilt when:
    - a longer sequence is requested,
    - the device or dtype changes, or
    - the module is training and the cached tensor is an inference tensor.
    """

    def __init__(
        self,
        dim: int,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        """
        dim: number of rotated features per head; the tables hold the frequencies for
            the even indices 0, 2, ..., < dim.
        base: rotary base used in ``base ** (-2i / dim)``.
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        scale_base: when not None, a ``scale`` buffer is registered. This base class's cache
            update never reads it; only subclasses that use ``self.scale`` apply it.
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
            otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02), the position
            indices were built with the dtype of self.inv_freq. In most cases this would be
            fp32, but if the model is trained in pure bf16 (not mixed precision), then
            self.inv_freq would be bf16, and the position indices are also in bf16.
            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions will coincide.
            To maintain compatibility with models previously trained in pure bf16,
            we add this option.
        device: device for the inverse-frequency and scale buffers.
        """
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32

        # Pass device by keyword. Subclasses override _compute_inv_freq with a
        # (base=None, device=None) signature, where a positional device would land in `base`.
        inv_freq = self._compute_inv_freq(device=device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.interleaved = interleaved
        self.scale_base = scale_base
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
        """Return the fp32 inverse frequencies ``1 / base ** (j / dim)`` for j = 0, 2, ..., < dim."""
        return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        """(Re)build the cos/sin tables for positions ``[0, seqlen)`` when the cache is stale."""
        # Rebuild if:
        # - seqlen exceeds the cached length,
        # - the device/dtype changed, or
        # - we are training and the cache was created under inference mode.
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # If inv_freq was cast to lower precision (e.g. model.half()), recompute
                # it in fp32 so positions and frequencies multiply in full precision.
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq
            freqs = torch.outer(t, inv_freq)
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def forward(
        self,
        qkv: torch.Tensor,
        kv: Optional[torch.Tensor] = None,
        seqlen_offset: Union[int, torch.Tensor] = 0,
        max_seqlen: Optional[int] = None,
    ) -> torch.Tensor:
        """Rotate the query and key slices of a packed qkv tensor.

        Args:
            qkv: packed tensor indexed as ``qkv[:, :, i]`` for i = 0 (q), 1 (k), 2 (v);
                dim 1 is treated as the sequence length, e.g. (batch, seqlen, 3, nheads, headdim).
            kv: accepted for API compatibility; ignored by this implementation.
            seqlen_offset: position of the first token of ``qkv`` in the cos/sin tables.
                An int offset also grows the cache to ``seqlen + seqlen_offset``.
            max_seqlen: if given, the cache is grown to at least this length.

        Returns:
            A new tensor shaped like ``qkv`` with q and k rotated and v unchanged.
            ``qkv`` itself is not modified.
        """
        seqlen = qkv.shape[1]
        # The cache must cover every sliced position: [seqlen_offset, seqlen_offset + seqlen).
        needed = seqlen + seqlen_offset if isinstance(seqlen_offset, int) else seqlen
        if max_seqlen is not None:
            needed = max(needed, max_seqlen)
        self._update_cos_sin_cache(needed, device=qkv.device, dtype=qkv.dtype)

        q_rot = apply_rotary_emb(qkv[:, :, 0], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved)
        k_rot = apply_rotary_emb(qkv[:, :, 1], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved)
        return torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
|
|
|
|
|
|
class NomicBertDynamicNTKRotaryEmbedding(NomicBertRotaryEmbedding):
    """Rotary embedding whose base is enlarged (dynamic NTK scaling) for long sequences.

    When a requested length exceeds ``max_position_embeddings``, the base becomes
    ``base * (factor * seqlen / max_pos - (factor - 1)) ** (dim / (dim - 2))``
    and ``inv_freq`` is re-registered with it.

    NOTE(review): ``inv_freq`` is only replaced for lengths above
    ``max_position_embeddings`` and is never reset for shorter ones. A larger cached
    table is also reused for shorter sequences -- confirm this matches the trained model.
    """

    def __init__(self, rotary_scaling_factor, max_position_embeddings, **kwargs):
        super().__init__(**kwargs)
        self.rotary_scaling_factor = rotary_scaling_factor
        self.max_position_embeddings = max_position_embeddings

    def _compute_inv_freq(self, base=None, device=None):
        """Return fp32 inverse frequencies for ``base`` (defaults to ``self.base``)."""
        if base is None:
            base = self.base
        return 1.0 / (base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))

    def _scaled_base(self, seqlen):
        """Return the NTK-scaled base for ``seqlen``, or ``self.base`` within the trained length."""
        if seqlen <= self.max_position_embeddings:
            return self.base
        factor = self.rotary_scaling_factor
        return self.base * (
            (factor * seqlen / self.max_position_embeddings) - (factor - 1)
        ) ** (self.dim / (self.dim - 2))

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        """Rescale inv_freq if needed, then (re)build the cos/sin tables when the cache is stale."""
        if seqlen > self.max_position_embeddings:
            inv_freq = self._compute_inv_freq(base=self._scaled_base(seqlen), device=device)
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Rebuild if:
        # - seqlen exceeds the cached length,
        # - the device/dtype changed, or
        # - we are training and the cache was created under inference mode.
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # Recompute inv_freq in fp32 if the buffer was cast to lower precision.
                # This path previously read the nonexistent `self.scaling_factor`.
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device, base=self._scaled_base(seqlen))
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
                # Queries use cos/sin * scale; keys use the reciprocal scale.
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
|
|
|
|
|
class NomicBertAttention(nn.Module):
    """Multi-head self-attention with optional rotary position embeddings."""

    def __init__(
        self,
        config,
    ) -> None:
        """Build projections, dropout and (optionally) the rotary embedding from ``config``.

        Reads these config fields:
        - ``n_embd``, ``n_head``, ``num_heads_kv`` (optional, defaults to ``n_head``).
        - ``rotary_emb_fraction``, ``rotary_emb_base``, ``rotary_emb_scale_base``,
          ``rotary_emb_interleaved``.
        - ``rotary_scaling_factor`` (optional; when truthy, the dynamic NTK rotary is
          used with ``max_trained_positions``), and ``rotary_head_dim`` (optional).
        - ``qkv_proj_bias``, ``attn_pdrop``, ``register_tokens`` (optional).

        ``use_flash_attn``, ``fused_bias_fc`` and ``causal`` are stored but not read in ``forward``.
        """
        super().__init__()
        self.embed_dim = config.n_embd
        self.use_flash_attn = config.use_flash_attn
        self.fused_bias_fc = config.fused_bias_fc

        self.num_heads = config.n_head
        self.num_heads_kv = config.num_heads_kv if getattr(config, "num_heads_kv", None) is not None else self.num_heads
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // self.num_heads

        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)

        # sqrt(head_dim), used by the non-SDPA fallback path.
        self.register_buffer(
            "norm_factor",
            torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
            persistent=False,
        )

        self.rotary_emb_dim = self.head_dim * config.rotary_emb_fraction
        if self.rotary_emb_dim > 0:
            if getattr(config, "rotary_scaling_factor", None):
                self.rotary_emb = NomicBertDynamicNTKRotaryEmbedding(
                    dim=self.rotary_emb_dim,
                    base=config.rotary_emb_base,
                    scale_base=config.rotary_emb_scale_base,
                    interleaved=config.rotary_emb_interleaved,
                    rotary_scaling_factor=config.rotary_scaling_factor,
                    max_position_embeddings=config.max_trained_positions,
                )
            else:
                self.rotary_emb = NomicBertRotaryEmbedding(
                    dim=self.rotary_emb_dim,
                    base=config.rotary_emb_base,
                    scale_base=config.rotary_emb_scale_base,
                    interleaved=config.rotary_emb_interleaved,
                )
        # When set, rotary is applied with heads moved to dim 1 (see forward).
        self.rotary_head_dim = getattr(config, "rotary_head_dim", False)

        self.Wqkv = nn.Linear(self.embed_dim, qkv_dim, bias=config.qkv_proj_bias)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
        self.causal = config.causal
        self.drop = nn.Dropout(config.attn_pdrop)
        # Leading tokens excluded from the external `rope` rotation.
        self.num_prefix_tokens = max(getattr(config, "register_tokens", 1), 1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        is_padded_inputs: Optional[bool] = True,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seq_len: Optional[int] = None,
        rope: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Self-attention over ``hidden_states``.

        Args:
            hidden_states: (batch, seqlen, embed_dim).
            attention_mask: additive mask broadcastable to (batch, heads, seqlen, seqlen),
                or None.
            past_key_value: presumably a ``(cache, cached_length)`` pair. Only
                ``cached_length`` is read, as the rotary offset for the new tokens.
            rope: external rotary embedding applied via ``apply_rot_embed_cat`` when the
                built-in rotary is disabled. The first ``num_prefix_tokens`` tokens are
                left unrotated.
            position_ids, output_attentions, use_cache, is_padded_inputs, cu_seqlens,
            max_seq_len: accepted for interface compatibility; unused.

        Returns:
            Attention output of shape (batch, seqlen, embed_dim). No KV cache or attention
            weights are returned.
        """
        if past_key_value is not None:
            # Read the length from the pair itself. It was previously read from the
            # already-unpacked cache element.
            past_len = past_key_value[1]
        else:
            past_len = 0

        qkv = self.Wqkv(hidden_states)
        qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)

        if self.rotary_emb_dim > 0:
            if self.rotary_head_dim:
                qkv = rearrange(qkv, "b s three h d -> b h three s d")
            qkv = self.rotary_emb(qkv, seqlen_offset=past_len)

            if self.rotary_head_dim:
                qkv = rearrange(qkv, "b h three s d -> b s three h d")
        elif rope is not None:
            # (b, s, three, h, d) -> three tensors of (b, h, s, d).
            q, k, v = qkv.permute(0, 3, 1, 2, 4).unbind(dim=-2)
            q = torch.cat(
                [q[:, :, : self.num_prefix_tokens], apply_rot_embed_cat(q[:, :, self.num_prefix_tokens :], rope)], dim=2
            ).type_as(q)
            k = torch.cat(
                [k[:, :, : self.num_prefix_tokens], apply_rot_embed_cat(k[:, :, self.num_prefix_tokens :], rope)], dim=2
            ).type_as(k)

            qkv = torch.stack([q, k, v], dim=-2)
            qkv = rearrange(qkv, "b h s three d -> b s three h d")

        query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        # (b, s, h, d) -> (b, h, s, d)
        query = query.permute(0, 2, 1, 3)
        key = key.permute(0, 2, 1, 3)
        value = value.permute(0, 2, 1, 3)
        if scaled_dot_product_attention is not None:
            # SDPA applies dropout_p unconditionally, so disable it outside training
            # (the fallback path gets this from nn.Dropout).
            dropout_p = self.drop.p if self.training else 0.0
            attn_output = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
            )
        else:
            attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor
            if attention_mask is not None:
                attention_scores = attention_scores + attention_mask

            attentions_probs = F.softmax(attention_scores, dim=-1)
            attentions_probs = self.drop(attentions_probs)

            attn_output = torch.matmul(attentions_probs, value)

        # (b, h, s, d) -> (b, s, h * d)
        attn_output = rearrange(attn_output.permute(0, 2, 1, 3), "... h d -> ... (h d)")

        attn_output = self.out_proj(attn_output)

        return attn_output
|
|
|
|
|
|
class NomicBertBlock(NomicBertPreTrainedModel):
    """One transformer layer: self-attention then an MLP.

    Each sub-layer has dropout, a residual connection and LayerNorm, applied
    pre-norm or post-norm according to ``config.prenorm``.
    """

    def __init__(
        self,
        config,
    ):
        super().__init__(config=config)
        self.prenorm = config.prenorm
        # Stored from the config; not read in forward below.
        self.fused_dropout_add_ln = config.fused_dropout_add_ln

        self.attn = NomicBertAttention(config)
        # "glu" -> sigmoid gate, "swiglu" -> SiLU gate, anything else -> GELU.
        activation = (
            F.sigmoid
            if config.activation_function == "glu"
            else (F.silu if config.activation_function == "swiglu" else F.gelu)
        )
        # Gated activations use the two-branch gated MLP; the rest use the plain MLP.
        if config.activation_function in ["glu", "swiglu", "geglu"]:
            self.mlp = NomciBertGatedMLP(
                config.n_embd,
                hidden_features=config.n_inner,
                bias1=config.mlp_fc1_bias,
                bias2=config.mlp_fc2_bias,
                activation=activation,
                fused_bias_fc=config.fused_bias_fc,
                norm_layer=getattr(config, "norm_mlp", False),
            )
        else:
            self.mlp = NomicBertMLP(
                config.n_embd,
                hidden_features=config.n_inner,
                bias1=config.mlp_fc1_bias,
                bias2=config.mlp_fc2_bias,
                activation=activation,
                fused_bias_fc=config.fused_bias_fc,
            )

        self.dropout1 = nn.Dropout(config.resid_pdrop)
        self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.norm2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.dropout2 = nn.Dropout(config.resid_pdrop)

    def forward(
        self,
        hidden_states: torch.Tensor,
        hidden_states2: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        is_padded_inputs: Optional[bool] = True,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seq_len: Optional[int] = None,
        rope: Optional[torch.Tensor] = None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            hidden_states2: not read by this layer.
            residual: must be None for post-norm. For pre-norm, it is the running
                residual stream from the previous layer (None on the first layer), and
                ``hidden_states`` is that layer's un-added sub-layer output.
            attention_mask, is_padded_inputs, cu_seqlens, max_seq_len, rope: forwarded
                to the attention module.
            position_ids, past_key_value, output_attentions, use_cache: accepted for
                interface compatibility; not read here.

        Returns:
            Pre-norm: ``(mlp_output, None, residual)``. ``mlp_output`` has not yet been
            added to ``residual``.
            Post-norm: ``(hidden_states, None, None)``.
        """
        if self.prenorm:
            # Add the incoming (dropped-out) sub-layer output to the residual stream,
            # then normalize before attention.
            dropped = self.dropout1(hidden_states)
            residual = (dropped + residual) if residual is not None else dropped
            hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            hidden_states = self.attn(
                hidden_states,
                attention_mask=attention_mask,
                is_padded_inputs=is_padded_inputs,
                cu_seqlens=cu_seqlens,
                max_seq_len=max_seq_len,
                rope=rope,
            )

            dropped = self.dropout2(hidden_states)
            residual = (dropped + residual) if residual is not None else dropped
            hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
            hidden_states = self.mlp(hidden_states)

            return hidden_states, None, residual
        else:
            assert residual is None
            attn_outputs = self.attn(
                hidden_states,
                attention_mask=attention_mask,
                is_padded_inputs=is_padded_inputs,
                cu_seqlens=cu_seqlens,
                max_seq_len=max_seq_len,
                rope=rope,
            )
            # Post-norm: residual add, then LayerNorm (cast to the norm's weight dtype).
            hidden_states = self.norm1((self.dropout1(attn_outputs) + hidden_states).to(dtype=self.norm1.weight.dtype))
            mlp_out = self.mlp(hidden_states)

            hidden_states = self.norm2((self.dropout2(mlp_out) + hidden_states).to(dtype=self.norm2.weight.dtype))
            return hidden_states, None, None
|
|
|
|
|
|
class NomicBertEncoder(nn.Module):
    """Stack of ``config.n_layer`` NomicBertBlock layers applied in order.

    ``gradient_checkpointing`` (default False) enables activation checkpointing
    in training mode.
    """

    def __init__(self, config: "GPT2Config"):
        super().__init__()
        self.layers = nn.ModuleList([NomicBertBlock(config) for _ in range(config.n_layer)])
        self.gradient_checkpointing = False
        self.config = config

    def forward(
        self,
        hidden_states: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        is_padded_inputs: Optional[bool] = True,
        rope: Optional[torch.Tensor] = None,
    ):
        """Run ``hidden_states`` through every layer and return the last layer's hidden states.

        Each layer receives and returns ``(hidden_states, hidden_states2, residual)``.
        ``past_key_values``, ``inputs_embeds``, ``output_hidden_states`` and
        ``return_dict`` are accepted for interface compatibility and are not used. Layers
        get ``past_key_value=None`` on both the plain and the checkpointed path.

        NOTE(review): for pre-norm configs the final ``residual`` is not added to the
        returned tensor here -- confirm the caller handles that.
        """
        hidden_states2 = None
        residual = None
        use_checkpointing = self.gradient_checkpointing and self.training
        if use_checkpointing:
            from torch.utils.checkpoint import checkpoint

        for layer in self.layers:
            # Positional order matches NomicBertBlock.forward:
            # (hidden_states, hidden_states2, residual, attention_mask, position_ids,
            #  past_key_value, is_padded_inputs, output_attentions, use_cache)
            layer_args = (
                hidden_states,
                hidden_states2,
                residual,
                attention_mask,
                position_ids,
                None,
                is_padded_inputs,
                output_attentions,
                use_cache,
            )
            if use_checkpointing:
                # Trailing positional args: cu_seqlens=None, max_seq_len=None, rope.
                hidden_states, hidden_states2, residual = checkpoint(
                    layer,
                    *layer_args,
                    None,
                    None,
                    rope,
                    use_reentrant=False,
                )
            else:
                hidden_states, hidden_states2, residual = layer(*layer_args, rope=rope)
        return hidden_states
|
|
|
|
|
|
class NomicBertPooler(nn.Module):
    """Pooler: tanh(dense(first token)), or tanh(dense(input)) when pooling is off."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.n_embd, config.n_embd)
        self.activation = nn.Tanh()

    def forward(self, hidden_states, pool=True):
        """Project and tanh-activate the first token (``pool=True``) or the whole tensor."""
        if pool:
            # Take the token at sequence index 0 of each batch item.
            hidden_states = hidden_states[:, 0]
        projected = self.dense(hidden_states)
        return self.activation(projected)
|
|
|
|
|
|
class NomicBertPredictionHeadTransform(nn.Module):
    """MLM head transform: dense -> activation -> LayerNorm.

    The activation is SiLU for ``"swiglu"``, tanh-approximated GELU for
    ``"gelu_new"``/``"gelu_fast"``/``"gelu_pytorch_tanh"``, and exact GELU otherwise.
    """

    def __init__(self, config):
        super().__init__()
        act_name = config.activation_function
        self.dense = nn.Linear(config.n_embd, config.n_embd, bias=config.mlp_fc1_bias)
        if act_name == "swiglu":
            self.transform_act_fn = F.silu
        elif act_name in ("gelu_new", "gelu_fast", "gelu_pytorch_tanh"):
            self.transform_act_fn = nn.GELU(approximate="tanh")
        else:
            self.transform_act_fn = nn.GELU(approximate="none")

        self.layer_norm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the dense projection, activation and LayerNorm in sequence."""
        return self.layer_norm(self.transform_act_fn(self.dense(hidden_states)))
|
|
|
|
|
|
class NomicBertLMPredictionHead(nn.Module):
    """Masked-LM head: transform the hidden states, then decode them to vocabulary logits."""

    def __init__(self, config):
        super().__init__()
        self.transform = NomicBertPredictionHeadTransform(config)
        # Decoder weight may be tied to the word embeddings by the owning model.
        self.decoder = nn.Linear(config.n_embd, config.vocab_size, bias=config.mlp_fc1_bias)

    def forward(self, hidden_states):
        """Return logits of shape (..., vocab_size)."""
        transformed = self.transform(hidden_states)
        return self.decoder(transformed)
|
|
|
|
|
|
class NomicBertPreTrainingHeads(nn.Module):
    """Pre-training heads; only the masked-LM prediction head is present."""

    def __init__(self, config):
        super().__init__()
        self.predictions = NomicBertLMPredictionHead(config)

    def forward(self, sequence_output):
        """Return masked-LM logits for ``sequence_output``."""
        return self.predictions(sequence_output)
|
|
|
|
|
|
class NomicBertModel(NomicBertPreTrainedModel):
    """Nomic BERT backbone.

    Pipeline: embeddings -> LayerNorm -> dropout -> encoder -> optional pooler.
    """

    def __init__(self, config: GPT2Config, add_pooling_layer=True):
        super().__init__(config)
        self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        if config.vocab_size % self.pad_vocab_size_multiple != 0:
            # Round the vocabulary up to a multiple of pad_vocab_size_multiple (mutates config).
            config.vocab_size += self.pad_vocab_size_multiple - (config.vocab_size % self.pad_vocab_size_multiple)

        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_pytorch_tanh",
            "swiglu",
            "geglu",
            "glu",
        ]

        self.embeddings = NomicBertEmbeddings(config)
        self.emb_drop = nn.Dropout(config.resid_pdrop)
        self.emb_ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.encoder = NomicBertEncoder(config)
        self.pooler = NomicBertPooler(config) if add_pooling_layer else None

        self.apply(partial(_init_weights, initializer_range=config.initializer_range))

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        token_type_ids=None,
        return_dict=None,
        matryoshka_dim=None,
        inputs_embeds=None,
    ):
        """Encode a batch.

        Args:
            input_ids / inputs_embeds: exactly one of them should be given.
            attention_mask: (batch, seqlen) mask of tokens to attend to, or None to attend
                to every token.
            matryoshka_dim: if set, keep only the first ``matryoshka_dim`` features of each
                token's output embedding.

        Returns:
            BaseModelOutputWithPoolingAndCrossAttentions with ``last_hidden_state`` and
            ``pooler_output`` (None without a pooler).

        Raises:
            ValueError: if both ``input_ids`` and ``inputs_embeds`` are given.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        hidden_states = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )
        hidden_states = self.emb_ln(hidden_states)
        hidden_states = self.emb_drop(hidden_states)

        # get_extended_attention_mask cannot take None. With no mask, every token is
        # attended, which is what attention does when it receives None.
        if attention_mask is not None:
            attention_mask = self.get_extended_attention_mask(attention_mask, hidden_states.shape[:-1])
        sequence_output = self.encoder(hidden_states, attention_mask=attention_mask, return_dict=return_dict)

        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if matryoshka_dim:
            # Truncate the embedding (last) axis, not the sequence axis.
            sequence_output = sequence_output[..., :matryoshka_dim]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
        )
|
|
|
|
|
|
class NomicBertForPreTraining(NomicBertPreTrainedModel):
    """Nomic BERT with a masked-language-modeling head whose decoder is tied to the word embeddings."""

    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]

    def __init__(self, config: GPT2Config):
        super().__init__(config)

        self.bert = NomicBertModel(config, add_pooling_layer=getattr(config, "add_pooling_layer", False))
        self.cls = NomicBertPreTrainingHeads(config)
        self.mlm_loss = nn.CrossEntropyLoss()

        # Re-initialize all weights, then tie the decoder to the input embeddings.
        self.apply(partial(_init_weights, initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        """Share the MLM decoder weight with the word-embedding matrix."""
        self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight

    def forward(
        self,
        input_ids,
        position_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
    ):
        """Compute masked-LM logits and, when ``labels`` is given, the MLM loss.

        Labels use -100 (CrossEntropyLoss's default ``ignore_index``) for positions that do
        not contribute to the loss.

        Returns:
            MaskedLMOutput with:
            - ``loss``: fp32 MLM loss, or None without labels.
            - ``logits``: (..., vocab_size).
            - ``hidden_states``: taken from the backbone output.
            - ``attentions``: None.
        """
        bool_mask = None if attention_mask is None else attention_mask.bool()
        outputs = self.bert(
            input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            attention_mask=bool_mask,
        )
        prediction_scores = self.cls(outputs.last_hidden_state)

        total_loss = None
        if labels is not None:
            flat_logits = rearrange(prediction_scores, "... v -> (...) v")
            flat_labels = rearrange(labels, "... -> (...)")
            total_loss = self.mlm_loss(flat_logits, flat_labels).float()

        return MaskedLMOutput(
            loss=total_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=None,
        )
|
|
|
|
|
|
class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
    """Nomic BERT with a dropout + linear head on the pooled output.

    Supports regression, single-label and multi-label classification.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = NomicBertModel(config)
        # Configs may define classifier_dropout=None; fall back to the embedding
        # dropout then, instead of passing None to nn.Dropout.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = config.embd_pdrop
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.n_embd, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        ``inputs_embeds`` is forwarded to the backbone as an alternative to ``input_ids``.
        ``head_mask``, ``output_attentions`` and ``output_hidden_states`` are accepted for
        interface compatibility and are not used.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.bert(
            input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask.bool() if attention_mask is not None else None,
            inputs_embeds=inputs_embeds,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Infer (and cache on the config) the problem type from num_labels / label dtype.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
class NomicBertForMultipleChoice(NomicBertPreTrainedModel):
    """NomicBert encoder with a one-score-per-choice head for multiple-choice tasks.

    Every choice is encoded as an independent sequence; the encoder output at
    index 1 (used here as the pooled representation) is mapped to a single
    score, and the scores of one example's choices are compared with
    cross-entropy.
    """

    def __init__(self, config):
        super().__init__(config)

        self.bert = NomicBertModel(config, add_pooling_layer=True)
        # `classifier_dropout` may be missing from the config *or* present but
        # explicitly None (nn.Dropout(None) raises TypeError); fall back to
        # `resid_pdrop` in both cases.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = config.resid_pdrop
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        unpad_inputs: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
        r"""
        Score every candidate choice and, if labels are given, compute the multiple-choice loss.

        The input tensors carry the choice dimension at index 1 (``num_choices`` is read from
        ``shape[1]`` of ``input_ids`` or ``inputs_embeds``). It is folded into the batch before
        encoding and unfolded again for the returned logits.

        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)

        ``head_mask``, ``output_attentions``, ``output_hidden_states`` and ``unpad_inputs`` are accepted
        for interface compatibility but are not forwarded to the encoder.

        Returns:
            A ``MultipleChoiceModelOutput`` when ``return_dict`` resolves truthy; otherwise the tuple
            ``(loss, reshaped_logits, *outputs[2:])``, with ``loss`` omitted when ``labels`` is None.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Fold the choice dimension into the batch so each choice is encoded on its own.
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # Unfold back to one row of `num_choices` scores per example.
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
class NomicBertForTokenClassification(NomicBertPreTrainedModel):
    """NomicBert encoder with a per-token linear classification head.

    The encoder's first output (the per-token hidden states) goes through
    dropout and a linear layer producing ``config.num_labels`` logits per token.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = NomicBertModel(config, add_pooling_layer=False)
        # `classifier_dropout` may be missing from the config *or* present but
        # explicitly None (nn.Dropout(None) raises TypeError); fall back to
        # `resid_pdrop` in both cases.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = config.resid_pdrop
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        Produce per-token logits and, if labels are given, the token classification loss.

        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.

        ``head_mask``, ``output_attentions`` and ``output_hidden_states`` are accepted for interface
        compatibility but are not forwarded to the encoder.

        Returns:
            A ``TokenClassifierOutput`` when ``return_dict`` resolves truthy; otherwise the tuple
            ``(loss, logits, *outputs[2:])``, with ``loss`` omitted when ``labels`` is None.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # Flatten batch and sequence so every token is one classification example.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
class NomicBertForQuestionAnswering(NomicBertPreTrainedModel):
    """NomicBert encoder with a span-extraction head for extractive question answering.

    A linear layer maps every token's hidden state to ``config.num_labels``
    scores, which are unpacked into span-start and span-end logits (so the
    head only works when ``num_labels`` is 2).
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = NomicBertModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        r"""
        Predict answer-span boundaries and, if both targets are given, the averaged span loss.

        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Index of the first token of the labelled answer span. Values are clamped to
            `[0, sequence_length]`; a clamped value of `sequence_length` is ignored by the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Index of the last token of the labelled answer span, treated the same way.

        ``head_mask``, ``output_attentions`` and ``output_hidden_states`` are accepted for
        interface compatibility but are not forwarded to the encoder.

        Returns:
            A ``QuestionAnsweringModelOutput`` when ``return_dict`` resolves truthy; otherwise the
            tuple ``(total_loss, start_logits, end_logits, *outputs[2:])``, with ``total_loss``
            omitted when either target is None.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )

        span_scores = self.qa_outputs(encoder_outputs[0])
        # Peel the last dim into one (batch, seq_len) tensor per boundary.
        start_logits, end_logits = (
            piece.squeeze(-1).contiguous() for piece in span_scores.split(1, dim=-1)
        )

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Targets with extra dims get their trailing axis dropped.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)

            # Out-of-range targets are clamped onto `seq_len`, which the loss ignores.
            seq_len = start_logits.size(1)
            start_positions = start_positions.clamp(0, seq_len)
            end_positions = end_positions.clamp(0, seq_len)

            criterion = nn.CrossEntropyLoss(ignore_index=seq_len)
            total_loss = (criterion(start_logits, start_positions) + criterion(end_logits, end_positions)) / 2

        if return_dict:
            return QuestionAnsweringModelOutput(
                loss=total_loss,
                start_logits=start_logits,
                end_logits=end_logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        tuple_output = (start_logits, end_logits) + encoder_outputs[2:]
        if total_loss is None:
            return tuple_output
        return (total_loss,) + tuple_output
|
|
|
|
def hf_vit_config_to_vit_config(vit_config: ViTConfig) -> GPT2Config:
    """Translate a Hugging Face ``ViTConfig`` into the ``GPT2Config`` consumed by the Nomic vision tower.

    Architecture sizes are copied from ``vit_config``: width, depth, heads, MLP width, activation,
    attention dropout, layer-norm eps, initializer range, and image/patch/channel geometry. The
    optional ``dropout``, ``mlp_fc1_bias`` and ``mlp_fc2_bias`` attributes are read with defaults.
    Every other setting is pinned to a fixed value. Sinusoidal position embeddings are enabled
    only when ``vit_config.model_type == "vit_mae"``.

    Args:
        vit_config: the Hugging Face ViT configuration to translate.

    Returns:
        A new ``GPT2Config`` holding the translated settings.
    """
    settings = {
        "n_embd": vit_config.hidden_size,
        "n_layer": vit_config.num_hidden_layers,
        "n_head": vit_config.num_attention_heads,
        "n_inner": vit_config.intermediate_size,
        "activation_function": vit_config.hidden_act,
        # No token vocabulary or text positions for a vision encoder.
        "vocab_size": 0,
        "n_positions": 0,
        "resid_pdrop": 0.0,
        "embd_pdrop": getattr(vit_config, "dropout", 0.0),
        "attn_pdrop": vit_config.attention_probs_dropout_prob,
        "layer_norm_epsilon": vit_config.layer_norm_eps,
        "initializer_range": vit_config.initializer_range,
        "bos_token_id": None,
        "eos_token_id": None,
        "drop_path_rate": 0.0,
        # NOTE(review): key is spelled "layernom"; confirm it matches whatever reads it.
        "prepre_layernom": False,
        "layer_scale": False,
        "layer_scale_init": None,
        "img_size": vit_config.image_size,
        "patch_size": vit_config.patch_size,
        "num_channels": vit_config.num_channels,
        "prenorm": True,
        "parallel_block": False,
        "parallel_block_tied_norm": False,
        "rotary_emb_fraction": 0,
        "tie_word_embeddings": False,
        "fused_dropout_add_ln": True,
        "fused_bias_fc": True,
        "patch_embed_bias": True,
        "use_flash_attn": True,
        "qkv_proj_bias": True,
        "mlp_fc1_bias": getattr(vit_config, "mlp_fc1_bias", True),
        "mlp_fc2_bias": getattr(vit_config, "mlp_fc2_bias", True),
        "use_rms_norm": False,
        "causal": False,
        "hidden_features_scaling_factor": 1.0,
        "mask_token": False,
        "learned_pos_embedding": False,
        "patch_dropout": 0,
        "sinusoidal_pos_embedding": vit_config.model_type == "vit_mae",
    }
    return GPT2Config(**settings)
|
|
|
|
|
|
class NomicAttentionPooling(nn.Module):
    """Attention pooling driven by a single learned latent query.

    A learned latent vector of shape ``(1, 1, embed_dim)`` is projected to
    queries and attends over the input sequence, which is projected to keys
    and values. The attended result goes through an output projection, giving
    one pooled position per batch item.
    """

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.n_embd
        self.use_flash_attn = config.use_flash_attn
        self.fused_bias_fc = config.fused_bias_fc

        self.num_heads = config.n_head
        self.num_heads_kv = config.num_heads_kv if getattr(config, "num_heads_kv", None) is not None else self.num_heads
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // self.num_heads

        # Keys and values are produced by one projection, packed as [k heads | v heads].
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        # sqrt(head_dim): scale for the dot-product scores; not saved in the state dict.
        self.register_buffer(
            "norm_factor",
            torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
            persistent=False,
        )

        self.Wq = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
        self.Wkv = nn.Linear(self.embed_dim, kv_dim, bias=config.qkv_proj_bias)

        self.latent = nn.Parameter(torch.zeros(1, 1, self.embed_dim))

        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
        self.causal = config.causal
        self.drop = nn.Dropout(config.attn_pdrop)

    def init_weights(self):
        """Re-initialize the latent query with a truncated normal of std ``embed_dim ** -0.5``."""
        trunc_normal_tf_(self.latent, std=self.embed_dim**-0.5)

    def forward(
        self,
        kv,
        attention_mask=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
        is_padded_inputs: Optional[bool] = True,
        output_attentions: bool = False,
    ):
        """Pool ``kv`` into one vector per batch item with latent-query attention.

        Args:
            kv: input hidden states of shape ``(batch, seq_len, embed_dim)``. They are projected
                to keys and values; the queries come from the learned latent, not from ``kv``.
            attention_mask: optional additive mask added to the raw scores of shape
                ``(batch, num_heads, 1, seq_len)`` before the softmax; it must broadcast to
                that shape (e.g. ``0`` to keep a key, ``-inf`` to drop it).
            cu_seqlens_k, max_seqlen_k, is_padded_inputs, output_attentions: accepted for
                interface compatibility; currently unused.

        Returns:
            Tensor of shape ``(batch, 1, embed_dim)``.
        """
        batch_size = kv.size(0)

        # One latent query per batch item: (B, 1, D) -> (B, 1, H, d).
        q = self.Wq(self.latent.expand(batch_size, -1, -1))
        query = q.reshape(*q.shape[:-1], self.num_heads, self.head_dim)

        # (B, S, 2 * Hkv * d) -> (B, S, 2, Hkv, d), then split off keys and values.
        kv = self.Wkv(kv)
        kv = kv.reshape(*kv.shape[:-1], 2, -1, self.head_dim)
        key, value = kv[:, :, 0], kv[:, :, 1]

        # Put heads before the length axis: (B, heads, len, d).
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        if self.num_heads_kv != self.num_heads:
            # Grouped-query attention: each KV head serves a contiguous group of
            # query heads, so repeat KV heads to line them up with the queries.
            n_rep = self.num_heads // self.num_heads_kv
            key = key.repeat_interleave(n_rep, dim=1)
            value = value.repeat_interleave(n_rep, dim=1)

        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.drop(attention_probs)

        attn_output = torch.matmul(attention_probs, value)
        # (B, H, 1, d) -> (B, 1, H, d) -> (B, 1, H * d)
        attn_output = attn_output.transpose(1, 2).flatten(-2)

        return self.out_proj(attn_output)
|
|
|
|
|
|
class NomicMultiHeadAttentionPooling(nn.Module):
    """Pooling head: latent-query attention (``NomicAttentionPooling``) followed by LayerNorm + MLP.

    The MLP is a gated variant when ``config.activation_function`` is ``glu``, ``swiglu`` or
    ``geglu``; otherwise it is a plain MLP. The gate/activation function is sigmoid for
    ``glu``, SiLU for ``swiglu`` and GELU for everything else.
    """

    def __init__(
        self,
        config,
    ):
        super().__init__()
        self.prenorm = config.prenorm
        self.fused_dropout_add_ln = config.fused_dropout_add_ln

        self.attn = NomicAttentionPooling(config)

        act_name = config.activation_function
        if act_name == "glu":
            activation = F.sigmoid
        elif act_name == "swiglu":
            activation = F.silu
        else:
            activation = F.gelu

        mlp_cls = NomciBertGatedMLP if act_name in ("glu", "swiglu", "geglu") else NomicBertMLP
        self.mlp = mlp_cls(
            config.n_embd,
            hidden_features=config.n_inner,
            bias1=config.mlp_fc1_bias,
            bias2=config.mlp_fc2_bias,
            activation=activation,
            fused_bias_fc=config.fused_bias_fc,
        )

        # NOTE(review): dropout1/dropout2 (and prenorm/fused_dropout_add_ln) are
        # not used in forward; only norm1 is.
        self.dropout1 = nn.Dropout(config.resid_pdrop)
        self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.dropout2 = nn.Dropout(config.resid_pdrop)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        r"""Pool ``hidden_states`` with latent-query attention and refine the result with an MLP.

        Args:
            hidden_states: sequence attended over by ``self.attn`` (it supplies keys and values).
            attention_mask: optional mask passed unchanged to ``self.attn``.

        Returns:
            ``hidden_states + mlp(norm1(attn(hidden_states)))``. The attention output has a
            single latent position per batch item, so the MLP result is broadcast-added to
            every position of ``hidden_states`` (assuming the MLP keeps that shape).
        """
        pooled = self.attn(hidden_states, attention_mask=attention_mask)
        refined = self.mlp(self.norm1(pooled))
        # NOTE(review): the residual here is the *input sequence*, not the pooled
        # attention output (timm's AttentionPoolLatent adds to the pooled output).
        # Left unchanged since pretrained checkpoints may depend on it; confirm.
        return hidden_states + refined
|
|
|
|
|
|
class NomicVisionPreTrainedModel(PreTrainedModel):
    """Abstract base class for Nomic vision models.

    Hooks the model into the Hugging Face ``PreTrainedModel`` machinery (weight
    initialization, downloading and loading pretrained checkpoints) and rejects
    configs that are not ``GPT2Config`` instances.
    """

    config_class = NomicBertConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # NOTE(review): no class named `Block` is visible in this chunk; confirm the
    # intended module name for device-map splitting.
    _no_split_modules = ["Block"]
    _skip_keys_device_placement = "past_key_values"

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config)
        if isinstance(config, GPT2Config):
            self.config = config
            return
        cls_name = type(self).__name__
        raise ValueError(
            f"Parameter config in `{cls_name}(config)` should be an instance of class `GPT2Config`. "
            "To create a model from a Google pretrained model use "
            f"`model = {cls_name}.from_pretrained(PRETRAINED_MODEL_NAME)`"
        )
|
|
|
|
|
|
class NomicVisionModel(NomicVisionPreTrainedModel):
    """Nomic vision encoder: patch embeddings, a stack of ``NomicBertBlock`` layers, and an
    attention-pooling head (``selector``)."""

    def __init__(self, config):
        super().__init__(config)

        self.embeddings = NomicVisionPatchEmbeddings(config)
        self.layers = nn.ModuleList([NomicBertBlock(config) for _ in range(config.n_layer)])

        self.selector = NomicMultiHeadAttentionPooling(config)

        self.global_pool = getattr(config, "global_pool", None)
        # Leading non-patch tokens skipped by average pooling: 1 unless
        # `config.no_cls_token` is truthy, plus `config.register_tokens`.
        self.num_prefix_tokens = (1 if not getattr(config, "no_cls_token", False) else 0) + getattr(
            config, "register_tokens", 0
        )

        self.apply(partial(_init_weights, initializer_range=config.initializer_range))

    def forward(
        self,
        pixel_values,
        attention_mask=None,
        position_ids=None,
        token_type_ids=None,
        return_dict=None,
        matryoshka_dim=None,
    ):
        """Encode ``pixel_values`` and pool them with the selector head.

        Args:
            pixel_values: images passed to ``self.embeddings``, which returns the patch
                embeddings together with the rotary embedding handed to every layer.
            attention_mask, position_ids, token_type_ids, return_dict, matryoshka_dim:
                accepted for interface compatibility; currently unused.

        Returns:
            ``BaseModelOutputWithPast`` with ``last_hidden_state`` set to the selector output
            and ``hidden_states`` set to the final encoder states (after average pooling when
            ``global_pool == "avg"``).
        """
        embeddings, rope = self.embeddings(pixel_values)

        hidden_states = embeddings
        residual = None
        for layer in self.layers:
            hidden_states, _, residual = layer(
                hidden_states, None, residual=residual, is_padded_inputs=False, rope=rope
            )

        # Layers hand back the residual stream separately; fold it in at the end.
        # Guarded so a model with no layers (or blocks that return no residual)
        # does not crash on `tensor + None`.
        if residual is not None:
            hidden_states = hidden_states + residual

        if self.global_pool == "avg":
            # NOTE(review): this leaves a 2-D (batch, dim) tensor, while the selector's
            # attention indexes its input as (batch, seq, dim); confirm this path works.
            hidden_states = hidden_states[:, self.num_prefix_tokens :].mean(dim=1)

        pooled_output = self.selector(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=pooled_output,
            hidden_states=hidden_states,
        )
|
|
|