Spaces:
Sleeping
Sleeping
Commit
·
a35db11
1
Parent(s):
1b11fde
feat: Support Moore Threads GPU (llama/8383)
Browse files* Update doc for MUSA
Signed-off-by: Xiaodong Ye <[email protected]>
* Add GGML_MUSA in Makefile
Signed-off-by: Xiaodong Ye <[email protected]>
* Add GGML_MUSA in CMake
Signed-off-by: Xiaodong Ye <[email protected]>
* CUDA => MUSA
Signed-off-by: Xiaodong Ye <[email protected]>
* MUSA adds support for __vsubss4
Signed-off-by: Xiaodong Ye <[email protected]>
* Fix CI build failure
Signed-off-by: Xiaodong Ye <[email protected]>
---------
Signed-off-by: Xiaodong Ye <[email protected]>
- ggml/CMakeLists.txt +1 -0
- ggml/include/ggml-cuda.h +3 -0
- ggml/src/CMakeLists.txt +55 -7
- ggml/src/ggml-common.h +5 -1
- ggml/src/ggml-cuda.cu +13 -9
- ggml/src/ggml-cuda/common.cuh +191 -3
ggml/CMakeLists.txt
CHANGED
|
@@ -113,6 +113,7 @@ set(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING
|
|
| 113 |
option(GGML_LLAMAFILE "ggml: use LLAMAFILE" OFF)
|
| 114 |
|
| 115 |
option(GGML_CUDA "ggml: use CUDA" OFF)
|
|
|
|
| 116 |
option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF)
|
| 117 |
option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF)
|
| 118 |
option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF)
|
|
|
|
| 113 |
option(GGML_LLAMAFILE "ggml: use LLAMAFILE" OFF)
|
| 114 |
|
| 115 |
option(GGML_CUDA "ggml: use CUDA" OFF)
|
| 116 |
+
option(GGML_MUSA "ggml: use MUSA" OFF)
|
| 117 |
option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF)
|
| 118 |
option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF)
|
| 119 |
option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF)
|
ggml/include/ggml-cuda.h
CHANGED
|
@@ -6,6 +6,9 @@
|
|
| 6 |
#ifdef GGML_USE_HIPBLAS
|
| 7 |
#define GGML_CUDA_NAME "ROCm"
|
| 8 |
#define GGML_CUBLAS_NAME "hipBLAS"
|
|
|
|
|
|
|
|
|
|
| 9 |
#else
|
| 10 |
#define GGML_CUDA_NAME "CUDA"
|
| 11 |
#define GGML_CUBLAS_NAME "cuBLAS"
|
|
|
|
| 6 |
#ifdef GGML_USE_HIPBLAS
|
| 7 |
#define GGML_CUDA_NAME "ROCm"
|
| 8 |
#define GGML_CUBLAS_NAME "hipBLAS"
|
| 9 |
+
#elif defined(GGML_USE_MUSA)
|
| 10 |
+
#define GGML_CUDA_NAME "MUSA"
|
| 11 |
+
#define GGML_CUBLAS_NAME "muBLAS"
|
| 12 |
#else
|
| 13 |
#define GGML_CUDA_NAME "CUDA"
|
| 14 |
#define GGML_CUBLAS_NAME "cuBLAS"
|
ggml/src/CMakeLists.txt
CHANGED
|
@@ -139,6 +139,17 @@ if (GGML_METAL)
|
|
| 139 |
)
|
| 140 |
endif()
|
| 141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
if (GGML_OPENMP)
|
| 143 |
find_package(OpenMP)
|
| 144 |
if (OpenMP_FOUND)
|
|
@@ -147,6 +158,11 @@ if (GGML_OPENMP)
|
|
| 147 |
add_compile_definitions(GGML_USE_OPENMP)
|
| 148 |
|
| 149 |
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
else()
|
| 151 |
message(WARNING "OpenMP not found")
|
| 152 |
endif()
|
|
@@ -249,7 +265,13 @@ endif()
|
|
| 249 |
if (GGML_CUDA)
|
| 250 |
cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES
|
| 251 |
|
| 252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
|
| 254 |
if (CUDAToolkit_FOUND)
|
| 255 |
message(STATUS "CUDA found")
|
|
@@ -268,7 +290,11 @@ if (GGML_CUDA)
|
|
| 268 |
endif()
|
| 269 |
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
|
| 270 |
|
| 271 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
|
| 273 |
file(GLOB GGML_HEADERS_CUDA "ggml-cuda/*.cuh")
|
| 274 |
list(APPEND GGML_HEADERS_CUDA "../include/ggml-cuda.h")
|
|
@@ -332,21 +358,40 @@ if (GGML_CUDA)
|
|
| 332 |
add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
|
| 333 |
endif()
|
| 334 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
if (GGML_STATIC)
|
| 336 |
if (WIN32)
|
| 337 |
# As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
|
| 338 |
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
|
| 339 |
else ()
|
| 340 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
endif()
|
| 342 |
else()
|
| 343 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 344 |
endif()
|
| 345 |
|
| 346 |
if (GGML_CUDA_NO_VMM)
|
| 347 |
# No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)
|
| 348 |
else()
|
| 349 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
endif()
|
| 351 |
else()
|
| 352 |
message(WARNING "CUDA not found")
|
|
@@ -857,8 +902,10 @@ function(get_flags CCID CCVER)
|
|
| 857 |
set(C_FLAGS -Wdouble-promotion)
|
| 858 |
set(CXX_FLAGS -Wno-array-bounds)
|
| 859 |
|
| 860 |
-
if (
|
| 861 |
-
|
|
|
|
|
|
|
| 862 |
endif()
|
| 863 |
if (CCVER VERSION_GREATER_EQUAL 8.1.0)
|
| 864 |
list(APPEND CXX_FLAGS -Wextra-semi)
|
|
@@ -1264,6 +1311,7 @@ endif()
|
|
| 1264 |
target_compile_definitions(ggml PUBLIC ${GGML_CDEF_PUBLIC})
|
| 1265 |
target_include_directories(ggml PUBLIC ../include)
|
| 1266 |
target_include_directories(ggml PRIVATE . ${GGML_EXTRA_INCLUDES})
|
|
|
|
| 1267 |
target_compile_features (ggml PRIVATE c_std_11) # don't bump
|
| 1268 |
|
| 1269 |
target_link_libraries(ggml PRIVATE Threads::Threads ${GGML_EXTRA_LIBS})
|
|
|
|
| 139 |
)
|
| 140 |
endif()
|
| 141 |
|
| 142 |
+
if (GGML_MUSA)
|
| 143 |
+
set(CMAKE_C_COMPILER clang)
|
| 144 |
+
set(CMAKE_C_EXTENSIONS OFF)
|
| 145 |
+
set(CMAKE_CXX_COMPILER clang++)
|
| 146 |
+
set(CMAKE_CXX_EXTENSIONS OFF)
|
| 147 |
+
|
| 148 |
+
set(GGML_CUDA ON)
|
| 149 |
+
|
| 150 |
+
list(APPEND GGML_CDEF_PUBLIC GGML_USE_MUSA)
|
| 151 |
+
endif()
|
| 152 |
+
|
| 153 |
if (GGML_OPENMP)
|
| 154 |
find_package(OpenMP)
|
| 155 |
if (OpenMP_FOUND)
|
|
|
|
| 158 |
add_compile_definitions(GGML_USE_OPENMP)
|
| 159 |
|
| 160 |
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
|
| 161 |
+
|
| 162 |
+
if (GGML_MUSA)
|
| 163 |
+
set(GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} "/usr/lib/llvm-10/include/openmp")
|
| 164 |
+
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} "/usr/lib/llvm-10/lib/libomp.so")
|
| 165 |
+
endif()
|
| 166 |
else()
|
| 167 |
message(WARNING "OpenMP not found")
|
| 168 |
endif()
|
|
|
|
| 265 |
if (GGML_CUDA)
|
| 266 |
cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES
|
| 267 |
|
| 268 |
+
if (GGML_MUSA)
|
| 269 |
+
list(APPEND CMAKE_MODULE_PATH "/usr/local/musa/cmake/")
|
| 270 |
+
find_package(MUSAToolkit)
|
| 271 |
+
set(CUDAToolkit_FOUND ${MUSAToolkit_FOUND})
|
| 272 |
+
else()
|
| 273 |
+
find_package(CUDAToolkit)
|
| 274 |
+
endif()
|
| 275 |
|
| 276 |
if (CUDAToolkit_FOUND)
|
| 277 |
message(STATUS "CUDA found")
|
|
|
|
| 290 |
endif()
|
| 291 |
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
|
| 292 |
|
| 293 |
+
if (GGML_MUSA)
|
| 294 |
+
set(CMAKE_CUDA_COMPILER ${MUSAToolkit_MCC_EXECUTABLE})
|
| 295 |
+
else()
|
| 296 |
+
enable_language(CUDA)
|
| 297 |
+
endif()
|
| 298 |
|
| 299 |
file(GLOB GGML_HEADERS_CUDA "ggml-cuda/*.cuh")
|
| 300 |
list(APPEND GGML_HEADERS_CUDA "../include/ggml-cuda.h")
|
|
|
|
| 358 |
add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
|
| 359 |
endif()
|
| 360 |
|
| 361 |
+
if (GGML_MUSA)
|
| 362 |
+
set_source_files_properties(${GGML_SOURCES_CUDA} PROPERTIES LANGUAGE CXX)
|
| 363 |
+
foreach(SOURCE ${GGML_SOURCES_CUDA})
|
| 364 |
+
set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_22")
|
| 365 |
+
endforeach()
|
| 366 |
+
endif()
|
| 367 |
+
|
| 368 |
if (GGML_STATIC)
|
| 369 |
if (WIN32)
|
| 370 |
# As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
|
| 371 |
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
|
| 372 |
else ()
|
| 373 |
+
if (GGML_MUSA)
|
| 374 |
+
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musart_static MUSA::mublas_static)
|
| 375 |
+
else()
|
| 376 |
+
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
|
| 377 |
+
endif()
|
| 378 |
endif()
|
| 379 |
else()
|
| 380 |
+
if (GGML_MUSA)
|
| 381 |
+
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musart MUSA::mublas)
|
| 382 |
+
else()
|
| 383 |
+
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
|
| 384 |
+
endif()
|
| 385 |
endif()
|
| 386 |
|
| 387 |
if (GGML_CUDA_NO_VMM)
|
| 388 |
# No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)
|
| 389 |
else()
|
| 390 |
+
if (GGML_MUSA)
|
| 391 |
+
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musa_driver) # required by muDeviceGetAttribute(), muMemGetAllocationGranularity(...), ...
|
| 392 |
+
else()
|
| 393 |
+
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
|
| 394 |
+
endif()
|
| 395 |
endif()
|
| 396 |
else()
|
| 397 |
message(WARNING "CUDA not found")
|
|
|
|
| 902 |
set(C_FLAGS -Wdouble-promotion)
|
| 903 |
set(CXX_FLAGS -Wno-array-bounds)
|
| 904 |
|
| 905 |
+
if (NOT GGML_MUSA)
|
| 906 |
+
if (CCVER VERSION_GREATER_EQUAL 7.1.0)
|
| 907 |
+
list(APPEND CXX_FLAGS -Wno-format-truncation)
|
| 908 |
+
endif()
|
| 909 |
endif()
|
| 910 |
if (CCVER VERSION_GREATER_EQUAL 8.1.0)
|
| 911 |
list(APPEND CXX_FLAGS -Wextra-semi)
|
|
|
|
| 1311 |
target_compile_definitions(ggml PUBLIC ${GGML_CDEF_PUBLIC})
|
| 1312 |
target_include_directories(ggml PUBLIC ../include)
|
| 1313 |
target_include_directories(ggml PRIVATE . ${GGML_EXTRA_INCLUDES})
|
| 1314 |
+
target_link_directories(ggml PRIVATE ${GGML_EXTRA_LIBDIRS})
|
| 1315 |
target_compile_features (ggml PRIVATE c_std_11) # don't bump
|
| 1316 |
|
| 1317 |
target_link_libraries(ggml PRIVATE Threads::Threads ${GGML_EXTRA_LIBS})
|
ggml/src/ggml-common.h
CHANGED
|
@@ -19,7 +19,11 @@ typedef half2 ggml_half2;
|
|
| 19 |
|
| 20 |
#define GGML_COMMON_DECL
|
| 21 |
#elif defined(GGML_COMMON_DECL_CUDA)
|
|
|
|
|
|
|
|
|
|
| 22 |
#include <cuda_fp16.h>
|
|
|
|
| 23 |
#include <cstdint>
|
| 24 |
|
| 25 |
typedef half ggml_half;
|
|
@@ -415,7 +419,7 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_
|
|
| 415 |
#define GGML_TABLE_END() };
|
| 416 |
|
| 417 |
#define GGML_COMMON_IMPL
|
| 418 |
-
#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP)
|
| 419 |
#include <cstdint>
|
| 420 |
|
| 421 |
#define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = {
|
|
|
|
| 19 |
|
| 20 |
#define GGML_COMMON_DECL
|
| 21 |
#elif defined(GGML_COMMON_DECL_CUDA)
|
| 22 |
+
#if defined(GGML_COMMON_DECL_MUSA)
|
| 23 |
+
#include <musa_fp16.h>
|
| 24 |
+
#else
|
| 25 |
#include <cuda_fp16.h>
|
| 26 |
+
#endif
|
| 27 |
#include <cstdint>
|
| 28 |
|
| 29 |
typedef half ggml_half;
|
|
|
|
| 419 |
#define GGML_TABLE_END() };
|
| 420 |
|
| 421 |
#define GGML_COMMON_IMPL
|
| 422 |
+
#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) || defined(GGML_COMMON_IMPL_MUSA)
|
| 423 |
#include <cstdint>
|
| 424 |
|
| 425 |
#define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = {
|
ggml/src/ggml-cuda.cu
CHANGED
|
@@ -167,7 +167,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
|
| 167 |
for (int id = 0; id < info.device_count; ++id) {
|
| 168 |
int device_vmm = 0;
|
| 169 |
|
| 170 |
-
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
|
| 171 |
CUdevice device;
|
| 172 |
CU_CHECK(cuDeviceGet(&device, id));
|
| 173 |
CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
|
|
@@ -179,7 +179,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
|
| 179 |
alloc_prop.location.id = id;
|
| 180 |
CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
|
| 181 |
}
|
| 182 |
-
#endif // !defined(GGML_USE_HIPBLAS)
|
| 183 |
info.devices[id].vmm = !!device_vmm;
|
| 184 |
|
| 185 |
cudaDeviceProp prop;
|
|
@@ -315,7 +315,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
|
|
| 315 |
};
|
| 316 |
|
| 317 |
// pool with virtual memory
|
| 318 |
-
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
|
| 319 |
struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
| 320 |
static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
|
| 321 |
|
|
@@ -409,14 +409,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
|
| 409 |
GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
|
| 410 |
}
|
| 411 |
};
|
| 412 |
-
#endif // !defined(GGML_USE_HIPBLAS)
|
| 413 |
|
| 414 |
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
|
| 415 |
-
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
|
| 416 |
if (ggml_cuda_info().devices[device].vmm) {
|
| 417 |
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
|
| 418 |
}
|
| 419 |
-
#endif
|
| 420 |
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
|
| 421 |
}
|
| 422 |
|
|
@@ -1341,7 +1341,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
|
|
| 1341 |
static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
|
| 1342 |
void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {
|
| 1343 |
|
| 1344 |
-
#if !defined(GGML_USE_HIPBLAS)
|
| 1345 |
// cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
|
| 1346 |
cudaMemcpy3DPeerParms p = {};
|
| 1347 |
p.dstDevice = dstDevice;
|
|
@@ -1355,7 +1355,7 @@ static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
|
|
| 1355 |
GGML_UNUSED(dstDevice);
|
| 1356 |
GGML_UNUSED(srcDevice);
|
| 1357 |
return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream);
|
| 1358 |
-
#endif // !defined(GGML_USE_HIPBLAS)
|
| 1359 |
}
|
| 1360 |
|
| 1361 |
static void ggml_cuda_op_mul_mat(
|
|
@@ -1828,6 +1828,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
|
| 1828 |
}
|
| 1829 |
}
|
| 1830 |
#else
|
|
|
|
|
|
|
|
|
|
| 1831 |
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
|
| 1832 |
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
|
| 1833 |
// use cublasGemmStridedBatchedEx
|
|
@@ -1870,6 +1873,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
|
| 1870 |
cu_compute_type,
|
| 1871 |
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
| 1872 |
}
|
|
|
|
| 1873 |
#endif
|
| 1874 |
|
| 1875 |
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
|
|
@@ -3027,7 +3031,7 @@ GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size
|
|
| 3027 |
return false;
|
| 3028 |
}
|
| 3029 |
|
| 3030 |
-
#if CUDART_VERSION >= 11100
|
| 3031 |
cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
|
| 3032 |
if (err != cudaSuccess) {
|
| 3033 |
// clear the error
|
|
|
|
| 167 |
for (int id = 0; id < info.device_count; ++id) {
|
| 168 |
int device_vmm = 0;
|
| 169 |
|
| 170 |
+
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
| 171 |
CUdevice device;
|
| 172 |
CU_CHECK(cuDeviceGet(&device, id));
|
| 173 |
CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
|
|
|
|
| 179 |
alloc_prop.location.id = id;
|
| 180 |
CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
|
| 181 |
}
|
| 182 |
+
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
| 183 |
info.devices[id].vmm = !!device_vmm;
|
| 184 |
|
| 185 |
cudaDeviceProp prop;
|
|
|
|
| 315 |
};
|
| 316 |
|
| 317 |
// pool with virtual memory
|
| 318 |
+
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
| 319 |
struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
| 320 |
static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
|
| 321 |
|
|
|
|
| 409 |
GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
|
| 410 |
}
|
| 411 |
};
|
| 412 |
+
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
| 413 |
|
| 414 |
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
|
| 415 |
+
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
| 416 |
if (ggml_cuda_info().devices[device].vmm) {
|
| 417 |
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
|
| 418 |
}
|
| 419 |
+
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
| 420 |
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
|
| 421 |
}
|
| 422 |
|
|
|
|
| 1341 |
static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
|
| 1342 |
void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {
|
| 1343 |
|
| 1344 |
+
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
| 1345 |
// cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
|
| 1346 |
cudaMemcpy3DPeerParms p = {};
|
| 1347 |
p.dstDevice = dstDevice;
|
|
|
|
| 1355 |
GGML_UNUSED(dstDevice);
|
| 1356 |
GGML_UNUSED(srcDevice);
|
| 1357 |
return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream);
|
| 1358 |
+
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
| 1359 |
}
|
| 1360 |
|
| 1361 |
static void ggml_cuda_op_mul_mat(
|
|
|
|
| 1828 |
}
|
| 1829 |
}
|
| 1830 |
#else
|
| 1831 |
+
#ifdef GGML_USE_MUSA
|
| 1832 |
+
GGML_ASSERT(false);
|
| 1833 |
+
#else // !GGML_USE_MUSA
|
| 1834 |
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
|
| 1835 |
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
|
| 1836 |
// use cublasGemmStridedBatchedEx
|
|
|
|
| 1873 |
cu_compute_type,
|
| 1874 |
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
| 1875 |
}
|
| 1876 |
+
#endif // GGML_USE_MUSA
|
| 1877 |
#endif
|
| 1878 |
|
| 1879 |
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
|
|
|
|
| 3031 |
return false;
|
| 3032 |
}
|
| 3033 |
|
| 3034 |
+
#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
|
| 3035 |
cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
|
| 3036 |
if (err != cudaSuccess) {
|
| 3037 |
// clear the error
|
ggml/src/ggml-cuda/common.cuh
CHANGED
|
@@ -12,6 +12,10 @@
|
|
| 12 |
#else
|
| 13 |
#define GGML_COMMON_DECL_CUDA
|
| 14 |
#define GGML_COMMON_IMPL_CUDA
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
#endif
|
| 16 |
#include "ggml-common.h"
|
| 17 |
|
|
@@ -114,6 +118,150 @@
|
|
| 114 |
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
|
| 115 |
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
|
| 116 |
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
#else
|
| 118 |
#include <cuda_runtime.h>
|
| 119 |
#include <cuda.h>
|
|
@@ -168,9 +316,13 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
|
|
| 168 |
|
| 169 |
#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
|
| 170 |
|
| 171 |
-
#if CUDART_VERSION >= 12000
|
| 172 |
static const char * cublas_get_error_str(const cublasStatus_t err) {
|
|
|
|
| 173 |
return cublasGetStatusString(err);
|
|
|
|
|
|
|
|
|
|
| 174 |
}
|
| 175 |
#else
|
| 176 |
static const char * cublas_get_error_str(const cublasStatus_t err) {
|
|
@@ -200,7 +352,7 @@ static const char * cu_get_error_str(CUresult err) {
|
|
| 200 |
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
|
| 201 |
#endif
|
| 202 |
|
| 203 |
-
#if CUDART_VERSION >= 11100
|
| 204 |
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
|
| 205 |
#else
|
| 206 |
#define GGML_CUDA_ASSUME(x)
|
|
@@ -214,6 +366,42 @@ typedef float dfloat; // dequantize float
|
|
| 214 |
typedef float2 dfloat2;
|
| 215 |
#endif //GGML_CUDA_F16
|
| 216 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
#if defined(GGML_USE_HIPBLAS)
|
| 218 |
#define __CUDA_ARCH__ 1300
|
| 219 |
|
|
@@ -455,7 +643,7 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
|
|
| 455 |
const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
|
| 456 |
return mask_low | mask_high;
|
| 457 |
}
|
| 458 |
-
#endif // CUDART_VERSION <
|
| 459 |
|
| 460 |
static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
|
| 461 |
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
|
|
|
| 12 |
#else
|
| 13 |
#define GGML_COMMON_DECL_CUDA
|
| 14 |
#define GGML_COMMON_IMPL_CUDA
|
| 15 |
+
#if defined(GGML_USE_MUSA)
|
| 16 |
+
#define GGML_COMMON_DECL_MUSA
|
| 17 |
+
#define GGML_COMMON_IMPL_MUSA
|
| 18 |
+
#endif
|
| 19 |
#endif
|
| 20 |
#include "ggml-common.h"
|
| 21 |
|
|
|
|
| 118 |
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
|
| 119 |
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
|
| 120 |
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
|
| 121 |
+
#elif defined(GGML_USE_MUSA)
|
| 122 |
+
#include <musa_runtime.h>
|
| 123 |
+
#include <musa.h>
|
| 124 |
+
#include <mublas.h>
|
| 125 |
+
#include <musa_fp16.h>
|
| 126 |
+
// XXX: Keep the following order the same as hipBLAS
|
| 127 |
+
// #define CUBLAS_COMPUTE_16F MUBLAS_COMPUTE_16F
|
| 128 |
+
// #define CUBLAS_COMPUTE_32F MUBLAS_COMPUTE_32F
|
| 129 |
+
#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
|
| 130 |
+
#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
|
| 131 |
+
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
|
| 132 |
+
#define CUBLAS_OP_N MUBLAS_OP_N
|
| 133 |
+
#define CUBLAS_OP_T MUBLAS_OP_T
|
| 134 |
+
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
|
| 135 |
+
// #define CUBLAS_TF32_TENSOR_OP_MATH 0
|
| 136 |
+
#define CUDA_R_16F MUSA_R_16F
|
| 137 |
+
#define CUDA_R_32F MUSA_R_32F
|
| 138 |
+
// #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
| 139 |
+
// #define cublasComputeType_t mublasComputeType_t
|
| 140 |
+
#define cublasCreate mublasCreate
|
| 141 |
+
#define cublasDestroy mublasDestroy
|
| 142 |
+
#define cublasGemmEx mublasGemmEx
|
| 143 |
+
#define cublasGemmBatchedEx mublasGemmBatchedEx
|
| 144 |
+
#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
|
| 145 |
+
#define cublasHandle_t mublasHandle_t
|
| 146 |
+
// #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
|
| 147 |
+
#define cublasSetMathMode mublasSetMathMode
|
| 148 |
+
#define cublasSetStream mublasSetStream
|
| 149 |
+
#define cublasSgemm mublasSgemm
|
| 150 |
+
#define cublasStatus_t mublasStatus_t
|
| 151 |
+
#define cudaDataType_t musaDataType_t //deprecated, new hipblasDatatype not in 5.6
|
| 152 |
+
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
|
| 153 |
+
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
|
| 154 |
+
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
|
| 155 |
+
#define cudaDeviceProp musaDeviceProp
|
| 156 |
+
#define cudaDeviceSynchronize musaDeviceSynchronize
|
| 157 |
+
#define cudaError_t musaError_t
|
| 158 |
+
#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
|
| 159 |
+
#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
|
| 160 |
+
#define cudaEventCreateWithFlags musaEventCreateWithFlags
|
| 161 |
+
#define cudaEventDisableTiming musaEventDisableTiming
|
| 162 |
+
#define cudaEventRecord musaEventRecord
|
| 163 |
+
#define cudaEventSynchronize musaEventSynchronize
|
| 164 |
+
#define cudaEvent_t musaEvent_t
|
| 165 |
+
#define cudaEventDestroy musaEventDestroy
|
| 166 |
+
#define cudaFree musaFree
|
| 167 |
+
#define cudaFreeHost musaFreeHost
|
| 168 |
+
#define cudaGetDevice musaGetDevice
|
| 169 |
+
#define cudaGetDeviceCount musaGetDeviceCount
|
| 170 |
+
#define cudaGetDeviceProperties musaGetDeviceProperties
|
| 171 |
+
#define cudaGetErrorString musaGetErrorString
|
| 172 |
+
#define cudaGetLastError musaGetLastError
|
| 173 |
+
#define cudaHostRegister musaHostRegister
|
| 174 |
+
#define cudaHostRegisterPortable musaHostRegisterPortable
|
| 175 |
+
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
|
| 176 |
+
#define cudaHostUnregister musaHostUnregister
|
| 177 |
+
#define cudaLaunchHostFunc musaLaunchHostFunc
|
| 178 |
+
#define cudaMalloc musaMalloc
|
| 179 |
+
#define cudaMallocHost musaMallocHost
|
| 180 |
+
#define cudaMemcpy musaMemcpy
|
| 181 |
+
#define cudaMemcpyAsync musaMemcpyAsync
|
| 182 |
+
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
|
| 183 |
+
#define cudaMemcpy2DAsync musaMemcpy2DAsync
|
| 184 |
+
#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
|
| 185 |
+
#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
|
| 186 |
+
#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
|
| 187 |
+
#define cudaMemcpyKind musaMemcpyKind
|
| 188 |
+
#define cudaMemset musaMemset
|
| 189 |
+
#define cudaMemsetAsync musaMemsetAsync
|
| 190 |
+
#define cudaMemGetInfo musaMemGetInfo
|
| 191 |
+
#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
|
| 192 |
+
#define cudaSetDevice musaSetDevice
|
| 193 |
+
#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
|
| 194 |
+
#define cudaStreamDestroy musaStreamDestroy
|
| 195 |
+
#define cudaStreamFireAndForget musaStreamFireAndForget
|
| 196 |
+
#define cudaStreamNonBlocking musaStreamNonBlocking
|
| 197 |
+
#define cudaStreamPerThread musaStreamPerThread
|
| 198 |
+
#define cudaStreamSynchronize musaStreamSynchronize
|
| 199 |
+
#define cudaStreamWaitEvent musaStreamWaitEvent
|
| 200 |
+
#define cudaStream_t musaStream_t
|
| 201 |
+
#define cudaSuccess musaSuccess
|
| 202 |
+
|
| 203 |
+
// XXX: Other CUDA => MUSA mapping
|
| 204 |
+
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
|
| 205 |
+
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
|
| 206 |
+
#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
|
| 207 |
+
#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
|
| 208 |
+
#define CUdevice MUdevice
|
| 209 |
+
#define CUdeviceptr MUdeviceptr
|
| 210 |
+
#define CUmemAccessDesc MUmemAccessDesc
|
| 211 |
+
#define CUmemAllocationProp MUmemAllocationProp
|
| 212 |
+
#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
|
| 213 |
+
#define cuDeviceGet muDeviceGet
|
| 214 |
+
#define cuDeviceGetAttribute muDeviceGetAttribute
|
| 215 |
+
#define cuMemAddressFree muMemAddressFree
|
| 216 |
+
#define cuMemAddressReserve muMemAddressReserve
|
| 217 |
+
#define cuMemCreate muMemCreate
|
| 218 |
+
#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
|
| 219 |
+
#define cuMemMap muMemMap
|
| 220 |
+
#define cuMemRelease muMemRelease
|
| 221 |
+
#define cuMemSetAccess muMemSetAccess
|
| 222 |
+
#define cuMemUnmap muMemUnmap
|
| 223 |
+
#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
|
| 224 |
+
#define cudaFuncSetAttribute musaFuncSetAttribute
|
| 225 |
+
#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
|
| 226 |
+
#define make_cudaExtent make_musaExtent
|
| 227 |
+
#define make_cudaPitchedPtr make_musaPitchedPtr
|
| 228 |
+
|
| 229 |
+
// XXX: USE_CUDA_GRAPH
|
| 230 |
+
#define CUDA_SUCCESS MUSA_SUCCESS
|
| 231 |
+
#define CUresult MUresult
|
| 232 |
+
#define cuGetErrorString muGetErrorString
|
| 233 |
+
#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
|
| 234 |
+
#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
|
| 235 |
+
#define cudaGraphDestroy musaGraphDestroy
|
| 236 |
+
#define cudaGraphExecDestroy musaGraphExecDestroy
|
| 237 |
+
#define cudaGraphExec_t musaGraphExec_t
|
| 238 |
+
#define cudaGraphExecUpdate musaGraphExecUpdate
|
| 239 |
+
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
|
| 240 |
+
#define cudaGraphGetNodes musaGraphGetNodes
|
| 241 |
+
#define cudaGraphInstantiate musaGraphInstantiate
|
| 242 |
+
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
|
| 243 |
+
#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
|
| 244 |
+
#define cudaGraphLaunch musaGraphLaunch
|
| 245 |
+
#define cudaGraphNodeGetType musaGraphNodeGetType
|
| 246 |
+
#define cudaGraphNode_t musaGraphNode_t
|
| 247 |
+
#define cudaGraphNodeType musaGraphNodeType
|
| 248 |
+
#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
|
| 249 |
+
#define cudaGraph_t musaGraph_t
|
| 250 |
+
#define cudaKernelNodeParams musaKernelNodeParams
|
| 251 |
+
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
|
| 252 |
+
#define cudaStreamEndCapture musaStreamEndCapture
|
| 253 |
+
|
| 254 |
+
// XXX: cuBLAS => muBLAS mapping
|
| 255 |
+
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
| 256 |
+
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
|
| 257 |
+
#define CUBLAS_COMPUTE_16F CUDA_R_16F
|
| 258 |
+
#define CUBLAS_COMPUTE_32F CUDA_R_32F
|
| 259 |
+
#define cublasComputeType_t cudaDataType_t
|
| 260 |
+
|
| 261 |
+
// XXX: Clang builtins mapping
|
| 262 |
+
#define __vsub4 __vsub4_musa
|
| 263 |
+
#define __vcmpeq4 __vcmpeq4_musa
|
| 264 |
+
#define __vcmpne4 __vcmpne4_musa
|
| 265 |
#else
|
| 266 |
#include <cuda_runtime.h>
|
| 267 |
#include <cuda.h>
|
|
|
|
| 316 |
|
| 317 |
#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
|
| 318 |
|
| 319 |
+
#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
|
| 320 |
static const char * cublas_get_error_str(const cublasStatus_t err) {
|
| 321 |
+
#ifndef GGML_USE_MUSA
|
| 322 |
return cublasGetStatusString(err);
|
| 323 |
+
#else
|
| 324 |
+
return mublasStatus_to_string(err);
|
| 325 |
+
#endif // GGML_USE_MUSA
|
| 326 |
}
|
| 327 |
#else
|
| 328 |
static const char * cublas_get_error_str(const cublasStatus_t err) {
|
|
|
|
| 352 |
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
|
| 353 |
#endif
|
| 354 |
|
| 355 |
+
#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
|
| 356 |
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
|
| 357 |
#else
|
| 358 |
#define GGML_CUDA_ASSUME(x)
|
|
|
|
| 366 |
typedef float2 dfloat2;
|
| 367 |
#endif //GGML_CUDA_F16
|
| 368 |
|
| 369 |
+
#if defined(GGML_USE_MUSA)
|
| 370 |
+
#ifndef __has_builtin
|
| 371 |
+
#define __has_builtin(x) 0
|
| 372 |
+
#endif
|
| 373 |
+
|
| 374 |
+
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
|
| 375 |
+
|
| 376 |
+
static __device__ __forceinline__ int __vsub4_musa(const int a, const int b) {
|
| 377 |
+
return __vsubss4(a, b);
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
static __device__ __forceinline__ unsigned int __vcmpeq4_musa(unsigned int a, unsigned int b) {
|
| 381 |
+
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
| 382 |
+
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
| 383 |
+
unsigned int c;
|
| 384 |
+
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
| 385 |
+
#pragma unroll
|
| 386 |
+
for (int i = 0; i < 4; ++i) {
|
| 387 |
+
vc[i] = va[i] == vb[i] ? 0xff : 0x00;
|
| 388 |
+
}
|
| 389 |
+
return c;
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
static __device__ __forceinline__ unsigned int __vcmpne4_musa(unsigned int a, unsigned int b) {
|
| 393 |
+
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
| 394 |
+
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
| 395 |
+
unsigned int c;
|
| 396 |
+
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
| 397 |
+
#pragma unroll
|
| 398 |
+
for (int i = 0; i < 4; ++i) {
|
| 399 |
+
vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
|
| 400 |
+
}
|
| 401 |
+
return c;
|
| 402 |
+
}
|
| 403 |
+
#endif // defined(GGML_USE_MUSA)
|
| 404 |
+
|
| 405 |
#if defined(GGML_USE_HIPBLAS)
|
| 406 |
#define __CUDA_ARCH__ 1300
|
| 407 |
|
|
|
|
| 643 |
const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
|
| 644 |
return mask_low | mask_high;
|
| 645 |
}
|
| 646 |
+
#endif // CUDART_VERSION < CUDART_HMASK
|
| 647 |
|
| 648 |
static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
|
| 649 |
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|