jeffbolznv and slaren committed
Commit 737f12d · 1 Parent(s): c7936d3

vulkan: Add fusion support for RMS_NORM+MUL (llama/14366)

* vulkan: Add fusion support for RMS_NORM+MUL

- Add a use_count to ggml_tensor, so we can detect if an output is used more than once.
- Change the ggml-vulkan rms_norm shader to optionally multiply by another tensor.
- Add detection logic and basic fusion logic in ggml-vulkan.
- Add some testing support for fusion. Rather than computing one node at a time, allow
for computing the whole graph and just testing one node's results. Add rms_norm_mul tests
and enable a llama test.

* extract some common fusion logic

* fix -Winconsistent-missing-override

* move ggml_can_fuse to a common function

* build fix

* C and C++ versions of can_fuse

* move use count to the graph to avoid data races and double increments when used in multiple threads

* use hash table lookup to find node index

* change use_counts to be indexed by hash table slot

* minimize hash lookups

style fixes

* last node doesn't need single use.
fix type.
handle mul operands being swapped.

* remove redundant parameter

---------

Co-authored-by: slaren <[email protected]>
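
To make the intent of the new helpers concrete, here is a minimal sketch (not part of the commit) of how a backend can consume the fusion support added below. It assumes the ggml-impl.h declarations introduced in this change (ggml_can_fuse and the per-graph use counts); the dispatch_* functions are hypothetical placeholders for backend-specific kernel launches.

// Sketch only, not from the commit. dispatch_node / dispatch_rms_norm_mul are
// hypothetical stand-ins for real kernel launches.
#include <stdio.h>
#include "ggml-impl.h"

static void dispatch_node(const struct ggml_tensor * node) {
    printf("launch %s kernel\n", ggml_op_name(node->op));
}

static void dispatch_rms_norm_mul(const struct ggml_tensor * norm, const struct ggml_tensor * mul) {
    printf("launch fused kernel for %s + %s\n", ggml_op_name(norm->op), ggml_op_name(mul->op));
}

static void run_graph_with_fusion(struct ggml_cgraph * cgraph) {
    for (int i = 0; i < cgraph->n_nodes; i++) {
        // Fusable only if the RMS_NORM output feeds nothing but the following MUL.
        if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
            dispatch_rms_norm_mul(cgraph->nodes[i], cgraph->nodes[i + 1]);
            i++; // the MUL node is already covered by the fused dispatch
        } else {
            dispatch_node(cgraph->nodes[i]);
        }
    }
}

This mirrors what the Vulkan backend below does with ctx->num_additional_fused_ops.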

ggml/include/ggml-backend.h CHANGED
@@ -339,7 +339,7 @@ extern "C" {
     typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
 
     // Compare the output of two backends
-    GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
+    GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node);
 
     // Tensor initialization
     GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
ggml/src/ggml-backend.cpp CHANGED
@@ -817,8 +817,9 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
         }
         if (sched->debug > 1) {
             ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
-            GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name,
-                fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node));
+            GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name,
+                fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node),
+                graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)]);
             for (int j = 0; j < GGML_MAX_SRC; j++) {
                 struct ggml_tensor * src = node->src[j];
                 if (src == NULL) {
@@ -1826,7 +1827,7 @@ void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
     ggml_free(copy.ctx_unallocated);
 }
 
-bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) {
+bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node) {
     struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
     if (copy.buffer == NULL) {
         return false;
@@ -1837,28 +1838,45 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
 
     assert(g1->n_nodes == g2->n_nodes);
 
-    for (int i = 0; i < g1->n_nodes; i++) {
-        struct ggml_tensor * t1 = g1->nodes[i];
-        struct ggml_tensor * t2 = g2->nodes[i];
+    if (test_node != nullptr) {
+        // Compute the whole graph and only test the output for a specific tensor
+        ggml_backend_graph_compute(backend1, g1);
+        ggml_backend_graph_compute(backend2, g2);
 
-        assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));
+        int test_node_idx = -1;
+        for (int i = 0; i < g1->n_nodes; i++) {
+            struct ggml_tensor * t1 = g1->nodes[i];
+            if (t1 == test_node) {
+                test_node_idx = i;
+                break;
+            }
+        }
+        GGML_ASSERT(test_node_idx != -1);
 
-        struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
-        struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);
+        callback(test_node_idx, g1->nodes[test_node_idx], g2->nodes[test_node_idx], user_data);
+    } else {
+        for (int i = 0; i < g1->n_nodes; i++) {
+            struct ggml_tensor * t1 = g1->nodes[i];
+            struct ggml_tensor * t2 = g2->nodes[i];
 
-        ggml_backend_graph_compute(backend1, &g1v);
-        ggml_backend_graph_compute(backend2, &g2v);
+            assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));
 
-        if (ggml_is_view_op(t1->op)) {
-            continue;
-        }
+            struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
+            struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);
 
-        // compare results, calculate rms etc
-        if (!callback(i, t1, t2, user_data)) {
-            break;
+            ggml_backend_graph_compute(backend1, &g1v);
+            ggml_backend_graph_compute(backend2, &g2v);
+
+            if (ggml_is_view_op(t1->op)) {
+                continue;
+            }
+
+            // compare results, calculate rms etc
+            if (!callback(i, t1, t2, user_data)) {
+                break;
+            }
         }
     }
-
     ggml_backend_graph_copy_free(copy);
 
     return true;
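
The extra test_node argument is what lets a test compute the whole graph on both backends and validate only one output, so backends remain free to fuse nodes internally. A minimal sketch of a caller (not from the commit; the shape check stands in for a real numeric comparison):

// Sketch only. The callback type and the new test_node parameter come from the
// header change above.
#include "ggml-backend.h"

static bool check_node(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data) {
    (void) node_index; (void) user_data;
    return ggml_are_same_shape(t1, t2); // a real test would compare the data, e.g. via NMSE
}

// Compute the full graph on both backends, but only validate `out`.
// Passing NULL as test_node keeps the old node-by-node comparison.
static bool compare_one_output(ggml_backend_t b1, ggml_backend_t b2,
                               struct ggml_cgraph * graph, struct ggml_tensor * out) {
    return ggml_backend_compare_graph_backend(b1, b2, graph, check_node, /*user_data=*/NULL, out);
}
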
ggml/src/ggml-impl.h CHANGED
@@ -301,6 +301,7 @@ struct ggml_cgraph {
     struct ggml_tensor ** grads; // the outputs of these tensors are the gradients of the nodes
     struct ggml_tensor ** grad_accs; // accumulators for node gradients
     struct ggml_tensor ** leafs; // tensors with constant data
+    int32_t * use_counts; // number of uses of each tensor, indexed by hash table slot
 
     struct ggml_hash_set visited_hash_set;
 
@@ -467,13 +468,76 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
 #define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
 #define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
 
+// return true if the node's results are only used by N other nodes
+// and can be fused into their calculations.
+static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
+    const struct ggml_tensor * node = cgraph->nodes[node_idx];
+
+    // check the use count against how many we're replacing
+    size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
+    if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) {
+        return false;
+    }
+
+    // if node is a view, some other node might be using the intermediate result
+    // via the view source.
+    if (node->view_src) {
+        return false;
+    }
+
+    // If the user requested output for the node, can't fuse
+    if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
+        return false;
+    }
+
+    return true;
+}
+
+// Returns true if nodes [i, i+ops.size()) are the sequence of ggml_ops in ops[]
+// and are fusable. Nodes are considered fusable according to this function if:
+// - all nodes except the last have only one use and are not views/outputs (see ggml_node_has_N_uses).
+// - all nodes except the last are a src of the following node.
+// - all nodes are the same shape.
+// TODO: Consider allowing GGML_OP_NONE nodes in between
+static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) {
+    if (node_idx + num_ops > cgraph->n_nodes) {
+        return false;
+    }
+
+    for (int i = 0; i < num_ops; ++i) {
+        struct ggml_tensor * node = cgraph->nodes[node_idx + i];
+        if (node->op != ops[i]) {
+            return false;
+        }
+        if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idx + i, 1)) {
+            return false;
+        }
+        if (i > 0) {
+            struct ggml_tensor * prev = cgraph->nodes[node_idx + i - 1];
+            if (node->src[0] != prev && node->src[1] != prev) {
+                return false;
+            }
+            if (!ggml_are_same_shape(node, prev)) {
+                return false;
+            }
+        }
+    }
+    return true;
+}
+
 #ifdef __cplusplus
 }
 #endif
 
 #ifdef __cplusplus
+#include <initializer_list>
 #include <vector>
 
+// nicer C++ syntax for ggml_can_fuse
+inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+    return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
+}
+
 // expose GGUF internals for test code
 GGML_API size_t gguf_type_size(enum gguf_type type);
 GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
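
Because the use count now lives on the graph (indexed by hash-table slot) rather than on ggml_tensor, callers read it through the visited hash set. A small sketch (not from the commit) of that lookup, following the same pattern as ggml_node_has_n_uses above:

// Sketch only: how the per-graph use count added above is read back.
#include "ggml-impl.h"

static int32_t node_use_count(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
    size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
    if (hash_pos == GGML_HASHSET_FULL || !ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos)) {
        return 0; // the tensor is not part of this graph
    }
    return cgraph->use_counts[hash_pos];
}
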
ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -425,6 +425,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_norm_f32;
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
+    vk_pipeline pipeline_rms_norm_mul_f32;
     vk_pipeline pipeline_rms_norm_back_f32;
     vk_pipeline pipeline_l2_norm_f32;
 
@@ -978,6 +979,10 @@ struct ggml_backend_vk_context {
 
     vk_command_pool compute_cmd_pool;
     vk_command_pool transfer_cmd_pool;
+
+    // number of additional consecutive nodes that are being fused with the
+    // node currently being processed
+    uint32_t num_additional_fused_ops {};
 };
 
 static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
@@ -2655,7 +2660,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
@@ -6430,7 +6436,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     case GGML_OP_RMS_NORM:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_rms_norm_f32;
+            return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
        }
         return nullptr;
     case GGML_OP_RMS_NORM_BACK:
@@ -7530,18 +7536,19 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
 }
 
-static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
     const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
 
-    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
         (uint32_t)ggml_nelements(src0),
-        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
         0,
-        op_params[0], 0.0f,
-        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+        op_params[0], 0.0f, 0,
     }, dryrun);
 }
 
@@ -8736,7 +8743,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* t
 
 // Returns true if node has enqueued work into the queue, false otherwise
 // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
-static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
+static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
+    ggml_tensor * node = cgraph->nodes[node_idx];
     if (ggml_is_empty(node) || !node->buffer) {
         return false;
     }
@@ -8974,8 +8982,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
 
         break;
     case GGML_OP_RMS_NORM:
-        ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
-
+        if (ctx->num_additional_fused_ops > 0) {
+            // fused rms_norm + mul
+            ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+            ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
+            ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, dryrun);
+        } else {
+            ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun);
+        }
         break;
     case GGML_OP_RMS_NORM_BACK:
         ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9710,10 +9724,15 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
 
     uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
+        if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+            ctx->num_additional_fused_ops = 1;
+        }
+        ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
         if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
             total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
         }
+        i += ctx->num_additional_fused_ops;
+        ctx->num_additional_fused_ops = 0;
     }
     if (ctx->device->need_compiles) {
         ggml_vk_load_shaders(ctx->device);
@@ -9775,14 +9794,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
         }
 
+        if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+            ctx->num_additional_fused_ops = 1;
+        }
+
         // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
         bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
         bool submit = (submitted_nodes >= nodes_per_submit) ||
                       (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
-                      (i == last_node) ||
+                      (i + ctx->num_additional_fused_ops == last_node) ||
                       (almost_ready && !ctx->almost_ready_fence_pending);
 
-        bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit);
+        bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);
 
         if (vk_perf_logger_enabled) {
             if (ctx->compute_ctx.expired()) {
@@ -9792,7 +9815,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             } else {
                 compute_ctx = ctx->compute_ctx.lock();
             }
-            compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+1);
+            // If there are fused ops, just write out timestamps for all nodes to keep the accounting simple
+            for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) {
+                compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1);
+            }
         }
 
         if (enqueued) {
@@ -9814,6 +9840,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             }
             submit_count++;
         }
+        i += ctx->num_additional_fused_ops;
+        ctx->num_additional_fused_ops = 0;
     }
 
     if (vk_perf_logger_enabled) {
ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp CHANGED
@@ -1,11 +1,13 @@
 #version 450
 
-#include "generic_unary_head.comp"
+#include "generic_binary_head.comp"
 #include "types.comp"
 
 #extension GL_EXT_control_flow_attributes : enable
 #define BLOCK_SIZE 512
 
+layout (constant_id = 1) const bool do_multiply = false;
+
 layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
 
 shared FLOAT_TYPE sum[BLOCK_SIZE];
@@ -25,6 +27,7 @@ void main() {
     const uint stride_sample = p.nb03;
 
     uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
+    uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
    uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
 
     sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
@@ -46,7 +49,13 @@ void main() {
     const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
     const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
 
-    [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
-        data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
+    if (do_multiply) {
+        [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+            data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
+        }
+    } else {
+        [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+            data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
+        }
     }
 }
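
For reference, with do_multiply set the fused path above computes, per output element of a row (eps is p.param1, the mean is taken over the columns of the current row, and b_offset indexes src1 for the current row/channel/sample):

    dst[col] = src0[col] * inversesqrt(mean(src0^2) + eps) * src1[col]

which is exactly RMS_NORM followed by an elementwise MUL, so the fused pipeline skips one write and one read of the intermediate normalized tensor. The two pipelines created above share this shader and differ only in the do_multiply specialization constant.
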
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp CHANGED
@@ -497,7 +497,7 @@ void process_shaders() {
     // Norms
     string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
-    string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 
ggml/src/ggml.c CHANGED
@@ -5850,19 +5850,32 @@ static void ggml_compute_backward(
     GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));
 }
 
-static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
+static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
     // check if already visited
-    if (ggml_hash_insert(&cgraph->visited_hash_set, node) == GGML_HASHSET_ALREADY_EXISTS) {
-        return;
+    size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
+    GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL);
+    if (!ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
+        // This is the first time we see this node in the current graph.
+        cgraph->visited_hash_set.keys[node_hash_pos] = node;
+        ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
+        cgraph->use_counts[node_hash_pos] = 0;
+    } else {
+        // already visited
+        return node_hash_pos;
     }
 
     for (int i = 0; i < GGML_MAX_SRC; ++i) {
         const int k =
             (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
             (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :
-            /* unknown order, just fall back to using i*/ i;
-        if (node->src[k]) {
-            ggml_visit_parents(cgraph, node->src[k]);
+            /* unknown order, just fall back to using i */ i;
+
+        struct ggml_tensor * src = node->src[k];
+        if (src) {
+            size_t src_hash_pos = ggml_visit_parents(cgraph, src);
+
+            // Update the use count for this operand.
+            cgraph->use_counts[src_hash_pos]++;
         }
     }
 
@@ -5886,6 +5899,8 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
         cgraph->nodes[cgraph->n_nodes] = node;
         cgraph->n_nodes++;
     }
+
+    return node_hash_pos;
 }
 
 static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
@@ -6023,6 +6038,7 @@ static size_t ggml_graph_nbytes(size_t size, bool grads) {
     incr_ptr_aligned(&p, sizeof(struct ggml_cgraph), 1);
     incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // nodes
     incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs
+    incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t)); // use_counts
     incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys
     if (grads) {
         incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads
@@ -6052,11 +6068,12 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
 
     void * p = cgraph + 1;
 
-    struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
-    struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
-    struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
-    struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
-    struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
+    struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+    struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+    int32_t * use_counts_ptr = incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t));
+    struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+    struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
+    struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
 
     ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));
 
@@ -6071,6 +6088,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
         /*.grads =*/ grads_ptr,
         /*.grad_accs =*/ grad_accs_ptr,
         /*.leafs =*/ leafs_ptr,
+        /*.use_counts =*/ use_counts_ptr,
         /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr },
         /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
     };
@@ -6097,7 +6115,8 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1)
         /*.grads =*/ NULL, // gradients would need visited_hash_set
         /*.grad_accs =*/ NULL,
         /*.leafs =*/ NULL,
-        /*.visited_hash_set =*/ { 0, NULL, NULL },
+        /*.use_counts =*/ cgraph0->use_counts,
+        /*.visited_hash_set =*/ cgraph0->visited_hash_set,
         /*.order =*/ cgraph0->order,
     };
 
@@ -6124,7 +6143,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
     for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
         // copy all hashset keys (tensors) that are in use
        if (ggml_bitset_get(src->visited_hash_set.used, i)) {
-            ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
+            size_t new_hash_pos = ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
+            dst->use_counts[new_hash_pos] = src->use_counts[i];
         }
     }
 