cuda: Add Q5_1, Q5_0, Q4_1 and Q4_0 to F32 conversion support. (llama/12000)
Files changed:
- ggml/src/ggml-cuda/cpy.cu (+92 -7)
- ggml/src/ggml-cuda/ggml-cuda.cu (+12 -0)
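
The change reuses the existing per-pair dequantize helpers (dequantize_q4_0, dequantize_q4_1, dequantize_q5_0 and dequantize_q5_1 from dequantize.cuh) through a new templated device block, cpy_blck_q_f32, which expands one quantized block back into qk floats. The destination indices j and j + qk/2 mirror how these formats pack two values per byte: the low nibbles hold elements 0..qk/2-1 and the high nibbles hold elements qk/2..qk-1. (Q8_0 stores one value per byte, which is why the reworked cpy_blck_q8_0_f32 writes to j and j + 1 instead.) A minimal host-side C++ sketch of the Q4_0 case, assuming ggml's standard block layout; the _ref names are illustrative, not part of the patch, and the scale is simplified from fp16 to float:

#include <cstdint>

// Reference for what cpy_blck_q_f32<dequantize_q4_0, QK4_0> computes per block.
constexpr int QK4_0 = 32;

struct block_q4_0_ref {
    float   d;              // scale (fp16 in ggml, float here for brevity)
    uint8_t qs[QK4_0 / 2];  // 32 4-bit values, two per byte
};

void dequant_block_q4_0_ref(const block_q4_0_ref & b, float * dst) {
    for (int j = 0; j < QK4_0 / 2; ++j) {
        const int x0 = (b.qs[j] & 0x0F) - 8;  // low nibble, stored with a +8 offset
        const int x1 = (b.qs[j] >>   4) - 8;  // high nibble
        dst[j]             = x0 * b.d;        // dq.x -> element j
        dst[j + QK4_0 / 2] = x1 * b.d;        // dq.y -> element j + qk/2
    }
}
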
ggml/src/ggml-cuda/cpy.cu
@@ -1,4 +1,5 @@
 #include "cpy.cuh"
+#include "dequantize.cuh"
 
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 
@@ -82,13 +83,14 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
 }
 
 static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
-    const block_q8_0 * xi = (const block_q8_0 *) cxi;
-    float * dsti = (float *) cdsti;
-
-    const float d = (float)xi->d;
-
-    for (int j = 0; j < QK8_0; j++) {
-        dsti[j] = xi->qs[j] * d;
+    float * cdstf = (float *)(cdsti);
+
+#pragma unroll
+    for (int j = 0; j < QK8_0; j += 2) {
+        dfloat2 dq;
+        dequantize_q8_0(cxi, 0, j, dq);
+        *(cdstf + j) = dq.x;
+        *(cdstf + j + 1) = dq.y;
     }
 }
 
@@ -225,6 +227,18 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
     memcpy(dsti->qh, &qh, sizeof(qh));
 }
 
+template<dequantize_kernel_t dequant, int qk>
+static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
+    float * cdstf = (float *)(cdsti);
+
+#pragma unroll
+    for (int j = 0; j < qk/2; j++) {
+        dfloat2 dq;
+        dequant(cxi, 0, j, dq);
+        *(cdstf + j) = dq.x;
+        *(cdstf + j + qk/2) = dq.y;
+    }
+}
 
 static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
     if (x <= val[0]) return 0;
@@ -387,6 +401,19 @@ static void ggml_cpy_f32_q4_0_cuda(
     (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q4_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_q4_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -398,6 +425,19 @@ static void ggml_cpy_f32_q4_1_cuda(
     (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q4_1_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_q5_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -409,6 +449,19 @@ static void ggml_cpy_f32_q5_0_cuda(
     (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q5_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_q5_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -420,6 +473,19 @@ static void ggml_cpy_f32_q5_1_cuda(
     (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q5_1_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_iq4_nl_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -488,14 +554,25 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
         ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
         ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
         ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
         ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
@@ -524,14 +601,22 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
+    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
         return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
+    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
+    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
         return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
         return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         return (void*) cpy_f32_f16<cpy_1_f32_f16>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
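For the 5-bit formats the dequantize helper must also splice the fifth bit back in from the block's separate qh bitfield before applying the scale: element j takes bit j of qh, and element j + qk/2 takes bit j + 16. A host-side reference for one Q5_0 block, mirroring what dequantize_q5_0 computes on the device (same caveats as above: the _ref names are illustrative and the fp16 scale is simplified to float):

#include <cstdint>
#include <cstring>

constexpr int QK5_0 = 32;

struct block_q5_0_ref {
    float   d;              // scale (fp16 in ggml)
    uint8_t qh[4];          // 32 high bits, one per element
    uint8_t qs[QK5_0 / 2];  // 32 low 4-bit parts, two per byte
};

void dequant_block_q5_0_ref(const block_q5_0_ref & b, float * dst) {
    uint32_t qh;
    std::memcpy(&qh, b.qh, sizeof(qh));
    for (int j = 0; j < QK5_0 / 2; ++j) {
        // Rebuild each 5-bit value from its low nibble plus one qh bit;
        // Q5_0 stores values with a +16 offset.
        const int x0 = ((b.qs[j] & 0x0F) | (((qh >>  j)       & 1) << 4)) - 16;
        const int x1 = ((b.qs[j] >>   4) | (((qh >> (j + 16)) & 1) << 4)) - 16;
        dst[j]             = x0 * b.d;
        dst[j + QK5_0 / 2] = x1 * b.d;
    }
}
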
ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3075,15 +3075,27 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
                     return true;
                 }
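With these checks in place the CUDA backend reports GGML_OP_CPY from Q4_0/Q4_1/Q5_0/Q5_1 to F32 as supported instead of forcing a CPU fallback. A sketch of the graph-level node this enables, using ggml's public API (the helper name is illustrative, not part of the change):

#include "ggml.h"

// Builds a copy node that dequantizes a Q4_0 tensor into an F32 tensor.
// On the CUDA backend this now dispatches to ggml_cpy_q4_0_f32_cuda.
// n must be a multiple of the Q4_0 block size (32 elements).
static struct ggml_tensor * build_q4_0_to_f32(struct ggml_context * ctx, int64_t n) {
    struct ggml_tensor * src = ggml_new_tensor_1d(ctx, GGML_TYPE_Q4_0, n);
    struct ggml_tensor * dst = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,  n);
    return ggml_cpy(ctx, src, dst);
}

The same pattern covers the other three formats; only the source tensor type changes.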