am17an committed on
Commit c7936d3 · 1 Parent(s): 1c3b94c

CUDA: add bf16 and f32 support to cublas_mul_mat_batched (llama/14361)


* CUDA: add bf16 and f32 support to cublas_mul_mat_batched

* Review: add type traits and make function more generic

* Review: make check more explicit, add back comments, and fix formatting

* Review: fix formatting, remove useless type conversion, fix naming for bools
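
The "type traits" mentioned in the review notes map each supported ggml_type to its CUDA storage type and to the cuBLAS compute/data types, so one templated implementation covers F32, BF16 and F16. A condensed sketch of the pattern (the complete definitions are in the ggml-cuda.cu diff below):

    // Condensed from the ggml-cuda.cu diff below; not a complete definition.
    template<ggml_type T> struct batched_mul_mat_traits;

    template<> struct batched_mul_mat_traits<GGML_TYPE_BF16> {
        using cuda_type = nv_bfloat16;                                              // device-side storage type
        static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;  // BF16 GEMMs accumulate in F32
        static inline const cudaDataType_t      data_type    = CUDA_R_16BF;
    };

    // The batched matmul is then instantiated per src0 type, e.g.:
    //     ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);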

ggml/src/ggml-cuda/convert.cu CHANGED
@@ -728,3 +728,25 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
             return nullptr;
     }
 }
+
+to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_F32:
+            return convert_unary_cuda<float, nv_bfloat16>;
+        case GGML_TYPE_F16:
+            return convert_unary_cuda<half, nv_bfloat16>;
+        default:
+            return nullptr;
+    }
+}
+
+to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_F16:
+            return convert_unary_cuda<half, float>;
+        case GGML_TYPE_BF16:
+            return convert_unary_cuda<nv_bfloat16, float>;
+        default:
+            return nullptr;
+    }
+}
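
The new getters mirror the existing ggml_get_to_fp16_nc_cuda and return nullptr for unsupported source types. A minimal, hypothetical call site (src, dst_bf16, the extents ne00..ne03, the strides s01..s03 and stream are placeholders, not part of the commit):

    // Hypothetical usage of the non-contiguous converter returned above;
    // the signature follows to_t_nc_cuda_t<nv_bfloat16> from convert.cuh.
    const to_bf16_nc_cuda_t to_bf16 = ggml_get_to_bf16_nc_cuda(src->type);
    GGML_ASSERT(to_bf16 != nullptr); // only F32 and F16 sources are handled
    to_bf16(src->data, dst_bf16, ne00, ne01, ne02, ne03, s01, s02, s03, stream);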
ggml/src/ggml-cuda/convert.cuh CHANGED
@@ -22,5 +22,10 @@ using to_t_nc_cuda_t = void (*)(const void * x, T * y,
     int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
     int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);
 
+typedef to_t_nc_cuda_t<float> to_fp32_nc_cuda_t;
 typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
+typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
+
+to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
 to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
+to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);
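
For reference, the new to_bf16_nc_cuda_t alias is to_t_nc_cuda_t instantiated with nv_bfloat16, i.e. roughly the following function-pointer type (informational expansion, not part of the header):

    typedef void (*to_bf16_nc_cuda_t)(const void * x, nv_bfloat16 * y,
                                      int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
                                      int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);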
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -1749,7 +1749,7 @@ static void ggml_cuda_op_mul_mat(
 }
 
 static __global__ void k_compute_batched_ptrs(
-        const half * src0_as_f16, const half * src1_as_f16, char * dst,
+        const void * src0_as_f16, const void * src1_as_f16, char * dst,
         const void ** ptrs_src, void ** ptrs_dst,
         int64_t ne12, int64_t ne13,
         int64_t ne23,
@@ -1772,83 +1772,131 @@ static __global__ void k_compute_batched_ptrs(
     ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
 }
 
-static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+// Type traits for mapping ggml types to CUDA/cuBLAS types
+template<ggml_type T>
+struct batched_mul_mat_traits;
+
+template<>
+struct batched_mul_mat_traits<GGML_TYPE_F32> {
+    using cuda_type = float;
+    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
+    static inline const cudaDataType_t data_type = CUDA_R_32F;
+    static inline const ggml_type ggml_type_val = GGML_TYPE_F32;
+    static inline const float alpha = 1.0f;
+    static inline const float beta = 0.0f;
+    static inline const void* get_alpha() { static const float val = alpha; return &val; }
+    static inline const void* get_beta() { static const float val = beta; return &val; }
+    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); }
+};
+
+template<>
+struct batched_mul_mat_traits<GGML_TYPE_BF16> {
+    using cuda_type = nv_bfloat16;
+    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
+    static inline const cudaDataType_t data_type = CUDA_R_16BF;
+    static inline const ggml_type ggml_type_val = GGML_TYPE_BF16;
+    static inline const float alpha = 1.0f;
+    static inline const float beta = 0.0f;
+    static inline const void* get_alpha() { static const float val = alpha; return &val; }
+    static inline const void* get_beta() { static const float val = beta; return &val; }
+    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); }
+};
+
+template<>
+struct batched_mul_mat_traits<GGML_TYPE_F16> {
+    using cuda_type = half;
+    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
+    static inline const cudaDataType_t data_type = CUDA_R_16F;
+    static inline const ggml_type ggml_type_val = GGML_TYPE_F16;
+    static inline const half alpha = 1.0;
+    static inline const half beta = 0.0;
+    static inline const void* get_alpha() { static const half val = alpha; return &val; }
+    static inline const void* get_beta() { static const half val = beta; return &val; }
+    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); }
+};
+
+template<ggml_type src0_type>
+static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    using traits = batched_mul_mat_traits<src0_type>;
+    using cuda_t = typename traits::cuda_type;
+
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
-
     GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == src0_type);
+    GGML_ASSERT(ggml_is_contiguous(dst));
 
     // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
     // As long as dst is contiguous this does not matter though.
-    GGML_ASSERT(ggml_is_contiguous(dst));
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int64_t ne_dst = ggml_nelements(dst);
-
     cudaStream_t main_stream = ctx.stream();
-
     CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
 
-    const half * src0_f16 = (const half *) src0->data;
     float * dst_ddf = (float *) dst->data;
-
-    const half * src1_f16 = (const half *) src1->data;
     const size_t ts_src1 = ggml_type_size(src1->type);
     GGML_ASSERT(nb10 == ts_src1);
     int64_t s11 = nb11 / ts_src1;
     int64_t s12 = nb12 / ts_src1;
     int64_t s13 = nb13 / ts_src1;
-    ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
 
-    // convert src1 to fp16
-    if (src1->type != GGML_TYPE_F16) {
-        const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
-        const int64_t ne_src1 = ggml_nelements(src1);
-        src1_f16_alloc.alloc(ne_src1);
-        GGML_ASSERT(to_fp16_cuda != nullptr);
+    const cuda_t * src0_ptr = nullptr;
+    const cuda_t * src1_ptr = nullptr;
 
-        to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+    ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
+
+    // Handle src0
+    src0_ptr = (const cuda_t *) src0->data;
+
+    // Handle src1 - convert if necessary
+    if (src1->type == src0_type) {
+        src1_ptr = (const cuda_t *) src1->data;
+    } else {
+        // Convert src1 to target type using traits conversion functions
+        const int64_t ne_src1 = ggml_nelements(src1);
+        src1_alloc.alloc(ne_src1);
 
-        src1_f16 = src1_f16_alloc.get();
+        const auto convert_func = traits::get_nc_converter(src1->type);
+        GGML_ASSERT(convert_func != nullptr);
+        convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+        src1_ptr = src1_alloc.get();
         s11 = ne10;
        s12 = ne11*s11;
         s13 = ne12*s12;
     }
 
-    ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
+    // Setup destination buffer
+    ggml_cuda_pool_alloc<cuda_t> dst_temp(ctx.pool());
     char * dst_t;
-
-    cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
-    cudaDataType_t cu_data_type = CUDA_R_16F;
-
-    // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
 
-    const half alpha_f16 = 1.0f;
-    const half beta_f16 = 0.0f;
-
+    cublasComputeType_t cu_compute_type = traits::compute_type;
+    cudaDataType_t cu_data_type = traits::data_type;
+    cudaDataType_t cu_data_type_a = traits::data_type;
+    cudaDataType_t cu_data_type_b = traits::data_type;
+    const void * alpha = traits::get_alpha();
+    const void * beta = traits::get_beta();
     const float alpha_f32 = 1.0f;
-    const float beta_f32 = 0.0f;
-
-    const void * alpha = &alpha_f16;
-    const void * beta = &beta_f16;
+    const float beta_f32 = 0.0f;
 
     if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        dst_t = (char *) dst_f16.alloc(ne_dst);
-
-        nbd2 /= sizeof(float) / sizeof(half);
-        nbd3 /= sizeof(float) / sizeof(half);
+        if constexpr (src0_type == GGML_TYPE_F32) {
+            dst_t = (char *) dst_ddf; // Direct F32 output
+        } else {
+            dst_t = (char *) dst_temp.alloc(ne_dst);
+            nbd2 /= sizeof(float) / sizeof(cuda_t);
+            nbd3 /= sizeof(float) / sizeof(cuda_t);
+        }
     } else {
         dst_t = (char *) dst_ddf;
-
         cu_compute_type = CUBLAS_COMPUTE_32F;
-        cu_data_type = CUDA_R_32F;
-
+        cu_data_type = CUDA_R_32F;
         alpha = &alpha_f32;
-        beta = &beta_f32;
+        beta = &beta_f32;
     }
 
     int id = ggml_cuda_get_device();
@@ -1856,7 +1904,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
         cu_compute_type = CUBLAS_COMPUTE_32F;
         alpha = &alpha_f32;
-        beta = &beta_f32;
+        beta = &beta_f32;
     }
 
     GGML_ASSERT(ne12 % ne02 == 0);
@@ -1866,35 +1914,15 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-#if 0
-    // use cublasGemmEx
-    {
-        for (int i13 = 0; i13 < ne13; ++i13) {
-            for (int i12 = 0; i12 < ne12; ++i12) {
-                int i03 = i13 / r3;
-                int i02 = i12 / r2;
-
-                CUBLAS_CHECK(
-                cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
-                    ne01, ne11, ne10,
-                    alpha, (const char *) src0_f16 + i03*nb03 + i02*nb02, CUDA_R_16F, nb01/sizeof(half),
-                           src1_f16 + i13*s13 + i12*s12, CUDA_R_16F, s11,
-                    beta, ( char *) dst_t + i13*nbd3 + i12*nbd2, cu_data_type, ne0,
-                    cu_compute_type,
-                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-            }
-        }
-    }
-#else
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA
-                       src1_f16, CUDA_R_16F, s11, s12, // strideB
-                beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
+                alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
+                       src1_ptr, cu_data_type_b, s11, s12, // strideB
+                beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
                 ne12*ne13,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1905,34 +1933,55 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
 
+        size_t src1_stride_size = sizeof(cuda_t);
+
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                src0_f16, src1_f16, dst_t,
+                src0_ptr, src1_ptr, dst_t,
                 ptrs_src.get(), ptrs_dst.get(),
                 ne12, ne13,
                 ne23,
                 nb02, nb03,
-                src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
-                src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
+                (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+                (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
                 nbd2, nbd3,
                 r2, r3);
+
         CUDA_CHECK(cudaGetLastError());
 
         CUBLAS_CHECK(
         cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11,
-                beta,  ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
+                alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
+                       (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
+                beta,  ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
                 ne23,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     }
-#endif
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
+    // Convert output back to F32 if needed
+    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val);
+        to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream);
+    }
+}
+
+static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);
+            break;
+        case GGML_TYPE_BF16:
+            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);
+            break;
+        case GGML_TYPE_F16:
+            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);
+            break;
+        default:
+            GGML_ABORT("Unsupported type");
     }
 }
 
@@ -1984,6 +2033,12 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
+    //TODO update for generic tensor parallelism
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
+    bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
+    bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
+
     if (!split && use_mul_mat_vec) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
@@ -1992,8 +2047,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
         ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
-        !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)
+        && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_mul_mat_vec) {