etasnadi commited on
Commit
32985b0
·
1 Parent(s): 9896625

ggml-vulkan: adds support for op CONV_TRANSPOSE_1D (llama/13813)

Browse files

* * ggml-vulkan: adds op CONV_TRANSPOSE_1D

* test-backend-ops: adds more spohisticated tests for CONV_TRANSPOSE_1D

* Missing barrier added to shader.
Number of additional tests reduced to 108.

* * Fixes typo in variable name.

* Removes extra whitespaces.

* Adds int64->int32 casts to prevent possible warnings.

* Problem size reduced in tests to pass tests with llvmpipe.

* supports_op condition moved from unintended position

ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -396,6 +396,7 @@ struct vk_device_struct {
396
  vk_pipeline pipeline_count_equal_i32;
397
  vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
398
  vk_pipeline pipeline_timestep_embedding_f32;
 
399
  vk_pipeline pipeline_pool2d_f32;
400
  vk_pipeline pipeline_rwkv_wkv6_f32;
401
  vk_pipeline pipeline_rwkv_wkv7_f32;
@@ -706,6 +707,21 @@ struct vk_op_timestep_embedding_push_constants {
706
  uint32_t max_period;
707
  };
708
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
709
  struct vk_op_pool2d_push_constants {
710
  uint32_t IW; uint32_t IH;
711
  uint32_t OW; uint32_t OH;
@@ -2726,6 +2742,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2726
 
2727
  ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
2728
 
 
 
2729
  ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
2730
 
2731
  ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
@@ -6392,6 +6410,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6392
  return ctx->device->pipeline_timestep_embedding_f32;
6393
  }
6394
  return nullptr;
 
 
 
 
 
6395
  case GGML_OP_POOL_2D:
6396
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6397
  return ctx->device->pipeline_pool2d_f32;
@@ -6726,6 +6749,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6726
  uint32_t half_ceil = (dim + 1) / 2;
6727
  elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
6728
  } break;
 
 
 
 
6729
  case GGML_OP_POOL_2D:
6730
  {
6731
  const uint32_t N = dst->ne[3];
@@ -7529,6 +7556,37 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context
7529
  }, dryrun);
7530
  }
7531
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7532
  static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7533
  uint32_t op = static_cast<uint32_t>(dst->op_params[0]);
7534
  const int32_t k1 = dst->op_params[1];
@@ -8600,6 +8658,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8600
  case GGML_OP_COUNT_EQUAL:
8601
  case GGML_OP_IM2COL:
8602
  case GGML_OP_TIMESTEP_EMBEDDING:
 
8603
  case GGML_OP_POOL_2D:
8604
  case GGML_OP_CONV_2D_DW:
8605
  case GGML_OP_RWKV_WKV6:
@@ -8664,6 +8723,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8664
  case GGML_OP_COUNT_EQUAL:
8665
  case GGML_OP_IM2COL:
8666
  case GGML_OP_TIMESTEP_EMBEDDING:
 
8667
  case GGML_OP_POOL_2D:
8668
  case GGML_OP_CONV_2D_DW:
8669
  case GGML_OP_LEAKY_RELU:
@@ -8835,6 +8895,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8835
  case GGML_OP_TIMESTEP_EMBEDDING:
8836
  ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);
8837
 
 
 
 
 
8838
  break;
8839
  case GGML_OP_POOL_2D:
8840
  ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
@@ -8963,6 +9027,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
8963
  case GGML_OP_COUNT_EQUAL:
8964
  case GGML_OP_IM2COL:
8965
  case GGML_OP_TIMESTEP_EMBEDDING:
 
8966
  case GGML_OP_POOL_2D:
8967
  case GGML_OP_CONV_2D_DW:
8968
  case GGML_OP_RWKV_WKV6:
@@ -10024,6 +10089,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10024
  case GGML_OP_LEAKY_RELU:
10025
  case GGML_OP_OPT_STEP_ADAMW:
10026
  return true;
 
 
10027
  default:
10028
  return false;
10029
  }
@@ -10515,6 +10582,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10515
  const int32_t dim = tensor->op_params[0];
10516
  const int32_t max_period = tensor->op_params[1];
10517
  tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period);
 
 
 
 
 
10518
  } else if (tensor->op == GGML_OP_POOL_2D) {
10519
  enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
10520
  const int32_t k0 = tensor->op_params[1];
 
396
  vk_pipeline pipeline_count_equal_i32;
397
  vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
398
  vk_pipeline pipeline_timestep_embedding_f32;
399
+ vk_pipeline pipeline_conv_transpose_1d_f32;
400
  vk_pipeline pipeline_pool2d_f32;
401
  vk_pipeline pipeline_rwkv_wkv6_f32;
402
  vk_pipeline pipeline_rwkv_wkv7_f32;
 
707
  uint32_t max_period;
708
  };
709
 
710
+ struct vk_op_conv_transpose_1d_push_constants {
711
+ uint32_t Cout;
712
+ uint32_t Cin;
713
+ uint32_t K;
714
+ uint32_t L;
715
+ uint32_t KL;
716
+
717
+ uint32_t nb01;
718
+ uint32_t nb02;
719
+ uint32_t nb11;
720
+ uint32_t nb1;
721
+
722
+ int32_t s0;
723
+ };
724
+
725
  struct vk_op_pool2d_push_constants {
726
  uint32_t IW; uint32_t IH;
727
  uint32_t OW; uint32_t OH;
 
2742
 
2743
  ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
2744
 
2745
+ ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1);
2746
+
2747
  ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
2748
 
2749
  ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
 
6410
  return ctx->device->pipeline_timestep_embedding_f32;
6411
  }
6412
  return nullptr;
6413
+ case GGML_OP_CONV_TRANSPOSE_1D:
6414
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6415
+ return ctx->device->pipeline_conv_transpose_1d_f32;
6416
+ }
6417
+ return nullptr;
6418
  case GGML_OP_POOL_2D:
6419
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6420
  return ctx->device->pipeline_pool2d_f32;
 
6749
  uint32_t half_ceil = (dim + 1) / 2;
6750
  elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
6751
  } break;
6752
+ case GGML_OP_CONV_TRANSPOSE_1D:
6753
+ {
6754
+ elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1}
6755
+ } break;
6756
  case GGML_OP_POOL_2D:
6757
  {
6758
  const uint32_t N = dst->ne[3];
 
7556
  }, dryrun);
7557
  }
7558
 
7559
+ static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
7560
+ // src0: (K, Cout, Cin, 1) -- kernel
7561
+ // src1: (L, Cin, 1, 1) -- input
7562
+ // dst: (*, Cout, 1, 1)
7563
+
7564
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
7565
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
7566
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
7567
+
7568
+ GGML_TENSOR_BINARY_OP_LOCALS
7569
+
7570
+ GGML_ASSERT(nb00 == sizeof(float));
7571
+ GGML_ASSERT(nb10 == sizeof(float));
7572
+
7573
+ const int32_t s0 = dst->op_params[0];
7574
+
7575
+ vk_op_conv_transpose_1d_push_constants p{};
7576
+ p.Cout = static_cast<uint32_t>(ne01);
7577
+ p.Cin = static_cast<uint32_t>(ne02);
7578
+ p.K = static_cast<uint32_t>(ne00);
7579
+ p.L = static_cast<uint32_t>(ne10);
7580
+ p.KL = static_cast<uint32_t>(ne0);
7581
+ p.nb01 = static_cast<uint32_t>(nb01 / nb00);
7582
+ p.nb02 = static_cast<uint32_t>(nb02 / nb00);
7583
+ p.nb11 = static_cast<uint32_t>(nb11 / nb10);
7584
+ p.nb1 = static_cast<uint32_t>(nb1 / nb0);
7585
+ p.s0 = static_cast<uint32_t>(s0);
7586
+
7587
+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p), dryrun);
7588
+ }
7589
+
7590
  static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7591
  uint32_t op = static_cast<uint32_t>(dst->op_params[0]);
7592
  const int32_t k1 = dst->op_params[1];
 
8658
  case GGML_OP_COUNT_EQUAL:
8659
  case GGML_OP_IM2COL:
8660
  case GGML_OP_TIMESTEP_EMBEDDING:
8661
+ case GGML_OP_CONV_TRANSPOSE_1D:
8662
  case GGML_OP_POOL_2D:
8663
  case GGML_OP_CONV_2D_DW:
8664
  case GGML_OP_RWKV_WKV6:
 
8723
  case GGML_OP_COUNT_EQUAL:
8724
  case GGML_OP_IM2COL:
8725
  case GGML_OP_TIMESTEP_EMBEDDING:
8726
+ case GGML_OP_CONV_TRANSPOSE_1D:
8727
  case GGML_OP_POOL_2D:
8728
  case GGML_OP_CONV_2D_DW:
8729
  case GGML_OP_LEAKY_RELU:
 
8895
  case GGML_OP_TIMESTEP_EMBEDDING:
8896
  ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);
8897
 
8898
+ break;
8899
+ case GGML_OP_CONV_TRANSPOSE_1D:
8900
+ ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node, dryrun);
8901
+
8902
  break;
8903
  case GGML_OP_POOL_2D:
8904
  ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
 
9027
  case GGML_OP_COUNT_EQUAL:
9028
  case GGML_OP_IM2COL:
9029
  case GGML_OP_TIMESTEP_EMBEDDING:
9030
+ case GGML_OP_CONV_TRANSPOSE_1D:
9031
  case GGML_OP_POOL_2D:
9032
  case GGML_OP_CONV_2D_DW:
9033
  case GGML_OP_RWKV_WKV6:
 
10089
  case GGML_OP_LEAKY_RELU:
10090
  case GGML_OP_OPT_STEP_ADAMW:
10091
  return true;
10092
+ case GGML_OP_CONV_TRANSPOSE_1D:
10093
+ return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
10094
  default:
10095
  return false;
10096
  }
 
10582
  const int32_t dim = tensor->op_params[0];
10583
  const int32_t max_period = tensor->op_params[1];
10584
  tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period);
10585
+ } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_1D){
10586
+ const int32_t s0 = tensor->op_params[0];
10587
+ const int32_t p0 = tensor->op_params[1];
10588
+ const int32_t d0 = tensor->op_params[2];
10589
+ tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0);
10590
  } else if (tensor->op == GGML_OP_POOL_2D) {
10591
  enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
10592
  const int32_t k0 = tensor->op_params[1];
ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #version 450
2
+
3
+ #include "types.comp"
4
+
5
+ layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // src0 - kernel: [K, Cout, Cin]
6
+ layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // src1 - input: [L, Cin]
7
+ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; // dst - result [KL, Cout]
8
+
9
+ layout(local_size_x = 128 , local_size_y = 1, local_size_z = 1) in;
10
+
11
+ layout (push_constant) uniform parameter {
12
+ uint32_t Cout;
13
+ uint32_t Cin;
14
+ uint32_t K;
15
+ uint32_t L;
16
+ uint32_t KL;
17
+
18
+ uint32_t nb01;
19
+ uint32_t nb02;
20
+ uint32_t nb11;
21
+ uint32_t nb1;
22
+
23
+ int32_t s0;
24
+ } p;
25
+
26
+
27
+ uint32_t Cout_idx = gl_WorkGroupID.x;
28
+ const uint32_t bs = gl_WorkGroupSize.x;
29
+ uint32_t tid = gl_LocalInvocationID.x;
30
+ // Code is more straightforward if we assume it is bs*s0+K instead of (bs-1)*s0+K.
31
+ uint32_t tmp_len = bs*p.s0+p.K;
32
+ shared D_TYPE tmp[4096];
33
+
34
+ uint splitWork(uint workSize){
35
+ return (bs + workSize -1) / bs;
36
+ }
37
+
38
+ void main(){
39
+ for(uint32_t i = 0; i < splitWork(tmp_len); i++){
40
+ uint32_t idx = i*bs+tid;
41
+ if(idx < tmp_len){
42
+ tmp[idx] = 0.0;
43
+ }
44
+ }
45
+
46
+ uint32_t L_blocks = splitWork(p.L);
47
+ for(uint32_t L_block_id = 0; L_block_id < L_blocks; L_block_id++){
48
+ if(L_block_id > 0){
49
+ barrier();
50
+ // Shift values in tmp to the current processing window
51
+ for(int i = 0; i < splitWork(tmp_len); i++){
52
+ uint32_t idx = i*bs+tid;
53
+ if(idx >= bs*p.s0 && idx < tmp_len){
54
+ tmp[idx-bs*p.s0] = tmp[idx];
55
+ tmp[idx] = 0.0;
56
+ }else if(idx >= p.K && idx < bs*p.s0){
57
+ tmp[idx] = 0.0;
58
+ }
59
+ }
60
+ }
61
+ barrier();
62
+
63
+ // Save contributions of the block to tmp
64
+ uint32_t L_idx = L_block_id*bs + tid;
65
+ for(uint32_t K_idx = 0; K_idx < p.K; K_idx++){
66
+ D_TYPE dp = 0.0;
67
+ for(uint32_t Cin_idx = 0; Cin_idx < p.Cin; Cin_idx++){
68
+ A_TYPE elemKrn = data_a[K_idx + Cout_idx * p.nb01 + Cin_idx * p.nb02];
69
+ if(L_idx < p.L){
70
+ B_TYPE elemInp = data_b[L_idx + Cin_idx*p.nb11];
71
+ dp = fma(elemKrn, elemInp, dp);
72
+ }
73
+ }
74
+ tmp[tid*p.s0 + K_idx] += dp;
75
+ barrier();
76
+ }
77
+
78
+ // Save the computed values except the last block that can have different size
79
+ uint32_t KLb_idx = L_block_id*bs*p.s0;
80
+ if(L_block_id < L_blocks-1){
81
+ for(uint32_t s0_idx = 0; s0_idx < p.s0; s0_idx++){
82
+ uint32_t sh_idx = p.s0*tid+s0_idx;
83
+ uint32_t KL_idx = KLb_idx+sh_idx;
84
+ if(KL_idx < p.KL){
85
+ data_d[KL_idx + Cout_idx*p.nb1] = tmp[sh_idx];
86
+ }
87
+ }
88
+ }
89
+ }
90
+
91
+ for(uint32_t i = 0; i < splitWork(tmp_len); i++){
92
+ uint32_t idx = i*bs+tid;
93
+ uint32_t KL_idx = (L_blocks-1)*bs*p.s0+idx;
94
+ if(KL_idx < p.KL){
95
+ data_d[KL_idx + Cout_idx*p.nb1] = tmp[idx];
96
+ }
97
+ }
98
+ }
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp CHANGED
@@ -622,6 +622,8 @@ void process_shaders() {
622
 
623
  string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
624
 
 
 
625
  string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
626
 
627
  string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
622
 
623
  string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
624
 
625
+ string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
626
+
627
  string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
628
 
629
  string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));