JohannesGaessler commited on
Commit
7b7c5d3
·
1 Parent(s): 6f5687a

CUDA: fix FP16 cuBLAS GEMM (llama/11396)

Browse files
Files changed (1) hide show
  1. ggml/src/ggml-cuda/ggml-cuda.cu +5 -5
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -1114,8 +1114,8 @@ static void ggml_cuda_op_mul_mat_cublas(
1114
  CUBLAS_CHECK(
1115
  cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
1116
  row_diff, src1_ncols, ne10,
1117
- &alpha, src0_ptr, CUDA_R_16F, ne00,
1118
- src1_ptr, CUDA_R_16F, ne10,
1119
  &beta, dst_dd_i, CUDA_R_32F, ldc,
1120
  CUBLAS_COMPUTE_32F,
1121
  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1128,9 +1128,9 @@ static void ggml_cuda_op_mul_mat_cublas(
1128
  CUBLAS_CHECK(
1129
  cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
1130
  row_diff, src1_ncols, ne10,
1131
- &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
1132
- src1_ptr, CUDA_R_16F, ne10,
1133
- &beta_f16, dst_dd_i, CUDA_R_16F, ldc,
1134
  CUBLAS_COMPUTE_16F,
1135
  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
1136
 
 
1114
  CUBLAS_CHECK(
1115
  cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
1116
  row_diff, src1_ncols, ne10,
1117
+ &alpha, src0_ptr, CUDA_R_16F, ne00,
1118
+ src1_ptr, CUDA_R_16F, ne10,
1119
  &beta, dst_dd_i, CUDA_R_32F, ldc,
1120
  CUBLAS_COMPUTE_32F,
1121
  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
1128
  CUBLAS_CHECK(
1129
  cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
1130
  row_diff, src1_ncols, ne10,
1131
+ &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
1132
+ src1_ptr, CUDA_R_16F, ne10,
1133
+ &beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
1134
  CUBLAS_COMPUTE_16F,
1135
  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
1136