import torch
import torch.nn as nn

from ._ops import ops


class LlamaRMSNorm(nn.Module):
    """RMSNorm layer that dispatches to the custom `rmsnorm_forward` kernel."""

    weight: torch.Tensor
    variance_epsilon: float

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Plain RMSNorm: no bias, no fused residual, no dropout.
        return ops.rmsnorm_forward(
            hidden_states,
            self.weight,
            bias=None,
            residual=None,
            eps=self.variance_epsilon,
            dropout_p=0.0,
            prenorm=False,
            residual_in_fp32=False,
        )
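

# --- Usage sketch (illustrative; not part of the uploaded kernel file) ---
# Because the class defines no __init__, the `weight` and `variance_epsilon`
# attributes are assigned manually below. This assumes the compiled extension
# behind `ops.rmsnorm_forward` has been built and a CUDA device is available;
# the hidden size of 4096, the fp16 dtype, and the helper name are arbitrary
# example choices.
def _example_usage() -> torch.Tensor:
    layer = LlamaRMSNorm()
    layer.weight = nn.Parameter(torch.ones(4096, device="cuda", dtype=torch.float16))
    layer.variance_epsilon = 1e-6

    hidden_states = torch.randn(1, 8, 4096, device="cuda", dtype=torch.float16)
    # __call__ routes through forward(), which dispatches to the custom kernel.
    return layer(hidden_states)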