Spaces:
Sleeping
Sleeping
ardfork
committed on
whisper : initial hipBLAS support (#1209)
Browse files
- CMakeLists.txt +32 -0
- Makefile +15 -0
- ggml-cuda.cu +51 -0
CMakeLists.txt
CHANGED
|
@@ -65,6 +65,7 @@ else()
|
|
| 65 |
option(WHISPER_BLAS_VENDOR "whisper: BLAS library vendor" Generic)
|
| 66 |
option(WHISPER_OPENBLAS "whisper: prefer OpenBLAS" OFF)
|
| 67 |
option(WHISPER_CUBLAS "whisper: support for cuBLAS" OFF)
|
|
|
|
| 68 |
option(WHISPER_CLBLAST "whisper: use CLBlast" OFF)
|
| 69 |
endif()
|
| 70 |
|
|
@@ -191,6 +192,37 @@ if (WHISPER_CUBLAS)
|
|
| 191 |
endif()
|
| 192 |
endif()
|
| 193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
if (WHISPER_CLBLAST)
|
| 195 |
find_package(CLBlast)
|
| 196 |
if (CLBlast_FOUND)
|
|
|
|
| 65 |
option(WHISPER_BLAS_VENDOR "whisper: BLAS library vendor" Generic)
|
| 66 |
option(WHISPER_OPENBLAS "whisper: prefer OpenBLAS" OFF)
|
| 67 |
option(WHISPER_CUBLAS "whisper: support for cuBLAS" OFF)
|
| 68 |
+
option(WHISPER_HIPBLAS "whisper: support for hipBLAS" OFF)
|
| 69 |
option(WHISPER_CLBLAST "whisper: use CLBlast" OFF)
|
| 70 |
endif()
|
| 71 |
|
|
|
|
| 192 |
endif()
|
| 193 |
endif()
|
| 194 |
|
| 195 |
+
|
| 196 |
+
# hipBLAS backend (AMD ROCm).
#
# When WHISPER_HIPBLAS is ON and HIP, hipBLAS and rocBLAS are all found,
# ggml-cuda.cu is compiled as C++ into the `ggml-rocm` object library, linked
# against the HIP / rocBLAS / hipBLAS imported targets, and appended to
# WHISPER_EXTRA_LIBS. GGML_USE_HIPBLAS and GGML_USE_CUBLAS are both added as
# directory-level compile definitions.
if (WHISPER_HIPBLAS)
    # Let find_package() also search the conventional ROCm install prefix.
    list(APPEND CMAKE_PREFIX_PATH /opt/rocm)

    # Pass variable names to if() instead of `${...}` expansions: an empty
    # value would otherwise vanish from the argument list and turn the
    # condition into a syntax error.
    if (NOT CMAKE_C_COMPILER_ID MATCHES "Clang")
        message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
    endif()
    if (NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang")
        message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
    endif()

    find_package(hip)
    find_package(hipblas)
    find_package(rocblas)

    # roc::rocblas is linked below, so rocBLAS has to be found as well;
    # otherwise the link line would reference a target that does not exist.
    if (hip_FOUND AND hipblas_FOUND AND rocblas_FOUND)
        message(STATUS "HIP and hipBLAS found")

        # Bail out before creating any target when a static build is requested.
        if (WHISPER_STATIC)
            message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
        endif()

        add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)

        add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
        set_property(TARGET ggml-rocm PROPERTY POSITION_INDEPENDENT_CODE ON)
        # Build the .cu source with the C++ compiler.
        set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
        target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas)

        list(APPEND WHISPER_EXTRA_LIBS ggml-rocm)
    else()
        message(WARNING "HIP, hipBLAS or rocBLAS not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
    endif()
endif()
|
| 225 |
+
|
| 226 |
if (WHISPER_CLBLAST)
|
| 227 |
find_package(CLBlast)
|
| 228 |
if (CLBlast_FOUND)
|
Makefile
CHANGED
|
@@ -161,6 +161,21 @@ ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
|
|
| 161 |
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
|
| 162 |
endif
|
| 163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
ifdef WHISPER_CLBLAST
|
| 165 |
CFLAGS += -DGGML_USE_CLBLAST
|
| 166 |
CXXFLAGS += -DGGML_USE_CLBLAST
|
|
|
|
| 161 |
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
|
| 162 |
endif
|
| 163 |
|
| 164 |
+
# hipBLAS (AMD ROCm) backend, enabled by defining WHISPER_HIPBLAS.
ifdef WHISPER_HIPBLAS

# Tool and target selection. All three use ?= so a value supplied by the
# caller takes precedence over the default.
# GPU_TARGETS defaults to whatever $(ROCM_PATH)/llvm/bin/amdgpu-arch prints.
ROCM_PATH   ?= /opt/rocm
HIPCC       ?= $(ROCM_PATH)/bin/hipcc
GPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)

# The HIP build defines GGML_USE_CUBLAS in addition to GGML_USE_HIPBLAS.
CFLAGS   += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS
CXXFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS

# One --offload-arch=<target> flag per entry in GPU_TARGETS.
HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS))

# Link against the ROCm libraries and record their directory as an rpath.
LDFLAGS  += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib -lhipblas -lamdhip64 -lrocblas

WHISPER_OBJ += ggml-cuda.o

# ggml-cuda.cu is compiled by hipcc with the input language forced to HIP.
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
	$(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<

endif
|
| 178 |
+
|
| 179 |
ifdef WHISPER_CLBLAST
|
| 180 |
CFLAGS += -DGGML_USE_CLBLAST
|
| 181 |
CXXFLAGS += -DGGML_USE_CLBLAST
|
ggml-cuda.cu
CHANGED
|
@@ -6,9 +6,60 @@
|
|
| 6 |
#include <atomic>
|
| 7 |
#include <assert.h>
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
#include <cuda_runtime.h>
|
| 10 |
#include <cublas_v2.h>
|
| 11 |
#include <cuda_fp16.h>
|
|
|
|
| 12 |
|
| 13 |
#include "ggml-cuda.h"
|
| 14 |
#include "ggml.h"
|
|
|
|
| 6 |
#include <atomic>
|
| 7 |
#include <assert.h>
|
| 8 |
|
| 9 |
+
// Backend selection: with GGML_USE_HIPBLAS the HIP / hipBLAS / rocBLAS headers
// are included and the CUDA / cuBLAS identifiers listed below are redefined as
// macros onto their HIP counterparts; otherwise the CUDA headers are included.
#ifdef GGML_USE_HIPBLAS

#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#include <rocblas/rocblas.h>

// --- type names ---
#define cublasHandle_t hipblasHandle_t
#define cublasStatus_t hipblasStatus_t
#define cudaDeviceProp hipDeviceProp_t
#define cudaError_t hipError_t
#define cudaEvent_t hipEvent_t
#define cudaMemcpyKind hipMemcpyKind
#define cudaStream_t hipStream_t

// --- constants and enumerators ---
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
// Placeholder value; cublasSetMathMode below ignores its arguments.
#define CUBLAS_TF32_TENSOR_OP_MATH 0
#define cudaEventDisableTiming hipEventDisableTiming
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaSuccess hipSuccess

// --- BLAS entry points ---
#define cublasCreate hipblasCreate
#define cublasGetStatusString rocblas_status_to_string
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
// These two expand straight to the success status without calling anything.
#define cublasLoggerConfigure(logIsOn, logToStdOut, logToStdErr, logFileName) CUBLAS_STATUS_SUCCESS
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS

// --- device management and error reporting ---
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaGetDevice hipGetDevice
#define cudaGetDeviceCount hipGetDeviceCount
#define cudaGetDeviceProperties hipGetDeviceProperties
#define cudaSetDevice hipSetDevice
#define cudaGetErrorString hipGetErrorString
#define cudaGetLastError hipGetLastError

// --- memory management and copies ---
#define cudaFree hipFree
#define cudaFreeHost hipHostFree
#define cudaMalloc hipMalloc
#define cudaMallocHost hipHostMalloc
#define cudaMemcpy hipMemcpy
#define cudaMemcpy2DAsync hipMemcpy2DAsync
#define cudaMemcpyAsync hipMemcpyAsync
#define cudaMemset hipMemset

// --- events and streams ---
#define cudaEventCreateWithFlags hipEventCreateWithFlags
#define cudaEventDestroy hipEventDestroy
#define cudaEventRecord hipEventRecord
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
// Forwards to hipStreamWaitEvent with an extra trailing 0 argument.
#define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0)

// --- warp intrinsics ---
// The mask argument is dropped; the remaining arguments are forwarded as-is.
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)

#else

#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>

#endif
|
| 63 |
|
| 64 |
#include "ggml-cuda.h"
|
| 65 |
#include "ggml.h"
|