Garf committed
Commit 6cb8158 · 1 Parent(s): 6641178

cuda: Add Q5_1, Q5_0, Q4_1 and Q4_0 to F32 conversion support. (llama/12000)
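With this change the CUDA backend can execute GGML_OP_CPY nodes whose source is a Q4_0, Q4_1, Q5_0 or Q5_1 tensor and whose destination is F32, instead of leaving such copies to a fallback backend. As a minimal sketch of the path this enables (assuming the public ggml API; dequant_to_f32 is a hypothetical helper, not part of this commit):

    // Hypothetical helper, not part of this commit: builds a GGML_OP_CPY node
    // that dequantizes a Q4_0 tensor into a fresh F32 tensor of the same shape.
    // With this commit the CUDA backend can run this node directly.
    static struct ggml_tensor * dequant_to_f32(struct ggml_context * ctx,
                                               struct ggml_tensor * src) { // e.g. GGML_TYPE_Q4_0
        struct ggml_tensor * dst = ggml_new_tensor_4d(ctx, GGML_TYPE_F32,
                src->ne[0], src->ne[1], src->ne[2], src->ne[3]);
        return ggml_cpy(ctx, src, dst); // copy + convert: Q4_0 -> F32
    }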
ggml/src/ggml-cuda/cpy.cu CHANGED
@@ -1,4 +1,5 @@
 #include "cpy.cuh"
+#include "dequantize.cuh"
 
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 
@@ -82,13 +83,14 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
 }
 
 static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
-    const block_q8_0 * xi = (const block_q8_0 *) cxi;
-    float * dsti = (float *) cdsti;
-
-    const float d = (float)xi->d;
-
-    for (int j = 0; j < QK8_0; j++) {
-        dsti[j] = xi->qs[j] * d;
+    float * cdstf = (float *)(cdsti);
+
+#pragma unroll
+    for (int j = 0; j < QK8_0; j += 2) {
+        dfloat2 dq;
+        dequantize_q8_0(cxi, 0, j, dq);
+        *(cdstf + j) = dq.x;
+        *(cdstf + j + 1) = dq.y;
     }
 }
 
@@ -225,6 +227,18 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
     memcpy(dsti->qh, &qh, sizeof(qh));
 }
 
+template<dequantize_kernel_t dequant, int qk>
+static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
+    float * cdstf = (float *)(cdsti);
+
+#pragma unroll
+    for (int j = 0; j < qk/2; j++) {
+        dfloat2 dq;
+        dequant(cxi, 0, j, dq);
+        *(cdstf + j) = dq.x;
+        *(cdstf + j + qk/2) = dq.y;
+    }
+}
 
 static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
     if (x <= val[0]) return 0;
@@ -387,6 +401,19 @@ static void ggml_cpy_f32_q4_0_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q4_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_q4_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -398,6 +425,19 @@ static void ggml_cpy_f32_q4_1_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q4_1_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_q5_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -409,6 +449,19 @@ static void ggml_cpy_f32_q5_0_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q5_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_q5_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -420,6 +473,19 @@ static void ggml_cpy_f32_q5_1_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q5_1_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_iq4_nl_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -488,14 +554,25 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+                               nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
         ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+                               nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+                               nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
         ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
         ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
@@ -524,14 +601,22 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
+    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
         return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
+    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
+    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
         return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
         return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         return (void*) cpy_f32_f16<cpy_1_f32_f16>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
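A note on the write pattern in cpy_blck_q_f32: ggml's dequantize_* device functions return a pair of values taken from the two halves of a block (for Q4_0, the low and high nibbles of the same byte), which is why dq.x lands at offset j while dq.y lands at j + qk/2. A CPU reference of the same Q4_0 layout, for illustration only (block_q4_0, QK4_0 and GGML_FP16_TO_FP32 are the usual ggml definitions):

    // Reference loop mirroring the block layout the new kernel writes:
    // element j comes from the low nibble of byte j, element j + QK4_0/2
    // from the high nibble, both scaled by the block's shared fp16 delta d.
    static void dequantize_block_q4_0_ref(const block_q4_0 * x, float * y) {
        const float d = GGML_FP16_TO_FP32(x->d);
        for (int j = 0; j < QK4_0/2; ++j) {
            y[j]           = ((x->qs[j] & 0x0F) - 8) * d; // first half of the block
            y[j + QK4_0/2] = ((x->qs[j] >>   4) - 8) * d; // second half
        }
    }

Q8_0 stores its elements contiguously rather than in nibble halves, which is why the reworked cpy_blck_q8_0_f32 above writes dq.y at j + 1 instead.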
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -3075,15 +3075,27 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
         return true;
     }
+    if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
+        return true;
+    }
     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
         return true;
     }
+    if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
+        return true;
+    }
     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
         return true;
     }
+    if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
+        return true;
+    }
     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
         return true;
     }
+    if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
+        return true;
+    }
     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
         return true;
     }
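The ggml_backend_cuda_device_supports_op additions mirror the new dispatch cases in ggml_cuda_cpy; without them the scheduler would still route these copies away from the CUDA device. A sketch of probing this from application code (assuming the ggml-backend device API; can_copy_q_to_f32 is a hypothetical wrapper, not part of this commit):

    #include "ggml-backend.h"

    // Hypothetical probe: ask a device whether it accepts a given GGML_OP_CPY
    // node. With this commit, a CUDA device reports true when the node's
    // src[0] is Q4_0/Q4_1/Q5_0/Q5_1 and src[1] is F32.
    static bool can_copy_q_to_f32(ggml_backend_dev_t dev, const struct ggml_tensor * cpy_node) {
        return ggml_backend_dev_supports_op(dev, cpy_node);
    }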