musa: Upgrade MUSA SDK version to rc4.0.1 and use mudnn::Unary::IDENTITY op to accelerate D2D memory copy (llama/13647)
9506ebb
#include <mudnn.h>          // muDNN header from the MUSA SDK

#include <memory>
#include <mutex>
#include <unordered_map>
#include <vector>

namespace mudnn = musa::dnn;
// Returns a human-readable error string for mudnn::Status
const char* mudnnGetErrorString(mudnn::Status err) {
    switch (err) {
        case mudnn::Status::SUCCESS:
            return "Success";
        case mudnn::Status::INVALID_PARAMETER:
            return "Invalid parameter";
        case mudnn::Status::NOT_INITIALIZED:
            return "Not initialized";
        case mudnn::Status::ALLOC_FAILED:
            return "Allocation failed";
        case mudnn::Status::NOT_SUPPORTED:
            return "Not supported";
        case mudnn::Status::INTERNAL_ERROR:
            return "Internal error";
        case mudnn::Status::ARCH_MISMATCH:
            return "Architecture mismatch";
        case mudnn::Status::EXECUTION_FAILED:
            return "Execution failed";
        default:
            return "Unknown mudnn status";
    }
}
// Error checking macro for MUDNN calls; reuses CUDA_CHECK_GEN from the CUDA backend's common.cuh
#define MUDNN_CHECK(err) CUDA_CHECK_GEN(err, ::mudnn::Status::SUCCESS, mudnnGetErrorString)
namespace {
    // Thread-safe cache for mudnn::Handle objects per device
    std::unordered_map<int, std::unique_ptr<mudnn::Handle>> handle_cache;
    std::mutex handle_cache_mutex;

    mudnn::Handle* get_cached_handle(int device_id) {
        std::lock_guard<std::mutex> lock(handle_cache_mutex);

        auto it = handle_cache.find(device_id);
        if (it != handle_cache.end()) {
            return it->second.get();
        }

        auto handle = std::make_unique<mudnn::Handle>(device_id);
        mudnn::Handle* handle_ptr = handle.get();
        handle_cache[device_id] = std::move(handle);
        return handle_ptr;
    }
}  // anonymous namespace
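A minimal usage sketch of the handle cache (device id 0 is assumed here for illustration): repeated calls for the same device return the same cached handle instead of constructing a new mudnn::Handle per copy.

// Sketch only: both calls return the same pointer for device 0
mudnn::Handle* h1 = get_cached_handle(0);
mudnn::Handle* h2 = get_cached_handle(0);
// h1 == h2; the handle is created once and then reused for subsequent copies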
// Extracts dimensions and strides from a ggml_tensor
int get_ggml_dims_and_strides(const ggml_tensor* tensor,
                              std::vector<int64_t>& dims,
                              std::vector<int64_t>& strides) {
    const int    ndims        = ggml_n_dims(tensor);
    const size_t element_size = ggml_element_size(tensor);

    dims.resize(ndims);
    strides.resize(ndims);

    for (int i = 0; i < ndims; ++i) {
        dims[i]    = tensor->ne[i];
        strides[i] = tensor->nb[i] / static_cast<int64_t>(element_size);
    }
    return ndims;
}
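As a worked example (not part of the patch), assume a contiguous GGML_TYPE_F32 tensor of shape [4, 3]; ggml stores byte strides in nb, and this helper converts them to element strides:

// Hypothetical contiguous F32 tensor: ne = {4, 3}, nb = {4, 16} bytes, element_size = 4
//   dims    = {4, 3}
//   strides = {4 / 4, 16 / 4} = {1, 4}   // element strides, not byte strides
// get_ggml_dims_and_strides returns ndims = 2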
// Converts ggml_type to mudnn::Tensor::Type
mudnn::Tensor::Type ggml_type_to_mudnn_type(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F32:
            return mudnn::Tensor::Type::FLOAT;
        case GGML_TYPE_F16:
            return mudnn::Tensor::Type::HALF;
        // TODO: Add support for other types
        default:
            MUDNN_CHECK(mudnn::Status::NOT_SUPPORTED);
    }
    return mudnn::Tensor::Type::FLOAT;  // Default fallback
}
// Asynchronous memory copy using mudnn::Unary::IDENTITY
musaError_t mudnnMemcpyAsync(ggml_backend_cuda_context& ctx, const ggml_tensor* dst, const ggml_tensor* src) {
    mudnn::Tensor tensor_dst, tensor_src;

    MUDNN_CHECK(tensor_dst.SetType(ggml_type_to_mudnn_type(dst->type)));
    MUDNN_CHECK(tensor_src.SetType(ggml_type_to_mudnn_type(src->type)));

    std::vector<int64_t> dims, strides;
    const int ndims = get_ggml_dims_and_strides(src, dims, strides);

    MUDNN_CHECK(tensor_dst.SetNdInfo(ndims, dims.data(), strides.data()));
    MUDNN_CHECK(tensor_src.SetNdInfo(ndims, dims.data(), strides.data()));
    MUDNN_CHECK(tensor_dst.SetAddr(dst->data));
    MUDNN_CHECK(tensor_src.SetAddr(src->data));

    mudnn::Unary op;
    MUDNN_CHECK(op.SetMode(mudnn::Unary::Mode::IDENTITY));
    MUDNN_CHECK(op.SetAlpha(0.0f));
    MUDNN_CHECK(op.SetBeta(0.0f));

    mudnn::Handle* handle = get_cached_handle(ctx.device);
    MUDNN_CHECK(handle->SetStream(ctx.stream()));
    MUDNN_CHECK(op.Run(*handle, tensor_dst, tensor_src));

    return musaSuccess;
}
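For context, a hedged sketch of how a copy path could dispatch to this helper on MUSA devices. The GGML_MUSA_MUDNN_COPY guard, the type check, and the plain cudaMemcpyAsync fallback below are assumptions for illustration, not part of this excerpt.

// Sketch only: hypothetical call site in the backend's D2D copy dispatch
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
    if (src->type == dst->type && (src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16)) {
        // Route supported same-type copies through muDNN's IDENTITY op
        CUDA_CHECK(mudnnMemcpyAsync(ctx, dst, src));
    } else
#endif
    {
        // Assumed fallback: ordinary device-to-device async copy on the context's stream
        CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(src),
                                   cudaMemcpyDeviceToDevice, ctx.stream()));
    }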