yeahdongcn and JohannesGaessler committed
Commit 9506ebb · Parent: 8602d10

musa: Upgrade MUSA SDK version to rc4.0.1 and use mudnn::Unary::IDENTITY op to accelerate D2D memory copy (llama/13647)


* musa: fix build warning (unused parameter)

Signed-off-by: Xiaodong Ye <[email protected]>

* musa: upgrade MUSA SDK version to rc4.0.1

Signed-off-by: Xiaodong Ye <[email protected]>

* musa: use mudnn::Unary::IDENTITY op to accelerate D2D memory copy

Signed-off-by: Xiaodong Ye <[email protected]>

* Update ggml/src/ggml-cuda/cpy.cu

Co-authored-by: Johannes Gäßler <[email protected]>

* musa: remove MUDNN_CHECK_GEN and use CUDA_CHECK_GEN instead in MUDNN_CHECK

Signed-off-by: Xiaodong Ye <[email protected]>

---------

Signed-off-by: Xiaodong Ye <[email protected]>
Co-authored-by: Johannes Gäßler <[email protected]>
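
The core of the change is the new mudnnMemcpyAsync helper: instead of a plain device-to-device cudaMemcpyAsync, it runs a muDNN Unary operator in IDENTITY mode over the source tensor, which writes the data into the destination on the device. The snippet below is a condensed, illustrative sketch of that path, not the committed code; status checking, the ggml type mapping, and the per-device handle cache are elided (see ggml/src/ggml-musa/mudnn.cu in the diff for the full implementation), and the function name identity_copy_sketch is made up for illustration.

    // Condensed sketch only: copy a contiguous F32 tensor `src` into `dst` with
    // mudnn::Unary::IDENTITY; dims/strides are the element counts and element
    // strides of the ggml tensor. Status return values are ignored here for brevity.
    namespace mudnn = musa::dnn;

    void identity_copy_sketch(mudnn::Handle & handle, void * dst, const void * src,
                              int ndims, const int64_t * dims, const int64_t * strides) {
        mudnn::Tensor tensor_dst, tensor_src;
        tensor_dst.SetType(mudnn::Tensor::Type::FLOAT);
        tensor_src.SetType(mudnn::Tensor::Type::FLOAT);
        tensor_dst.SetNdInfo(ndims, dims, strides);
        tensor_src.SetNdInfo(ndims, dims, strides);
        tensor_dst.SetAddr(dst);
        tensor_src.SetAddr(src);

        mudnn::Unary op;
        op.SetMode(mudnn::Unary::Mode::IDENTITY);   // identity: out = in, i.e. a copy
        op.SetAlpha(0.0f);
        op.SetBeta(0.0f);
        op.Run(handle, tensor_dst, tensor_src);     // runs asynchronously on the handle's stream
    }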

ggml/src/ggml-cuda/cpy.cu CHANGED
@@ -1,5 +1,8 @@
 #include "cpy.cuh"
 #include "dequantize.cuh"
+#ifdef GGML_USE_MUSA
+#include "ggml-musa/mudnn.cuh"
+#endif // GGML_USE_MUSA
 
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 
@@ -597,7 +600,14 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
 #endif
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
         GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
-        CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
+#ifdef GGML_USE_MUSA
+        if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
+            CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0));
+        } else
+#endif // GGML_USE_MUSA
+        {
+            CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
ggml/src/ggml-cuda/fattn-mma-f16.cuh CHANGED
@@ -772,7 +772,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
     GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
     GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
-    GGML_UNUSED(kb0);
+    GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
     NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
 }
ggml/src/ggml-musa/CMakeLists.txt CHANGED
@@ -27,12 +27,15 @@ if (MUSAToolkit_FOUND)
 
     file(GLOB   GGML_HEADERS_MUSA "../ggml-cuda/*.cuh")
     list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h")
+    list(APPEND GGML_HEADERS_MUSA "../ggml-musa/mudnn.cuh")
 
     file(GLOB   GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
     file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
     list(APPEND GGML_SOURCES_MUSA ${SRCS})
     file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu")
     list(APPEND GGML_SOURCES_MUSA ${SRCS})
+    file(GLOB   SRCS "../ggml-musa/*.cu")
+    list(APPEND GGML_SOURCES_MUSA ${SRCS})
 
     if (GGML_CUDA_FA_ALL_QUANTS)
         file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
@@ -62,7 +65,9 @@ if (MUSAToolkit_FOUND)
     )
 
     # TODO: do not use CUDA definitions for MUSA
-    target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
+    if (NOT GGML_BACKEND_DL)
+        target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
+    endif()
 
     add_compile_definitions(GGML_USE_MUSA)
     add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
@@ -92,9 +97,10 @@ if (MUSAToolkit_FOUND)
     endif()
 
     if (GGML_STATIC)
+        # TODO: mudnn has not provided static libraries yet
         target_link_libraries(ggml-musa PRIVATE MUSA::musart_static MUSA::mublas_static)
     else()
-        target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas)
+        target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas mudnn)
     endif()
 
     if (GGML_CUDA_NO_VMM)
ggml/src/ggml-musa/mudnn.cu ADDED
@@ -0,0 +1,112 @@
+#include <mutex>
+#include <mudnn.h>
+
+#include "mudnn.cuh"
+
+namespace mudnn = musa::dnn;
+
+// Returns a human-readable error string for mudnn::Status
+const char* mudnnGetErrorString(mudnn::Status err) {
+    switch (err) {
+        case mudnn::Status::SUCCESS:
+            return "Success";
+        case mudnn::Status::INVALID_PARAMETER:
+            return "Invalid parameter";
+        case mudnn::Status::NOT_INITIALIZED:
+            return "Not initialized";
+        case mudnn::Status::ALLOC_FAILED:
+            return "Allocation failed";
+        case mudnn::Status::NOT_SUPPORTED:
+            return "Not supported";
+        case mudnn::Status::INTERNAL_ERROR:
+            return "Internal error";
+        case mudnn::Status::ARCH_MISMATCH:
+            return "Architecture mismatch";
+        case mudnn::Status::EXECUTION_FAILED:
+            return "Execution failed";
+        default:
+            return "Unknown mudnn status";
+    }
+}
+
+// Error checking macro for MUDNN calls
+#define MUDNN_CHECK(err) CUDA_CHECK_GEN(err, mudnn::Status::SUCCESS, mudnnGetErrorString)
+
+namespace {
+    // Thread-safe cache for mudnn::Handle objects per device
+    std::unordered_map<int, std::unique_ptr<mudnn::Handle>> handle_cache;
+    std::mutex handle_cache_mutex;
+
+    mudnn::Handle* get_cached_handle(int device_id) {
+        std::lock_guard<std::mutex> lock(handle_cache_mutex);
+        auto it = handle_cache.find(device_id);
+        if (it != handle_cache.end()) {
+            return it->second.get();
+        }
+        auto handle = std::make_unique<mudnn::Handle>(device_id);
+        mudnn::Handle* handle_ptr = handle.get();
+        handle_cache[device_id] = std::move(handle);
+        return handle_ptr;
+    }
+}
+
+// Extracts dimensions and strides from a ggml_tensor
+int get_ggml_dims_and_strides(const ggml_tensor* tensor,
+                              std::vector<int64_t>& dims,
+                              std::vector<int64_t>& strides) {
+    const int ndims = ggml_n_dims(tensor);
+    const size_t element_size = ggml_element_size(tensor);
+
+    dims.resize(ndims);
+    strides.resize(ndims);
+
+    for (int i = 0; i < ndims; ++i) {
+        dims[i] = tensor->ne[i];
+        strides[i] = tensor->nb[i] / static_cast<int64_t>(element_size);
+    }
+    return ndims;
+}
+
+// Converts ggml_type to mudnn::Tensor::Type
+mudnn::Tensor::Type ggml_type_to_mudnn_type(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_F32:
+            return mudnn::Tensor::Type::FLOAT;
+        case GGML_TYPE_F16:
+            return mudnn::Tensor::Type::HALF;
+
+        // TODO: Add support for other types
+
+        default:
+            MUDNN_CHECK(mudnn::Status::NOT_SUPPORTED);
+    }
+
+    return mudnn::Tensor::Type::FLOAT; // Default fallback
+}
+
+// Asynchronous memory copy using mudnn::Unary::IDENTITY
+musaError_t mudnnMemcpyAsync(ggml_backend_cuda_context& ctx, const ggml_tensor* dst, const ggml_tensor* src) {
+    mudnn::Tensor tensor_dst, tensor_src;
+
+    MUDNN_CHECK(tensor_dst.SetType(ggml_type_to_mudnn_type(dst->type)));
+    MUDNN_CHECK(tensor_src.SetType(ggml_type_to_mudnn_type(src->type)));
+
+    std::vector<int64_t> dims, strides;
+    const int ndims = get_ggml_dims_and_strides(src, dims, strides);
+
+    MUDNN_CHECK(tensor_dst.SetNdInfo(ndims, dims.data(), strides.data()));
+    MUDNN_CHECK(tensor_src.SetNdInfo(ndims, dims.data(), strides.data()));
+    MUDNN_CHECK(tensor_dst.SetAddr(dst->data));
+    MUDNN_CHECK(tensor_src.SetAddr(src->data));
+
+    mudnn::Unary op;
+    MUDNN_CHECK(op.SetMode(mudnn::Unary::Mode::IDENTITY));
+    MUDNN_CHECK(op.SetAlpha(0.0f));
+    MUDNN_CHECK(op.SetBeta(0.0f));
+
+    mudnn::Handle* handle = get_cached_handle(ctx.device);
+    MUDNN_CHECK(handle->SetStream(ctx.stream()));
+    MUDNN_CHECK(op.Run(*handle, tensor_dst, tensor_src));
+
+    return musaSuccess;
+}
ggml/src/ggml-musa/mudnn.cuh ADDED
@@ -0,0 +1,12 @@
+#pragma once
+
+#include "../include/ggml.h"
+#include "../ggml-cuda/common.cuh"
+
+// Asynchronously copies data from src tensor to dst tensor using the provided context.
+// Returns a musaError_t indicating success or failure.
+musaError_t mudnnMemcpyAsync(
+    ggml_backend_cuda_context &ctx,
+    const ggml_tensor *dst,
+    const ggml_tensor *src
+);
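
For reference, here is a hedged usage sketch of the helper declared above, mirroring how ggml_cuda_cpy dispatches to it in the cpy.cu change. The wrapper name copy_tensor_d2d is illustrative only, and the fallback assumes the existing CUDA-to-MUSA symbol mapping already used by this backend.

    // Illustrative only: use the muDNN fast path for F32/F16 and fall back to the
    // plain async device-to-device memcpy for other contiguous same-type tensors.
    static void copy_tensor_d2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst, const ggml_tensor * src) {
        GGML_ASSERT(ggml_nbytes(dst) == ggml_nbytes(src));
    #ifdef GGML_USE_MUSA
        if (src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16) {
            CUDA_CHECK(mudnnMemcpyAsync(ctx, dst, src));  // mudnn::Unary::IDENTITY under the hood
            return;
        }
    #endif // GGML_USE_MUSA
        CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(src),
                                   cudaMemcpyDeviceToDevice, ctx.stream()));
    }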