from typing import TYPE_CHECKING

import torch

# Neuron ships a torch version that doesn't even have impl_abstract.
if TYPE_CHECKING:

    def register_fake(fn):
        return lambda name: fn

else:
    try:
        from torch.library import register_fake
    except ImportError:
        from torch.library import impl_abstract as register_fake

try:
    from ._ops import ops, add_op_namespace_prefix
except ImportError as e:
    # Fallback for local development.
    try:
        import _moe  # noqa: F401  (importing registers the ops with torch)

        ops = torch.ops._moe

        def add_op_namespace_prefix(op_name: str):
            return f"_moe::{op_name}"

    except ImportError:
        raise e

from .scalar_type import ScalarType


def gptq_marlin_moe_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """Repack GPTQ-quantized expert weights into the Marlin kernel layout.

    Repacks one expert at a time; `b_q_weight` and `perm` are indexed by
    expert along dim 0.
    """
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    output = torch.empty(
        (num_experts, size_k // 16, size_n * (num_bits // 2)),
        device=b_q_weight.device,
        dtype=b_q_weight.dtype,
    )
    for e in range(num_experts):
        output[e] = ops.gptq_marlin_repack(
            b_q_weight[e], perm[e], size_k, size_n, num_bits
        )
    return output


def awq_marlin_moe_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """Repack AWQ-quantized expert weights into the Marlin kernel layout.

    `perm` is accepted for signature parity with the GPTQ variant but is not
    used by the AWQ repack kernel.
    """
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    output = torch.empty(
        (num_experts, size_k // 16, size_n * (num_bits // 2)),
        device=b_q_weight.device,
        dtype=b_q_weight.dtype,
    )
    for e in range(num_experts):
        output[e] = ops.awq_marlin_repack(b_q_weight[e], size_k, size_n, num_bits)
    return output


def moe_sum(input: torch.Tensor, output: torch.Tensor):
    """Sum `input` over its top-k dimension, writing the result into `output`."""
    ops.moe_sum(input, output)


def moe_align_block_size(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    experts_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """Sort tokens by assigned expert and pad each expert's token count to a
    multiple of `block_size`, writing into the preallocated output tensors."""
    ops.moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
    )


def topk_softmax(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    token_expert_indicies: torch.Tensor,
    gating_output: torch.Tensor,
) -> None:
    """Apply softmax to the gating logits and select the top-k experts per
    token, writing routing weights and expert ids into the output tensors."""
    ops.topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output)


if hasattr(ops, "marlin_gemm_moe"):

    @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
    def marlin_gemm_moe_fake(
        a: torch.Tensor,
        b_q_weights: torch.Tensor,
        sorted_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        b_scales: torch.Tensor,
        b_zero_points: torch.Tensor,
        g_idx: torch.Tensor,
        perm: torch.Tensor,
        workspace: torch.Tensor,
        b_q_type: ScalarType,
        size_m: torch.SymInt,
        size_n: torch.SymInt,
        size_k: torch.SymInt,
        is_k_full: bool,
        num_experts: int,
        topk: int,
        moe_block_size: int,
        replicate_input: bool,
        apply_weights: bool,
    ) -> torch.Tensor:
        # Fake (meta) implementation: only output shape/dtype/device matter
        # for tracing; no computation is performed.
        return torch.empty((size_m, topk, size_n), dtype=a.dtype, device=a.device)


def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Compute SiLU(x[..., :d]) * x[..., d:] into `out`, with d = x.shape[-1] // 2."""
    ops.silu_and_mul(out, x)
    return out
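

# A minimal usage sketch, not part of the library API: it shows how the
# routing helpers above are typically wired together. The token/expert counts
# and buffer dtypes here are illustrative assumptions, and running this
# requires the compiled `_moe` extension and a CUDA device.
if __name__ == "__main__":
    num_tokens, num_experts, topk, block_size = 8, 4, 2, 16
    device = "cuda"

    # Router logits for each token over all experts.
    gating_output = torch.randn(num_tokens, num_experts, device=device)

    # Preallocate the buffers that topk_softmax fills in place.
    topk_weights = torch.empty(num_tokens, topk, dtype=torch.float32, device=device)
    topk_ids = torch.empty(num_tokens, topk, dtype=torch.int32, device=device)
    token_expert_indicies = torch.empty_like(topk_ids)
    topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output)

    # moe_align_block_size pads every expert's token count up to a multiple
    # of block_size, so the worst case is num_tokens * topk entries plus
    # (block_size - 1) padding slots per expert.
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_token_ids = torch.empty(
        max_num_tokens_padded, dtype=torch.int32, device=device
    )
    experts_ids = torch.empty(
        (max_num_tokens_padded + block_size - 1) // block_size,
        dtype=torch.int32,
        device=device,
    )
    num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device=device)
    moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
    )

    # silu_and_mul fuses the gated activation: out = SiLU(x[..., :d]) * x[..., d:].
    hidden = 32
    x = torch.randn(num_tokens, 2 * hidden, device=device)
    out = torch.empty(num_tokens, hidden, device=device)
    silu_and_mul(out, x)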