JohannesGaessler committed on
Commit
3ff7660
·
unverified ·
1 Parent(s): 5d130aa

CUDA: fixed mmvq kernel for bs 2,3,4 and -sm row (llama/5386)

Browse files
Files changed (1) hide show
  1. ggml-cuda.cu +39 -27
ggml-cuda.cu CHANGED
@@ -5313,7 +5313,7 @@ template <bool need_check> static __global__ void
5313
  template <int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
5314
  static __global__ void mul_mat_vec_q(
5315
  const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
5316
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par) {
5317
 
5318
  const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par;
5319
 
@@ -5352,7 +5352,7 @@ static __global__ void mul_mat_vec_q(
5352
  tmp[j] = warp_reduce_sum(tmp[j]);
5353
 
5354
  if (threadIdx.x == 0) {
5355
- dst[j*nrows_x + row] = tmp[j];
5356
  }
5357
  }
5358
  }
@@ -6828,7 +6828,7 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
6828
  template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot>
6829
  static void mul_mat_vec_q_cuda(
6830
  const void * vx, const void * vy, float * dst,
6831
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, cudaStream_t stream) {
6832
 
6833
  GGML_ASSERT(ncols_x % qk == 0);
6834
  GGML_ASSERT(ncols_y <= 4);
@@ -6839,40 +6839,40 @@ static void mul_mat_vec_q_cuda(
6839
  switch (ncols_y) {
6840
  case 1:
6841
  mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
6842
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6843
  break;
6844
  case 2:
6845
  mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
6846
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6847
  break;
6848
  case 3:
6849
  mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
6850
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6851
  break;
6852
  case 4:
6853
  mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
6854
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6855
  break;
6856
  // case 5:
6857
  // mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
6858
- // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6859
  // break;
6860
  // case 6:
6861
  // mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
6862
- // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6863
  // break;
6864
  // case 7:
6865
  // mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
6866
- // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6867
  // break;
6868
  // case 8:
6869
  // mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
6870
- // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6871
  // break;
6872
  default:
6873
  GGML_ASSERT(false);
6874
  // mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
6875
- // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6876
  break;
6877
  }
6878
  }
@@ -8391,7 +8391,7 @@ static void ggml_cuda_op_mul_mat_q(
8391
  CUDA_CHECK(cudaGetDevice(&id));
8392
 
8393
  // the main device has a larger memory buffer to hold the results from all GPUs
8394
- // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
8395
  const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
8396
 
8397
  switch (src0->type) {
@@ -8525,58 +8525,70 @@ static void ggml_cuda_op_mul_mat_vec_q(
8525
  const int64_t ne00 = src0->ne[0];
8526
  const int64_t row_diff = row_high - row_low;
8527
 
 
 
 
 
 
 
 
 
 
 
 
 
8528
  switch (src0->type) {
8529
  case GGML_TYPE_Q4_0:
8530
  mul_mat_vec_q_cuda<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
8531
- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8532
  break;
8533
  case GGML_TYPE_Q4_1:
8534
  mul_mat_vec_q_cuda<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
8535
- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8536
  break;
8537
  case GGML_TYPE_Q5_0:
8538
  mul_mat_vec_q_cuda<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
8539
- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8540
  break;
8541
  case GGML_TYPE_Q5_1:
8542
  mul_mat_vec_q_cuda<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
8543
- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8544
  break;
8545
  case GGML_TYPE_Q8_0:
8546
  mul_mat_vec_q_cuda<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
8547
- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8548
  break;
8549
  case GGML_TYPE_Q2_K:
8550
  mul_mat_vec_q_cuda<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
8551
- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8552
  break;
8553
  case GGML_TYPE_Q3_K:
8554
  mul_mat_vec_q_cuda<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
8555
- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8556
  break;
8557
  case GGML_TYPE_Q4_K:
8558
  mul_mat_vec_q_cuda<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
8559
- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8560
  break;
8561
  case GGML_TYPE_Q5_K:
8562
  mul_mat_vec_q_cuda<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
8563
- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8564
  break;
8565
  case GGML_TYPE_Q6_K:
8566
  mul_mat_vec_q_cuda<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
8567
- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8568
  break;
8569
  case GGML_TYPE_IQ2_XXS:
8570
  mul_mat_vec_q_cuda<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
8571
- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8572
  break;
8573
  case GGML_TYPE_IQ2_XS:
8574
  mul_mat_vec_q_cuda<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
8575
- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8576
  break;
8577
  case GGML_TYPE_IQ3_XXS:
8578
  mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
8579
- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8580
  break;
8581
  default:
8582
  GGML_ASSERT(false);
@@ -9909,7 +9921,7 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
9909
  ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
9910
  }
9911
  } else {
9912
- if (src1->ne[1] <= 4 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type)) {
9913
  ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
9914
  } else if (use_mul_mat_q) {
9915
  ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
 
5313
  template <int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
5314
  static __global__ void mul_mat_vec_q(
5315
  const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
5316
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par, const int nrows_dst) {
5317
 
5318
  const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par;
5319
 
 
5352
  tmp[j] = warp_reduce_sum(tmp[j]);
5353
 
5354
  if (threadIdx.x == 0) {
5355
+ dst[j*nrows_dst + row] = tmp[j];
5356
  }
5357
  }
5358
  }
 
6828
  template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot>
6829
  static void mul_mat_vec_q_cuda(
6830
  const void * vx, const void * vy, float * dst,
6831
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
6832
 
6833
  GGML_ASSERT(ncols_x % qk == 0);
6834
  GGML_ASSERT(ncols_y <= 4);
 
6839
  switch (ncols_y) {
6840
  case 1:
6841
  mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
6842
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6843
  break;
6844
  case 2:
6845
  mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
6846
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6847
  break;
6848
  case 3:
6849
  mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
6850
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6851
  break;
6852
  case 4:
6853
  mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
6854
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6855
  break;
6856
  // case 5:
6857
  // mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
6858
+ // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6859
  // break;
6860
  // case 6:
6861
  // mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
6862
+ // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6863
  // break;
6864
  // case 7:
6865
  // mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
6866
+ // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6867
  // break;
6868
  // case 8:
6869
  // mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
6870
+ // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6871
  // break;
6872
  default:
6873
  GGML_ASSERT(false);
6874
  // mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
6875
+ // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6876
  break;
6877
  }
6878
  }
 
8391
  CUDA_CHECK(cudaGetDevice(&id));
8392
 
8393
  // the main device has a larger memory buffer to hold the results from all GPUs
8394
+ // nrows_dst == nrows of the matrix that the kernel writes into
8395
  const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
8396
 
8397
  switch (src0->type) {
 
8525
  const int64_t ne00 = src0->ne[0];
8526
  const int64_t row_diff = row_high - row_low;
8527
 
8528
+ const int64_t ne10 = src1->ne[0];
8529
+ GGML_ASSERT(ne10 % QK8_1 == 0);
8530
+
8531
+ const int64_t ne0 = dst->ne[0];
8532
+
8533
+ int id;
8534
+ CUDA_CHECK(cudaGetDevice(&id));
8535
+
8536
+ // the main device has a larger memory buffer to hold the results from all GPUs
8537
+ // nrows_dst == nrows of the matrix that the kernel writes into
8538
+ const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
8539
+
8540
  switch (src0->type) {
8541
  case GGML_TYPE_Q4_0:
8542
  mul_mat_vec_q_cuda<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
8543
+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
8544
  break;
8545
  case GGML_TYPE_Q4_1:
8546
  mul_mat_vec_q_cuda<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
8547
+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
8548
  break;
8549
  case GGML_TYPE_Q5_0:
8550
  mul_mat_vec_q_cuda<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
8551
+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
8552
  break;
8553
  case GGML_TYPE_Q5_1:
8554
  mul_mat_vec_q_cuda<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
8555
+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
8556
  break;
8557
  case GGML_TYPE_Q8_0:
8558
  mul_mat_vec_q_cuda<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
8559
+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
8560
  break;
8561
  case GGML_TYPE_Q2_K:
8562
  mul_mat_vec_q_cuda<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
8563
+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
8564
  break;
8565
  case GGML_TYPE_Q3_K:
8566
  mul_mat_vec_q_cuda<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
8567
+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
8568
  break;
8569
  case GGML_TYPE_Q4_K:
8570
  mul_mat_vec_q_cuda<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
8571
+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
8572
  break;
8573
  case GGML_TYPE_Q5_K:
8574
  mul_mat_vec_q_cuda<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
8575
+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
8576
  break;
8577
  case GGML_TYPE_Q6_K:
8578
  mul_mat_vec_q_cuda<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
8579
+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
8580
  break;
8581
  case GGML_TYPE_IQ2_XXS:
8582
  mul_mat_vec_q_cuda<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
8583
+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
8584
  break;
8585
  case GGML_TYPE_IQ2_XS:
8586
  mul_mat_vec_q_cuda<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
8587
+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
8588
  break;
8589
  case GGML_TYPE_IQ3_XXS:
8590
  mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
8591
+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
8592
  break;
8593
  default:
8594
  GGML_ASSERT(false);
 
9921
  ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
9922
  }
9923
  } else {
9924
+ if (src1->ne[1] <= 4 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32) {
9925
  ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
9926
  } else if (use_mul_mat_q) {
9927
  ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);