Spaces:
Running
CANN: add support for ACL Graph (llama/15065)
Browse files* feat(cann): add optional support for ACL Graph execution
This commit adds support for executing ggml computational graphs using
Huawei's ACL graph mode via the USE_CANN_GRAPH flag. The support can be
enabled at compile time using the CMake option:
-DUSE_CANN_GRAPH=ON
By default, ACL graph execution is **disabled**, and the fallback path
uses node-by-node execution.
Key additions:
- CMake option to toggle graph mode
- Graph capture and execution logic using
- Tensor property matching to determine whether graph update is required
- Safe fallback and logging if the environment variable LLAMA_SET_ROWS
is unset or invalid
This prepares the backend for performance improvements in repetitive graph
execution scenarios on Ascend devices.
Signed-off-by: noemotiovon <[email protected]>
* Fix review comments
Signed-off-by: noemotiovon <[email protected]>
* remane USE_CANN_GRAPH to USE_ACL_GRAPH
Signed-off-by: noemotiovon <[email protected]>
* fix typo
Signed-off-by: noemotiovon <[email protected]>
---------
Signed-off-by: noemotiovon <[email protected]>
- ggml/src/ggml-cann/CMakeLists.txt +14 -0
- ggml/src/ggml-cann/common.h +36 -0
- ggml/src/ggml-cann/ggml-cann.cpp +178 -19
|
@@ -31,6 +31,13 @@ string(REGEX MATCH "[0-9]+[a-zA-Z]" SOC_TYPE_MAJOR_SN "${SOC_VERSION}")
|
|
| 31 |
set(SOC_TYPE_COMPILE_OPTION "ASCEND_${SOC_TYPE_MAJOR_SN}")
|
| 32 |
string(TOUPPER ${SOC_TYPE_COMPILE_OPTION} SOC_TYPE_COMPILE_OPTION)
|
| 33 |
message(STATUS "CANN: SOC_VERSION = ${SOC_VERSION}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
if (CANN_INSTALL_DIR)
|
| 36 |
# Only Support Linux.
|
|
@@ -68,6 +75,13 @@ if (CANN_INSTALL_DIR)
|
|
| 68 |
|
| 69 |
target_compile_definitions(ggml-cann PRIVATE "-D${SOC_TYPE_COMPILE_OPTION}")
|
| 70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
message(STATUS "CANN: CANN_INCLUDE_DIRS = ${CANN_INCLUDE_DIRS}")
|
| 72 |
message(STATUS "CANN: CANN_LIBRARIES = ${CANN_LIBRARIES}")
|
| 73 |
else()
|
|
|
|
| 31 |
set(SOC_TYPE_COMPILE_OPTION "ASCEND_${SOC_TYPE_MAJOR_SN}")
|
| 32 |
string(TOUPPER ${SOC_TYPE_COMPILE_OPTION} SOC_TYPE_COMPILE_OPTION)
|
| 33 |
message(STATUS "CANN: SOC_VERSION = ${SOC_VERSION}")
|
| 34 |
+
option(USE_ACL_GRAPH "Enable CANN graph execution (ACL graph mode)" OFF)
|
| 35 |
+
|
| 36 |
+
if(USE_ACL_GRAPH AND (SOC_TYPE_MAJOR_SN STREQUAL "310P" OR SOC_TYPE_COMPILE_OPTION STREQUAL "ASCEND_310P"))
|
| 37 |
+
message(FATAL_ERROR
|
| 38 |
+
"CANN Graph (ACL graph mode) is not supported on 310P devices. "
|
| 39 |
+
"Please build with -DUSE_ACL_GRAPH=OFF or use a supported SOC.")
|
| 40 |
+
endif()
|
| 41 |
|
| 42 |
if (CANN_INSTALL_DIR)
|
| 43 |
# Only Support Linux.
|
|
|
|
| 75 |
|
| 76 |
target_compile_definitions(ggml-cann PRIVATE "-D${SOC_TYPE_COMPILE_OPTION}")
|
| 77 |
|
| 78 |
+
if (USE_ACL_GRAPH)
|
| 79 |
+
target_compile_definitions(ggml-cann PRIVATE USE_ACL_GRAPH)
|
| 80 |
+
message(STATUS "CANN: USE_ACL_GRAPH is enabled.")
|
| 81 |
+
else()
|
| 82 |
+
message(STATUS "CANN: USE_ACL_GRAPH is disabled.")
|
| 83 |
+
endif()
|
| 84 |
+
|
| 85 |
message(STATUS "CANN: CANN_INCLUDE_DIRS = ${CANN_INCLUDE_DIRS}")
|
| 86 |
message(STATUS "CANN: CANN_LIBRARIES = ${CANN_LIBRARIES}")
|
| 87 |
else()
|
|
@@ -337,6 +337,29 @@ private:
|
|
| 337 |
int32_t device_;
|
| 338 |
};
|
| 339 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
/**
|
| 341 |
* @brief Context for managing CANN backend operations.
|
| 342 |
*/
|
|
@@ -345,8 +368,13 @@ struct ggml_backend_cann_context {
|
|
| 345 |
std::string name; /**< Name of the device. */
|
| 346 |
std::string description; /**< Description of the device. */
|
| 347 |
aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
cann_task_queue task_queue;
|
| 349 |
bool async_mode;
|
|
|
|
| 350 |
|
| 351 |
aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
|
| 352 |
|
|
@@ -362,6 +390,14 @@ struct ggml_backend_cann_context {
|
|
| 362 |
async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
|
| 363 |
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
|
| 364 |
device, async_mode ? "ON" : "OFF");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
}
|
| 366 |
|
| 367 |
/**
|
|
|
|
| 337 |
int32_t device_;
|
| 338 |
};
|
| 339 |
|
| 340 |
+
#ifdef USE_ACL_GRAPH
|
| 341 |
+
struct ggml_graph_node_properties {
|
| 342 |
+
void * node_address;
|
| 343 |
+
ggml_op node_op;
|
| 344 |
+
int64_t ne[GGML_MAX_DIMS];
|
| 345 |
+
size_t nb[GGML_MAX_DIMS];
|
| 346 |
+
void * src_address[GGML_MAX_SRC];
|
| 347 |
+
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
|
| 348 |
+
};
|
| 349 |
+
|
| 350 |
+
struct ggml_cann_graph {
|
| 351 |
+
~ggml_cann_graph() {
|
| 352 |
+
if (graph != nullptr) {
|
| 353 |
+
aclmdlRIDestroy(graph);
|
| 354 |
+
}
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
aclmdlRI graph = nullptr;
|
| 358 |
+
|
| 359 |
+
std::vector<ggml_graph_node_properties> ggml_graph_properties;
|
| 360 |
+
};
|
| 361 |
+
#endif // USE_ACL_GRAPH
|
| 362 |
+
|
| 363 |
/**
|
| 364 |
* @brief Context for managing CANN backend operations.
|
| 365 |
*/
|
|
|
|
| 368 |
std::string name; /**< Name of the device. */
|
| 369 |
std::string description; /**< Description of the device. */
|
| 370 |
aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
|
| 371 |
+
#ifdef USE_ACL_GRAPH
|
| 372 |
+
/// Cached CANN ACL graph used for executing the current ggml computation graph.
|
| 373 |
+
std::unique_ptr<ggml_cann_graph> cann_graph;
|
| 374 |
+
#endif
|
| 375 |
cann_task_queue task_queue;
|
| 376 |
bool async_mode;
|
| 377 |
+
bool support_set_rows;
|
| 378 |
|
| 379 |
aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
|
| 380 |
|
|
|
|
| 390 |
async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
|
| 391 |
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
|
| 392 |
device, async_mode ? "ON" : "OFF");
|
| 393 |
+
|
| 394 |
+
support_set_rows = parse_bool(get_env("LLAMA_SET_ROWS").value_or(""));
|
| 395 |
+
GGML_LOG_INFO("%s: LLAMA_SET_ROWS is %s\n", __func__, support_set_rows ? "ON" : "OFF");
|
| 396 |
+
|
| 397 |
+
if (!support_set_rows) {
|
| 398 |
+
GGML_LOG_INFO("%s: CANN Graph currently only supports execution when LLAMA_SET_ROWS is ON. "
|
| 399 |
+
"Falling back to eager mode.\n", __func__);
|
| 400 |
+
}
|
| 401 |
}
|
| 402 |
|
| 403 |
/**
|
|
@@ -2075,6 +2075,160 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
|
|
| 2075 |
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
|
| 2076 |
}
|
| 2077 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2078 |
/**
|
| 2079 |
* @brief Computes a computational graph using a CANN backend.
|
| 2080 |
*
|
|
@@ -2091,26 +2245,37 @@ static enum ggml_status ggml_backend_cann_graph_compute(
|
|
| 2091 |
ggml_backend_t backend, ggml_cgraph* cgraph) {
|
| 2092 |
ggml_backend_cann_context* cann_ctx =
|
| 2093 |
(ggml_backend_cann_context*)backend->context;
|
| 2094 |
-
|
| 2095 |
ggml_cann_set_device(cann_ctx->device);
|
| 2096 |
-
//release temp buffer create by set tensor.
|
| 2097 |
release_nz_workspace();
|
|
|
|
|
|
|
|
|
|
| 2098 |
|
| 2099 |
-
|
| 2100 |
-
|
|
|
|
|
|
|
| 2101 |
|
| 2102 |
-
|
| 2103 |
-
|
|
|
|
|
|
|
| 2104 |
}
|
| 2105 |
|
| 2106 |
-
|
| 2107 |
-
|
| 2108 |
-
if (!ok) {
|
| 2109 |
-
GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__,
|
| 2110 |
-
node->name, ggml_op_name(node->op));
|
| 2111 |
-
}
|
| 2112 |
-
GGML_ASSERT(ok);
|
| 2113 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2114 |
|
| 2115 |
return GGML_STATUS_SUCCESS;
|
| 2116 |
}
|
|
@@ -2226,12 +2391,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
|
| 2226 |
// only support F32 and F16.
|
| 2227 |
return false;
|
| 2228 |
}
|
| 2229 |
-
|
| 2230 |
-
if (!ggml_are_same_shape(op, src) && !ggml_is_contiguous(op)) {
|
| 2231 |
-
// unsupport dst is not contiguous.
|
| 2232 |
-
return false;
|
| 2233 |
-
}
|
| 2234 |
-
|
| 2235 |
return true;
|
| 2236 |
} break;
|
| 2237 |
case GGML_OP_CONT: {
|
|
|
|
| 2075 |
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
|
| 2076 |
}
|
| 2077 |
|
| 2078 |
+
#ifdef USE_ACL_GRAPH
|
| 2079 |
+
/**
|
| 2080 |
+
* @brief Populate the internal CANN graph node properties from the ggml computation graph.
|
| 2081 |
+
*
|
| 2082 |
+
* This function copies all node attributes (operation type, dimensions, strides, input sources,
|
| 2083 |
+
* and operation parameters) into the cached CANN graph structure for later reuse or comparison.
|
| 2084 |
+
*
|
| 2085 |
+
* @param cann_ctx The CANN backend context.
|
| 2086 |
+
* @param cgraph The ggml computational graph.
|
| 2087 |
+
*/
|
| 2088 |
+
static void set_ggml_graph_node_properties(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
|
| 2089 |
+
for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) {
|
| 2090 |
+
ggml_tensor * node = cgraph->nodes[node_idx];
|
| 2091 |
+
cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_address = node->data;
|
| 2092 |
+
cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_op = node->op;
|
| 2093 |
+
|
| 2094 |
+
for (int dim = 0; dim < GGML_MAX_DIMS; dim++) {
|
| 2095 |
+
cann_ctx->cann_graph->ggml_graph_properties[node_idx].ne[dim] = node->ne[dim];
|
| 2096 |
+
cann_ctx->cann_graph->ggml_graph_properties[node_idx].nb[dim] = node->nb[dim];
|
| 2097 |
+
}
|
| 2098 |
+
for (int src = 0; src < GGML_MAX_SRC; src++) {
|
| 2099 |
+
cann_ctx->cann_graph->ggml_graph_properties[node_idx].src_address[src] =
|
| 2100 |
+
node->src[src] ? node->src[src]->data : nullptr;
|
| 2101 |
+
}
|
| 2102 |
+
memcpy(cann_ctx->cann_graph->ggml_graph_properties[node_idx].op_params, node->op_params, GGML_MAX_OP_PARAMS);
|
| 2103 |
+
}
|
| 2104 |
+
}
|
| 2105 |
+
|
| 2106 |
+
/**
|
| 2107 |
+
* @brief Check if a ggml tensor node matches a previously captured CANN graph node.
|
| 2108 |
+
*
|
| 2109 |
+
* This function compares all relevant fields (address, op type, shape, source inputs, op params)
|
| 2110 |
+
* to determine whether the current node matches a previously recorded version.
|
| 2111 |
+
*
|
| 2112 |
+
* @param node The current ggml tensor node.
|
| 2113 |
+
* @param graph_node_properties The stored properties of a CANN graph node.
|
| 2114 |
+
* @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
|
| 2115 |
+
*/
|
| 2116 |
+
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
|
| 2117 |
+
if (node->data != graph_node_properties->node_address &&
|
| 2118 |
+
node->op != GGML_OP_VIEW) {
|
| 2119 |
+
return false;
|
| 2120 |
+
}
|
| 2121 |
+
if (node->op != graph_node_properties->node_op) {
|
| 2122 |
+
return false;
|
| 2123 |
+
}
|
| 2124 |
+
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
| 2125 |
+
if (node->ne[i] != graph_node_properties->ne[i]) {
|
| 2126 |
+
return false;
|
| 2127 |
+
}
|
| 2128 |
+
if (node->nb[i] != graph_node_properties->nb[i]) {
|
| 2129 |
+
return false;
|
| 2130 |
+
}
|
| 2131 |
+
}
|
| 2132 |
+
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
| 2133 |
+
if (node->src[i] &&
|
| 2134 |
+
node->src[i]->data != graph_node_properties->src_address[i] &&
|
| 2135 |
+
node->op != GGML_OP_VIEW
|
| 2136 |
+
) {
|
| 2137 |
+
return false;
|
| 2138 |
+
}
|
| 2139 |
+
}
|
| 2140 |
+
if (node->op == GGML_OP_SCALE &&
|
| 2141 |
+
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
|
| 2142 |
+
return false;
|
| 2143 |
+
}
|
| 2144 |
+
return true;
|
| 2145 |
+
}
|
| 2146 |
+
|
| 2147 |
+
/**
|
| 2148 |
+
* @brief Determine if the CANN graph needs to be rebuilt due to graph changes.
|
| 2149 |
+
*
|
| 2150 |
+
* This checks whether the number or properties of ggml graph nodes have changed
|
| 2151 |
+
* compared to the last captured CANN graph. If so, the CANN graph must be re-captured.
|
| 2152 |
+
*
|
| 2153 |
+
* @param cann_ctx The CANN backend context.
|
| 2154 |
+
* @param cgraph The current ggml computation graph.
|
| 2155 |
+
* @return true if an update is required; false otherwise.
|
| 2156 |
+
*/
|
| 2157 |
+
static bool is_cann_graph_update_required(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
|
| 2158 |
+
// The number of nodes is different, so the graph needs to be reconstructed.
|
| 2159 |
+
if (cann_ctx->cann_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
|
| 2160 |
+
cann_ctx->cann_graph->ggml_graph_properties.resize(cgraph->n_nodes);
|
| 2161 |
+
return true;
|
| 2162 |
+
}
|
| 2163 |
+
|
| 2164 |
+
// The number of nodes is the same; iterate over each node to check whether they match.
|
| 2165 |
+
for (int i = 0; i < cgraph->n_nodes; i++) {
|
| 2166 |
+
bool has_matching_properties = ggml_graph_node_has_matching_properties(
|
| 2167 |
+
cgraph->nodes[i], &cann_ctx->cann_graph->ggml_graph_properties[i]);
|
| 2168 |
+
if(!has_matching_properties) {
|
| 2169 |
+
return true;
|
| 2170 |
+
}
|
| 2171 |
+
}
|
| 2172 |
+
return false;
|
| 2173 |
+
}
|
| 2174 |
+
#endif // USE_ACL_GRAPH
|
| 2175 |
+
|
| 2176 |
+
/**
|
| 2177 |
+
* @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
|
| 2178 |
+
*
|
| 2179 |
+
* If CANN graph execution is enabled and graph capture is required, this function begins
|
| 2180 |
+
* graph capture, runs the graph, ends capture, and stores the captured graph.
|
| 2181 |
+
*
|
| 2182 |
+
* Otherwise, it falls back to op-by-op execution using the CANN compute kernel dispatcher.
|
| 2183 |
+
*
|
| 2184 |
+
* @param cann_ctx The CANN backend context.
|
| 2185 |
+
* @param cgraph The ggml computation graph.
|
| 2186 |
+
* @param use_cann_graph Whether to use CANN graph execution.
|
| 2187 |
+
* @param cann_graph_update_required Whether graph capture is needed due to graph changes.
|
| 2188 |
+
*/
|
| 2189 |
+
static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph,
|
| 2190 |
+
bool & use_cann_graph, bool & cann_graph_update_required) {
|
| 2191 |
+
#ifdef USE_ACL_GRAPH
|
| 2192 |
+
if (use_cann_graph && cann_graph_update_required) {
|
| 2193 |
+
if (cann_ctx->cann_graph->graph != nullptr) {
|
| 2194 |
+
ACL_CHECK(aclmdlRIDestroy(cann_ctx->cann_graph->graph));
|
| 2195 |
+
cann_ctx->cann_graph->graph = nullptr;
|
| 2196 |
+
}
|
| 2197 |
+
ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
|
| 2198 |
+
}
|
| 2199 |
+
#endif // USE_ACL_GRAPH
|
| 2200 |
+
|
| 2201 |
+
// Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
|
| 2202 |
+
// With the use of CANN graphs, the execution will be performed by the graph launch.
|
| 2203 |
+
if (!use_cann_graph || cann_graph_update_required) {
|
| 2204 |
+
for (int i = 0; i < cgraph->n_nodes; i++) {
|
| 2205 |
+
ggml_tensor * node = cgraph->nodes[i];
|
| 2206 |
+
|
| 2207 |
+
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
| 2208 |
+
continue;
|
| 2209 |
+
}
|
| 2210 |
+
|
| 2211 |
+
bool ok = ggml_cann_compute_forward(*cann_ctx, node);
|
| 2212 |
+
if (!ok) {
|
| 2213 |
+
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
|
| 2214 |
+
}
|
| 2215 |
+
GGML_ASSERT(ok);
|
| 2216 |
+
}
|
| 2217 |
+
}
|
| 2218 |
+
|
| 2219 |
+
#ifdef USE_ACL_GRAPH
|
| 2220 |
+
if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture
|
| 2221 |
+
ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &cann_ctx->cann_graph->graph));
|
| 2222 |
+
}
|
| 2223 |
+
|
| 2224 |
+
if (use_cann_graph) {
|
| 2225 |
+
// Execute graph
|
| 2226 |
+
ACL_CHECK(aclmdlRIExecuteAsync(cann_ctx->cann_graph->graph, cann_ctx->stream()));
|
| 2227 |
+
}
|
| 2228 |
+
#endif // USE_ACL_GRAPH
|
| 2229 |
+
}
|
| 2230 |
+
|
| 2231 |
+
|
| 2232 |
/**
|
| 2233 |
* @brief Computes a computational graph using a CANN backend.
|
| 2234 |
*
|
|
|
|
| 2245 |
ggml_backend_t backend, ggml_cgraph* cgraph) {
|
| 2246 |
ggml_backend_cann_context* cann_ctx =
|
| 2247 |
(ggml_backend_cann_context*)backend->context;
|
|
|
|
| 2248 |
ggml_cann_set_device(cann_ctx->device);
|
|
|
|
| 2249 |
release_nz_workspace();
|
| 2250 |
+
#ifdef USE_ACL_GRAPH
|
| 2251 |
+
bool use_cann_graph = true;
|
| 2252 |
+
bool cann_graph_update_required = false;
|
| 2253 |
|
| 2254 |
+
// check environment LLAMA_SET_ROWS
|
| 2255 |
+
if (!cann_ctx->support_set_rows) {
|
| 2256 |
+
use_cann_graph = false;
|
| 2257 |
+
}
|
| 2258 |
|
| 2259 |
+
if (use_cann_graph) {
|
| 2260 |
+
if (cann_ctx->cann_graph == nullptr) {
|
| 2261 |
+
cann_ctx->cann_graph.reset(new ggml_cann_graph());
|
| 2262 |
+
cann_graph_update_required = true;
|
| 2263 |
}
|
| 2264 |
|
| 2265 |
+
cann_graph_update_required = is_cann_graph_update_required(cann_ctx, cgraph);
|
| 2266 |
+
set_ggml_graph_node_properties(cann_ctx, cgraph);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2267 |
}
|
| 2268 |
+
#else
|
| 2269 |
+
bool use_cann_graph = false;
|
| 2270 |
+
bool cann_graph_update_required = false;
|
| 2271 |
+
#endif // USE_ACL_GRAPH
|
| 2272 |
+
|
| 2273 |
+
evaluate_and_capture_cann_graph(
|
| 2274 |
+
cann_ctx,
|
| 2275 |
+
cgraph,
|
| 2276 |
+
use_cann_graph,
|
| 2277 |
+
cann_graph_update_required
|
| 2278 |
+
);
|
| 2279 |
|
| 2280 |
return GGML_STATUS_SUCCESS;
|
| 2281 |
}
|
|
|
|
| 2391 |
// only support F32 and F16.
|
| 2392 |
return false;
|
| 2393 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2394 |
return true;
|
| 2395 |
} break;
|
| 2396 |
case GGML_OP_CONT: {
|