ardfork commited on
Commit
e093092
·
unverified ·
1 Parent(s): 8ae21a0

whisper : initial hipBLAS support (#1209)

Browse files
Files changed (3) hide show
  1. CMakeLists.txt +32 -0
  2. Makefile +15 -0
  3. ggml-cuda.cu +51 -0
CMakeLists.txt CHANGED
@@ -65,6 +65,7 @@ else()
65
  option(WHISPER_BLAS_VENDOR "whisper: BLAS library vendor" Generic)
66
  option(WHISPER_OPENBLAS "whisper: prefer OpenBLAS" OFF)
67
  option(WHISPER_CUBLAS "whisper: support for cuBLAS" OFF)
 
68
  option(WHISPER_CLBLAST "whisper: use CLBlast" OFF)
69
  endif()
70
 
@@ -191,6 +192,37 @@ if (WHISPER_CUBLAS)
191
  endif()
192
  endif()
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  if (WHISPER_CLBLAST)
195
  find_package(CLBlast)
196
  if (CLBlast_FOUND)
 
65
  option(WHISPER_BLAS_VENDOR "whisper: BLAS library vendor" Generic)
66
  option(WHISPER_OPENBLAS "whisper: prefer OpenBLAS" OFF)
67
  option(WHISPER_CUBLAS "whisper: support for cuBLAS" OFF)
68
+ option(WHISPER_HIPBLAS "whisper: support for hipBLAS" OFF)
69
  option(WHISPER_CLBLAST "whisper: use CLBlast" OFF)
70
  endif()
71
 
 
192
  endif()
193
  endif()
194
 
195
+
196
+ if (WHISPER_HIPBLAS)
197
+ list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
198
+ if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang")
199
+ message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
200
+ endif()
201
+ if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
202
+ message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
203
+ endif()
204
+
205
+ find_package(hip)
206
+ find_package(hipblas)
207
+ find_package(rocblas)
208
+
209
+ if (${hipblas_FOUND} AND ${hip_FOUND})
210
+ message(STATUS "HIP and hipBLAS found")
211
+ add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
212
+ add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
213
+ set_property(TARGET ggml-rocm PROPERTY POSITION_INDEPENDENT_CODE ON)
214
+ set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
215
+ target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas)
216
+
217
+ if (WHISPER_STATIC)
218
+ message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
219
+ endif()
220
+ set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ggml-rocm)
221
+ else()
222
+ message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
223
+ endif()
224
+ endif()
225
+
226
  if (WHISPER_CLBLAST)
227
  find_package(CLBlast)
228
  if (CLBlast_FOUND)
Makefile CHANGED
@@ -161,6 +161,21 @@ ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
161
  $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
162
  endif
163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  ifdef WHISPER_CLBLAST
165
  CFLAGS += -DGGML_USE_CLBLAST
166
  CXXFLAGS += -DGGML_USE_CLBLAST
 
161
  $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
162
  endif
163
 
164
+ ifdef WHISPER_HIPBLAS
165
+ ROCM_PATH ?= /opt/rocm
166
+ HIPCC ?= $(ROCM_PATH)/bin/hipcc
167
+ GPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)
168
+ CFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS
169
+ CXXFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS
170
+ LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib
171
+ LDFLAGS += -lhipblas -lamdhip64 -lrocblas
172
+ HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS))
173
+ WHISPER_OBJ += ggml-cuda.o
174
+
175
+ ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
176
+ $(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
177
+ endif
178
+
179
  ifdef WHISPER_CLBLAST
180
  CFLAGS += -DGGML_USE_CLBLAST
181
  CXXFLAGS += -DGGML_USE_CLBLAST
ggml-cuda.cu CHANGED
@@ -6,9 +6,60 @@
6
  #include <atomic>
7
  #include <assert.h>
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  #include <cuda_runtime.h>
10
  #include <cublas_v2.h>
11
  #include <cuda_fp16.h>
 
12
 
13
  #include "ggml-cuda.h"
14
  #include "ggml.h"
 
6
  #include <atomic>
7
  #include <assert.h>
8
 
9
+ #if defined(GGML_USE_HIPBLAS)
10
+ #include <hip/hip_runtime.h>
11
+ #include <hipblas/hipblas.h>
12
+ #include <hip/hip_fp16.h>
13
+ #include <rocblas/rocblas.h>
14
+ #define CUBLAS_OP_N HIPBLAS_OP_N
15
+ #define CUBLAS_OP_T HIPBLAS_OP_T
16
+ #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
17
+ #define CUBLAS_TF32_TENSOR_OP_MATH 0
18
+ #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
19
+ #define cublasCreate hipblasCreate
20
+ #define cublasGetStatusString rocblas_status_to_string
21
+ #define cublasHandle_t hipblasHandle_t
22
+ #define cublasLoggerConfigure(logIsOn, logToStdOut, logToStdErr, logFileName) CUBLAS_STATUS_SUCCESS
23
+ #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
24
+ #define cublasSetStream hipblasSetStream
25
+ #define cublasSgemm hipblasSgemm
26
+ #define cublasStatus_t hipblasStatus_t
27
+ #define cudaDeviceProp hipDeviceProp_t
28
+ #define cudaDeviceSynchronize hipDeviceSynchronize
29
+ #define cudaError_t hipError_t
30
+ #define cudaEventCreateWithFlags hipEventCreateWithFlags
31
+ #define cudaEventDestroy hipEventDestroy
32
+ #define cudaEventDisableTiming hipEventDisableTiming
33
+ #define cudaEventRecord hipEventRecord
34
+ #define cudaEvent_t hipEvent_t
35
+ #define cudaFree hipFree
36
+ #define cudaFreeHost hipHostFree
37
+ #define cudaGetDevice hipGetDevice
38
+ #define cudaGetDeviceCount hipGetDeviceCount
39
+ #define cudaGetDeviceProperties hipGetDeviceProperties
40
+ #define cudaGetErrorString hipGetErrorString
41
+ #define cudaGetLastError hipGetLastError
42
+ #define cudaMalloc hipMalloc
43
+ #define cudaMallocHost hipHostMalloc
44
+ #define cudaMemcpy hipMemcpy
45
+ #define cudaMemcpy2DAsync hipMemcpy2DAsync
46
+ #define cudaMemcpyAsync hipMemcpyAsync
47
+ #define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
48
+ #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
49
+ #define cudaMemcpyHostToDevice hipMemcpyHostToDevice
50
+ #define cudaMemcpyKind hipMemcpyKind
51
+ #define cudaMemset hipMemset
52
+ #define cudaSetDevice hipSetDevice
53
+ #define cudaStreamCreateWithFlags hipStreamCreateWithFlags
54
+ #define cudaStreamNonBlocking hipStreamNonBlocking
55
+ #define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0)
56
+ #define cudaStream_t hipStream_t
57
+ #define cudaSuccess hipSuccess
58
+ #else
59
  #include <cuda_runtime.h>
60
  #include <cublas_v2.h>
61
  #include <cuda_fp16.h>
62
+ #endif
63
 
64
  #include "ggml-cuda.h"
65
  #include "ggml.h"