Chenguang Li commited on
Commit
137a0dc
·
1 Parent(s): bf225d6

CANN: add support for ACL Graph (llama/15065)

Browse files

* feat(cann): add optional support for ACL Graph execution

This commit adds support for executing ggml computational graphs using
Huawei's ACL graph mode via the USE_CANN_GRAPH flag. The support can be
enabled at compile time using the CMake option:

-DUSE_CANN_GRAPH=ON

By default, ACL graph execution is **disabled**, and the fallback path
uses node-by-node execution.

Key additions:
- CMake option to toggle graph mode
- Graph capture and execution logic using
- Tensor property matching to determine whether graph update is required
- Safe fallback and logging if the environment variable LLAMA_SET_ROWS
is unset or invalid

This prepares the backend for performance improvements in repetitive graph
execution scenarios on Ascend devices.

Signed-off-by: noemotiovon <[email protected]>

* Fix review comments

Signed-off-by: noemotiovon <[email protected]>

* remane USE_CANN_GRAPH to USE_ACL_GRAPH

Signed-off-by: noemotiovon <[email protected]>

* fix typo

Signed-off-by: noemotiovon <[email protected]>

---------

Signed-off-by: noemotiovon <[email protected]>

ggml/src/ggml-cann/CMakeLists.txt CHANGED
@@ -31,6 +31,13 @@ string(REGEX MATCH "[0-9]+[a-zA-Z]" SOC_TYPE_MAJOR_SN "${SOC_VERSION}")
31
  set(SOC_TYPE_COMPILE_OPTION "ASCEND_${SOC_TYPE_MAJOR_SN}")
32
  string(TOUPPER ${SOC_TYPE_COMPILE_OPTION} SOC_TYPE_COMPILE_OPTION)
33
  message(STATUS "CANN: SOC_VERSION = ${SOC_VERSION}")
 
 
 
 
 
 
 
34
 
35
  if (CANN_INSTALL_DIR)
36
  # Only Support Linux.
@@ -68,6 +75,13 @@ if (CANN_INSTALL_DIR)
68
 
69
  target_compile_definitions(ggml-cann PRIVATE "-D${SOC_TYPE_COMPILE_OPTION}")
70
 
 
 
 
 
 
 
 
71
  message(STATUS "CANN: CANN_INCLUDE_DIRS = ${CANN_INCLUDE_DIRS}")
72
  message(STATUS "CANN: CANN_LIBRARIES = ${CANN_LIBRARIES}")
73
  else()
 
31
  set(SOC_TYPE_COMPILE_OPTION "ASCEND_${SOC_TYPE_MAJOR_SN}")
32
  string(TOUPPER ${SOC_TYPE_COMPILE_OPTION} SOC_TYPE_COMPILE_OPTION)
33
  message(STATUS "CANN: SOC_VERSION = ${SOC_VERSION}")
34
+ option(USE_ACL_GRAPH "Enable CANN graph execution (ACL graph mode)" OFF)
35
+
36
+ if(USE_ACL_GRAPH AND (SOC_TYPE_MAJOR_SN STREQUAL "310P" OR SOC_TYPE_COMPILE_OPTION STREQUAL "ASCEND_310P"))
37
+ message(FATAL_ERROR
38
+ "CANN Graph (ACL graph mode) is not supported on 310P devices. "
39
+ "Please build with -DUSE_ACL_GRAPH=OFF or use a supported SOC.")
40
+ endif()
41
 
42
  if (CANN_INSTALL_DIR)
43
  # Only Support Linux.
 
75
 
76
  target_compile_definitions(ggml-cann PRIVATE "-D${SOC_TYPE_COMPILE_OPTION}")
77
 
78
+ if (USE_ACL_GRAPH)
79
+ target_compile_definitions(ggml-cann PRIVATE USE_ACL_GRAPH)
80
+ message(STATUS "CANN: USE_ACL_GRAPH is enabled.")
81
+ else()
82
+ message(STATUS "CANN: USE_ACL_GRAPH is disabled.")
83
+ endif()
84
+
85
  message(STATUS "CANN: CANN_INCLUDE_DIRS = ${CANN_INCLUDE_DIRS}")
86
  message(STATUS "CANN: CANN_LIBRARIES = ${CANN_LIBRARIES}")
87
  else()
ggml/src/ggml-cann/common.h CHANGED
@@ -337,6 +337,29 @@ private:
337
  int32_t device_;
338
  };
339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  /**
341
  * @brief Context for managing CANN backend operations.
342
  */
@@ -345,8 +368,13 @@ struct ggml_backend_cann_context {
345
  std::string name; /**< Name of the device. */
346
  std::string description; /**< Description of the device. */
347
  aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
 
 
 
 
348
  cann_task_queue task_queue;
349
  bool async_mode;
 
350
 
351
  aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
352
 
@@ -362,6 +390,14 @@ struct ggml_backend_cann_context {
362
  async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
363
  GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
364
  device, async_mode ? "ON" : "OFF");
 
 
 
 
 
 
 
 
365
  }
366
 
367
  /**
 
337
  int32_t device_;
338
  };
339
 
340
+ #ifdef USE_ACL_GRAPH
341
+ struct ggml_graph_node_properties {
342
+ void * node_address;
343
+ ggml_op node_op;
344
+ int64_t ne[GGML_MAX_DIMS];
345
+ size_t nb[GGML_MAX_DIMS];
346
+ void * src_address[GGML_MAX_SRC];
347
+ int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
348
+ };
349
+
350
+ struct ggml_cann_graph {
351
+ ~ggml_cann_graph() {
352
+ if (graph != nullptr) {
353
+ aclmdlRIDestroy(graph);
354
+ }
355
+ }
356
+
357
+ aclmdlRI graph = nullptr;
358
+
359
+ std::vector<ggml_graph_node_properties> ggml_graph_properties;
360
+ };
361
+ #endif // USE_ACL_GRAPH
362
+
363
  /**
364
  * @brief Context for managing CANN backend operations.
365
  */
 
368
  std::string name; /**< Name of the device. */
369
  std::string description; /**< Description of the device. */
370
  aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
371
+ #ifdef USE_ACL_GRAPH
372
+ /// Cached CANN ACL graph used for executing the current ggml computation graph.
373
+ std::unique_ptr<ggml_cann_graph> cann_graph;
374
+ #endif
375
  cann_task_queue task_queue;
376
  bool async_mode;
377
+ bool support_set_rows;
378
 
379
  aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
380
 
 
390
  async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
391
  GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
392
  device, async_mode ? "ON" : "OFF");
393
+
394
+ support_set_rows = parse_bool(get_env("LLAMA_SET_ROWS").value_or(""));
395
+ GGML_LOG_INFO("%s: LLAMA_SET_ROWS is %s\n", __func__, support_set_rows ? "ON" : "OFF");
396
+
397
+ if (!support_set_rows) {
398
+ GGML_LOG_INFO("%s: CANN Graph currently only supports execution when LLAMA_SET_ROWS is ON. "
399
+ "Falling back to eager mode.\n", __func__);
400
+ }
401
  }
402
 
403
  /**
ggml/src/ggml-cann/ggml-cann.cpp CHANGED
@@ -2075,6 +2075,160 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
2075
  ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
2076
  }
2077
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2078
  /**
2079
  * @brief Computes a computational graph using a CANN backend.
2080
  *
@@ -2091,26 +2245,37 @@ static enum ggml_status ggml_backend_cann_graph_compute(
2091
  ggml_backend_t backend, ggml_cgraph* cgraph) {
2092
  ggml_backend_cann_context* cann_ctx =
2093
  (ggml_backend_cann_context*)backend->context;
2094
-
2095
  ggml_cann_set_device(cann_ctx->device);
2096
- //release temp buffer create by set tensor.
2097
  release_nz_workspace();
 
 
 
2098
 
2099
- for (int i = 0; i < cgraph->n_nodes; i++) {
2100
- ggml_tensor* node = cgraph->nodes[i];
 
 
2101
 
2102
- if (ggml_is_empty(node) || node->op == GGML_OP_NONE) {
2103
- continue;
 
 
2104
  }
2105
 
2106
- bool ok = ggml_cann_compute_forward(*cann_ctx, node);
2107
-
2108
- if (!ok) {
2109
- GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__,
2110
- node->name, ggml_op_name(node->op));
2111
- }
2112
- GGML_ASSERT(ok);
2113
  }
 
 
 
 
 
 
 
 
 
 
 
2114
 
2115
  return GGML_STATUS_SUCCESS;
2116
  }
@@ -2226,12 +2391,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2226
  // only support F32 and F16.
2227
  return false;
2228
  }
2229
-
2230
- if (!ggml_are_same_shape(op, src) && !ggml_is_contiguous(op)) {
2231
- // unsupport dst is not contiguous.
2232
- return false;
2233
- }
2234
-
2235
  return true;
2236
  } break;
2237
  case GGML_OP_CONT: {
 
2075
  ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
2076
  }
2077
 
2078
+ #ifdef USE_ACL_GRAPH
2079
+ /**
2080
+ * @brief Populate the internal CANN graph node properties from the ggml computation graph.
2081
+ *
2082
+ * This function copies all node attributes (operation type, dimensions, strides, input sources,
2083
+ * and operation parameters) into the cached CANN graph structure for later reuse or comparison.
2084
+ *
2085
+ * @param cann_ctx The CANN backend context.
2086
+ * @param cgraph The ggml computational graph.
2087
+ */
2088
+ static void set_ggml_graph_node_properties(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
2089
+ for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) {
2090
+ ggml_tensor * node = cgraph->nodes[node_idx];
2091
+ cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_address = node->data;
2092
+ cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_op = node->op;
2093
+
2094
+ for (int dim = 0; dim < GGML_MAX_DIMS; dim++) {
2095
+ cann_ctx->cann_graph->ggml_graph_properties[node_idx].ne[dim] = node->ne[dim];
2096
+ cann_ctx->cann_graph->ggml_graph_properties[node_idx].nb[dim] = node->nb[dim];
2097
+ }
2098
+ for (int src = 0; src < GGML_MAX_SRC; src++) {
2099
+ cann_ctx->cann_graph->ggml_graph_properties[node_idx].src_address[src] =
2100
+ node->src[src] ? node->src[src]->data : nullptr;
2101
+ }
2102
+ memcpy(cann_ctx->cann_graph->ggml_graph_properties[node_idx].op_params, node->op_params, GGML_MAX_OP_PARAMS);
2103
+ }
2104
+ }
2105
+
2106
+ /**
2107
+ * @brief Check if a ggml tensor node matches a previously captured CANN graph node.
2108
+ *
2109
+ * This function compares all relevant fields (address, op type, shape, source inputs, op params)
2110
+ * to determine whether the current node matches a previously recorded version.
2111
+ *
2112
+ * @param node The current ggml tensor node.
2113
+ * @param graph_node_properties The stored properties of a CANN graph node.
2114
+ * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
2115
+ */
2116
+ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
2117
+ if (node->data != graph_node_properties->node_address &&
2118
+ node->op != GGML_OP_VIEW) {
2119
+ return false;
2120
+ }
2121
+ if (node->op != graph_node_properties->node_op) {
2122
+ return false;
2123
+ }
2124
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
2125
+ if (node->ne[i] != graph_node_properties->ne[i]) {
2126
+ return false;
2127
+ }
2128
+ if (node->nb[i] != graph_node_properties->nb[i]) {
2129
+ return false;
2130
+ }
2131
+ }
2132
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
2133
+ if (node->src[i] &&
2134
+ node->src[i]->data != graph_node_properties->src_address[i] &&
2135
+ node->op != GGML_OP_VIEW
2136
+ ) {
2137
+ return false;
2138
+ }
2139
+ }
2140
+ if (node->op == GGML_OP_SCALE &&
2141
+ memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
2142
+ return false;
2143
+ }
2144
+ return true;
2145
+ }
2146
+
2147
+ /**
2148
+ * @brief Determine if the CANN graph needs to be rebuilt due to graph changes.
2149
+ *
2150
+ * This checks whether the number or properties of ggml graph nodes have changed
2151
+ * compared to the last captured CANN graph. If so, the CANN graph must be re-captured.
2152
+ *
2153
+ * @param cann_ctx The CANN backend context.
2154
+ * @param cgraph The current ggml computation graph.
2155
+ * @return true if an update is required; false otherwise.
2156
+ */
2157
+ static bool is_cann_graph_update_required(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
2158
+ // The number of nodes is different, so the graph needs to be reconstructed.
2159
+ if (cann_ctx->cann_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
2160
+ cann_ctx->cann_graph->ggml_graph_properties.resize(cgraph->n_nodes);
2161
+ return true;
2162
+ }
2163
+
2164
+ // The number of nodes is the same; iterate over each node to check whether they match.
2165
+ for (int i = 0; i < cgraph->n_nodes; i++) {
2166
+ bool has_matching_properties = ggml_graph_node_has_matching_properties(
2167
+ cgraph->nodes[i], &cann_ctx->cann_graph->ggml_graph_properties[i]);
2168
+ if(!has_matching_properties) {
2169
+ return true;
2170
+ }
2171
+ }
2172
+ return false;
2173
+ }
2174
+ #endif // USE_ACL_GRAPH
2175
+
2176
+ /**
2177
+ * @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
2178
+ *
2179
+ * If CANN graph execution is enabled and graph capture is required, this function begins
2180
+ * graph capture, runs the graph, ends capture, and stores the captured graph.
2181
+ *
2182
+ * Otherwise, it falls back to op-by-op execution using the CANN compute kernel dispatcher.
2183
+ *
2184
+ * @param cann_ctx The CANN backend context.
2185
+ * @param cgraph The ggml computation graph.
2186
+ * @param use_cann_graph Whether to use CANN graph execution.
2187
+ * @param cann_graph_update_required Whether graph capture is needed due to graph changes.
2188
+ */
2189
+ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph,
2190
+ bool & use_cann_graph, bool & cann_graph_update_required) {
2191
+ #ifdef USE_ACL_GRAPH
2192
+ if (use_cann_graph && cann_graph_update_required) {
2193
+ if (cann_ctx->cann_graph->graph != nullptr) {
2194
+ ACL_CHECK(aclmdlRIDestroy(cann_ctx->cann_graph->graph));
2195
+ cann_ctx->cann_graph->graph = nullptr;
2196
+ }
2197
+ ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
2198
+ }
2199
+ #endif // USE_ACL_GRAPH
2200
+
2201
+ // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
2202
+ // With the use of CANN graphs, the execution will be performed by the graph launch.
2203
+ if (!use_cann_graph || cann_graph_update_required) {
2204
+ for (int i = 0; i < cgraph->n_nodes; i++) {
2205
+ ggml_tensor * node = cgraph->nodes[i];
2206
+
2207
+ if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
2208
+ continue;
2209
+ }
2210
+
2211
+ bool ok = ggml_cann_compute_forward(*cann_ctx, node);
2212
+ if (!ok) {
2213
+ GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
2214
+ }
2215
+ GGML_ASSERT(ok);
2216
+ }
2217
+ }
2218
+
2219
+ #ifdef USE_ACL_GRAPH
2220
+ if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture
2221
+ ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &cann_ctx->cann_graph->graph));
2222
+ }
2223
+
2224
+ if (use_cann_graph) {
2225
+ // Execute graph
2226
+ ACL_CHECK(aclmdlRIExecuteAsync(cann_ctx->cann_graph->graph, cann_ctx->stream()));
2227
+ }
2228
+ #endif // USE_ACL_GRAPH
2229
+ }
2230
+
2231
+
2232
  /**
2233
  * @brief Computes a computational graph using a CANN backend.
2234
  *
 
2245
  ggml_backend_t backend, ggml_cgraph* cgraph) {
2246
  ggml_backend_cann_context* cann_ctx =
2247
  (ggml_backend_cann_context*)backend->context;
 
2248
  ggml_cann_set_device(cann_ctx->device);
 
2249
  release_nz_workspace();
2250
+ #ifdef USE_ACL_GRAPH
2251
+ bool use_cann_graph = true;
2252
+ bool cann_graph_update_required = false;
2253
 
2254
+ // check environment LLAMA_SET_ROWS
2255
+ if (!cann_ctx->support_set_rows) {
2256
+ use_cann_graph = false;
2257
+ }
2258
 
2259
+ if (use_cann_graph) {
2260
+ if (cann_ctx->cann_graph == nullptr) {
2261
+ cann_ctx->cann_graph.reset(new ggml_cann_graph());
2262
+ cann_graph_update_required = true;
2263
  }
2264
 
2265
+ cann_graph_update_required = is_cann_graph_update_required(cann_ctx, cgraph);
2266
+ set_ggml_graph_node_properties(cann_ctx, cgraph);
 
 
 
 
 
2267
  }
2268
+ #else
2269
+ bool use_cann_graph = false;
2270
+ bool cann_graph_update_required = false;
2271
+ #endif // USE_ACL_GRAPH
2272
+
2273
+ evaluate_and_capture_cann_graph(
2274
+ cann_ctx,
2275
+ cgraph,
2276
+ use_cann_graph,
2277
+ cann_graph_update_required
2278
+ );
2279
 
2280
  return GGML_STATUS_SUCCESS;
2281
  }
 
2391
  // only support F32 and F16.
2392
  return false;
2393
  }
 
 
 
 
 
 
2394
  return true;
2395
  } break;
2396
  case GGML_OP_CONT: {