#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

// Register the extension's custom CUDA ops: each op gets a schema via
// ops.def(...) and a CUDA implementation via ops.impl(...).
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Gather tokens into expert order using the sorted indices and expert bins.
  ops.def("gather("
          "Tensor x, "
          "Tensor indices, "
          "Tensor bins, "
          "Tensor! output, "
          "int E, "
          "int C, "
          "int top_k) -> ()");
  ops.impl("gather", torch::kCUDA, &gather_cuda);

  // Scatter expert outputs back to token order, weighted by the routing weights.
  ops.def("scatter("
          "Tensor src, "
          "Tensor indices, "
          "Tensor bins, "
          "Tensor weights, "
          "Tensor! y, "
          "int T, "
          "int E, "
          "int C, "
          "int top_k) -> ()");
  ops.impl("scatter", torch::kCUDA, &scatter_cuda);

  // Sort x by its low `end_bit` bits, writing the sorted keys and the permutation.
  ops.def("sort("
          "Tensor x, "
          "int end_bit, "
          "Tensor! x_out, "
          "Tensor! iota_out) -> ()");
  ops.impl("sort", torch::kCUDA, &sort_cuda);

  // Bincount of `input` (with at least `minlength` bins) followed by a cumulative sum.
  ops.def("bincount_cumsum("
          "Tensor input, "
          "Tensor! output, "
          "int minlength) -> ()");
  ops.impl("bincount_cumsum", torch::kCUDA, &bincount_cumsum_cuda);

  // index_select with int32 indices, writing into a preallocated output tensor.
  ops.def("index_select_out("
          "Tensor! out, "
          "Tensor input, "
          "Tensor idx_int32) -> Tensor");
  ops.impl("index_select_out", torch::kCUDA, &index_select_out_cuda);

  // Batched matrix multiply over variable-sized per-expert groups.
  ops.def("batch_mm("
          "Tensor x, "
          "Tensor weights, "
          "Tensor batch_sizes, "
          "Tensor! output, "
          "bool trans_b=False) -> Tensor");
  ops.impl("batch_mm", torch::kCUDA, &batch_mm);

  // Fused MoE experts forward pass.
  ops.def("experts("
          "Tensor hidden_states, "
          "Tensor router_indices, "
          "Tensor routing_weights, "
          "Tensor gate_up_proj, "
          "Tensor gate_up_proj_bias, "
          "Tensor down_proj, "
          "Tensor down_proj_bias, "
          "int expert_capacity, "
          "int num_experts, "
          "int top_k) -> Tensor");
  ops.impl("experts", torch::kCUDA, &experts_cuda);

  // Backward pass for the fused experts op; returns the gradient tensors.
  ops.def("experts_backward("
          "Tensor grad_out, "
          "Tensor hidden_states, "
          "Tensor router_indices, "
          "Tensor routing_weights, "
          "Tensor gate_up_proj, "
          "Tensor gate_up_proj_bias, "
          "Tensor down_proj, "
          "Tensor down_proj_bias, "
          "int expert_capacity, "
          "int num_experts, "
          "int top_k) -> Tensor[]");
  ops.impl("experts_backward", torch::kCUDA, &experts_backward_cuda);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)