import torch
import torch.nn as nn

from ._ops import ops


class LlamaRMSNorm(nn.Module):
    """RMSNorm layer backed by the fused ``ops.rmsnorm_forward`` kernel.

    The attributes below are annotations only; ``weight`` and
    ``variance_epsilon`` are expected to already exist on the module this
    layer stands in for (no ``__init__`` is defined on purpose).
    """

    weight: torch.Tensor
    variance_epsilon: float
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Plain RMSNorm: no bias, no fused residual add, no dropout, and
        # post-norm output only (prenorm=False returns just the normalized
        # tensor rather than a (normalized, residual) pair).
        return ops.rmsnorm_forward(
            hidden_states,
            self.weight,
            bias=None,
            residual=None,
            eps=self.variance_epsilon,
            dropout_p=0.0,
            prenorm=False,
            residual_in_fp32=False,
        )
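

# A minimal pure-PyTorch sketch of the math the fused kernel is assumed to
# implement, useful as a reference for sanity-checking the op's output. It is
# not part of the kernel API; it follows the standard RMSNorm recipe of
# computing the variance in fp32 and casting back to the input dtype.
def _rmsnorm_reference(
    hidden_states: torch.Tensor, weight: torch.Tensor, eps: float
) -> torch.Tensor:
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    # Mean of squares over the hidden dimension, then scale by 1/sqrt(var+eps).
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)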