talk-llama : sync llama.cpp
- examples/talk-llama/llama-batch.cpp +3 -1
- examples/talk-llama/llama-context.cpp +79 -113
- examples/talk-llama/llama-cparams.cpp +4 -0
- examples/talk-llama/llama-cparams.h +2 -0
- examples/talk-llama/llama-grammar.cpp +12 -2
- examples/talk-llama/llama-graph.cpp +137 -233
- examples/talk-llama/llama-graph.h +49 -7
- examples/talk-llama/llama-hparams.cpp +17 -1
- examples/talk-llama/llama-hparams.h +34 -5
- examples/talk-llama/llama-kv-cache.cpp +724 -479
- examples/talk-llama/llama-kv-cache.h +194 -91
- examples/talk-llama/llama-kv-cells.h +379 -0
- examples/talk-llama/llama-memory.h +4 -3
- examples/talk-llama/llama-model.cpp +278 -95
- examples/talk-llama/llama-model.h +4 -1
- examples/talk-llama/llama-sampling.cpp +2 -2
- examples/talk-llama/llama-vocab.cpp +4 -4
- examples/talk-llama/llama.h +25 -124
examples/talk-llama/llama-batch.cpp
CHANGED
@@ -1,5 +1,6 @@
 #include "llama-batch.h"
 
+#include <cassert>
 #include <cstring>
 #include <algorithm>
 
@@ -281,9 +282,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
     batch = in_batch;
     GGML_ASSERT(batch.n_tokens > 0);
     if (!batch.pos) {
+        assert(p0 >= 0);
         pos.resize(batch.n_tokens);
         for (int32_t i = 0; i < batch.n_tokens; i++) {
-            pos[i] =
+            pos[i] = p0 + i;
         }
         batch.pos = pos.data();
     }
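For context, the hunk above can be read as: when the caller supplies no explicit positions, the batch allocator derives them from a starting position p0, which must now be non-negative. A minimal standalone sketch of that fill rule (the helper name is illustrative, not the library's API):

#include <cassert>
#include <cstdint>
#include <vector>

// Sketch of the position auto-fill rule shown in the hunk above:
// if the batch carries no positions, token i gets position p0 + i,
// and p0 is required to be non-negative.
static std::vector<int32_t> fill_positions(int32_t n_tokens, int32_t p0) {
    assert(p0 >= 0);
    std::vector<int32_t> pos(n_tokens);
    for (int32_t i = 0; i < n_tokens; i++) {
        pos[i] = p0 + i;
    }
    return pos;
}

int main() {
    const auto pos = fill_positions(4, 10); // -> 10, 11, 12, 13
    return pos[3] == 13 ? 0 : 1;
}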
examples/talk-llama/llama-context.cpp
CHANGED
@@ -25,7 +25,11 @@ llama_context::llama_context(
 
     const auto & hparams = model.hparams;
 
-    cparams.n_seq_max
+    cparams.n_seq_max = std::max(1u, params.n_seq_max);
+    if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
+        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
+    }
+
     cparams.n_threads       = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch;
     cparams.yarn_ext_factor = params.yarn_ext_factor;
@@ -93,6 +97,7 @@ llama_context::llama_context(
     }
 
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
+
     cparams.op_offload = params.op_offload;
 
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
@@ -176,8 +181,9 @@ llama_context::llama_context(
     // init the memory module
     if (!hparams.vocab_only) {
         llama_memory_params params_mem = {
-            /*.type_k
-            /*.type_v
+            /*.type_k   =*/ params.type_k,
+            /*.type_v   =*/ params.type_v,
+            /*.swa_full =*/ params.swa_full,
         };
 
         memory.reset(model.create_memory(params_mem, cparams));
@@ -687,12 +693,18 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
+    // TODO: move the validation to the llama_batch_allocr
     if (batch.token) {
         for (int32_t i = 0; i < n_tokens; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
+
+            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
+                throw -1;
+            }
         }
     }
 
@@ -846,7 +858,7 @@ int llama_context::encode(llama_batch & inp_batch) {
 
 int llama_context::decode(llama_batch & inp_batch) {
     if (!memory) {
+        LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
         return encode(inp_batch);
     }
 
@@ -855,11 +867,17 @@ int llama_context::decode(llama_batch & inp_batch) {
         return -1;
     }
 
+    if (!inp_batch.pos) {
+        if (inp_batch.seq_id) {
+            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
+            return -1;
+        }
+    }
+
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
     // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
 
     const llama_batch & batch = batch_allocr.batch;
 
@@ -875,11 +893,17 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
+    // TODO: move the validation to the llama_batch_allocr
     if (batch.token) {
         for (int64_t i = 0; i < n_tokens_all; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
+                return -1;
+            }
+
+            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
+                return -1;
             }
         }
     }
@@ -947,8 +971,6 @@ int llama_context::decode(llama_batch & inp_batch) {
 
         // find KV slot
         if (!kv_self->find_slot(ubatch)) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
             return 1;
         }
 
@@ -2093,6 +2115,7 @@ llama_context_params llama_context_default_params() {
         /*.flash_attn =*/ false,
         /*.no_perf    =*/ true,
         /*.op_offload =*/ true,
+        /*.swa_full   =*/ true,
     };
 
     return result;
@@ -2287,65 +2310,51 @@ int32_t llama_apply_adapter_cvec(
     return res ? 0 : -1;
 }
 
-//
-// kv cache view
-//
-
-llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
-    const auto * kv = ctx->get_kv_self();
-    if (kv == nullptr) {
-        LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
-        return {};
-    }
-
-    return llama_kv_cache_view_init(*kv, n_seq_max);
-}
-
-void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
-    const auto * kv = ctx->get_kv_self();
-    if (kv == nullptr) {
-        LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
-        return;
-    }
-
-    llama_kv_cache_view_update(view, kv);
-}
-
 //
 // kv cache
 //
 
 // deprecated
-int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
-    return llama_kv_self_n_tokens(ctx);
-}
-
 int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
         return 0;
     }
 
+    int32_t res = 0;
+
+    for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
+        const llama_pos p0 = kv->seq_pos_min(s);
+        const llama_pos p1 = kv->seq_pos_max(s);
+
+        if (p0 >= 0) {
+            res += (p1 - p0) + 1;
+        }
+    }
+
+    return res;
 }
 
+// deprecated
+// note: this is the same as above - will be removed anyway, so it's ok
 int32_t llama_kv_self_used_cells(const llama_context * ctx) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
         return 0;
     }
 
+    int32_t res = 0;
+
+    for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
+        const llama_pos p0 = kv->seq_pos_min(s);
+        const llama_pos p1 = kv->seq_pos_max(s);
+
+        if (p0 >= 0) {
+            res += (p1 - p0) + 1;
+        }
+    }
+
+    return res;
 }
 
 void llama_kv_self_clear(llama_context * ctx) {
@@ -2357,15 +2366,6 @@ void llama_kv_self_clear(llama_context * ctx) {
     kv->clear();
 }
 
-// deprecated
-bool llama_kv_cache_seq_rm(
-        llama_context * ctx,
-        llama_seq_id seq_id,
-        llama_pos p0,
-        llama_pos p1) {
-    return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
-}
-
 bool llama_kv_self_seq_rm(
         llama_context * ctx,
         llama_seq_id seq_id,
@@ -2379,16 +2379,6 @@ bool llama_kv_self_seq_rm(
     return kv->seq_rm(seq_id, p0, p1);
 }
 
-// deprecated
-void llama_kv_cache_seq_cp(
-        llama_context * ctx,
-        llama_seq_id seq_id_src,
-        llama_seq_id seq_id_dst,
-        llama_pos p0,
-        llama_pos p1) {
-    llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
-}
-
 void llama_kv_self_seq_cp(
         llama_context * ctx,
         llama_seq_id seq_id_src,
@@ -2403,13 +2393,6 @@ void llama_kv_self_seq_cp(
     kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
 }
 
-// deprecated
-void llama_kv_cache_seq_keep(
-        llama_context * ctx,
-        llama_seq_id seq_id) {
-    llama_kv_self_seq_keep(ctx, seq_id);
-}
-
 void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
     auto * kv = ctx->get_kv_self();
     if (!kv) {
@@ -2419,16 +2402,6 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
     kv->seq_keep(seq_id);
 }
 
-// deprecated
-void llama_kv_cache_seq_add(
-        llama_context * ctx,
-        llama_seq_id seq_id,
-        llama_pos p0,
-        llama_pos p1,
-        llama_pos delta) {
-    llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
-}
-
 void llama_kv_self_seq_add(
         llama_context * ctx,
         llama_seq_id seq_id,
@@ -2443,16 +2416,6 @@ void llama_kv_self_seq_add(
     kv->seq_add(seq_id, p0, p1, delta);
 }
 
-// deprecated
-void llama_kv_cache_seq_div(
-        llama_context * ctx,
-        llama_seq_id seq_id,
-        llama_pos p0,
-        llama_pos p1,
-        int d) {
-    llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
-}
-
 void llama_kv_self_seq_div(
         llama_context * ctx,
        llama_seq_id seq_id,
@@ -2467,25 +2430,24 @@ void llama_kv_self_seq_div(
     kv->seq_div(seq_id, p0, p1, d);
 }
 
+llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
+    const auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return -1;
+    }
+
+    return kv->seq_pos_min(seq_id);
 }
 
 llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
-        return
+        return -1;
     }
 
     return kv->seq_pos_max(seq_id);
 }
 
-// deprecated
-void llama_kv_cache_defrag(llama_context * ctx) {
-    llama_kv_self_defrag(ctx);
-}
-
 void llama_kv_self_defrag(llama_context * ctx) {
     auto * kv = ctx->get_kv_self();
     if (!kv) {
@@ -2496,11 +2458,6 @@ void llama_kv_self_defrag(llama_context * ctx) {
     kv->defrag_sched(-1.0f);
 }
 
-// deprecated
-bool llama_kv_cache_can_shift(const llama_context * ctx) {
-    return llama_kv_self_can_shift(ctx);
-}
-
 bool llama_kv_self_can_shift(const llama_context * ctx) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
@@ -2510,11 +2467,6 @@ bool llama_kv_self_can_shift(const llama_context * ctx) {
     return kv->get_can_shift();
 }
 
-// deprecated
-void llama_kv_cache_update(llama_context * ctx) {
-    llama_kv_self_update(ctx);
-}
-
 // llama state API
 
 // deprecated
@@ -2637,7 +2589,21 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
         llama_batch batch) {
+    int ret = ctx->decode(batch);
+
+    // defrag and try again
+    // TODO: distinguish return code when we are sure that even after defrag there is no space available
+    if (ret == 1) {
+        llama_kv_self_defrag(ctx);
+        ret = ctx->decode(batch);
+
+        if (ret == 1) {
+            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+            return ret;
+        }
+    }
+
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
    }
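The reworked llama_kv_self_n_tokens above no longer asks the cache for a flat token count; it sums the per-sequence position ranges reported by seq_pos_min/seq_pos_max. A standalone sketch of that counting rule, with plain vectors standing in for the cache queries (a negative minimum marks an empty sequence):

#include <cstdint>
#include <vector>

// Count cached tokens by summing [seq_pos_min, seq_pos_max] per sequence,
// mirroring the loop in llama_kv_self_n_tokens above.
// A negative p0 marks an unused sequence and contributes nothing.
static int32_t count_kv_tokens(const std::vector<int32_t> & p_min,
                               const std::vector<int32_t> & p_max) {
    int32_t res = 0;
    for (size_t s = 0; s < p_min.size(); ++s) {
        if (p_min[s] >= 0) {
            res += (p_max[s] - p_min[s]) + 1;
        }
    }
    return res;
}

int main() {
    // two active sequences holding positions 0..9 and 5..7, plus one empty sequence
    const std::vector<int32_t> p_min = { 0, 5, -1 };
    const std::vector<int32_t> p_max = { 9, 7, -1 };
    return count_kv_tokens(p_min, p_max) == 13 ? 0 : 1;
}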
examples/talk-llama/llama-cparams.cpp
CHANGED
@@ -1 +1,5 @@
 #include "llama-cparams.h"
+
+size_t llama_max_parallel_sequences(void) {
+    return LLAMA_MAX_PARALLEL_SEQUENCES;
+}
examples/talk-llama/llama-cparams.h
CHANGED
@@ -4,6 +4,8 @@
 
 #include <cstdint>
 
+#define LLAMA_MAX_PARALLEL_SEQUENCES 64
+
 struct llama_cparams {
     uint32_t n_ctx;           // context size used during inference
     uint32_t n_batch;
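The new LLAMA_MAX_PARALLEL_SEQUENCES limit is enforced in the llama_context constructor (see the llama-context.cpp hunk above) and reported by llama_max_parallel_sequences(). A hedged sketch of how a caller might clamp its requested sequence count before creating a context; the constant is redefined locally only to keep the snippet self-contained, and the exact public exposure of the limit is an assumption:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Local mirror of the limit introduced in llama-cparams.h (assumed value: 64).
#define LLAMA_MAX_PARALLEL_SEQUENCES 64

// Clamp a requested sequence count to the range the context accepts,
// matching the std::max(1u, ...) / upper-bound check shown above.
static uint32_t clamp_n_seq_max(uint32_t requested) {
    return std::min(std::max(requested, 1u), (uint32_t) LLAMA_MAX_PARALLEL_SEQUENCES);
}

int main() {
    std::printf("%u %u %u\n", clamp_n_seq_max(0), clamp_n_seq_max(8), clamp_n_seq_max(1000));
    // prints: 1 8 64
    return 0;
}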
examples/talk-llama/llama-grammar.cpp
CHANGED
@@ -1177,8 +1177,18 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
     for (const auto & trigger_pattern : grammar.trigger_patterns) {
         if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
             grammar.awaiting_trigger = false;
-            // get from the first
+            // get from the first matched capturing group to the end of the string
+            size_t start = std::string::npos;
+            for (auto i = 1u; i < match.size(); i++) {
+                if (match.length(i) > 0) {
+                    start = match.position(i);
+                    break;
+                }
+            }
+            if (start == std::string::npos) {
+                start = match.position(0);
+            }
+            auto constrained_str = grammar.trigger_buffer.substr(start);
             // std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
             grammar.trigger_buffer.clear();
             llama_grammar_accept_str(grammar, constrained_str);
examples/talk-llama/llama-graph.cpp
CHANGED
@@ -9,33 +9,6 @@
 #include <cmath>
 #include <cstring>
 
-static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
-    // TODO move to hparams if a T5 variant appears that uses a different value
-    const int64_t max_distance = 128;
-
-    if (bidirectional) {
-        n_buckets >>= 1;
-    }
-
-    const int64_t max_exact = n_buckets >> 1;
-
-    int32_t relative_position = x - y;
-    int32_t relative_bucket = 0;
-
-    if (bidirectional) {
-        relative_bucket += (relative_position > 0) * n_buckets;
-        relative_position = abs(relative_position);
-    } else {
-        relative_position = -std::min<int32_t>(relative_position, 0);
-    }
-
-    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
-    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
-
-    return relative_bucket;
-}
-
 void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     if (ubatch->token) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -110,22 +83,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
     if (pos_bucket) {
-        GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
-        GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
-
-        int32_t * data = (int32_t *) pos_bucket->data;
-
-        const int64_t n_kv = kv_self->n;
-
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_kv; ++i) {
-                    data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
-                }
-            }
-        }
+        kv_self->set_input_pos_bucket(pos_bucket, ubatch);
     }
 }
 
@@ -403,99 +361,18 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask
-    const int64_t n_seqs = ubatch->n_seqs;
-
-    float * data     = nullptr;
-    float * data_swa = nullptr;
-
-    if (self_kq_mask) {
-        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-        data = (float *) self_kq_mask->data;
-    }
-
-    if (self_kq_mask_swa) {
-        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
-        data_swa = (float *) self_kq_mask_swa->data;
-    }
-
-    // Use only the previous KV cells of the correct sequence for each token of the ubatch.
-    // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-    // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
-    //   Causal mask:
-    //      xxx-------
-    //      xxxx------
-    //      xxxxx-----
-    //   Non-causal mask:
-    //      xxxxx-----
-    //      xxxxx-----
-    //      xxxxx-----
-    // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
-    for (int h = 0; h < 1; ++h) {
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch->seq_id[s][0];
-
-            for (int j = 0; j < n_seq_tokens; ++j) {
-                const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
-                for (int i = 0; i < n_kv; ++i) {
-                    float f;
-                    // mask the token if:
-                    if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
-                        || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
-                    ) {
-                        f = -INFINITY;
-                    } else {
-                        if (hparams.use_alibi) {
-                            f = -std::abs(kv_self->cells[i].pos - pos);
-                        } else {
-                            f = 0.0f;
-                        }
-                    }
-
-                    if (data) {
-                        data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                    }
-
-                    // may need to cut off old tokens for sliding window
-                    // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
-                    if (data_swa) {
-                        if (hparams.n_attn_chunk) {
-                            llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
-                            if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
-                                f = -INFINITY;
-                            }
-                        } else {
-                            if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
-                                f = -INFINITY;
-                            }
-                        }
-                        data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                    }
-                }
-            }
-        }
-
-        for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-            for (int j = 0; j < n_kv; ++j) {
-                data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-            }
-        }
-    }
+    if (self_kq_mask) {
+        kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
+}
+
+void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
+    if (self_kq_mask) {
+        kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
+
+    if (self_kq_mask_swa) {
+        kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
     }
 }
 
@@ -545,7 +422,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     n_layer          (hparams.n_layer),
     n_rot            (hparams.n_rot),
     n_ctx            (cparams.n_ctx),
-    n_ctx_per_seq    (cparams.n_ctx / cparams.n_seq_max),
     n_head           (hparams.n_head()),
     n_head_kv        (hparams.n_head_kv()),
     n_embd_head_k    (hparams.n_embd_head_k),
@@ -1153,7 +1029,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
 
-    const auto n_kv = kv_self->
+    const auto n_kv = kv_self->get_n();
 
     auto & cur = inp->pos_bucket;
 
@@ -1188,16 +1064,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_tensor * kq_b,
         ggml_tensor * kq_mask,
         ggml_tensor * v_mla,
-        bool v_trans,
         float kq_scale) const {
-    //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-    //const int64_t n_head    = hparams.n_head(il);
-    //const int64_t n_head_kv = hparams.n_head_kv(il);
+    const bool v_trans = v->nb[1] > v->nb[2];
 
+    q = ggml_permute(ctx0, q, 0, 2, 1, 3);
+    k = ggml_permute(ctx0, k, 0, 2, 1, 3);
+    v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
     const auto n_tokens = q->ne[1];
     const auto n_head   = q->ne[2];
@@ -1336,17 +1208,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask();
 
-    ggml_tensor * q =
-
-    ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
-    //cb(k, "k", il);
-
-    ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-    //cb(k, "v", il);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = k_cur;
+    ggml_tensor * v = v_cur;
 
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1369,22 +1235,16 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
-    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
-    ggml_set_input(inp->self_kq_mask);
-
-    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
-
-        GGML_ASSERT(hparams.n_swa > 0);
-
-        inp->
-        //cb(inp->
-        ggml_set_input(inp->
-
-        inp->
+    {
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
+
+        const auto n_kv = kv_self->get_n();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
     return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
@@ -1409,81 +1269,104 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, v_cur);
 
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
-    const auto & n_ctx = cparams.n_ctx;
 
-        // note: the V cache is transposed when not using flash attention
-        v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
-                (  n_ctx)*ggml_element_size(kv_self->v_l[il]),
-                (kv_head)*ggml_element_size(kv_self->v_l[il]));
-    }
-
-    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
-
-    const auto & n_embd_head_k = hparams.n_embd_head_k;
-    const auto & n_embd_head_v = hparams.n_embd_head_v;
-
-    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    ggml_tensor * k =
-        ggml_view_3d(ctx0, kv_self->k_l[il],
-                n_embd_head_k, n_kv, n_head_kv,
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
-                0);
-    //cb(k, "k", il);
-
-    ggml_tensor * v = !v_trans ?
-        ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_embd_head_v, n_kv, n_head_kv,
-                ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
-                0) :
-        ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_kv, n_embd_head_v, n_head_kv,
-                ggml_element_size(kv_self->v_l[il])*n_ctx,
-                ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
-                0);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
+
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv_self->get_k(ctx0, il);
+    ggml_tensor * v = kv_self->get_v(ctx0, il);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
+    }
+
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
+
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
+
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
+
+    {
+        const auto n_kv = kv_self->get_kv_base()->get_n();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
+
+        const auto n_kv = kv_self->get_kv_swa()->get_n();
+
+        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    }
+
+    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_unified_iswa * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
     const bool is_swa = hparams.is_swa(il);
 
+    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
+
+    const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
+
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
+    }
+
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv->get_k(ctx0, il);
+    ggml_tensor * v = kv->get_v(ctx0, il);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1534,17 +1417,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask_cross();
 
-    ggml_tensor * q =
-
-    ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
-    //cb(k, "k", il);
-
-    ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-    //cb(k, "v", il);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = k_cur;
+    ggml_tensor * v = v_cur;
 
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1712,3 +1589,30 @@ void llm_graph_context::build_pooling(
 
     ggml_build_forward_expand(gf, cur);
 }
+
+int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
+    // TODO move to hparams if a T5 variant appears that uses a different value
+    const int64_t max_distance = 128;
+
+    if (bidirectional) {
+        n_buckets >>= 1;
+    }
+
+    const int64_t max_exact = n_buckets >> 1;
+
+    int32_t relative_position = x - y;
+    int32_t relative_bucket = 0;
+
+    if (bidirectional) {
+        relative_bucket += (relative_position > 0) * n_buckets;
+        relative_position = abs(relative_position);
+    } else {
+        relative_position = -std::min<int32_t>(relative_position, 0);
+    }
+
+    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
+    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
+
+    return relative_bucket;
+}
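llama_relative_position_bucket is now a free function exported from this file (last hunk above) instead of a static helper. A standalone sketch that copies the bucket mapping verbatim and drives it with a few offsets; llama_pos is assumed to be a 32-bit integer here, and the function is renamed only to avoid any clash with the library symbol:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

using llama_pos = int32_t; // assumption: llama_pos is a 32-bit position type

// Same mapping as the hunk above: turn a relative position (x - y) into a
// T5-style bucket index, with half the buckets reserved for the positive
// direction when attention is bidirectional.
static int32_t relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
    const int64_t max_distance = 128;

    if (bidirectional) {
        n_buckets >>= 1;
    }

    const int64_t max_exact = n_buckets >> 1;

    int32_t relative_position = x - y;
    int32_t relative_bucket = 0;

    if (bidirectional) {
        relative_bucket += (relative_position > 0) * n_buckets;
        relative_position = abs(relative_position);
    } else {
        relative_position = -std::min<int32_t>(relative_position, 0);
    }

    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);

    return relative_bucket;
}

int main() {
    // causal case: key at position 0, queries at positions 0..4, 32 buckets
    for (llama_pos q = 0; q < 5; ++q) {
        std::printf("key 0, query %d -> bucket %d\n", q, relative_position_bucket(0, q, 32, false));
    }
    return 0;
}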
examples/talk-llama/llama-graph.h
CHANGED
@@ -19,6 +19,7 @@ struct llama_cparams;
 
 class llama_memory_i;
 class llama_kv_cache_unified;
+class llama_kv_cache_unified_iswa;
 class llama_kv_cache_recurrent;
 
 // certain models (typically multi-modal) can produce different types of graphs
@@ -255,6 +256,31 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+
+    const llama_kv_cache_unified * kv_self;
+};
+
+class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
+public:
+    llm_graph_input_attn_kv_unified_iswa(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            const llama_kv_cache_unified_iswa * kv_self) :
+        hparams(hparams),
+        cparams(cparams),
+        kv_self(kv_self) {
+    }
+    ~llm_graph_input_attn_kv_unified_iswa() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
     ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
@@ -266,7 +292,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified * kv_self;
+    const llama_kv_cache_unified_iswa * kv_self;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -378,7 +404,6 @@ struct llm_graph_context {
     const int64_t n_layer;
     const int64_t n_rot;
     const int64_t n_ctx;       // user-specified context size (can be different from n_ctx_train)
-    const int64_t n_ctx_per_seq;
     const int64_t n_head;
     const int64_t n_head_kv;
     const int64_t n_embd_head_k;
@@ -507,13 +532,12 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn_mha(
             ggml_cgraph * gf,
-            ggml_tensor * q,
-            ggml_tensor * k,
-            ggml_tensor * v,
+            ggml_tensor * q,     // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k,     // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v,     // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
             ggml_tensor * kq_b,
             ggml_tensor * kq_mask,
-            ggml_tensor * v_mla,
-            bool v_trans,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale) const;
 
     llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
@@ -546,6 +570,21 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
+    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
+
+    ggml_tensor * build_attn(
+            llm_graph_input_attn_kv_unified_iswa * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+            ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+            float kq_scale,
+            int il) const;
+
     llm_graph_input_attn_cross * build_attn_inp_cross() const;
 
     ggml_tensor * build_attn(
@@ -596,3 +635,6 @@ struct llm_graph_context {
             ggml_tensor * cls_out,
             ggml_tensor * cls_out_b) const;
 };
+
+// TODO: better name
+int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);
examples/talk-llama/llama-hparams.cpp
CHANGED
@@ -2,6 +2,22 @@
 
 #include "ggml.h"
 
+void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
+    }
+}
+
+bool llama_hparams::is_swa_any() const {
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (swa_layers[il]) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
 uint32_t llama_hparams::n_head(uint32_t il) const {
     if (il < n_layer) {
         return n_head_arr[il];
@@ -72,7 +88,7 @@ uint32_t llama_hparams::n_embd_v_s() const {
 
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
-        return
+        return swa_layers[il];
    }
 
    GGML_ABORT("fatal error");
examples/talk-llama/llama-hparams.h
CHANGED
@@ -14,6 +14,12 @@ enum llama_expert_gating_func_type {
     LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
 };
 
+enum llama_swa_type {
+    LLAMA_SWA_TYPE_NONE     = 0,
+    LLAMA_SWA_TYPE_STANDARD = 1,
+    LLAMA_SWA_TYPE_CHUNKED  = 2,
+};
+
 struct llama_hparams_posnet {
     uint32_t n_embd;
     uint32_t n_layer;
@@ -35,8 +41,6 @@ struct llama_hparams {
     uint32_t n_embd_features = 0;
     uint32_t n_layer;
     uint32_t n_rot;
-    uint32_t n_swa = 0;         // sliding window attention (SWA)
-    uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
     uint32_t n_expert = 0;
@@ -96,6 +100,15 @@ struct llama_hparams {
 
     std::array<int, 4> rope_sections;
 
+    // Sliding Window Attention (SWA)
+    llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
+    // the size of the sliding window (0 - no SWA)
+    uint32_t n_swa = 0;
+    // if swa_layers[il] == true, then layer il is SWA
+    // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
+    // by default, all layers are dense
+    std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
+
     // for State Space Models
     uint32_t ssm_d_conv  = 0;
     uint32_t ssm_d_inner = 0;
@@ -116,11 +129,10 @@ struct llama_hparams {
     bool causal_attn   = true;
     bool use_alibi     = false;
     bool attn_soft_cap = false;
+    bool use_kq_norm   = true;
 
+    // llama4
     uint32_t n_moe_layer_step        = 0;
-    bool     use_kq_norm             = true;
-    uint32_t n_attn_chunk            = 0;
-    // values below seems to be fixed on llama4
     uint32_t n_no_rope_layer_step    = 4;
     uint32_t n_attn_temp_floor_scale = 8192;
     float    f_attn_temp_scale       = 0.1;
@@ -133,6 +145,23 @@ struct llama_hparams {
     enum llama_rope_type         rope_type               = LLAMA_ROPE_TYPE_NONE;
     enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
 
+    // this value n_pattern means that every nth layer is dense (i.e. non-SWA)
+    // note that if n_pattern == 0, all layers are SWA
+    //           if n_pattern == 1, all layers are dense
+    // example: n_pattern = 3
+    //   il == 0: swa
+    //   il == 1: swa
+    //   il == 2: dense
+    //   il == 3: swa
+    //   il == 4: swa
+    //   il == 5: dense
+    //   il == 6: swa
+    //   etc ...
+    void set_swa_pattern(uint32_t n_pattern);
+
+    // return true if one of the layers is SWA
+    bool is_swa_any() const;
+
     uint32_t n_head(uint32_t il = 0) const;
 
     uint32_t n_head_kv(uint32_t il = 0) const;
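The n_pattern comment above fully determines the swa_layers flags. A standalone sketch of that mapping follows; it is not the patch's set_swa_pattern implementation, and MAX_LAYERS is a stand-in for LLAMA_MAX_LAYERS.

    // sketch of the n_pattern -> swa_layers mapping described above
    #include <array>
    #include <cstdint>
    #include <cstdio>

    static constexpr uint32_t MAX_LAYERS = 8; // stand-in for LLAMA_MAX_LAYERS

    static std::array<bool, MAX_LAYERS> make_swa_layers(uint32_t n_layer, uint32_t n_pattern) {
        std::array<bool, MAX_LAYERS> swa_layers{};
        for (uint32_t il = 0; il < n_layer; ++il) {
            // n_pattern == 0 -> every layer is SWA
            // n_pattern == 1 -> every layer is dense
            // otherwise layers with (il % n_pattern) == n_pattern - 1 are dense
            swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
        }
        return swa_layers;
    }

    int main() {
        const auto layers = make_swa_layers(/*n_layer=*/7, /*n_pattern=*/3);
        for (uint32_t il = 0; il < 7; ++il) {
            std::printf("il == %u: %s\n", il, layers[il] ? "swa" : "dense");
        }
        return 0; // prints swa, swa, dense, swa, swa, dense, swa - matching the comment
    }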
examples/talk-llama/llama-kv-cache.cpp
CHANGED
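The changes below replace the flat kv_cell vector with a llama_kv_cells helper and add a recovery stack around find_slot(): before a ubatch is written into the cache, the previous state of the touched cells is saved; commit() drops the saved state, restore() re-applies it. A simplified standalone sketch of that bookkeeping follows; the real types in this file carry more state (positions, seq_id sets, shift deltas), so this is only an illustration of the pattern, not the actual implementation.

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    struct cell {
        int64_t pos = -1; // simplified: the real cells also track seq_ids and shift deltas
    };

    struct recovery_state {
        uint32_t          i;     // first cell index that was overwritten
        std::vector<cell> cells; // previous contents of those cells
    };

    struct kv_cells_sketch {
        std::vector<cell>           data;
        std::vector<recovery_state> recovery;

        // analogous to find_slot(): remember the old cells, then overwrite them
        void write(uint32_t head, const std::vector<cell> & incoming) {
            recovery_state st;
            st.i = head;
            st.cells.assign(data.begin() + head, data.begin() + head + incoming.size());
            recovery.push_back(std::move(st));

            std::copy(incoming.begin(), incoming.end(), data.begin() + head);
        }

        // analogous to commit(): the ubatch was accepted, keep the new state
        void commit() {
            recovery.clear();
        }

        // analogous to restore(): roll back every write since the last commit
        void restore() {
            for (auto it = recovery.rbegin(); it != recovery.rend(); ++it) {
                std::copy(it->cells.begin(), it->cells.end(), data.begin() + it->i);
            }
            recovery.clear();
        }
    };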
|
@@ -23,32 +23,21 @@ uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
|
|
| 23 |
}
|
| 24 |
|
| 25 |
llama_kv_cache_unified::llama_kv_cache_unified(
|
| 26 |
-
const llama_model &
|
| 27 |
-
|
| 28 |
-
ggml_type
|
| 29 |
-
|
| 30 |
-
bool
|
| 31 |
-
|
| 32 |
-
uint32_t
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");
|
| 42 |
-
|
| 43 |
-
head = 0;
|
| 44 |
-
size = kv_size;
|
| 45 |
-
used = 0;
|
| 46 |
-
|
| 47 |
-
this->type_k = type_k;
|
| 48 |
-
this->type_v = type_v;
|
| 49 |
-
|
| 50 |
-
cells.clear();
|
| 51 |
-
cells.resize(kv_size);
|
| 52 |
|
| 53 |
// create a context for each buffer type
|
| 54 |
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
|
@@ -56,7 +45,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
|
| 56 |
auto it = ctx_map.find(buft);
|
| 57 |
if (it == ctx_map.end()) {
|
| 58 |
ggml_init_params params = {
|
| 59 |
-
/*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
|
| 60 |
/*.mem_buffer =*/ NULL,
|
| 61 |
/*.no_alloc =*/ true,
|
| 62 |
};
|
|
@@ -75,37 +64,48 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
|
| 75 |
return it->second;
|
| 76 |
};
|
| 77 |
|
| 78 |
-
|
| 79 |
-
v_l.reserve(n_layer);
|
| 80 |
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
const char * dev_name = "CPU";
|
| 86 |
|
| 87 |
ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
|
| 88 |
|
| 89 |
if (offload) {
|
| 90 |
-
auto * dev = model.dev_layer(
|
| 91 |
buft = ggml_backend_dev_buffer_type(dev);
|
| 92 |
|
| 93 |
dev_name = ggml_backend_dev_name(dev);
|
| 94 |
}
|
| 95 |
|
| 96 |
-
LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__,
|
| 97 |
|
| 98 |
ggml_context * ctx = ctx_for_buft(buft);
|
| 99 |
if (!ctx) {
|
| 100 |
throw std::runtime_error("failed to create ggml context for kv cache");
|
| 101 |
}
|
| 102 |
|
| 103 |
-
ggml_tensor * k
|
| 104 |
-
ggml_tensor * v
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
}
|
| 110 |
|
| 111 |
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
|
@@ -117,8 +117,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
|
| 117 |
if (!buf) {
|
| 118 |
throw std::runtime_error("failed to allocate buffer for kv cache");
|
| 119 |
}
|
| 120 |
-
|
| 121 |
LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
|
|
|
|
|
|
|
| 122 |
bufs.emplace_back(buf);
|
| 123 |
}
|
| 124 |
|
|
@@ -126,20 +128,17 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
|
| 126 |
const size_t memory_size_k = size_k_bytes();
|
| 127 |
const size_t memory_size_v = size_v_bytes();
|
| 128 |
|
| 129 |
-
LLAMA_LOG_INFO("%s:
|
| 130 |
-
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
|
| 131 |
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
|
| 132 |
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
|
| 133 |
}
|
| 134 |
}
|
| 135 |
|
| 136 |
void llama_kv_cache_unified::clear() {
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
cells[i].seq_id.clear();
|
| 140 |
-
}
|
| 141 |
head = 0;
|
| 142 |
-
used = 0;
|
| 143 |
|
| 144 |
for (auto & buf : bufs) {
|
| 145 |
ggml_backend_buffer_clear(buf.get(), 0);
|
|
@@ -147,7 +146,7 @@ void llama_kv_cache_unified::clear() {
|
|
| 147 |
}
|
| 148 |
|
| 149 |
bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
| 150 |
-
uint32_t new_head = size;
|
| 151 |
|
| 152 |
if (p0 < 0) {
|
| 153 |
p0 = 0;
|
|
@@ -157,32 +156,20 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
|
|
| 157 |
p1 = std::numeric_limits<llama_pos>::max();
|
| 158 |
}
|
| 159 |
|
| 160 |
-
for (uint32_t i = 0; i < size; ++i) {
|
| 161 |
-
if (cells
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
} else if (cells[i].has_seq_id(seq_id)) {
|
| 165 |
-
cells[i].seq_id.erase(seq_id);
|
| 166 |
-
} else {
|
| 167 |
-
continue;
|
| 168 |
-
}
|
| 169 |
-
if (cells[i].is_empty()) {
|
| 170 |
-
// keep count of the number of used cells
|
| 171 |
-
if (cells[i].pos >= 0) {
|
| 172 |
-
used--;
|
| 173 |
-
}
|
| 174 |
-
|
| 175 |
-
cells[i].pos = -1;
|
| 176 |
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
}
|
| 181 |
}
|
| 182 |
}
|
| 183 |
|
| 184 |
// If we freed up a slot, set head to it so searching can start there.
|
| 185 |
-
if (new_head != size && new_head < head) {
|
| 186 |
head = new_head;
|
| 187 |
}
|
| 188 |
|
|
@@ -202,49 +189,40 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
|
|
| 202 |
p1 = std::numeric_limits<llama_pos>::max();
|
| 203 |
}
|
| 204 |
|
| 205 |
-
|
| 206 |
-
|
|
|
|
|
|
|
| 207 |
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
cells[i].seq_id.insert(seq_id_dst);
|
| 211 |
}
|
| 212 |
}
|
| 213 |
}
|
| 214 |
|
| 215 |
void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
|
| 216 |
-
uint32_t new_head = size;
|
| 217 |
|
| 218 |
-
for (uint32_t i = 0; i < size; ++i) {
|
| 219 |
-
if (
|
| 220 |
-
if (cells
|
| 221 |
-
used--;
|
| 222 |
-
}
|
| 223 |
-
|
| 224 |
-
cells[i].pos = -1;
|
| 225 |
-
cells[i].seq_id.clear();
|
| 226 |
-
|
| 227 |
-
if (new_head == size){
|
| 228 |
new_head = i;
|
| 229 |
}
|
| 230 |
-
} else {
|
| 231 |
-
cells[i].seq_id.clear();
|
| 232 |
-
cells[i].seq_id.insert(seq_id);
|
| 233 |
}
|
| 234 |
}
|
| 235 |
|
| 236 |
// If we freed up a slot, set head to it so searching can start there.
|
| 237 |
-
if (new_head != size && new_head < head) {
|
| 238 |
head = new_head;
|
| 239 |
}
|
| 240 |
}
|
| 241 |
|
| 242 |
-
void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos
|
| 243 |
-
if (
|
| 244 |
return;
|
| 245 |
}
|
| 246 |
|
| 247 |
-
uint32_t new_head = size;
|
| 248 |
|
| 249 |
if (p0 < 0) {
|
| 250 |
p0 = 0;
|
|
@@ -254,24 +232,19 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
|
|
| 254 |
p1 = std::numeric_limits<llama_pos>::max();
|
| 255 |
}
|
| 256 |
|
| 257 |
-
// If there is no range then return early to avoid looping over
|
| 258 |
if (p0 == p1) {
|
| 259 |
return;
|
| 260 |
}
|
| 261 |
|
| 262 |
-
for (uint32_t i = 0; i < size; ++i) {
|
| 263 |
-
if (cells
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
cells[i].delta += delta;
|
| 267 |
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
}
|
| 272 |
-
cells[i].pos = -1;
|
| 273 |
-
cells[i].seq_id.clear();
|
| 274 |
-
if (new_head == size) {
|
| 275 |
new_head = i;
|
| 276 |
}
|
| 277 |
}
|
|
@@ -280,7 +253,7 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
|
|
| 280 |
|
| 281 |
// If we freed up a slot, set head to it so searching can start there.
|
| 282 |
// Otherwise we just start the next search from the beginning.
|
| 283 |
-
head = new_head != size ? new_head : 0;
|
| 284 |
}
|
| 285 |
|
| 286 |
void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
|
@@ -301,66 +274,41 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
|
|
| 301 |
return;
|
| 302 |
}
|
| 303 |
|
| 304 |
-
for (uint32_t i = 0; i < size; ++i) {
|
| 305 |
-
if (cells
|
| 306 |
-
|
|
|
|
| 307 |
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
cells[i].pos /= d;
|
| 311 |
-
cells[i].delta += cells[i].pos - p_old;
|
| 312 |
-
}
|
| 313 |
}
|
| 314 |
}
|
| 315 |
}
|
| 316 |
|
| 317 |
-
llama_pos llama_kv_cache_unified::
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
for (uint32_t i = 0; i < size; ++i) {
|
| 321 |
-
if (cells[i].has_seq_id(seq_id)) {
|
| 322 |
-
result = std::max(result, cells[i].pos);
|
| 323 |
-
}
|
| 324 |
-
}
|
| 325 |
|
| 326 |
-
|
|
|
|
| 327 |
}
|
| 328 |
|
| 329 |
void llama_kv_cache_unified::restore() {
|
| 330 |
-
|
| 331 |
-
|
| 332 |
}
|
| 333 |
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
for (auto & range : pending.ranges) {
|
| 337 |
-
for (uint32_t i = range.c0; i < range.c1; ++i) {
|
| 338 |
-
cells[i].seq_id.clear();
|
| 339 |
-
|
| 340 |
-
// keep count of the number of used cells
|
| 341 |
-
if (cells[i].pos >= 0) {
|
| 342 |
-
used--;
|
| 343 |
-
}
|
| 344 |
-
|
| 345 |
-
cells[i].pos = -1;
|
| 346 |
-
}
|
| 347 |
-
|
| 348 |
-
new_head = std::min(new_head, range.c0);
|
| 349 |
-
}
|
| 350 |
-
|
| 351 |
-
if (new_head != size && new_head < head) {
|
| 352 |
-
head = new_head;
|
| 353 |
-
}
|
| 354 |
}
|
| 355 |
|
| 356 |
void llama_kv_cache_unified::commit() {
|
| 357 |
-
if (
|
| 358 |
-
LLAMA_LOG_WARN("%s:
|
| 359 |
-
__func__, "https://github.com/ggml-org/llama.cpp/pull/
|
| 360 |
return;
|
| 361 |
}
|
| 362 |
|
| 363 |
-
|
| 364 |
}
|
| 365 |
|
| 366 |
bool llama_kv_cache_unified::update(llama_context & lctx) {
|
|
@@ -368,7 +316,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
|
|
| 368 |
|
| 369 |
auto * sched = lctx.get_sched();
|
| 370 |
|
| 371 |
-
if (
|
| 372 |
if (!get_can_shift()) {
|
| 373 |
GGML_ABORT("The current KV cache / model configuration does not support K-shift");
|
| 374 |
}
|
|
@@ -392,13 +340,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
|
|
| 392 |
need_reserve = true;
|
| 393 |
}
|
| 394 |
|
| 395 |
-
|
| 396 |
-
has_shift = false;
|
| 397 |
-
|
| 398 |
-
for (uint32_t i = 0; i < size; ++i) {
|
| 399 |
-
cells[i].delta = 0;
|
| 400 |
-
}
|
| 401 |
-
}
|
| 402 |
}
|
| 403 |
|
| 404 |
if (do_defrag) {
|
|
@@ -429,7 +371,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
|
|
| 429 |
void llama_kv_cache_unified::defrag_sched(float thold) {
|
| 430 |
// - do not defrag small contexts (i.e. < 2048 tokens)
|
| 431 |
// - count the padding towards the number of used tokens
|
| 432 |
-
const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(
|
| 433 |
|
| 434 |
// queue defragmentation for next llama_kv_cache_update
|
| 435 |
if (fragmentation > thold) {
|
|
@@ -440,7 +382,7 @@ void llama_kv_cache_unified::defrag_sched(float thold) {
|
|
| 440 |
}
|
| 441 |
|
| 442 |
void llama_kv_cache_unified::set_full() {
|
| 443 |
-
n = size;
|
| 444 |
|
| 445 |
// when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
|
| 446 |
// affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
|
|
@@ -450,51 +392,67 @@ void llama_kv_cache_unified::set_full() {
|
|
| 450 |
head = 0;
|
| 451 |
}
|
| 452 |
|
| 453 |
-
llama_sbatch llama_kv_cache_unified::sbatch_init(
|
| 454 |
-
const llama_batch & batch,
|
| 455 |
-
bool logits_all) {
|
| 456 |
return llama_sbatch(batch, hparams.n_embd, true, logits_all);
|
| 457 |
}
|
| 458 |
|
| 459 |
-
llama_ubatch llama_kv_cache_unified::ubatch_next(
|
| 460 |
-
llama_sbatch & sbatch,
|
| 461 |
-
uint32_t n_ubatch,
|
| 462 |
-
bool embd_pooled) const {
|
| 463 |
GGML_UNUSED(embd_pooled);
|
| 464 |
return sbatch.split_simple(n_ubatch);
|
| 465 |
}
|
| 466 |
|
| 467 |
-
bool llama_kv_cache_unified::find_slot(
|
| 468 |
-
const llama_ubatch & ubatch) {
|
| 469 |
const uint32_t n_tokens = ubatch.n_tokens;
|
| 470 |
-
const uint32_t n_seqs = ubatch.n_seqs;
|
| 471 |
-
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
|
| 472 |
|
| 473 |
// if we have enough unused cells before the current head ->
|
| 474 |
// better to start searching from the beginning of the cache, hoping to fill it
|
| 475 |
-
if (head >
|
| 476 |
head = 0;
|
| 477 |
}
|
| 478 |
|
| 479 |
// otherwise, one cell per token.
|
| 480 |
|
| 481 |
-
if (n_tokens > size) {
|
| 482 |
-
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %
|
| 483 |
return false;
|
| 484 |
}
|
| 485 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 486 |
uint32_t n_tested = 0;
|
| 487 |
|
| 488 |
while (true) {
|
| 489 |
-
if (head + n_tokens > size) {
|
| 490 |
-
n_tested += size - head;
|
| 491 |
head = 0;
|
| 492 |
continue;
|
| 493 |
}
|
| 494 |
|
| 495 |
bool found = true;
|
| 496 |
for (uint32_t i = 0; i < n_tokens; i++) {
|
| 497 |
-
|
|
|
|
| 498 |
found = false;
|
| 499 |
head += i + 1;
|
| 500 |
n_tested += i + 1;
|
|
@@ -506,66 +464,257 @@ bool llama_kv_cache_unified::find_slot(
|
|
| 506 |
break;
|
| 507 |
}
|
| 508 |
|
| 509 |
-
if (n_tested >= size) {
|
| 510 |
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
|
| 511 |
return false;
|
| 512 |
}
|
| 513 |
}
|
| 514 |
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
uint32_t k = s*n_seq_tokens + i;
|
| 518 |
-
cells[head + k].pos = ubatch.pos[k];
|
| 519 |
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
|
|
|
|
|
|
| 523 |
}
|
| 524 |
}
|
| 525 |
|
| 526 |
-
used += n_tokens;
|
| 527 |
-
|
| 528 |
-
pending.ranges.push_back({head, head + n_tokens});
|
| 529 |
-
|
| 530 |
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
| 531 |
// after enough generations, the benefit from this heuristic disappears
|
| 532 |
// if we start defragmenting the cache, the benefit from this will be more important
|
| 533 |
-
n = std::min(size, std::max(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 534 |
|
| 535 |
-
|
|
|
|
| 536 |
|
|
|
|
| 537 |
return true;
|
| 538 |
}
|
| 539 |
|
| 540 |
-
|
| 541 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 542 |
|
| 543 |
-
|
| 544 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
}
|
| 546 |
|
| 547 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 548 |
}
|
| 549 |
|
| 550 |
-
|
| 551 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 552 |
}
|
| 553 |
|
| 554 |
-
|
| 555 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 556 |
}
|
| 557 |
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
|
|
|
|
|
|
|
|
|
| 562 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 563 |
|
| 564 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 565 |
}
|
| 566 |
|
| 567 |
size_t llama_kv_cache_unified::total_size() const {
|
| 568 |
size_t size = 0;
|
|
|
|
| 569 |
for (const auto & buf : bufs) {
|
| 570 |
size += ggml_backend_buffer_get_size(buf.get());
|
| 571 |
}
|
|
@@ -576,8 +725,8 @@ size_t llama_kv_cache_unified::total_size() const {
|
|
| 576 |
size_t llama_kv_cache_unified::size_k_bytes() const {
|
| 577 |
size_t size_k_bytes = 0;
|
| 578 |
|
| 579 |
-
for (const auto &
|
| 580 |
-
size_k_bytes += ggml_nbytes(k);
|
| 581 |
}
|
| 582 |
|
| 583 |
return size_k_bytes;
|
|
@@ -586,8 +735,8 @@ size_t llama_kv_cache_unified::size_k_bytes() const {
|
|
| 586 |
size_t llama_kv_cache_unified::size_v_bytes() const {
|
| 587 |
size_t size_v_bytes = 0;
|
| 588 |
|
| 589 |
-
for (const auto &
|
| 590 |
-
size_v_bytes += ggml_nbytes(v);
|
| 591 |
}
|
| 592 |
|
| 593 |
return size_v_bytes;
|
|
@@ -651,13 +800,7 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
|
|
| 651 |
GGML_UNUSED(ubatch);
|
| 652 |
|
| 653 |
if (k_shift) {
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
int32_t * data = (int32_t *) k_shift->data;
|
| 657 |
-
|
| 658 |
-
for (uint32_t i = 0; i < kv_self->size; ++i) {
|
| 659 |
-
data[i] = kv_self->cells[i].delta;
|
| 660 |
-
}
|
| 661 |
}
|
| 662 |
}
|
| 663 |
|
|
@@ -667,13 +810,9 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
|
|
| 667 |
ggml_cgraph * gf) const {
|
| 668 |
auto res = std::make_unique<llm_graph_result>();
|
| 669 |
|
| 670 |
-
const auto & n_layer = hparams.n_layer;
|
| 671 |
-
|
| 672 |
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
| 673 |
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
| 674 |
|
| 675 |
-
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
|
| 676 |
-
|
| 677 |
//GGML_ASSERT(kv_self->size == n_ctx);
|
| 678 |
|
| 679 |
auto inp = std::make_unique<llm_graph_input_k_shift>(this);
|
|
@@ -681,24 +820,22 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
|
|
| 681 |
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
|
| 682 |
ggml_set_input(inp->k_shift);
|
| 683 |
|
| 684 |
-
for (
|
|
|
|
|
|
|
| 685 |
const int64_t n_head_kv = hparams.n_head_kv(il);
|
| 686 |
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
| 687 |
|
| 688 |
-
const
|
|
|
|
| 689 |
|
| 690 |
-
|
| 691 |
-
// if we decide to make them configurable, like the non-sliding ones
|
| 692 |
-
const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
|
| 693 |
-
const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
|
| 694 |
-
|
| 695 |
-
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
|
| 696 |
|
| 697 |
ggml_tensor * k =
|
| 698 |
-
ggml_view_3d(ctx,
|
| 699 |
-
n_embd_head_k, n_head_kv, size,
|
| 700 |
-
ggml_row_size(
|
| 701 |
-
ggml_row_size(
|
| 702 |
0);
|
| 703 |
|
| 704 |
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
|
|
@@ -803,44 +940,46 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
|
|
| 803 |
nm++;
|
| 804 |
}
|
| 805 |
|
| 806 |
-
for (
|
|
|
|
|
|
|
| 807 |
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
| 808 |
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
| 809 |
|
| 810 |
-
ggml_tensor * view_k_src = ggml_view_2d(ctx,
|
| 811 |
n_embd_k_gqa, nm,
|
| 812 |
-
ggml_row_size(
|
| 813 |
-
ggml_row_size(
|
| 814 |
|
| 815 |
-
ggml_tensor * view_k_dst = ggml_view_2d(ctx,
|
| 816 |
n_embd_k_gqa, nm,
|
| 817 |
-
ggml_row_size(
|
| 818 |
-
ggml_row_size(
|
| 819 |
|
| 820 |
ggml_tensor * view_v_src;
|
| 821 |
ggml_tensor * view_v_dst;
|
| 822 |
|
| 823 |
if (cparams.flash_attn) {
|
| 824 |
// NOTE: the V cache is not transposed when using flash attention
|
| 825 |
-
view_v_src = ggml_view_2d(ctx,
|
| 826 |
n_embd_v_gqa, nm,
|
| 827 |
-
ggml_row_size(
|
| 828 |
-
ggml_row_size(
|
| 829 |
|
| 830 |
-
view_v_dst = ggml_view_2d(ctx,
|
| 831 |
n_embd_v_gqa, nm,
|
| 832 |
-
ggml_row_size(
|
| 833 |
-
ggml_row_size(
|
| 834 |
} else {
|
| 835 |
-
view_v_src = ggml_view_2d(ctx,
|
| 836 |
nm, n_embd_v_gqa,
|
| 837 |
-
ggml_row_size(
|
| 838 |
-
ggml_row_size(
|
| 839 |
|
| 840 |
-
view_v_dst = ggml_view_2d(ctx,
|
| 841 |
nm, n_embd_v_gqa,
|
| 842 |
-
ggml_row_size(
|
| 843 |
-
ggml_row_size(
|
| 844 |
}
|
| 845 |
|
| 846 |
ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
|
|
@@ -857,10 +996,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
|
|
| 857 |
}
|
| 858 |
|
| 859 |
bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
| 860 |
-
const uint32_t n_layer =
|
| 861 |
|
| 862 |
-
const uint32_t n_kv =
|
| 863 |
-
const uint32_t n_used =
|
| 864 |
|
| 865 |
assert(n_used <= n_kv);
|
| 866 |
|
|
@@ -888,9 +1027,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|
| 888 |
ids.resize(n_kv, n_kv);
|
| 889 |
|
| 890 |
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
|
| 891 |
-
|
| 892 |
-
|
| 893 |
-
if (!cell0.is_empty()) {
|
| 894 |
ids[i0] = i0;
|
| 895 |
|
| 896 |
continue;
|
|
@@ -901,7 +1038,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|
| 901 |
uint32_t nh = 1;
|
| 902 |
|
| 903 |
// determine the size of the hole
|
| 904 |
-
while (i0 + nh < n_used && cells
|
| 905 |
nh++;
|
| 906 |
}
|
| 907 |
|
|
@@ -910,9 +1047,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|
| 910 |
|
| 911 |
// starting from the end, find nh non-empty cells
|
| 912 |
for (; is > i0; --is) {
|
| 913 |
-
|
| 914 |
-
|
| 915 |
-
if (cell1.is_empty() || ids[is] != n_kv) {
|
| 916 |
continue;
|
| 917 |
}
|
| 918 |
|
|
@@ -939,9 +1074,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|
| 939 |
|
| 940 |
// go back and move the nf cells to the hole
|
| 941 |
for (; i1 < n_kv; ++i1) {
|
| 942 |
-
|
| 943 |
-
|
| 944 |
-
if (cell1.is_empty() || ids[i1] != n_kv) {
|
| 945 |
if (n_moves == max_moves) {
|
| 946 |
stop = true;
|
| 947 |
break;
|
|
@@ -955,10 +1088,8 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|
| 955 |
ids[i1] = i0 + nf;
|
| 956 |
|
| 957 |
// move the cell meta data
|
| 958 |
-
cells
|
| 959 |
|
| 960 |
-
// clear the old cell and move the head there
|
| 961 |
-
cell1 = kv_cell();
|
| 962 |
head = n_used;
|
| 963 |
|
| 964 |
if (!cont) {
|
|
@@ -993,16 +1124,30 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|
| 993 |
return true;
|
| 994 |
}
|
| 995 |
|
| 996 |
-
|
| 997 |
-
|
| 998 |
-
const kv_cell & cell = cells[i - 1];
|
| 999 |
|
| 1000 |
-
|
| 1001 |
-
|
| 1002 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1003 |
}
|
| 1004 |
|
| 1005 |
-
return
|
| 1006 |
}
|
| 1007 |
|
| 1008 |
void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
|
|
@@ -1011,23 +1156,24 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
|
|
| 1011 |
|
| 1012 |
// Count the number of cells with the specified seq_id
|
| 1013 |
// Find all the ranges of cells with this seq id (or all, when -1)
|
| 1014 |
-
uint32_t cell_range_begin = size;
|
| 1015 |
-
|
| 1016 |
-
|
| 1017 |
-
if ((seq_id == -1
|
| 1018 |
++cell_count;
|
| 1019 |
-
if (cell_range_begin == size) {
|
| 1020 |
cell_range_begin = i;
|
| 1021 |
}
|
| 1022 |
} else {
|
| 1023 |
-
if (cell_range_begin != size) {
|
| 1024 |
cell_ranges.emplace_back(cell_range_begin, i);
|
| 1025 |
-
cell_range_begin = size;
|
| 1026 |
}
|
| 1027 |
}
|
| 1028 |
}
|
| 1029 |
-
|
| 1030 |
-
|
|
|
|
| 1031 |
}
|
| 1032 |
|
| 1033 |
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
|
|
@@ -1064,17 +1210,24 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
|
|
| 1064 |
void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
|
| 1065 |
for (const auto & range : cell_ranges) {
|
| 1066 |
for (uint32_t i = range.first; i < range.second; ++i) {
|
| 1067 |
-
|
| 1068 |
-
|
| 1069 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1070 |
|
| 1071 |
io.write(&pos, sizeof(pos));
|
| 1072 |
io.write(&n_seq_id, sizeof(n_seq_id));
|
| 1073 |
|
| 1074 |
-
|
| 1075 |
-
|
| 1076 |
-
io.write(&seq_id, sizeof(seq_id));
|
| 1077 |
-
}
|
| 1078 |
}
|
| 1079 |
}
|
| 1080 |
}
|
|
@@ -1082,7 +1235,7 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::
|
|
| 1082 |
|
| 1083 |
void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
|
| 1084 |
const uint32_t v_trans = this->v_trans ? 1 : 0;
|
| 1085 |
-
const uint32_t n_layer =
|
| 1086 |
|
| 1087 |
io.write(&v_trans, sizeof(v_trans));
|
| 1088 |
io.write(&n_layer, sizeof(n_layer));
|
|
@@ -1091,56 +1244,63 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
|
|
| 1091 |
|
| 1092 |
// Iterate and write all the keys first, each row is a cell
|
| 1093 |
// Get whole range at a time
|
| 1094 |
-
for (
|
|
|
|
|
|
|
| 1095 |
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
| 1096 |
|
| 1097 |
// Write key type
|
| 1098 |
-
const int32_t k_type_i = (int32_t)
|
| 1099 |
io.write(&k_type_i, sizeof(k_type_i));
|
| 1100 |
|
| 1101 |
// Write row size of key
|
| 1102 |
-
const uint64_t k_size_row = ggml_row_size(
|
| 1103 |
io.write(&k_size_row, sizeof(k_size_row));
|
| 1104 |
|
| 1105 |
// Read each range of cells of k_size length each into tmp_buf and write out
|
| 1106 |
for (const auto & range : cell_ranges) {
|
| 1107 |
const size_t range_size = range.second - range.first;
|
| 1108 |
const size_t buf_size = range_size * k_size_row;
|
| 1109 |
-
io.write_tensor(
|
| 1110 |
}
|
| 1111 |
}
|
| 1112 |
|
| 1113 |
if (!v_trans) {
|
| 1114 |
-
for (
|
|
|
|
|
|
|
| 1115 |
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 1116 |
|
| 1117 |
// Write value type
|
| 1118 |
-
const int32_t v_type_i = (int32_t)
|
| 1119 |
io.write(&v_type_i, sizeof(v_type_i));
|
| 1120 |
|
| 1121 |
// Write row size of value
|
| 1122 |
-
const uint64_t v_size_row = ggml_row_size(
|
| 1123 |
io.write(&v_size_row, sizeof(v_size_row));
|
| 1124 |
|
| 1125 |
// Read each range of cells of v_size length each into tmp_buf and write out
|
| 1126 |
for (const auto & range : cell_ranges) {
|
| 1127 |
const size_t range_size = range.second - range.first;
|
| 1128 |
const size_t buf_size = range_size * v_size_row;
|
| 1129 |
-
io.write_tensor(
|
| 1130 |
}
|
| 1131 |
}
|
| 1132 |
} else {
|
| 1133 |
// When v is transposed, we also need the element size and get the element ranges from each row
|
| 1134 |
-
const uint32_t kv_size = size;
|
| 1135 |
-
|
|
|
|
|
|
|
|
|
|
| 1136 |
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 1137 |
|
| 1138 |
// Write value type
|
| 1139 |
-
const int32_t v_type_i = (int32_t)
|
| 1140 |
io.write(&v_type_i, sizeof(v_type_i));
|
| 1141 |
|
| 1142 |
// Write element size
|
| 1143 |
-
const uint32_t v_size_el = ggml_type_size(
|
| 1144 |
io.write(&v_size_el, sizeof(v_size_el));
|
| 1145 |
|
| 1146 |
// Write GQA embedding size
|
|
@@ -1153,7 +1313,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
|
|
| 1153 |
const size_t range_size = range.second - range.first;
|
| 1154 |
const size_t src_offset = (range.first + j * kv_size) * v_size_el;
|
| 1155 |
const size_t buf_size = range_size * v_size_el;
|
| 1156 |
-
io.write_tensor(
|
| 1157 |
}
|
| 1158 |
}
|
| 1159 |
}
|
|
@@ -1170,8 +1330,6 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|
| 1170 |
llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
| 1171 |
|
| 1172 |
batch.n_tokens = cell_count;
|
| 1173 |
-
batch.n_seq_tokens = cell_count;
|
| 1174 |
-
batch.n_seqs = 1;
|
| 1175 |
|
| 1176 |
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 1177 |
llama_pos pos;
|
|
@@ -1180,32 +1338,40 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|
| 1180 |
io.read_to(&pos, sizeof(pos));
|
| 1181 |
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
| 1182 |
|
| 1183 |
-
if (n_seq_id !=
|
| 1184 |
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
|
| 1185 |
return false;
|
| 1186 |
}
|
| 1187 |
|
| 1188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1189 |
}
|
| 1190 |
-
|
| 1191 |
-
batch.seq_id[0] = &dest_seq_id;
|
| 1192 |
if (!find_slot(batch)) {
|
| 1193 |
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
| 1194 |
return false;
|
| 1195 |
}
|
|
|
|
| 1196 |
commit();
|
| 1197 |
|
| 1198 |
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
| 1199 |
// Assume that this is one contiguous block of cells
|
| 1200 |
-
GGML_ASSERT(head + cell_count <= size);
|
| 1201 |
-
GGML_ASSERT(cells
|
| 1202 |
-
GGML_ASSERT(cells
|
| 1203 |
-
GGML_ASSERT(cells
|
| 1204 |
-
GGML_ASSERT(cells
|
| 1205 |
} else {
|
| 1206 |
// whole KV cache restore
|
| 1207 |
|
| 1208 |
-
if (cell_count > size) {
|
| 1209 |
LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
|
| 1210 |
return false;
|
| 1211 |
}
|
|
@@ -1213,34 +1379,28 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|
| 1213 |
clear();
|
| 1214 |
|
| 1215 |
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 1216 |
-
kv_cell & cell = cells[i];
|
| 1217 |
-
|
| 1218 |
llama_pos pos;
|
| 1219 |
uint32_t n_seq_id;
|
| 1220 |
|
| 1221 |
io.read_to(&pos, sizeof(pos));
|
| 1222 |
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
| 1223 |
|
| 1224 |
-
|
| 1225 |
|
| 1226 |
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
| 1227 |
llama_seq_id seq_id;
|
| 1228 |
io.read_to(&seq_id, sizeof(seq_id));
|
| 1229 |
|
| 1230 |
-
|
| 1231 |
-
|
| 1232 |
-
if (seq_id < 0) {
|
| 1233 |
-
//LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
|
| 1234 |
-
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
|
| 1235 |
return false;
|
| 1236 |
}
|
| 1237 |
|
| 1238 |
-
|
| 1239 |
}
|
| 1240 |
}
|
| 1241 |
|
| 1242 |
head = 0;
|
| 1243 |
-
used = cell_count;
|
| 1244 |
}
|
| 1245 |
|
| 1246 |
return true;
|
|
@@ -1249,15 +1409,16 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|
| 1249 |
bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
|
| 1250 |
uint32_t v_trans;
|
| 1251 |
uint32_t n_layer;
|
|
|
|
| 1252 |
io.read_to(&v_trans, sizeof(v_trans));
|
| 1253 |
io.read_to(&n_layer, sizeof(n_layer));
|
| 1254 |
|
| 1255 |
-
if (n_layer !=
|
| 1256 |
-
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer,
|
| 1257 |
return false;
|
| 1258 |
}
|
| 1259 |
-
if (cell_count > size) {
|
| 1260 |
-
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
|
| 1261 |
return false;
|
| 1262 |
}
|
| 1263 |
if (this->v_trans != (bool) v_trans) {
|
|
@@ -1266,13 +1427,15 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
|
| 1266 |
}
|
| 1267 |
|
| 1268 |
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
|
| 1269 |
-
for (
|
|
|
|
|
|
|
| 1270 |
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
| 1271 |
|
| 1272 |
// Read type of key
|
| 1273 |
int32_t k_type_i_ref;
|
| 1274 |
io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
|
| 1275 |
-
const int32_t k_type_i = (int32_t)
|
| 1276 |
if (k_type_i != k_type_i_ref) {
|
| 1277 |
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
|
| 1278 |
return false;
|
|
@@ -1281,7 +1444,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
|
| 1281 |
// Read row size of key
|
| 1282 |
uint64_t k_size_row_ref;
|
| 1283 |
io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
|
| 1284 |
-
const size_t k_size_row = ggml_row_size(
|
| 1285 |
if (k_size_row != k_size_row_ref) {
|
| 1286 |
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
|
| 1287 |
return false;
|
|
@@ -1289,18 +1452,20 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
|
| 1289 |
|
| 1290 |
if (cell_count) {
|
| 1291 |
// Read and set the keys for the whole cell range
|
| 1292 |
-
ggml_backend_tensor_set(
|
| 1293 |
}
|
| 1294 |
}
|
| 1295 |
|
| 1296 |
if (!this->v_trans) {
|
| 1297 |
-
for (
|
|
|
|
|
|
|
| 1298 |
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 1299 |
|
| 1300 |
// Read type of value
|
| 1301 |
int32_t v_type_i_ref;
|
| 1302 |
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
| 1303 |
-
const int32_t v_type_i = (int32_t)
|
| 1304 |
if (v_type_i != v_type_i_ref) {
|
| 1305 |
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
| 1306 |
return false;
|
|
@@ -1309,7 +1474,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
|
| 1309 |
// Read row size of value
|
| 1310 |
uint64_t v_size_row_ref;
|
| 1311 |
io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
|
| 1312 |
-
const size_t v_size_row = ggml_row_size(
|
| 1313 |
if (v_size_row != v_size_row_ref) {
|
| 1314 |
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
|
| 1315 |
return false;
|
|
@@ -1317,18 +1482,20 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
|
| 1317 |
|
| 1318 |
if (cell_count) {
|
| 1319 |
// Read and set the values for the whole cell range
|
| 1320 |
-
ggml_backend_tensor_set(
|
| 1321 |
}
|
| 1322 |
}
|
| 1323 |
} else {
|
| 1324 |
// For each layer, read the values for each cell (transposed)
|
| 1325 |
-
for (
|
|
|
|
|
|
|
| 1326 |
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 1327 |
|
| 1328 |
// Read type of value
|
| 1329 |
int32_t v_type_i_ref;
|
| 1330 |
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
| 1331 |
-
const int32_t v_type_i = (int32_t)
|
| 1332 |
if (v_type_i != v_type_i_ref) {
|
| 1333 |
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
| 1334 |
return false;
|
|
@@ -1337,7 +1504,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
|
| 1337 |
// Read element size of value
|
| 1338 |
uint32_t v_size_el_ref;
|
| 1339 |
io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
|
| 1340 |
-
const size_t v_size_el = ggml_type_size(
|
| 1341 |
if (v_size_el != v_size_el_ref) {
|
| 1342 |
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
|
| 1343 |
return false;
|
|
@@ -1354,8 +1521,8 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
|
| 1354 |
if (cell_count) {
|
| 1355 |
// For each row in the transposed matrix, read the values for the whole cell range
|
| 1356 |
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
| 1357 |
-
const size_t dst_offset = (head + j * size) * v_size_el;
|
| 1358 |
-
ggml_backend_tensor_set(
|
| 1359 |
}
|
| 1360 |
}
|
| 1361 |
}
|
|
@@ -1364,6 +1531,193 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
|
| 1364 |
return true;
|
| 1365 |
}
|
| 1366 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1367 |
//
|
| 1368 |
// llama_kv_cache_recurrent
|
| 1369 |
//
|
|
@@ -1373,19 +1727,17 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
|
|
| 1373 |
ggml_type type_k,
|
| 1374 |
ggml_type type_v,
|
| 1375 |
bool offload,
|
| 1376 |
-
uint32_t kv_size
|
|
|
|
| 1377 |
const int32_t n_layer = hparams.n_layer;
|
| 1378 |
|
| 1379 |
-
LLAMA_LOG_INFO("%s: kv_size = %
|
| 1380 |
-
__func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
|
| 1381 |
|
| 1382 |
head = 0;
|
| 1383 |
size = kv_size;
|
| 1384 |
used = 0;
|
| 1385 |
|
| 1386 |
-
this->type_k = type_k;
|
| 1387 |
-
this->type_v = type_v;
|
| 1388 |
-
|
| 1389 |
cells.clear();
|
| 1390 |
cells.resize(kv_size);
|
| 1391 |
|
|
@@ -1623,8 +1975,8 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
|
|
| 1623 |
}
|
| 1624 |
}
|
| 1625 |
|
| 1626 |
-
void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos
|
| 1627 |
-
if (
|
| 1628 |
return;
|
| 1629 |
}
|
| 1630 |
|
|
@@ -1647,7 +1999,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
|
|
| 1647 |
if (tail_id >= 0) {
|
| 1648 |
kv_cell & cell = cells[tail_id];
|
| 1649 |
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
| 1650 |
-
cell.pos +=
|
| 1651 |
}
|
| 1652 |
}
|
| 1653 |
}
|
|
@@ -1683,8 +2035,24 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
|
|
| 1683 |
}
|
| 1684 |
}
|
| 1685 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1686 |
llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
| 1687 |
-
llama_pos result =
|
| 1688 |
|
| 1689 |
for (uint32_t i = 0; i < size; ++i) {
|
| 1690 |
if (cells[i].has_seq_id(seq_id)) {
|
|
@@ -1707,8 +2075,8 @@ void llama_kv_cache_recurrent::commit() {
|
|
| 1707 |
pending.ranges.clear();
|
| 1708 |
}
|
| 1709 |
|
| 1710 |
-
bool llama_kv_cache_recurrent::update(llama_context &
|
| 1711 |
-
GGML_UNUSED(
|
| 1712 |
return false;
|
| 1713 |
}
|
| 1714 |
|
|
@@ -1769,7 +2137,7 @@ bool llama_kv_cache_recurrent::find_slot(
|
|
| 1769 |
if (seq_id < 0 || (uint32_t) seq_id >= size) {
|
| 1770 |
// too big seq_id
|
| 1771 |
// TODO: would it be possible to resize the cache instead?
|
| 1772 |
-
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%
|
| 1773 |
return false;
|
| 1774 |
}
|
| 1775 |
if (j > 0) {
|
|
@@ -1912,29 +2280,6 @@ bool llama_kv_cache_recurrent::find_slot(
|
|
     return n >= n_seqs;
 }
 
-int32_t llama_kv_cache_recurrent::get_n_tokens() const {
-    int32_t result = 0;
-
-    for (uint32_t i = 0; i < size; i++) {
-        result += cells[i].seq_id.size();
-    }
-
-    return result;
-}
-
-int32_t llama_kv_cache_recurrent::get_used_cells() const {
-    return used;
-}
-
-llama_pos llama_kv_cache_recurrent::get_pos_max() const {
-    llama_pos pos_max = -1;
-    for (const auto & cell : cells) {
-        pos_max = std::max(pos_max, cell.pos);
-    }
-
-    return pos_max;
-}
-
 bool llama_kv_cache_recurrent::get_can_shift() const {
     return false;
 }

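With the cache-wide get_pos_max()/get_n_tokens()/get_used_cells() helpers removed, callers rely on the per-sequence API that this sync keeps (seq_pos_max, seq_pos_min). An illustrative fragment, assuming the surrounding llama.cpp context and a known n_seq_max; it is not part of the patch:

    // query only the sequences the caller actually tracks
    llama_pos pos_max = -1;
    for (llama_seq_id s = 0; s < (llama_seq_id) n_seq_max; ++s) {
        pos_max = std::max(pos_max, kv_self->seq_pos_max(s));
    }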
@@ -2063,6 +2408,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
|
|
| 2063 |
io.read_to(&cell_count, sizeof(cell_count));
|
| 2064 |
|
| 2065 |
bool res = true;
|
|
|
|
| 2066 |
res = res && state_read_meta(io, cell_count, seq_id);
|
| 2067 |
res = res && state_read_data(io, cell_count);
|
| 2068 |
|
|
@@ -2391,104 +2737,3 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
|
|
| 2391 |
|
| 2392 |
return true;
|
| 2393 |
}
|
| 2394 |
-
-//
-// kv cache view
-//
-
-llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max) {
-    llama_kv_cache_view result = {
-        /*.n_cells            = */ 0,
-        /*.n_seq_max          = */ n_seq_max,
-        /*.token_count        = */ 0,
-        /*.used_cells         = */ kv.get_used_cells(),
-        /*.max_contiguous     = */ 0,
-        /*.max_contiguous_idx = */ -1,
-        /*.cells              = */ nullptr,
-        /*.cells_sequences    = */ nullptr,
-    };
-
-    return result;
-}
-
-void llama_kv_cache_view_free(llama_kv_cache_view * view) {
-    if (view->cells != nullptr) {
-        free(view->cells);
-        view->cells = nullptr;
-    }
-    if (view->cells_sequences != nullptr) {
-        free(view->cells_sequences);
-        view->cells_sequences = nullptr;
-    }
-}
-
-void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv) {
-    // TODO: rework this in the future, for now quick hack
-    const llama_kv_cache_unified * kvu = dynamic_cast<const llama_kv_cache_unified *>(kv);
-    if (kvu == nullptr) {
-        LLAMA_LOG_ERROR("%s: the kv_cache_view currently works only with llama_kv_cache_unified\n", __func__);
-        return;
-    }
-
-    if (uint32_t(view->n_cells) < kvu->size || view->cells == nullptr) {
-        view->n_cells = int32_t(kvu->size);
-        void * p = realloc(view->cells, sizeof(llama_kv_cache_view_cell) * view->n_cells);
-        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
-        view->cells = (llama_kv_cache_view_cell *)p;
-        p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells);
-        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
-        view->cells_sequences = (llama_seq_id *)p;
-    }
-
-    const std::vector<llama_kv_cache_unified::kv_cell> & kv_cells = kvu->cells;
-    llama_kv_cache_view_cell * c_curr = view->cells;
-    llama_seq_id * cs_curr = view->cells_sequences;
-    int32_t used_cells = 0;
-    int32_t token_count = 0;
-    int32_t curr_contig_idx = -1;
-    uint32_t max_contig = 0;
-    int32_t max_contig_idx = -1;
-
-    for (int32_t i = 0; i < int32_t(kvu->size); i++, c_curr++, cs_curr += view->n_seq_max) {
-        const size_t curr_size = kv_cells[i].seq_id.size();
-        token_count += curr_size;
-        c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
-
-        if (curr_size > 0) {
-            if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
-                max_contig = i - curr_contig_idx;
-                max_contig_idx = curr_contig_idx;
-            }
-            curr_contig_idx = -1;
-        } else if (curr_contig_idx < 0) {
-            curr_contig_idx = i;
-        }
-
-        int seq_idx = 0;
-        for (const llama_seq_id it : kv_cells[i].seq_id) {
-            if (seq_idx >= view->n_seq_max) {
-                break;
-            }
-            cs_curr[seq_idx] = it;
-            seq_idx++;
-        }
-        if (seq_idx != 0) {
-            used_cells++;
-        }
-        for (; seq_idx < view->n_seq_max; seq_idx++) {
-            cs_curr[seq_idx] = -1;
-        }
-    }
-    if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
-        max_contig_idx = curr_contig_idx;
-        max_contig = kv_cells.size() - curr_contig_idx;
-    }
-    view->max_contiguous = max_contig;
-    view->max_contiguous_idx = max_contig_idx;
-    view->token_count = token_count;
-    view->used_cells = used_cells;
-    if (uint32_t(used_cells) != kvu->used) {
-        LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
-            __func__, kvu->used, used_cells);
-    }
-}
|
|
|
| 23 |
}
|
| 24 |
|
| 25 |
llama_kv_cache_unified::llama_kv_cache_unified(
|
| 26 |
+
const llama_model & model,
|
| 27 |
+
layer_filter_cb && filter,
|
| 28 |
+
ggml_type type_k,
|
| 29 |
+
ggml_type type_v,
|
| 30 |
+
bool v_trans,
|
| 31 |
+
bool offload,
|
| 32 |
+
uint32_t kv_size,
|
| 33 |
+
uint32_t n_seq_max,
|
| 34 |
+
uint32_t n_pad,
|
| 35 |
+
uint32_t n_swa,
|
| 36 |
+
llama_swa_type swa_type) :
|
| 37 |
+
model(model), hparams(model.hparams), v_trans(v_trans),
|
| 38 |
+
n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
|
| 39 |
+
|
| 40 |
+
GGML_ASSERT(kv_size % n_pad == 0);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
// create a context for each buffer type
|
| 43 |
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
|
|
|
| 45 |
auto it = ctx_map.find(buft);
|
| 46 |
if (it == ctx_map.end()) {
|
| 47 |
ggml_init_params params = {
|
| 48 |
+
/*.mem_size =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()),
|
| 49 |
/*.mem_buffer =*/ NULL,
|
| 50 |
/*.no_alloc =*/ true,
|
| 51 |
};
|
|
|
|
| 64 |
return it->second;
|
| 65 |
};
|
| 66 |
|
| 67 |
+
head = 0;
|
|
|
|
| 68 |
|
| 69 |
+
cells.resize(kv_size);
|
| 70 |
+
|
| 71 |
+
for (uint32_t il = 0; il < hparams.n_layer; il++) {
|
| 72 |
+
if (filter && !filter(il)) {
|
| 73 |
+
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
|
| 74 |
+
continue;
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
| 78 |
+
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 79 |
|
| 80 |
const char * dev_name = "CPU";
|
| 81 |
|
| 82 |
ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
|
| 83 |
|
| 84 |
if (offload) {
|
| 85 |
+
auto * dev = model.dev_layer(il);
|
| 86 |
buft = ggml_backend_dev_buffer_type(dev);
|
| 87 |
|
| 88 |
dev_name = ggml_backend_dev_name(dev);
|
| 89 |
}
|
| 90 |
|
| 91 |
+
LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name);
|
| 92 |
|
| 93 |
ggml_context * ctx = ctx_for_buft(buft);
|
| 94 |
if (!ctx) {
|
| 95 |
throw std::runtime_error("failed to create ggml context for kv cache");
|
| 96 |
}
|
| 97 |
|
| 98 |
+
ggml_tensor * k;
|
| 99 |
+
ggml_tensor * v;
|
| 100 |
+
|
| 101 |
+
k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
|
| 102 |
+
v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
|
| 103 |
+
|
| 104 |
+
ggml_format_name(k, "cache_k_l%d", il);
|
| 105 |
+
ggml_format_name(v, "cache_v_l%d", il);
|
| 106 |
+
|
| 107 |
+
map_layer_ids[il] = layers.size();
|
| 108 |
+
layers.push_back({ il, k, v });
|
| 109 |
}
|
| 110 |
|
| 111 |
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
|
|
|
| 117 |
if (!buf) {
|
| 118 |
throw std::runtime_error("failed to allocate buffer for kv cache");
|
| 119 |
}
|
| 120 |
+
|
| 121 |
LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
|
| 122 |
+
|
| 123 |
+
ggml_backend_buffer_clear(buf, 0);
|
| 124 |
bufs.emplace_back(buf);
|
| 125 |
}
|
| 126 |
|
|
|
|
| 128 |
const size_t memory_size_k = size_k_bytes();
|
| 129 |
const size_t memory_size_v = size_v_bytes();
|
| 130 |
|
| 131 |
+
LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
|
| 132 |
+
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
|
| 133 |
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
|
| 134 |
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
|
| 135 |
}
|
| 136 |
}
|
| 137 |
|
| 138 |
void llama_kv_cache_unified::clear() {
|
| 139 |
+
cells.reset();
|
| 140 |
+
|
|
|
|
|
|
|
| 141 |
head = 0;
|
|
|
|
| 142 |
|
| 143 |
for (auto & buf : bufs) {
|
| 144 |
ggml_backend_buffer_clear(buf.get(), 0);
|
|
|
|
| 146 |
}
|
| 147 |
|
| 148 |
bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
| 149 |
+
uint32_t new_head = cells.size();
|
| 150 |
|
| 151 |
if (p0 < 0) {
|
| 152 |
p0 = 0;
|
|
|
|
| 156 |
p1 = std::numeric_limits<llama_pos>::max();
|
| 157 |
}
|
| 158 |
|
| 159 |
+
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 160 |
+
if (!cells.pos_in(i, p0, p1)) {
|
| 161 |
+
continue;
|
| 162 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
+
if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
|
| 165 |
+
if (new_head == cells.size()) {
|
| 166 |
+
new_head = i;
|
| 167 |
}
|
| 168 |
}
|
| 169 |
}
|
| 170 |
|
| 171 |
// If we freed up a slot, set head to it so searching can start there.
|
| 172 |
+
if (new_head != cells.size() && new_head < head) {
|
| 173 |
head = new_head;
|
| 174 |
}
|
| 175 |
|
|
|
|
| 189 |
p1 = std::numeric_limits<llama_pos>::max();
|
| 190 |
}
|
| 191 |
|
| 192 |
+
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 193 |
+
if (!cells.pos_in(i, p0, p1)) {
|
| 194 |
+
continue;
|
| 195 |
+
}
|
| 196 |
|
| 197 |
+
if (cells.seq_has(i, seq_id_src)) {
|
| 198 |
+
cells.seq_add(i, seq_id_dst);
|
|
|
|
| 199 |
}
|
| 200 |
}
|
| 201 |
}
|
| 202 |
|
| 203 |
void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
|
| 204 |
+
uint32_t new_head = cells.size();
|
| 205 |
|
| 206 |
+
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 207 |
+
if (cells.seq_keep(i, seq_id)) {
|
| 208 |
+
if (new_head == cells.size()) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
new_head = i;
|
| 210 |
}
|
|
|
|
|
|
|
|
|
|
| 211 |
}
|
| 212 |
}
|
| 213 |
|
| 214 |
// If we freed up a slot, set head to it so searching can start there.
|
| 215 |
+
if (new_head != cells.size() && new_head < head) {
|
| 216 |
head = new_head;
|
| 217 |
}
|
| 218 |
}
|
| 219 |
|
| 220 |
+
void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
| 221 |
+
if (shift == 0) {
|
| 222 |
return;
|
| 223 |
}
|
| 224 |
|
| 225 |
+
uint32_t new_head = cells.size();
|
| 226 |
|
| 227 |
if (p0 < 0) {
|
| 228 |
p0 = 0;
|
|
|
|
| 232 |
p1 = std::numeric_limits<llama_pos>::max();
|
| 233 |
}
|
| 234 |
|
| 235 |
+
// If there is no range then return early to avoid looping over all cells.
|
| 236 |
if (p0 == p1) {
|
| 237 |
return;
|
| 238 |
}
|
| 239 |
|
| 240 |
+
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 241 |
+
if (!cells.pos_in(i, p0, p1)) {
|
| 242 |
+
continue;
|
| 243 |
+
}
|
|
|
|
| 244 |
|
| 245 |
+
if (cells.seq_has(i, seq_id)) {
|
| 246 |
+
if (cells.pos_add(i, shift)) {
|
| 247 |
+
if (new_head == cells.size()) {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
new_head = i;
|
| 249 |
}
|
| 250 |
}
|
|
|
|
| 253 |
|
| 254 |
// If we freed up a slot, set head to it so searching can start there.
|
| 255 |
// Otherwise we just start the next search from the beginning.
|
| 256 |
+
head = new_head != cells.size() ? new_head : 0;
|
| 257 |
}
|
| 258 |
|
| 259 |
void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
|
|
|
| 274 |
return;
|
| 275 |
}
|
| 276 |
|
| 277 |
+
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 278 |
+
if (!cells.pos_in(i, p0, p1)) {
|
| 279 |
+
continue;
|
| 280 |
+
}
|
| 281 |
|
| 282 |
+
if (cells.seq_has(i, seq_id)) {
|
| 283 |
+
cells.pos_div(i, d);
|
|
|
|
|
|
|
|
|
|
| 284 |
}
|
| 285 |
}
|
| 286 |
}
|
| 287 |
|
| 288 |
+
llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
|
| 289 |
+
return cells.seq_pos_min(seq_id);
|
| 290 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
|
| 292 |
+
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
|
| 293 |
+
return cells.seq_pos_max(seq_id);
|
| 294 |
}
|
| 295 |
|
| 296 |
void llama_kv_cache_unified::restore() {
|
| 297 |
+
for (auto & state : recovery.states) {
|
| 298 |
+
cells.set(state.i, state.cells);
|
| 299 |
}
|
| 300 |
|
| 301 |
+
recovery.clear();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
}
|
| 303 |
|
| 304 |
void llama_kv_cache_unified::commit() {
|
| 305 |
+
if (recovery.states.empty()) {
|
| 306 |
+
LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n",
|
| 307 |
+
__func__, "https://github.com/ggml-org/llama.cpp/pull/13194");
|
| 308 |
return;
|
| 309 |
}
|
| 310 |
|
| 311 |
+
recovery.clear();
|
| 312 |
}
|
| 313 |
|
| 314 |
bool llama_kv_cache_unified::update(llama_context & lctx) {
|
|
|
|
| 316 |
|
| 317 |
auto * sched = lctx.get_sched();
|
| 318 |
|
| 319 |
+
if (cells.get_has_shift()) {
|
| 320 |
if (!get_can_shift()) {
|
| 321 |
GGML_ABORT("The current KV cache / model configuration does not support K-shift");
|
| 322 |
}
|
|
|
|
| 340 |
need_reserve = true;
|
| 341 |
}
|
| 342 |
|
| 343 |
+
cells.reset_shift();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 344 |
}
|
| 345 |
|
| 346 |
if (do_defrag) {
|
|
|
|
| 371 |
void llama_kv_cache_unified::defrag_sched(float thold) {
|
| 372 |
// - do not defrag small contexts (i.e. < 2048 tokens)
|
| 373 |
// - count the padding towards the number of used tokens
|
| 374 |
+
const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n)) : 0.0f;
|
| 375 |
|
| 376 |
// queue defragmentation for next llama_kv_cache_update
|
| 377 |
if (fragmentation > thold) {
|
|
|
|
| 382 |
}
|
| 383 |
|
| 384 |
void llama_kv_cache_unified::set_full() {
|
| 385 |
+
n = cells.size();
|
| 386 |
|
| 387 |
// when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
|
| 388 |
// affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
|
|
|
|
| 392 |
head = 0;
|
| 393 |
}
|
| 394 |
|
| 395 |
+
llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) {
|
|
|
|
|
|
|
| 396 |
return llama_sbatch(batch, hparams.n_embd, true, logits_all);
|
| 397 |
}
|
| 398 |
|
| 399 |
+
llama_ubatch llama_kv_cache_unified::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
|
|
|
|
|
|
|
|
|
|
| 400 |
GGML_UNUSED(embd_pooled);
|
| 401 |
return sbatch.split_simple(n_ubatch);
|
| 402 |
}
|
| 403 |
|
| 404 | +bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
| 405 |     const uint32_t n_tokens = ubatch.n_tokens;
| 406 |
| 407 |     // if we have enough unused cells before the current head ->
| 408 |     //   better to start searching from the beginning of the cache, hoping to fill it
| 409 | +    if (head > cells.get_used() + 2*ubatch.n_tokens) {
| 410 |         head = 0;
| 411 |     }
| 412 |
| 413 |     // otherwise, one cell per token.
| 414 |
| 415 | +    if (n_tokens > cells.size()) {
| 416 | +        LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
| 417 |         return false;
| 418 |     }
| 419 |
| 420 | +//#define FIND_SLOT_DEBUG 1
| 421 | +#if FIND_SLOT_DEBUG
| 422 | +    LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
| 423 | +
| 424 | +    // for debugging
| 425 | +    {
| 426 | +        std::string ss;
| 427 | +        if (n_swa > 0) {
| 428 | +            for (uint32_t i = 0; i < size; ++i) {
| 429 | +                if (cells.is_empty(i)) {
| 430 | +                    ss += '.';
| 431 | +                } else {
| 432 | +                    ss += 'x';
| 433 | +                }
| 434 | +                if (i%256 == 255) {
| 435 | +                    ss += '\n';
| 436 | +                }
| 437 | +            }
| 438 | +        }
| 439 | +        LLAMA_LOG_WARN("\n%s\n", ss.c_str());
| 440 | +    }
| 441 | +#endif
| 442 | +
| 443 |     uint32_t n_tested = 0;
| 444 |
| 445 |     while (true) {
| 446 | +        if (head + n_tokens > cells.size()) {
| 447 | +            n_tested += cells.size() - head;
| 448 |             head = 0;
| 449 |             continue;
| 450 |         }
| 451 |
| 452 |         bool found = true;
| 453 |         for (uint32_t i = 0; i < n_tokens; i++) {
| 454 | +            // TODO: improve to accept cells that are masked by the SWA
| 455 | +            if (!cells.is_empty(head + i)) {
| 456 |                 found = false;
| 457 |                 head += i + 1;
| 458 |                 n_tested += i + 1;
| 464 |                 break;
| 465 |             }
| 466 |
| 467 | +        if (n_tested >= cells.size()) {
| 468 |             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
| 469 |             return false;
| 470 |         }
| 471 |     }
| 472 |
| 473 | +    // store the old state of the cells in the recovery stack
| 474 | +    recovery.states.push_back({head, cells.cp(head, n_tokens)});
| 475 |
| 476 | +    for (uint32_t i = 0; i < n_tokens; ++i) {
| 477 | +        cells.pos_set(head + i, ubatch.pos[i]);
| 478 | +
| 479 | +        for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
| 480 | +            cells.seq_add(head + i, ubatch.seq_id[i][j]);
| 481 |         }
| 482 |     }
| 483 |
| 484 |     // a heuristic, to avoid attending the full cache if it is not yet utilized
| 485 |     // after enough generations, the benefit from this heuristic disappears
| 486 |     // if we start defragmenting the cache, the benefit from this will be more important
| 487 | +    n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
| 488 | +
| 489 | +#ifdef FIND_SLOT_DEBUG
| 490 | +    LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
| 491 | +#endif
| 492 |
| 493 | +    return true;
| 494 | +}
|
| 495 |
|
| 496 |
+
bool llama_kv_cache_unified::get_can_shift() const {
|
| 497 |
return true;
|
| 498 |
}
|
| 499 |
|
| 500 |
+
uint32_t llama_kv_cache_unified::get_n() const {
|
| 501 |
+
return n;
|
| 502 |
+
}
|
| 503 |
+
|
| 504 |
+
uint32_t llama_kv_cache_unified::get_size() const {
|
| 505 |
+
return cells.size();
|
| 506 |
+
}
|
| 507 |
+
|
| 508 |
+
ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const {
|
| 509 |
+
const int32_t ikv = map_layer_ids.at(il);
|
| 510 |
+
|
| 511 |
+
auto * k = layers[ikv].k;
|
| 512 |
+
|
| 513 |
+
return ggml_view_3d(ctx, k,
|
| 514 |
+
hparams.n_embd_head_k, hparams.n_head_kv(il), n,
|
| 515 |
+
ggml_row_size(k->type, hparams.n_embd_head_k),
|
| 516 |
+
ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
|
| 517 |
+
0);
|
| 518 |
+
}
|
| 519 |
+
|
| 520 |
+
ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) const {
|
| 521 |
+
const int32_t ikv = map_layer_ids.at(il);
|
| 522 |
+
|
| 523 |
+
auto * v = layers[ikv].v;
|
| 524 |
|
| 525 |
+
if (!v_trans) {
|
| 526 |
+
// note: v->nb[1] <= v->nb[2]
|
| 527 |
+
return ggml_view_3d(ctx, v,
|
| 528 |
+
hparams.n_embd_head_v, hparams.n_head_kv(il), n,
|
| 529 |
+
ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
|
| 530 |
+
ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
|
| 531 |
+
0);
|
| 532 |
}
|
| 533 |
|
| 534 |
+
// note: v->nb[1] > v->nb[2]
|
| 535 |
+
return ggml_view_3d(ctx, v,
|
| 536 |
+
n, hparams.n_head_kv(il), hparams.n_embd_head_v,
|
| 537 |
+
ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
|
| 538 |
+
ggml_row_size(v->type, v->ne[1]), // v->nb[2]
|
| 539 |
+
0);
|
| 540 |
}
|
| 541 |
|
| 542 |
+
ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
|
| 543 |
+
const int32_t ikv = map_layer_ids.at(il);
|
| 544 |
+
|
| 545 |
+
auto * k = layers[ikv].k;
|
| 546 |
+
|
| 547 |
+
const int64_t n_tokens = k_cur->ne[2];
|
| 548 |
+
|
| 549 |
+
ggml_tensor * k_view = ggml_view_1d(ctx, k,
|
| 550 |
+
n_tokens*hparams.n_embd_k_gqa(il),
|
| 551 |
+
ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head);
|
| 552 |
+
|
| 553 |
+
return ggml_cpy(ctx, k_cur, k_view);
|
| 554 |
}
|
| 555 |
|
| 556 |
+
ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
|
| 557 |
+
const int32_t ikv = map_layer_ids.at(il);
|
| 558 |
+
|
| 559 |
+
auto * v = layers[ikv].v;
|
| 560 |
+
|
| 561 |
+
const int64_t n_tokens = v_cur->ne[2];
|
| 562 |
+
|
| 563 |
+
v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
|
| 564 |
+
|
| 565 |
+
ggml_tensor * v_view = nullptr;
|
| 566 |
+
|
| 567 |
+
if (!v_trans) {
|
| 568 |
+
v_view = ggml_view_1d(ctx, v,
|
| 569 |
+
n_tokens*hparams.n_embd_v_gqa(il),
|
| 570 |
+
ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head);
|
| 571 |
+
} else {
|
| 572 |
+
// note: the V cache is transposed when not using flash attention
|
| 573 |
+
v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
|
| 574 |
+
(v->ne[1])*ggml_element_size(v),
|
| 575 |
+
( head)*ggml_element_size(v));
|
| 576 |
+
|
| 577 |
+
v_cur = ggml_transpose(ctx, v_cur);
|
| 578 |
+
}
|
| 579 |
+
|
| 580 |
+
return ggml_cpy(ctx, v_cur, v_view);
|
| 581 |
+
}
|
| 582 |
+
|
| 583 |
+
void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax) {
|
| 584 |
+
// no pruning is needed when the cache does not use SWA
|
| 585 |
+
GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE && "do not prune non-SWA cache");
|
| 586 |
+
|
| 587 |
+
int n_attended = 0;
|
| 588 |
+
|
| 589 |
+
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 590 |
+
if (!cells.seq_has(i, seq_id)) {
|
| 591 |
+
continue;
|
| 592 |
+
}
|
| 593 |
+
|
| 594 |
+
const llama_pos p0 = cells.pos_get(i);
|
| 595 |
+
|
| 596 |
+
if (p0 <= pmin && !is_masked_swa(p0, pmin)) {
|
| 597 |
+
n_attended++;
|
| 598 |
+
}
|
| 599 |
+
|
| 600 |
+
if (is_masked_swa(p0, pmax)) {
|
| 601 |
+
cells.seq_rm(i, seq_id);
|
| 602 |
+
}
|
| 603 |
+
}
|
| 604 |
+
|
| 605 |
+
if (n_attended < std::min<int>(n_swa, pmin)) {
|
| 606 |
+
LLAMA_LOG_WARN("%s: partial SWA cache detected - possible loss of information, pmin = %d, n_attended = %d, n_swa = %d\n", __func__, pmin, n_attended, n_swa);
|
| 607 |
+
}
|
| 608 |
+
}
|
| 609 |
+
|
| 610 |
+
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
| 611 |
+
const int64_t n_tokens = ubatch->n_tokens;
|
| 612 |
+
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
| 613 |
+
const int64_t n_seqs = ubatch->n_seqs;
|
| 614 |
+
|
| 615 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
|
| 616 |
+
float * data = (float *) dst->data;
|
| 617 |
+
|
| 618 |
+
const int64_t n_kv = n;
|
| 619 |
+
|
| 620 |
+
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
|
| 621 |
+
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
| 622 |
+
// Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
|
| 623 |
+
// Causal mask:
|
| 624 |
+
// xxx-------
|
| 625 |
+
// xxxx------
|
| 626 |
+
// xxxxx-----
|
| 627 |
+
// Non-causal mask:
|
| 628 |
+
// xxxxx-----
|
| 629 |
+
// xxxxx-----
|
| 630 |
+
// xxxxx-----
|
| 631 |
+
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
|
| 632 |
+
for (int h = 0; h < 1; ++h) {
|
| 633 |
+
for (int s = 0; s < n_seqs; ++s) {
|
| 634 |
+
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
| 635 |
+
|
| 636 |
+
for (int j = 0; j < n_seq_tokens; ++j) {
|
| 637 |
+
const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
|
| 638 |
+
|
| 639 |
+
for (int i = 0; i < n_kv; ++i) {
|
| 640 |
+
float f = 0.0f;
|
| 641 |
+
|
| 642 |
+
bool masked = false;
|
| 643 |
+
|
| 644 |
+
if (cells.is_empty(i)) {
|
| 645 |
+
masked = true;
|
| 646 |
+
} else {
|
| 647 |
+
const llama_pos p0 = cells.pos_get(i);
|
| 648 |
+
|
| 649 |
+
// mask the token if not the same sequence
|
| 650 |
+
masked = masked || (!cells.seq_has(i, seq_id));
|
| 651 |
+
|
| 652 |
+
// mask future tokens
|
| 653 |
+
masked = masked || (causal_attn && p0 > p1);
|
| 654 |
+
|
| 655 |
+
// apply SWA if any
|
| 656 |
+
masked = masked || (is_masked_swa(p0, p1));
|
| 657 |
+
|
| 658 |
+
if (!masked && hparams.use_alibi) {
|
| 659 |
+
f = -std::abs(p0 - p1);
|
| 660 |
+
}
|
| 661 |
+
}
|
| 662 |
+
|
| 663 |
+
if (masked) {
|
| 664 |
+
f = -INFINITY;
|
| 665 |
+
}
|
| 666 |
+
|
| 667 |
+
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
| 668 |
+
}
|
| 669 |
+
}
|
| 670 |
+
}
|
| 671 |
+
|
| 672 |
+
// mask padded tokens
|
| 673 |
+
if (data) {
|
| 674 |
+
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
| 675 |
+
for (int j = 0; j < n_kv; ++j) {
|
| 676 |
+
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
| 677 |
+
}
|
| 678 |
+
}
|
| 679 |
+
}
|
| 680 |
+
}
|
| 681 |
}
|
| 682 |
|
| 683 |
+
void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
|
| 684 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
|
| 685 |
+
|
| 686 |
+
int32_t * data = (int32_t *) dst->data;
|
| 687 |
+
|
| 688 |
+
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 689 |
+
data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
|
| 690 |
}
|
| 691 |
+
}
|
| 692 |
+
|
| 693 |
+
void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
| 694 |
+
const int64_t n_tokens = ubatch->n_tokens;
|
| 695 |
+
|
| 696 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
|
| 697 |
+
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
|
| 698 |
+
|
| 699 |
+
int32_t * data = (int32_t *) dst->data;
|
| 700 |
+
|
| 701 |
+
const int64_t n_kv = n;
|
| 702 |
|
| 703 |
+
for (int h = 0; h < 1; ++h) {
|
| 704 |
+
for (int j = 0; j < n_tokens; ++j) {
|
| 705 |
+
for (int i = 0; i < n_kv; ++i) {
|
| 706 |
+
// the position when the cells is empty is irrelevant - it will be masked out later in the attention
|
| 707 |
+
const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
|
| 708 |
+
|
| 709 |
+
data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
|
| 710 |
+
}
|
| 711 |
+
}
|
| 712 |
+
}
|
| 713 |
}
|
| 714 |
|
| 715 |
size_t llama_kv_cache_unified::total_size() const {
|
| 716 |
size_t size = 0;
|
| 717 |
+
|
| 718 |
for (const auto & buf : bufs) {
|
| 719 |
size += ggml_backend_buffer_get_size(buf.get());
|
| 720 |
}
|
|
|
|
| 725 |
size_t llama_kv_cache_unified::size_k_bytes() const {
|
| 726 |
size_t size_k_bytes = 0;
|
| 727 |
|
| 728 |
+
for (const auto & layer : layers) {
|
| 729 |
+
size_k_bytes += ggml_nbytes(layer.k);
|
| 730 |
}
|
| 731 |
|
| 732 |
return size_k_bytes;
|
|
|
|
| 735 |
size_t llama_kv_cache_unified::size_v_bytes() const {
|
| 736 |
size_t size_v_bytes = 0;
|
| 737 |
|
| 738 |
+
for (const auto & layer : layers) {
|
| 739 |
+
size_v_bytes += ggml_nbytes(layer.v);
|
| 740 |
}
|
| 741 |
|
| 742 |
return size_v_bytes;
|
|
|
|
| 800 |
GGML_UNUSED(ubatch);
|
| 801 |
|
| 802 |
if (k_shift) {
|
| 803 |
+
kv_self->set_input_k_shift(k_shift);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 804 |
}
|
| 805 |
}
|
| 806 |
|
|
|
|
| 810 |
ggml_cgraph * gf) const {
|
| 811 |
auto res = std::make_unique<llm_graph_result>();
|
| 812 |
|
|
|
|
|
|
|
| 813 |
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
| 814 |
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
| 815 |
|
|
|
|
|
|
|
| 816 |
//GGML_ASSERT(kv_self->size == n_ctx);
|
| 817 |
|
| 818 |
auto inp = std::make_unique<llm_graph_input_k_shift>(this);
|
|
|
|
| 820 |
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
|
| 821 |
ggml_set_input(inp->k_shift);
|
| 822 |
|
| 823 |
+
for (const auto & layer : layers) {
|
| 824 |
+
const uint32_t il = layer.il;
|
| 825 |
+
|
| 826 |
const int64_t n_head_kv = hparams.n_head_kv(il);
|
| 827 |
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
| 828 |
|
| 829 |
+
const float freq_base_l = model.get_rope_freq_base (cparams, il);
|
| 830 |
+
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
|
| 831 |
|
| 832 |
+
ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 833 |
|
| 834 |
ggml_tensor * k =
|
| 835 |
+
ggml_view_3d(ctx, layer.k,
|
| 836 |
+
n_embd_head_k, n_head_kv, cells.size(),
|
| 837 |
+
ggml_row_size(layer.k->type, n_embd_head_k),
|
| 838 |
+
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
| 839 |
0);
|
| 840 |
|
| 841 |
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
|
|
|
|
| 940 |
nm++;
|
| 941 |
}
|
| 942 |
|
| 943 |
+
for (const auto & layer : layers) {
|
| 944 |
+
const uint32_t il = layer.il;
|
| 945 |
+
|
| 946 |
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
| 947 |
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
| 948 |
|
| 949 |
+
ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k,
|
| 950 |
n_embd_k_gqa, nm,
|
| 951 |
+
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
| 952 |
+
ggml_row_size(layer.k->type, n_embd_k_gqa*i));
|
| 953 |
|
| 954 |
+
ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k,
|
| 955 |
n_embd_k_gqa, nm,
|
| 956 |
+
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
| 957 |
+
ggml_row_size(layer.k->type, n_embd_k_gqa*id));
|
| 958 |
|
| 959 |
ggml_tensor * view_v_src;
|
| 960 |
ggml_tensor * view_v_dst;
|
| 961 |
|
| 962 |
if (cparams.flash_attn) {
|
| 963 |
// NOTE: the V cache is not transposed when using flash attention
|
| 964 |
+
view_v_src = ggml_view_2d(ctx, layer.v,
|
| 965 |
n_embd_v_gqa, nm,
|
| 966 |
+
ggml_row_size(layer.v->type, n_embd_v_gqa),
|
| 967 |
+
ggml_row_size(layer.v->type, n_embd_v_gqa*i));
|
| 968 |
|
| 969 |
+
view_v_dst = ggml_view_2d(ctx, layer.v,
|
| 970 |
n_embd_v_gqa, nm,
|
| 971 |
+
ggml_row_size(layer.v->type, n_embd_v_gqa),
|
| 972 |
+
ggml_row_size(layer.v->type, n_embd_v_gqa*id));
|
| 973 |
} else {
|
| 974 |
+
view_v_src = ggml_view_2d(ctx, layer.v,
|
| 975 |
nm, n_embd_v_gqa,
|
| 976 |
+
ggml_row_size(layer.v->type, cells.size()),
|
| 977 |
+
ggml_row_size(layer.v->type, i));
|
| 978 |
|
| 979 |
+
view_v_dst = ggml_view_2d(ctx, layer.v,
|
| 980 |
nm, n_embd_v_gqa,
|
| 981 |
+
ggml_row_size(layer.v->type, cells.size()),
|
| 982 |
+
ggml_row_size(layer.v->type, id));
|
| 983 |
}
|
| 984 |
|
| 985 |
ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
|
|
|
|
| 996 |
}
|
| 997 |
|
| 998 |
bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
| 999 |
+
const uint32_t n_layer = layers.size();
|
| 1000 |
|
| 1001 |
+
const uint32_t n_kv = cells.used_max_p1();
|
| 1002 |
+
const uint32_t n_used = cells.get_used();
|
| 1003 |
|
| 1004 |
assert(n_used <= n_kv);
|
| 1005 |
|
|
|
|
| 1027 |
ids.resize(n_kv, n_kv);
|
| 1028 |
|
| 1029 |
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
|
| 1030 |
+
if (!cells.is_empty(i0)) {
|
|
|
|
|
|
|
| 1031 |
ids[i0] = i0;
|
| 1032 |
|
| 1033 |
continue;
|
|
|
|
| 1038 |
uint32_t nh = 1;
|
| 1039 |
|
| 1040 |
// determine the size of the hole
|
| 1041 |
+
while (i0 + nh < n_used && cells.is_empty(i0 + nh)) {
|
| 1042 |
nh++;
|
| 1043 |
}
|
| 1044 |
|
|
|
|
| 1047 |
|
| 1048 |
// starting from the end, find nh non-empty cells
|
| 1049 |
for (; is > i0; --is) {
|
| 1050 |
+
if (cells.is_empty(is) || ids[is] != n_kv) {
|
|
|
|
|
|
|
| 1051 |
continue;
|
| 1052 |
}
|
| 1053 |
|
|
|
|
| 1074 |
|
| 1075 |
// go back and move the nf cells to the hole
|
| 1076 |
for (; i1 < n_kv; ++i1) {
|
| 1077 |
+
if (cells.is_empty(i1) || ids[i1] != n_kv) {
|
|
|
|
|
|
|
| 1078 |
if (n_moves == max_moves) {
|
| 1079 |
stop = true;
|
| 1080 |
break;
|
|
|
|
| 1088 |
ids[i1] = i0 + nf;
|
| 1089 |
|
| 1090 |
// move the cell meta data
|
| 1091 |
+
cells.mv(i1, i0 + nf);
|
| 1092 |
|
|
|
|
|
|
|
| 1093 |
head = n_used;
|
| 1094 |
|
| 1095 |
if (!cont) {
|
|
|
|
| 1124 |
return true;
|
| 1125 |
}
|
| 1126 |
|
| 1127 |
+
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
|
| 1128 |
+
assert(p0 >= 0 && p1 >= 0);
|
|
|
|
| 1129 |
|
| 1130 |
+
switch (swa_type) {
|
| 1131 |
+
case LLAMA_SWA_TYPE_NONE:
|
| 1132 |
+
{
|
| 1133 |
+
} break;
|
| 1134 |
+
case LLAMA_SWA_TYPE_STANDARD:
|
| 1135 |
+
{
|
| 1136 |
+
if (p1 - p0 >= (int32_t) n_swa) {
|
| 1137 |
+
return true;
|
| 1138 |
+
}
|
| 1139 |
+
} break;
|
| 1140 |
+
case LLAMA_SWA_TYPE_CHUNKED:
|
| 1141 |
+
{
|
| 1142 |
+
const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
|
| 1143 |
+
|
| 1144 |
+
if (p0 < pos_chunk_start) {
|
| 1145 |
+
return true;
|
| 1146 |
+
}
|
| 1147 |
+
} break;
|
| 1148 |
}
|
| 1149 |
|
| 1150 |
+
return false;
|
| 1151 |
}
|
| 1152 |
|
| 1153 |
void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
|
|
|
|
| 1156 |
|
| 1157 |
// Count the number of cells with the specified seq_id
|
| 1158 |
// Find all the ranges of cells with this seq id (or all, when -1)
|
| 1159 |
+
uint32_t cell_range_begin = cells.size();
|
| 1160 |
+
|
| 1161 |
+
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 1162 |
+
if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
|
| 1163 |
++cell_count;
|
| 1164 |
+
if (cell_range_begin == cells.size()) {
|
| 1165 |
cell_range_begin = i;
|
| 1166 |
}
|
| 1167 |
} else {
|
| 1168 |
+
if (cell_range_begin != cells.size()) {
|
| 1169 |
cell_ranges.emplace_back(cell_range_begin, i);
|
| 1170 |
+
cell_range_begin = cells.size();
|
| 1171 |
}
|
| 1172 |
}
|
| 1173 |
}
|
| 1174 |
+
|
| 1175 |
+
if (cell_range_begin != cells.size()) {
|
| 1176 |
+
cell_ranges.emplace_back(cell_range_begin, cells.size());
|
| 1177 |
}
|
| 1178 |
|
| 1179 |
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
|
|
|
|
| 1210 |
void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
|
| 1211 |
for (const auto & range : cell_ranges) {
|
| 1212 |
for (uint32_t i = range.first; i < range.second; ++i) {
|
| 1213 |
+
std::vector<llama_seq_id> seq_ids;
|
| 1214 |
+
|
| 1215 |
+
for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
|
| 1216 |
+
if (cur == seq_id || seq_id == -1) {
|
| 1217 |
+
if (cells.seq_has(i, cur)) {
|
| 1218 |
+
seq_ids.push_back(cur);
|
| 1219 |
+
}
|
| 1220 |
+
}
|
| 1221 |
+
}
|
| 1222 |
+
|
| 1223 |
+
const llama_pos pos = cells.pos_get(i);
|
| 1224 |
+
const uint32_t n_seq_id = seq_ids.size();
|
| 1225 |
|
| 1226 |
io.write(&pos, sizeof(pos));
|
| 1227 |
io.write(&n_seq_id, sizeof(n_seq_id));
|
| 1228 |
|
| 1229 |
+
for (const auto & seq_id : seq_ids) {
|
| 1230 |
+
io.write(&seq_id, sizeof(seq_id));
|
|
|
|
|
|
|
| 1231 |
}
|
| 1232 |
}
|
| 1233 |
}
|
|
|
|
| 1235 |
|
| 1236 |
void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
|
| 1237 |
const uint32_t v_trans = this->v_trans ? 1 : 0;
|
| 1238 |
+
const uint32_t n_layer = layers.size();
|
| 1239 |
|
| 1240 |
io.write(&v_trans, sizeof(v_trans));
|
| 1241 |
io.write(&n_layer, sizeof(n_layer));
|
|
|
|
| 1244 |
|
| 1245 |
// Iterate and write all the keys first, each row is a cell
|
| 1246 |
// Get whole range at a time
|
| 1247 |
+
for (const auto & layer : layers) {
|
| 1248 |
+
const uint32_t il = layer.il;
|
| 1249 |
+
|
| 1250 |
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
| 1251 |
|
| 1252 |
// Write key type
|
| 1253 |
+
const int32_t k_type_i = (int32_t)layer.k->type;
|
| 1254 |
io.write(&k_type_i, sizeof(k_type_i));
|
| 1255 |
|
| 1256 |
// Write row size of key
|
| 1257 |
+
const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
|
| 1258 |
io.write(&k_size_row, sizeof(k_size_row));
|
| 1259 |
|
| 1260 |
// Read each range of cells of k_size length each into tmp_buf and write out
|
| 1261 |
for (const auto & range : cell_ranges) {
|
| 1262 |
const size_t range_size = range.second - range.first;
|
| 1263 |
const size_t buf_size = range_size * k_size_row;
|
| 1264 |
+
io.write_tensor(layer.k, range.first * k_size_row, buf_size);
|
| 1265 |
}
|
| 1266 |
}
|
| 1267 |
|
| 1268 |
if (!v_trans) {
|
| 1269 |
+
for (const auto & layer : layers) {
|
| 1270 |
+
const uint32_t il = layer.il;
|
| 1271 |
+
|
| 1272 |
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 1273 |
|
| 1274 |
// Write value type
|
| 1275 |
+
const int32_t v_type_i = (int32_t)layer.v->type;
|
| 1276 |
io.write(&v_type_i, sizeof(v_type_i));
|
| 1277 |
|
| 1278 |
// Write row size of value
|
| 1279 |
+
const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
|
| 1280 |
io.write(&v_size_row, sizeof(v_size_row));
|
| 1281 |
|
| 1282 |
// Read each range of cells of v_size length each into tmp_buf and write out
|
| 1283 |
for (const auto & range : cell_ranges) {
|
| 1284 |
const size_t range_size = range.second - range.first;
|
| 1285 |
const size_t buf_size = range_size * v_size_row;
|
| 1286 |
+
io.write_tensor(layer.v, range.first * v_size_row, buf_size);
|
| 1287 |
}
|
| 1288 |
}
|
| 1289 |
} else {
|
| 1290 |
// When v is transposed, we also need the element size and get the element ranges from each row
|
| 1291 |
+
const uint32_t kv_size = cells.size();
|
| 1292 |
+
|
| 1293 |
+
for (const auto & layer : layers) {
|
| 1294 |
+
const uint32_t il = layer.il;
|
| 1295 |
+
|
| 1296 |
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 1297 |
|
| 1298 |
// Write value type
|
| 1299 |
+
const int32_t v_type_i = (int32_t)layer.v->type;
|
| 1300 |
io.write(&v_type_i, sizeof(v_type_i));
|
| 1301 |
|
| 1302 |
// Write element size
|
| 1303 |
+
const uint32_t v_size_el = ggml_type_size(layer.v->type);
|
| 1304 |
io.write(&v_size_el, sizeof(v_size_el));
|
| 1305 |
|
| 1306 |
// Write GQA embedding size
|
|
|
|
| 1313 |
const size_t range_size = range.second - range.first;
|
| 1314 |
const size_t src_offset = (range.first + j * kv_size) * v_size_el;
|
| 1315 |
const size_t buf_size = range_size * v_size_el;
|
| 1316 |
+
io.write_tensor(layer.v, src_offset, buf_size);
|
| 1317 |
}
|
| 1318 |
}
|
| 1319 |
}
|
|
|
|
| 1330 |
llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
| 1331 |
|
| 1332 |
batch.n_tokens = cell_count;
|
|
|
|
|
|
|
| 1333 |
|
| 1334 |
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 1335 |
llama_pos pos;
|
|
|
|
| 1338 |
io.read_to(&pos, sizeof(pos));
|
| 1339 |
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
| 1340 |
|
| 1341 |
+
if (n_seq_id != 1) {
|
| 1342 |
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
|
| 1343 |
return false;
|
| 1344 |
}
|
| 1345 |
|
| 1346 |
+
// read the sequence id, but directly discard it - we will use dest_seq_id instead
|
| 1347 |
+
{
|
| 1348 |
+
llama_seq_id seq_id;
|
| 1349 |
+
io.read_to(&seq_id, sizeof(seq_id));
|
| 1350 |
+
}
|
| 1351 |
+
|
| 1352 |
+
batch.pos[i] = pos;
|
| 1353 |
+
batch.n_seq_id[i] = n_seq_id;
|
| 1354 |
+
batch.seq_id[i] = &dest_seq_id;
|
| 1355 |
}
|
| 1356 |
+
|
|
|
|
| 1357 |
if (!find_slot(batch)) {
|
| 1358 |
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
| 1359 |
return false;
|
| 1360 |
}
|
| 1361 |
+
|
| 1362 |
commit();
|
| 1363 |
|
| 1364 |
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
| 1365 |
// Assume that this is one contiguous block of cells
|
| 1366 |
+
GGML_ASSERT(head + cell_count <= cells.size());
|
| 1367 |
+
GGML_ASSERT(cells.pos_get(head) == batch.pos[0]);
|
| 1368 |
+
GGML_ASSERT(cells.pos_get(head + cell_count - 1) == batch.pos[cell_count - 1]);
|
| 1369 |
+
GGML_ASSERT(cells.seq_has(head, dest_seq_id));
|
| 1370 |
+
GGML_ASSERT(cells.seq_has(head + cell_count - 1, dest_seq_id));
|
| 1371 |
} else {
|
| 1372 |
// whole KV cache restore
|
| 1373 |
|
| 1374 |
+
if (cell_count > cells.size()) {
|
| 1375 |
LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
|
| 1376 |
return false;
|
| 1377 |
}
|
|
|
|
| 1379 |
clear();
|
| 1380 |
|
| 1381 |
for (uint32_t i = 0; i < cell_count; ++i) {
|
|
|
|
|
|
|
| 1382 |
llama_pos pos;
|
| 1383 |
uint32_t n_seq_id;
|
| 1384 |
|
| 1385 |
io.read_to(&pos, sizeof(pos));
|
| 1386 |
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
| 1387 |
|
| 1388 |
+
cells.pos_set(i, pos);
|
| 1389 |
|
| 1390 |
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
| 1391 |
llama_seq_id seq_id;
|
| 1392 |
io.read_to(&seq_id, sizeof(seq_id));
|
| 1393 |
|
| 1394 |
+
if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
|
| 1395 |
+
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
|
|
|
|
|
|
|
|
|
|
| 1396 |
return false;
|
| 1397 |
}
|
| 1398 |
|
| 1399 |
+
cells.seq_add(i, seq_id);
|
| 1400 |
}
|
| 1401 |
}
|
| 1402 |
|
| 1403 |
head = 0;
|
|
|
|
| 1404 |
}
|
| 1405 |
|
| 1406 |
return true;
|
|
|
|
| 1409 |
bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
|
| 1410 |
uint32_t v_trans;
|
| 1411 |
uint32_t n_layer;
|
| 1412 |
+
|
| 1413 |
io.read_to(&v_trans, sizeof(v_trans));
|
| 1414 |
io.read_to(&n_layer, sizeof(n_layer));
|
| 1415 |
|
| 1416 |
+
if (n_layer != layers.size()) {
|
| 1417 |
+
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
|
| 1418 |
return false;
|
| 1419 |
}
|
| 1420 |
+
if (cell_count > cells.size()) {
|
| 1421 |
+
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
|
| 1422 |
return false;
|
| 1423 |
}
|
| 1424 |
if (this->v_trans != (bool) v_trans) {
|
|
|
|
| 1427 |
}
|
| 1428 |
|
| 1429 |
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
|
| 1430 |
+
for (const auto & layer : layers) {
|
| 1431 |
+
const uint32_t il = layer.il;
|
| 1432 |
+
|
| 1433 |
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
| 1434 |
|
| 1435 |
// Read type of key
|
| 1436 |
int32_t k_type_i_ref;
|
| 1437 |
io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
|
| 1438 |
+
const int32_t k_type_i = (int32_t) layer.k->type;
|
| 1439 |
if (k_type_i != k_type_i_ref) {
|
| 1440 |
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
|
| 1441 |
return false;
|
|
|
|
| 1444 |
// Read row size of key
|
| 1445 |
uint64_t k_size_row_ref;
|
| 1446 |
io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
|
| 1447 |
+
const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
|
| 1448 |
if (k_size_row != k_size_row_ref) {
|
| 1449 |
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
|
| 1450 |
return false;
|
|
|
|
| 1452 |
|
| 1453 |
if (cell_count) {
|
| 1454 |
// Read and set the keys for the whole cell range
|
| 1455 |
+
ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
|
| 1456 |
}
|
| 1457 |
}
|
| 1458 |
|
| 1459 |
if (!this->v_trans) {
|
| 1460 |
+
for (const auto & layer : layers) {
|
| 1461 |
+
const uint32_t il = layer.il;
|
| 1462 |
+
|
| 1463 |
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 1464 |
|
| 1465 |
// Read type of value
|
| 1466 |
int32_t v_type_i_ref;
|
| 1467 |
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
| 1468 |
+
const int32_t v_type_i = (int32_t)layer.v->type;
|
| 1469 |
if (v_type_i != v_type_i_ref) {
|
| 1470 |
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
| 1471 |
return false;
|
|
|
|
| 1474 |
// Read row size of value
|
| 1475 |
uint64_t v_size_row_ref;
|
| 1476 |
io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
|
| 1477 |
+
const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
|
| 1478 |
if (v_size_row != v_size_row_ref) {
|
| 1479 |
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
|
| 1480 |
return false;
|
|
|
|
| 1482 |
|
| 1483 |
if (cell_count) {
|
| 1484 |
// Read and set the values for the whole cell range
|
| 1485 |
+
ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
|
| 1486 |
}
|
| 1487 |
}
|
| 1488 |
} else {
|
| 1489 |
// For each layer, read the values for each cell (transposed)
|
| 1490 |
+
for (const auto & layer : layers) {
|
| 1491 |
+
const uint32_t il = layer.il;
|
| 1492 |
+
|
| 1493 |
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 1494 |
|
| 1495 |
// Read type of value
|
| 1496 |
int32_t v_type_i_ref;
|
| 1497 |
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
| 1498 |
+
const int32_t v_type_i = (int32_t)layer.v->type;
|
| 1499 |
if (v_type_i != v_type_i_ref) {
|
| 1500 |
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
| 1501 |
return false;
|
|
|
|
| 1504 |
// Read element size of value
|
| 1505 |
uint32_t v_size_el_ref;
|
| 1506 |
io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
|
| 1507 |
+
const size_t v_size_el = ggml_type_size(layer.v->type);
|
| 1508 |
if (v_size_el != v_size_el_ref) {
|
| 1509 |
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
|
| 1510 |
return false;
|
|
|
|
| 1521 |
if (cell_count) {
|
| 1522 |
// For each row in the transposed matrix, read the values for the whole cell range
|
| 1523 |
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
| 1524 |
+
const size_t dst_offset = (head + j * cells.size()) * v_size_el;
|
| 1525 |
+
ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
| 1526 |
}
|
| 1527 |
}
|
| 1528 |
}
|
|
|
|
| 1531 |
return true;
|
| 1532 |
}
|
| 1533 |
|
| 1534 |
+
//
|
| 1535 |
+
// llama_kv_cache_unified_iswa
|
| 1536 |
+
//
|
| 1537 |
+
|
| 1538 |
+
llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
| 1539 |
+
const llama_model & model,
|
| 1540 |
+
ggml_type type_k,
|
| 1541 |
+
ggml_type type_v,
|
| 1542 |
+
bool v_trans,
|
| 1543 |
+
bool offload,
|
| 1544 |
+
bool swa_full,
|
| 1545 |
+
uint32_t kv_size,
|
| 1546 |
+
uint32_t n_seq_max,
|
| 1547 |
+
uint32_t n_batch,
|
| 1548 |
+
uint32_t n_pad) : hparams(model.hparams) {
|
| 1549 |
+
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
|
| 1550 |
+
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
|
| 1551 |
+
|
| 1552 |
+
const uint32_t size_base = kv_size;
|
| 1553 |
+
|
| 1554 |
+
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad));
|
| 1555 |
+
|
| 1556 |
+
// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size and disable pruning
|
| 1557 |
+
if (swa_full) {
|
| 1558 |
+
LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
|
| 1559 |
+
__func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
|
| 1560 |
+
|
| 1561 |
+
size_swa = size_base;
|
| 1562 |
+
do_prune = false;
|
| 1563 |
+
}
|
| 1564 |
+
|
| 1565 |
+
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
|
| 1566 |
+
|
| 1567 |
+
kv_base = std::make_unique<llama_kv_cache_unified>(
|
| 1568 |
+
model, std::move(filter_base), type_k, type_v,
|
| 1569 |
+
v_trans, offload, size_base, n_seq_max, n_pad,
|
| 1570 |
+
0, LLAMA_SWA_TYPE_NONE);
|
| 1571 |
+
|
| 1572 |
+
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
|
| 1573 |
+
|
| 1574 |
+
kv_swa = std::make_unique<llama_kv_cache_unified>(
|
| 1575 |
+
model, std::move(filter_swa), type_k, type_v,
|
| 1576 |
+
v_trans, offload, size_swa, n_seq_max, n_pad,
|
| 1577 |
+
hparams.n_swa, hparams.swa_type);
|
| 1578 |
+
}
|
| 1579 |
+
|
| 1580 |
+
void llama_kv_cache_unified_iswa::clear() {
|
| 1581 |
+
kv_base->clear();
|
| 1582 |
+
kv_swa ->clear();
|
| 1583 |
+
}
|
| 1584 |
+
|
| 1585 |
+
bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
| 1586 |
+
bool res = true;
|
| 1587 |
+
|
| 1588 |
+
res = res & kv_base->seq_rm(seq_id, p0, p1);
|
| 1589 |
+
res = res & kv_swa ->seq_rm(seq_id, p0, p1);
|
| 1590 |
+
|
| 1591 |
+
return res;
|
| 1592 |
+
}
|
| 1593 |
+
|
| 1594 |
+
void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
| 1595 |
+
kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
| 1596 |
+
kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
| 1597 |
+
}
|
| 1598 |
+
|
| 1599 |
+
void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
|
| 1600 |
+
kv_base->seq_keep(seq_id);
|
| 1601 |
+
kv_swa ->seq_keep(seq_id);
|
| 1602 |
+
}
|
| 1603 |
+
|
| 1604 |
+
void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
| 1605 |
+
kv_base->seq_add(seq_id, p0, p1, shift);
|
| 1606 |
+
kv_swa ->seq_add(seq_id, p0, p1, shift);
|
| 1607 |
+
}
|
| 1608 |
+
|
| 1609 |
+
void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
| 1610 |
+
kv_base->seq_div(seq_id, p0, p1, d);
|
| 1611 |
+
kv_swa ->seq_div(seq_id, p0, p1, d);
|
| 1612 |
+
}
|
| 1613 |
+
|
| 1614 |
+
llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
|
| 1615 |
+
// the base cache is a superset of the SWA cache, so we can just check the SWA cache
|
| 1616 |
+
return kv_swa->seq_pos_min(seq_id);
|
| 1617 |
+
}
|
| 1618 |
+
|
| 1619 |
+
llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
|
| 1620 |
+
return kv_swa->seq_pos_max(seq_id);
|
| 1621 |
+
}
|
| 1622 |
+
|
| 1623 |
+
void llama_kv_cache_unified_iswa::restore() {
|
| 1624 |
+
kv_base->restore();
|
| 1625 |
+
kv_swa ->restore();
|
| 1626 |
+
}
|
| 1627 |
+
|
| 1628 |
+
void llama_kv_cache_unified_iswa::commit() {
|
| 1629 |
+
kv_base->commit();
|
| 1630 |
+
kv_swa ->commit();
|
| 1631 |
+
|
| 1632 |
+
// slide the attention window, forgetting/pruning old tokens that are outside the window
|
| 1633 |
+
if (do_prune) {
|
| 1634 |
+
for (const auto & [seq_id, entry] : pending.pos) {
|
| 1635 |
+
kv_swa->prune_swa(seq_id, entry.pmin, entry.pmax);
|
| 1636 |
+
}
|
| 1637 |
+
|
| 1638 |
+
}
|
| 1639 |
+
|
| 1640 |
+
pending.clear();
|
| 1641 |
+
}
|
| 1642 |
+
|
| 1643 |
+
bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
|
| 1644 |
+
bool res = true;
|
| 1645 |
+
|
| 1646 |
+
res = res & kv_base->update(lctx);
|
| 1647 |
+
res = res & kv_swa ->update(lctx);
|
| 1648 |
+
|
| 1649 |
+
return res;
|
| 1650 |
+
}
|
| 1651 |
+
|
| 1652 |
+
void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
|
| 1653 |
+
kv_base->defrag_sched(thold);
|
| 1654 |
+
kv_swa ->defrag_sched(thold);
|
| 1655 |
+
}
|
| 1656 |
+
|
| 1657 |
+
void llama_kv_cache_unified_iswa::set_full() {
|
| 1658 |
+
kv_base->set_full();
|
| 1659 |
+
kv_swa ->set_full();
|
| 1660 |
+
}
|
| 1661 |
+
|
| 1662 |
+
llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
|
| 1663 |
+
pending.clear();
|
| 1664 |
+
|
| 1665 |
+
if (do_prune) {
|
| 1666 |
+
for (int i = 0; i < batch.n_tokens; ++i) {
|
| 1667 |
+
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
|
| 1668 |
+
const llama_seq_id seq_id = batch.seq_id[i][s];
|
| 1669 |
+
const llama_pos pos = batch.pos[i];
|
| 1670 |
+
|
| 1671 |
+
if (pending.pos.find(seq_id) == pending.pos.end()) {
|
| 1672 |
+
pending.pos[seq_id].pmin = pos;
|
| 1673 |
+
pending.pos[seq_id].pmax = pos;
|
| 1674 |
+
} else {
|
| 1675 |
+
pending.pos[seq_id].pmin = std::min(pending.pos[seq_id].pmin, pos);
|
| 1676 |
+
pending.pos[seq_id].pmax = std::max(pending.pos[seq_id].pmax, pos);
|
| 1677 |
+
}
|
| 1678 |
+
}
|
| 1679 |
+
}
|
| 1680 |
+
}
|
| 1681 |
+
|
| 1682 |
+
return llama_sbatch(batch, hparams.n_embd, true, logits_all);
|
| 1683 |
+
}
|
| 1684 |
+
|
| 1685 |
+
llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
|
| 1686 |
+
GGML_UNUSED(embd_pooled);
|
| 1687 |
+
return sbatch.split_simple(n_ubatch);
|
| 1688 |
+
}
|
| 1689 |
+
|
| 1690 |
+
bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
|
| 1691 |
+
bool res = true;
|
| 1692 |
+
|
| 1693 |
+
res = res & kv_base->find_slot(batch);
|
| 1694 |
+
res = res & kv_swa ->find_slot(batch);
|
| 1695 |
+
|
| 1696 |
+
return res;
|
| 1697 |
+
}
|
| 1698 |
+
|
| 1699 |
+
bool llama_kv_cache_unified_iswa::get_can_shift() const {
|
| 1700 |
+
return kv_base->get_size() == kv_swa->get_size();
|
| 1701 |
+
}
|
| 1702 |
+
|
| 1703 |
+
void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
|
| 1704 |
+
kv_base->state_write(io, seq_id);
|
| 1705 |
+
kv_swa ->state_write(io, seq_id);
|
| 1706 |
+
}
|
| 1707 |
+
|
| 1708 |
+
void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
|
| 1709 |
+
kv_base->state_read(io, seq_id);
|
| 1710 |
+
kv_swa ->state_read(io, seq_id);
|
| 1711 |
+
}
|
| 1712 |
+
|
| 1713 |
+
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_base() const {
|
| 1714 |
+
return kv_base.get();
|
| 1715 |
+
}
|
| 1716 |
+
|
| 1717 |
+
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_swa() const {
|
| 1718 |
+
return kv_swa.get();
|
| 1719 |
+
}
|
| 1720 |
+
|
| 1721 |
//
|
| 1722 |
// llama_kv_cache_recurrent
|
| 1723 |
//
|
|
|
|
| 1727 |
ggml_type type_k,
|
| 1728 |
ggml_type type_v,
|
| 1729 |
bool offload,
|
| 1730 |
+
uint32_t kv_size,
|
| 1731 |
+
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
|
| 1732 |
const int32_t n_layer = hparams.n_layer;
|
| 1733 |
|
| 1734 |
+
LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
|
| 1735 |
+
__func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
|
| 1736 |
|
| 1737 |
head = 0;
|
| 1738 |
size = kv_size;
|
| 1739 |
used = 0;
|
| 1740 |
|
|
|
|
|
|
|
|
|
|
| 1741 |
cells.clear();
|
| 1742 |
cells.resize(kv_size);
|
| 1743 |
|
|
|
|
| 1975 |
}
|
| 1976 |
}
|
| 1977 |
|
| 1978 |
+
void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
| 1979 |
+
if (shift == 0) {
|
| 1980 |
return;
|
| 1981 |
}
|
| 1982 |
|
|
|
|
| 1999 |
if (tail_id >= 0) {
|
| 2000 |
kv_cell & cell = cells[tail_id];
|
| 2001 |
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
| 2002 |
+
cell.pos += shift;
|
| 2003 |
}
|
| 2004 |
}
|
| 2005 |
}
|
|
|
|
| 2035 |
}
|
| 2036 |
}
|
| 2037 |
|
| 2038 |
+
llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
|
| 2039 |
+
llama_pos result = std::numeric_limits<llama_pos>::max();
|
| 2040 |
+
|
| 2041 |
+
for (uint32_t i = 0; i < size; ++i) {
|
| 2042 |
+
if (cells[i].has_seq_id(seq_id)) {
|
| 2043 |
+
result = std::min(result, cells[i].pos);
|
| 2044 |
+
}
|
| 2045 |
+
}
|
| 2046 |
+
|
| 2047 |
+
if (result == std::numeric_limits<llama_pos>::max()) {
|
| 2048 |
+
result = -1;
|
| 2049 |
+
}
|
| 2050 |
+
|
| 2051 |
+
return result;
|
| 2052 |
+
}
|
| 2053 |
+
|
| 2054 |
llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
| 2055 |
+
llama_pos result = -1;
|
| 2056 |
|
| 2057 |
for (uint32_t i = 0; i < size; ++i) {
|
| 2058 |
if (cells[i].has_seq_id(seq_id)) {
|
|
|
|
| 2075 |
pending.ranges.clear();
|
| 2076 |
}
|
| 2077 |
|
| 2078 |
+
bool llama_kv_cache_recurrent::update(llama_context & ctx) {
|
| 2079 |
+
GGML_UNUSED(ctx);
|
| 2080 |
return false;
|
| 2081 |
}
|
| 2082 |
|
|
|
|
| 2137 |
if (seq_id < 0 || (uint32_t) seq_id >= size) {
|
| 2138 |
// too big seq_id
|
| 2139 |
// TODO: would it be possible to resize the cache instead?
|
| 2140 |
+
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max);
|
| 2141 |
return false;
|
| 2142 |
}
|
| 2143 |
if (j > 0) {
|
|
|
|
| 2280 |
return n >= n_seqs;
|
| 2281 |
}
|
| 2282 |
|
|
|
|
|
|
|
| 2283 |
bool llama_kv_cache_recurrent::get_can_shift() const {
|
| 2284 |
return false;
|
| 2285 |
}
|
|
|
|
| 2408 |
io.read_to(&cell_count, sizeof(cell_count));
|
| 2409 |
|
| 2410 |
bool res = true;
|
| 2411 |
+
|
| 2412 |
res = res && state_read_meta(io, cell_count, seq_id);
|
| 2413 |
res = res && state_read_data(io, cell_count);
|
| 2414 |
|
|
|
|
| 2737 |
|
| 2738 |
return true;
|
| 2739 |
}
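For reference, the sliding-window masking rule that the new `is_masked_swa` helper applies in the hunks above can be summarized in a small standalone sketch. This is an illustrative approximation only: the enum values, the function name `is_masked_swa_sketch`, and the `main` driver are stand-ins and are not part of llama.cpp.

#include <cstdint>
#include <cstdio>

enum swa_type_t { SWA_NONE, SWA_STANDARD, SWA_CHUNKED };

// p0 = position of the cached token, p1 = position of the token being processed
static bool is_masked_swa_sketch(swa_type_t type, int32_t n_swa, int32_t p0, int32_t p1) {
    switch (type) {
        case SWA_NONE:
            break;
        case SWA_STANDARD:
            // standard window: only the last n_swa positions before p1 stay visible
            if (p1 - p0 >= n_swa) {
                return true;
            }
            break;
        case SWA_CHUNKED:
            // chunked window: everything before the chunk containing p1 is masked
            if (p0 < (p1 / n_swa) * n_swa) {
                return true;
            }
            break;
    }
    return false;
}

int main() {
    // with n_swa = 4 and a current position of 10, the standard window keeps
    // positions 7..10 visible, while the chunked window keeps 8..10 visible
    for (int32_t p0 = 0; p0 <= 10; ++p0) {
        printf("p0=%2d standard_masked=%d chunked_masked=%d\n", p0,
               (int) is_masked_swa_sketch(SWA_STANDARD, 4, p0, 10),
               (int) is_masked_swa_sketch(SWA_CHUNKED,  4, p0, 10));
    }
    return 0;
}

This is also the rule `prune_swa` relies on above when deciding which cells can be dropped from the SWA cache after a commit.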
|
|
|
|
|
|
|
examples/talk-llama/llama-kv-cache.h
CHANGED
|
@@ -4,10 +4,12 @@
|
|
| 4 |
#include "llama-io.h"
|
| 5 |
#include "llama-graph.h"
|
| 6 |
#include "llama-memory.h"
|
|
|
|
| 7 |
|
| 8 |
#include "ggml-cpp.h"
|
| 9 |
|
| 10 |
#include <set>
|
|
|
|
| 11 |
#include <vector>
|
| 12 |
|
| 13 |
struct llama_cparams;
|
|
@@ -34,12 +36,16 @@ struct llama_kv_cache : public llama_memory_i {
|
|
| 34 |
virtual void defrag_sched(float thold) = 0;
|
| 35 |
|
| 36 |
// simulate full cache, used for allocating worst-case compute buffers
|
|
|
|
| 37 |
virtual void set_full() = 0;
|
| 38 |
|
| 39 |
//
|
| 40 |
// batch processing
|
| 41 |
//
|
| 42 |
|
|
|
|
|
|
|
|
|
|
| 43 |
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
|
| 44 |
|
| 45 |
// different KV caches require different batch splitting strategies
|
|
@@ -48,11 +54,10 @@ struct llama_kv_cache : public llama_memory_i {
|
|
| 48 |
// find an empty slot of size "n_tokens" in the cache
|
| 49 |
virtual bool find_slot(const llama_ubatch & batch) = 0;
|
| 50 |
|
|
|
|
|
|
|
| 51 |
// getters
|
| 52 |
-
virtual
|
| 53 |
-
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
|
| 54 |
-
virtual llama_pos get_pos_max() const = 0;
|
| 55 |
-
virtual bool get_can_shift() const = 0;
|
| 56 |
|
| 57 |
bool get_can_edit() const override { return get_can_shift(); }
|
| 58 |
|
|
@@ -87,38 +92,25 @@ private:
|
|
| 87 |
// llama_kv_cache_unified
|
| 88 |
//
|
| 89 |
|
| 90 |
-
// TODO: add notion of max sequences
|
| 91 |
class llama_kv_cache_unified : public llama_kv_cache {
|
| 92 |
public:
|
| 93 |
-
struct kv_cell {
|
| 94 |
-
llama_pos pos = -1;
|
| 95 |
-
llama_pos delta = 0;
|
| 96 |
-
|
| 97 |
-
std::set<llama_seq_id> seq_id;
|
| 98 |
-
|
| 99 |
-
bool has_seq_id(const llama_seq_id & id) const {
|
| 100 |
-
return seq_id.find(id) != seq_id.end();
|
| 101 |
-
}
|
| 102 |
-
|
| 103 |
-
bool is_empty() const {
|
| 104 |
-
return seq_id.empty();
|
| 105 |
-
}
|
| 106 |
-
|
| 107 |
-
bool is_same_seq(const kv_cell & other) const {
|
| 108 |
-
return seq_id == other.seq_id;
|
| 109 |
-
}
|
| 110 |
-
};
|
| 111 |
-
|
| 112 |
static uint32_t get_padding(const llama_cparams & cparams);
|
| 113 |
|
|
|
|
|
|
|
|
|
|
| 114 |
llama_kv_cache_unified(
|
| 115 |
-
const llama_model &
|
| 116 |
-
|
| 117 |
-
ggml_type
|
| 118 |
-
|
| 119 |
-
bool
|
| 120 |
-
|
| 121 |
-
uint32_t
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
~llama_kv_cache_unified() = default;
|
| 124 |
|
|
@@ -130,10 +122,11 @@ public:
|
|
| 130 |
|
| 131 |
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
| 132 |
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
| 133 |
-
void seq_keep(llama_seq_id seq_id)
|
| 134 |
-
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos
|
| 135 |
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
| 136 |
|
|
|
|
| 137 |
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
| 138 |
|
| 139 |
//
|
|
@@ -150,7 +143,6 @@ public:
|
|
| 150 |
void set_full() override;
|
| 151 |
|
| 152 |
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
|
| 153 |
-
|
| 154 |
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
|
| 155 |
|
| 156 |
// updates the cache head
|
|
@@ -158,50 +150,94 @@ public:
|
|
| 158 |
// to the first cell of the slot.
|
| 159 |
bool find_slot(const llama_ubatch & batch) override;
|
| 160 |
|
| 161 |
-
int32_t get_n_tokens() const override;
|
| 162 |
-
int32_t get_used_cells() const override;
|
| 163 |
-
|
| 164 |
-
// TODO: better data structures to reduce the cost of this operation
|
| 165 |
-
llama_pos get_pos_max() const override;
|
| 166 |
-
|
| 167 |
bool get_can_shift() const override;
|
| 168 |
|
| 169 |
// state write/load
|
| 170 |
|
| 171 |
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
| 172 |
-
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1)
|
| 173 |
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
|
| 178 |
-
|
| 179 |
-
uint32_t
|
| 180 |
|
| 181 |
-
|
|
|
|
|
|
|
| 182 |
|
| 183 |
-
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
|
| 186 |
private:
|
| 187 |
const llama_model & model;
|
| 188 |
const llama_hparams & hparams;
|
| 189 |
|
| 190 |
-
|
| 191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
|
|
|
|
| 193 |
bool v_trans = true; // the value tensor is transposed
|
| 194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
|
| 196 |
// required padding
|
| 197 |
-
uint32_t
|
| 198 |
|
| 199 |
-
|
| 200 |
-
|
|
|
|
|
|
|
| 201 |
|
| 202 |
std::vector<ggml_context_ptr> ctxs;
|
| 203 |
std::vector<ggml_backend_buffer_ptr> bufs;
|
| 204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
// defrag
|
| 206 |
struct {
|
| 207 |
std::vector<uint32_t> ids;
|
|
@@ -210,25 +246,13 @@ private:
|
|
| 210 |
// return true if cells have been moved
|
| 211 |
bool defrag_prepare(int32_t n_max_nodes);
|
| 212 |
|
| 213 |
-
// commit/restore cache
|
| 214 |
-
struct slot_range {
|
| 215 |
-
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
|
| 216 |
-
uint32_t c1 = 0;
|
| 217 |
-
};
|
| 218 |
-
|
| 219 |
-
// pending cell updates that are not yet committed
|
| 220 |
-
struct {
|
| 221 |
-
std::vector<slot_range> ranges;
|
| 222 |
-
} pending;
|
| 223 |
-
|
| 224 |
-
// find how many cells are currently in use
|
| 225 |
-
uint32_t cell_max() const;
|
| 226 |
-
|
| 227 |
size_t total_size() const;
|
| 228 |
|
| 229 |
size_t size_k_bytes() const;
|
| 230 |
size_t size_v_bytes() const;
|
| 231 |
|
|
|
|
|
|
|
| 232 |
ggml_tensor * build_rope_shift(
|
| 233 |
const llama_cparams & cparams,
|
| 234 |
ggml_context * ctx,
|
|
@@ -255,6 +279,100 @@ private:
|
|
| 255 |
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
| 256 |
};
|
| 257 |
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
//
|
| 259 |
// llama_kv_cache_recurrent
|
| 260 |
//
|
|
@@ -286,7 +404,8 @@ public:
|
|
| 286 |
ggml_type type_k,
|
| 287 |
ggml_type type_v,
|
| 288 |
bool offload,
|
| 289 |
-
uint32_t kv_size
|
|
|
|
| 290 |
|
| 291 |
~llama_kv_cache_recurrent() = default;
|
| 292 |
|
|
@@ -298,10 +417,11 @@ public:
|
|
| 298 |
|
| 299 |
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
| 300 |
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
| 301 |
-
void seq_keep(llama_seq_id seq_id)
|
| 302 |
-
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos
|
| 303 |
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
| 304 |
|
|
|
|
| 305 |
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
| 306 |
|
| 307 |
//
|
|
@@ -311,24 +431,17 @@ public:
|
|
| 311 |
void restore() override;
|
| 312 |
void commit() override;
|
| 313 |
|
| 314 |
-
bool update(llama_context &
|
| 315 |
|
| 316 |
void defrag_sched(float thold) override;
|
| 317 |
|
| 318 |
void set_full() override;
|
| 319 |
|
| 320 |
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
|
| 321 |
-
|
| 322 |
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
|
| 323 |
|
| 324 |
bool find_slot(const llama_ubatch & batch) override;
|
| 325 |
|
| 326 |
-
int32_t get_n_tokens() const override;
|
| 327 |
-
int32_t get_used_cells() const override;
|
| 328 |
-
|
| 329 |
-
// TODO: better data structures to reduce the cost of this operation
|
| 330 |
-
llama_pos get_pos_max() const override;
|
| 331 |
-
|
| 332 |
bool get_can_shift() const override;
|
| 333 |
|
| 334 |
// TODO: temporary methods - they are not really const as they do const_cast<>, fix this
|
|
@@ -368,8 +481,7 @@ private:
|
|
| 368 |
std::vector<slot_range> ranges;
|
| 369 |
} pending;
|
| 370 |
|
| 371 |
-
|
| 372 |
-
ggml_type type_v = GGML_TYPE_F16;
|
| 373 |
|
| 374 |
std::vector<ggml_context_ptr> ctxs;
|
| 375 |
std::vector<ggml_backend_buffer_ptr> bufs;
|
|
@@ -388,12 +500,3 @@ private:
|
|
| 388 |
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
| 389 |
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
| 390 |
};
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
//
|
| 394 |
-
// kv cache view
|
| 395 |
-
//
|
| 396 |
-
|
| 397 |
-
llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max);
|
| 398 |
-
|
| 399 |
-
void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv);
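The `recovery` member that this header adds to `llama_kv_cache_unified` (see the new declaration further down) backs the `find_slot()` / `commit()` / `restore()` flow from llama-kv-cache.cpp: each `find_slot()` pushes a snapshot of the cells it is about to overwrite, a successful `commit()` discards the snapshots, and `restore()` writes them back. The sketch below is a simplified, hypothetical rendering of that pattern; `recovery_stack`, `cell_range`, and `save()` are illustrative names, and the real code stores full `llama_kv_cells_unified` objects rather than bare positions.

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// stand-in for the real llama_kv_cells_unified: just the positions of a cell range
using cell_range = std::vector<int32_t>;

struct recovery_stack {
    struct state {
        uint32_t   i;     // first cell index of the range that was overwritten
        cell_range cells; // copy of those cells taken before the ubatch was applied
    };

    std::vector<state> states;

    // find_slot(): snapshot the cells about to be overwritten
    void save(uint32_t i, cell_range cells) {
        states.push_back({i, std::move(cells)});
    }

    // commit(): the batch succeeded - the snapshots are no longer needed
    void commit() {
        states.clear();
    }

    // restore(): the batch failed - write the snapshots back and discard them
    void restore(cell_range & all_cells) {
        for (const auto & s : states) {
            for (size_t k = 0; k < s.cells.size(); ++k) {
                all_cells[s.i + k] = s.cells[k];
            }
        }
        states.clear();
    }
};

Keeping the snapshots as a per-ubatch stack means a failed batch can be rolled back regardless of how many ubatches were already placed before the failure.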
|
|
|
|
| 4 |
#include "llama-io.h"
|
| 5 |
#include "llama-graph.h"
|
| 6 |
#include "llama-memory.h"
|
| 7 |
+
#include "llama-kv-cells.h"
|
| 8 |
|
| 9 |
#include "ggml-cpp.h"
|
| 10 |
|
| 11 |
#include <set>
|
| 12 |
+
#include <unordered_map>
|
| 13 |
#include <vector>
|
| 14 |
|
| 15 |
struct llama_cparams;
|
|
|
|
| 36 |
virtual void defrag_sched(float thold) = 0;
|
| 37 |
|
| 38 |
// simulate full cache, used for allocating worst-case compute buffers
|
| 39 |
+
// TODO: remove
|
| 40 |
virtual void set_full() = 0;
|
| 41 |
|
| 42 |
//
|
| 43 |
// batch processing
|
| 44 |
//
|
| 45 |
|
| 46 |
+
// =============================================================================================================
|
| 47 |
+
// TODO: refactor and simplify this [TAG: KV_API]
|
| 48 |
+
|
| 49 |
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
|
| 50 |
|
| 51 |
// different KV caches require different batch splitting strategies
|
|
|
|
| 54 |
// find an empty slot of size "n_tokens" in the cache
|
| 55 |
virtual bool find_slot(const llama_ubatch & batch) = 0;
|
| 56 |
|
| 57 |
+
// =============================================================================================================
|
| 58 |
+
|
| 59 |
// getters
|
| 60 |
+
virtual bool get_can_shift() const = 0;
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
bool get_can_edit() const override { return get_can_shift(); }
|
| 63 |
|
|
|
|
| 92 |
// llama_kv_cache_unified
|
| 93 |
//
|
| 94 |
|
|
|
|
| 95 |
class llama_kv_cache_unified : public llama_kv_cache {
|
| 96 |
public:
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
static uint32_t get_padding(const llama_cparams & cparams);
|
| 98 |
|
| 99 |
+
// this callback is used to filter out layers that should not be included in the cache
|
| 100 |
+
using layer_filter_cb = std::function<bool(int32_t il)>;
|
| 101 |
+
|
| 102 |
llama_kv_cache_unified(
|
| 103 |
+
const llama_model & model,
|
| 104 |
+
layer_filter_cb && filter,
|
| 105 |
+
ggml_type type_k,
|
| 106 |
+
ggml_type type_v,
|
| 107 |
+
bool v_trans,
|
| 108 |
+
bool offload,
|
| 109 |
+
uint32_t kv_size,
|
| 110 |
+
uint32_t n_seq_max,
|
| 111 |
+
uint32_t n_pad,
|
| 112 |
+
uint32_t n_swa,
|
| 113 |
+
llama_swa_type swa_type);
|
| 114 |
|
| 115 |
~llama_kv_cache_unified() = default;
|
| 116 |
|
|
|
|
| 122 |
|
| 123 |
     bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id) override;
+    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
     void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
     llama_pos seq_pos_max(llama_seq_id seq_id) const override;

     //
...
     void set_full() override;

     llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
     llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;

     // updates the cache head
     // to the first cell of the slot.
     bool find_slot(const llama_ubatch & batch) override;

     bool get_can_shift() const override;

     // state write/load

     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;

+    //
+    // llama_kv_cache_unified specific API
+    //

+    uint32_t get_n() const;
+    uint32_t get_size() const;

+    // get views of the current state of the cache
+    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
+    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;

+    // store k_cur and v_cur in the cache based on the current head location
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
+
+    void prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax);
+
+    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
+    void set_input_k_shift   (ggml_tensor * dst) const;
+    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

 private:
     const llama_model & model;
     const llama_hparams & hparams;

+    struct kv_layer {
+        // layer index in the model
+        // note: can be different from the layer index in the KV cache
+        uint32_t il;
+
+        ggml_tensor * k;
+        ggml_tensor * v;
+    };

+    bool do_defrag = false;
     bool v_trans   = true; // the value tensor is transposed
+
+    uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
+
+    // computed before each graph build
+    // TODO: cells should start to maintain this value dynamically based on the edits
+    uint32_t n = 0;
+
+    const uint32_t n_seq_max = 1;

     // required padding
+    const uint32_t n_pad = 1;

+    // SWA
+    const uint32_t n_swa = 0;
+
+    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

     std::vector<ggml_context_ptr>        ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;

+    llama_kv_cells_unified cells;
+
+    std::vector<kv_layer> layers;
+
+    // model layer id -> KV cache layer id
+    std::unordered_map<int32_t, int32_t> map_layer_ids;
+
+    // recovery information used to restore the KV cells to their original state in case of a failure
+    // TODO: do not store as a state in the llama_kv_cache object, instead return upon batch preparation
+    //       to achieve that, first need to refactor the llama_kv_cache interface [TAG: KV_API]
+    struct {
+        void clear() {
+            states.clear();
+        }
+
+        struct state {
+            uint32_t i;
+
+            llama_kv_cells_unified cells;
+        };
+
+        // stack with the partial states before each ubatch
+        std::vector<state> states;
+    } recovery;
+
     // defrag
     struct {
         std::vector<uint32_t> ids;
...
     // return true if cells have been moved
     bool defrag_prepare(int32_t n_max_nodes);

     size_t total_size() const;

     size_t size_k_bytes() const;
     size_t size_v_bytes() const;

+    bool is_masked_swa(llama_pos p0, llama_pos p1) const;
+
     ggml_tensor * build_rope_shift(
             const llama_cparams & cparams,
                    ggml_context * ctx,
...
     bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };

+//
+// llama_kv_cache_unified_iswa
+//
+
+// utilizes two instances of llama_kv_cache_unified
+// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
+// upon successful commit, the SWA cache removes old tokens outside the n_swa window
+
+class llama_kv_cache_unified_iswa : public llama_kv_cache {
+public:
+    llama_kv_cache_unified_iswa(
+            const llama_model & model,
+                    ggml_type   type_k,
+                    ggml_type   type_v,
+                         bool   v_trans,
+                         bool   offload,
+                         bool   swa_full,
+                     uint32_t   kv_size,
+                     uint32_t   n_seq_max,
+                     uint32_t   n_batch,
+                     uint32_t   n_pad);
+
+    ~llama_kv_cache_unified_iswa() = default;
+
+    //
+    // llama_memory_i
+    //
+
+    void clear() override;
+
+    bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id) override;
+    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+    //
+    // llama_kv_cache
+    //
+
+    void restore() override;
+    void commit()  override;
+
+    bool update(llama_context & ctx) override;
+
+    void defrag_sched(float thold) override;
+
+    void set_full() override;
+
+    llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
+    llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
+
+    bool find_slot(const llama_ubatch & batch) override;
+
+    bool get_can_shift() const override;
+
+    // state write/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;
+
+    //
+    // llama_kv_cache_unified_iswa specific API
+    //
+
+    llama_kv_cache_unified * get_kv_base() const;
+    llama_kv_cache_unified * get_kv_swa () const;
+
+private:
+    const llama_hparams & hparams;
+
+    bool do_prune = true;
+
+    struct {
+        struct entry {
+            llama_pos pmin;
+            llama_pos pmax;
+        };
+
+        void clear() {
+            pos.clear();
+        }
+
+        // used to perform SWA pruning of old tokens
+        std::unordered_map<llama_seq_id, entry> pos;
+    } pending;
+
+    std::unique_ptr<llama_kv_cache_unified> kv_base;
+    std::unique_ptr<llama_kv_cache_unified> kv_swa;
+};
+
 //
 // llama_kv_cache_recurrent
 //
...
             ggml_type   type_k,
             ggml_type   type_v,
                  bool   offload,
+            uint32_t    kv_size,
+            uint32_t    n_seq_max);

     ~llama_kv_cache_recurrent() = default;
...

     bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id) override;
+    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
     void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
     llama_pos seq_pos_max(llama_seq_id seq_id) const override;

     //
...
     void restore() override;
     void commit()  override;

+    bool update(llama_context & ctx) override;

     void defrag_sched(float thold) override;

     void set_full() override;

     llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
     llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;

     bool find_slot(const llama_ubatch & batch) override;

     bool get_can_shift() const override;

     // TODO: temporary methods - they are not really const as they do const_cast<>, fix this
...
         std::vector<slot_range> ranges;
     } pending;

+    const uint32_t n_seq_max = 1;

     std::vector<ggml_context_ptr>        ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
...
     bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
     bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };
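The llama_kv_cache_unified_iswa class above pairs a full-history cache for the non-SWA layers with a windowed cache for the SWA layers that is pruned after each commit. As a rough illustration of that split, here is a minimal, self-contained C++ sketch using toy types only — toy_kv_cache, toy_iswa_cache and its commit() are hypothetical names for this example, not the API declared above:

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

struct toy_kv_cache {
    std::vector<int32_t> pos; // one entry per cached token position
};

struct toy_iswa_cache {
    toy_kv_cache base;  // non-SWA layers: keep the full history
    toy_kv_cache swa;   // SWA layers: only the last n_swa positions are ever needed
    int32_t      n_swa;

    // store position p in both caches, then drop SWA entries that fell out of the window
    void commit(int32_t p) {
        base.pos.push_back(p);
        swa.pos.push_back(p);

        std::vector<int32_t> kept;
        for (int32_t q : swa.pos) {
            if (q > p - n_swa) {
                kept.push_back(q);
            }
        }
        swa.pos = std::move(kept);
    }
};

int main() {
    toy_iswa_cache cache{{}, {}, /*n_swa =*/ 4};
    for (int32_t p = 0; p < 10; ++p) {
        cache.commit(p);
    }
    // prints: base: 10 cells, swa: 4 cells
    std::printf("base: %zu cells, swa: %zu cells\n", cache.base.pos.size(), cache.swa.pos.size());
}

With n_swa = 4 the base cache grows with every committed token while the SWA cache stays capped at the window size — which is the memory saving the class comment refers to.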
examples/talk-llama/llama-kv-cells.h
ADDED
@@ -0,0 +1,379 @@
#pragma once

#include "llama.h"
#include "llama-cparams.h"

#include <bitset>
#include <cassert>
#include <vector>
#include <set>

// meta information about KV cells that can be part of multiple sequences at the same time
// TODO: add unit tests
class llama_kv_cells_unified {
public:
    void reset() {
        for (uint32_t i = 0; i < pos.size(); ++i) {
            pos[i]   = -1;
            shift[i] =  0;
            seq[i].reset();
        }

        has_shift = false;

        used.clear();

        for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
            seq_pos[s].clear();
        }
    }

    void reset_shift() {
        has_shift = false;

        for (uint32_t i = 0; i < shift.size(); ++i) {
            shift[i] = 0;
        }
    }

    uint32_t size() const {
        return pos.size();
    }

    void resize(uint32_t n) {
        pos.resize(n);
        shift.resize(n);
        seq.resize(n);

        reset();
    }

    bool is_empty(uint32_t i) const {
        assert(i < pos.size());
        assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0);

        return pos[i] == -1;
    }

    uint32_t get_used() const {
        return used.size();
    }

    // the index of the first cell that is used
    // return 0 if no cells are used
    uint32_t used_min() const {
        return used.empty() ? 0 : *used.begin();
    }

    // the index of the last cell that is used + 1
    // return 0 if no cells are used
    uint32_t used_max_p1() const {
#if 0
        if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin());
        if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin());
        if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin());
#endif

        return used.empty() ? 0 : *used.rbegin() + 1;
    }

    bool get_has_shift() const {
        return has_shift;
    }

    // move cell isrc to idst (used during defrag)
    void mv(uint32_t isrc, uint32_t idst) {
        assert(isrc < pos.size());
        assert(idst < pos.size());

        pos  [idst] = pos  [isrc];
        shift[idst] = shift[isrc];
        seq  [idst] = seq  [isrc];

        pos  [isrc] = -1;
        shift[isrc] =  0;
        seq  [isrc].reset();

        used.erase (isrc);
        used.insert(idst);
    }

    // copy the state of cells [i, i + n) (used for save/restore the state of the cells)
    llama_kv_cells_unified cp(uint32_t i, uint32_t n) const {
        assert(i + n <= pos.size());

        llama_kv_cells_unified res;

        res.resize(n);

        for (uint32_t j = 0; j < n; ++j) {
            res.pos[j] = pos[i + j];
            res.seq[j] = seq[i + j];

            assert(shift[i + j] == 0);
        }

        return res;
    }

    // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
    void set(uint32_t i, const llama_kv_cells_unified & other) {
        assert(i + other.pos.size() <= pos.size());

        for (uint32_t j = 0; j < other.pos.size(); ++j) {
            if (pos[i + j] == -1 && other.pos[j] != -1) {
                used.insert(i + j);
            }

            if (pos[i + j] != -1 && other.pos[j] == -1) {
                used.erase(i + j);
            }

            if (pos[i + j] != -1) {
                seq_pos_rm(i + j);
            }

            pos[i + j] = other.pos[j];
            seq[i + j] = other.seq[j];

            if (pos[i + j] != -1) {
                seq_pos_add(i + j);
            }

            assert(shift[i + j] == 0);
        }
    }

    // note: call only if the cell has seq_id
    // return true if the cell becomes empty
    bool seq_rm(uint32_t i, llama_seq_id seq_id) {
        assert(i < pos.size());
        assert(seq[i].test(seq_id));
        assert(pos[i] != -1);
        assert(seq_id >= 0);

        seq[i].reset(seq_id);
        seq_pos[seq_id].erase(pos[i]);

        if (seq[i].none()) {
            pos[i] = -1;

            used.erase(i);

            return true;
        }

        return false;
    }

    // return true if the cell becomes empty (i.e. it did not contain seq_id before the call)
    bool seq_keep(uint32_t i, llama_seq_id seq_id) {
        assert(i < pos.size());

        if (seq[i].test(seq_id)) {
            seq_pos_rm(i);
            seq[i].reset();

            seq[i].set(seq_id);
            seq_pos[seq_id].insert(pos[i]);

            return false;
        }

        if (seq[i].any()) {
            seq_pos_rm(i);
            seq[i].reset();

            pos[i] = -1;

            used.erase(i);

            return true;
        }

        assert(pos[i] == -1);

        return false;
    }

    bool seq_has(uint32_t i, llama_seq_id seq_id) const {
        assert(i < pos.size());
        assert(seq_id >= 0);

        return seq[i].test(seq_id);
    }

    // note: call only if the cell is not empty and the seq_id is not in the cell
    void seq_add(uint32_t i, llama_seq_id seq_id) {
        assert(i < pos.size());
        assert(pos[i] != -1);
        assert(!seq[i].test(seq_id));

        seq[i].set(seq_id);
        seq_pos[seq_id].insert(pos[i]);
    }

    // the minimum position of sequence seq_id currently present in any of the cells
    // return -1 if the sequence is not present
    llama_pos seq_pos_min(llama_seq_id seq_id) const {
        assert(seq_id >= 0);
        assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);

        if (seq_pos[seq_id].empty()) {
            return -1;
        }

        return *seq_pos[seq_id].begin();
    }

    // the maximum position of sequence seq_id currently present in any of the cells
    // return -1 if the sequence is not present
    llama_pos seq_pos_max(llama_seq_id seq_id) const {
        assert(seq_id >= 0);
        assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);

        if (seq_pos[seq_id].empty()) {
            return -1;
        }

        return *seq_pos[seq_id].rbegin();
    }

    // note: call only if the cell is not empty
    llama_pos pos_get(uint32_t i) const {
        assert(i < pos.size());
        assert(pos[i] != -1);

        return pos[i];
    }

    // note: call only if the cell is not empty
    llama_pos get_shift(uint32_t i) const {
        assert(i < pos.size());
        assert(pos[i] != -1);

        return shift[i];
    }

    // check if a cell is not empty and its position is within [p0, p1)
    bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const {
        assert(i < pos.size());

        return pos[i] >= p0 && pos[i] < p1;
    }

    // set the position of an empty cell
    // does not modify "has_shift"
    // note: call only if the cell is empty
    void pos_set(uint32_t i, llama_pos p) {
        assert(i < pos.size());
        assert(pos[i] == -1);

        pos[i] = p;

        used.insert(i);
    }

    // pos[i] = pos[i] + d
    // sets "has_shift" to true
    // note: call only if the cell is not empty
    bool pos_add(uint32_t i, llama_pos d) {
        assert(i < pos.size());
        assert(pos[i] != -1);

        seq_pos_rm(i);

        pos[i]   += d;
        shift[i] += d;

        seq_pos_add(i);

        has_shift = true;

        if (pos[i] < 0) {
            seq_pos_rm(i);

            seq[i].reset();
            pos[i] = -1;

            used.erase(i);

            return true;
        }

        return false;
    }

    // pos[i] = pos[i] / d
    // sets "has_shift" to true
    // note: call only if the cell is not empty
    void pos_div(uint32_t i, int d) {
        assert(i < pos.size());
        assert(pos[i] != -1);

        const llama_pos p_old = pos[i];

        seq_pos_rm(i);

        pos[i]   /= d;
        shift[i] += p_old - pos[i];

        seq_pos_add(i);

        has_shift = true;
    }

private:
    bool has_shift = false;

    // set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
    std::set<uint32_t> used;

    std::vector<llama_pos> pos;

    // this array accumulates any applied shifts to the pos array since the last reset_shift() call
    // this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
    //
    //   cells.pos_add(x, shift_x);
    //   cells.pos_div(y, shift_y);
    //   ...
    //
    //   if (cells.has_shift()) {
    //     for (int i = 0; i < n; ++i) {
    //       auto shift_i = cells.get_shift(i);
    //       ...
    //     }
    //     cells.reset_shift();
    //   }
    //
    std::vector<llama_pos> shift;

    using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;

    // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
    std::vector<bits_t> seq;

    // the set seq_pos[s] tells us which positions are currently present for sequence s
    // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
    std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];

    // helper functions for updating `seq_pos`, one cell at a time:

    // remove cell i
    void seq_pos_rm(uint32_t i) {
        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
            if (seq[i].test(s)) {
                seq_pos[s].erase(pos[i]);
            }
        }
    }

    // add cell i
    void seq_pos_add(uint32_t i) {
        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
            if (seq[i].test(s)) {
                seq_pos[s].insert(pos[i]);
            }
        }
    }
};
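The per-sequence bookkeeping in llama_kv_cells_unified relies on std::set keeping its elements ordered, so seq_pos_min()/seq_pos_max() reduce to *begin()/*rbegin(). A tiny standalone demonstration of that pattern (toy code, independent of the class above):

#include <cstdio>
#include <set>

int main() {
    std::set<int> seq_pos; // positions of one sequence currently present in the cache

    const int inserted[] = {7, 3, 12, 5};
    for (int p : inserted) {
        seq_pos.insert(p);  // mirrors seq_pos_add() when a cell gains this sequence
    }
    seq_pos.erase(3);       // mirrors seq_pos_rm() when the cell holding position 3 goes away

    const int pmin = seq_pos.empty() ? -1 : *seq_pos.begin();   // smallest position
    const int pmax = seq_pos.empty() ? -1 : *seq_pos.rbegin();  // largest position

    std::printf("min = %d, max = %d\n", pmin, pmax); // prints: min = 5, max = 12
}

The trade-off is an O(log n) set update per cell edit in exchange for constant-time min/max queries, which seq_pos_min()/seq_pos_max() need on every ubatch.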
examples/talk-llama/llama-memory.h
CHANGED
@@ -7,8 +7,8 @@ struct llama_memory_params {
     ggml_type type_k;
     ggml_type type_v;

-    //
-
+    // use full-size SWA cache
+    bool swa_full;
 };

 // general concept of LLM memory
@@ -22,9 +22,10 @@ public:
     virtual bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
     virtual void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
     virtual void seq_keep(llama_seq_id seq_id) = 0;
-    virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos
+    virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
     virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;

+    virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
     virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;

     virtual bool get_can_edit() const = 0;
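To make the seq_add/seq_div contract above concrete, here is an illustration of the assumed semantics — positions in [p0, p1) are shifted or integer-divided, mirroring what pos_add()/pos_div() do per cell in llama-kv-cells.h. toy_seq is a made-up type for this sketch, not an implementation of llama_memory_i:

#include <cstdio>
#include <vector>

struct toy_seq {
    std::vector<int> pos; // positions of the tokens of one sequence

    // shift every position in [p0, p1) by `shift` (can be negative)
    void seq_add(int p0, int p1, int shift) {
        for (int & p : pos) {
            if (p >= p0 && p < p1) {
                p += shift;
            }
        }
    }

    // integer-divide every position in [p0, p1) by d
    void seq_div(int p0, int p1, int d) {
        for (int & p : pos) {
            if (p >= p0 && p < p1) {
                p /= d;
            }
        }
    }
};

int main() {
    toy_seq s{{0, 1, 2, 3, 4, 5}};
    s.seq_add(3, 6, -2); // close a gap: 3,4,5 become 1,2,3
    s.seq_div(0, 6,  2); // crude position compression by integer division
    for (int p : s.pos) {
        std::printf("%d ", p); // prints: 0 0 1 0 1 1
    }
    std::printf("\n");
}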
examples/talk-llama/llama-model.cpp
CHANGED
|
@@ -463,11 +463,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|
| 463 |
GGML_ASSERT(hparams.n_expert_used == 0);
|
| 464 |
}
|
| 465 |
|
| 466 |
-
// zero-out the array hparams
|
| 467 |
std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
|
| 468 |
std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
|
| 469 |
std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
|
| 470 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 471 |
ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false);
|
| 472 |
ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false);
|
| 473 |
|
|
@@ -571,9 +574,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|
| 571 |
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
| 572 |
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
| 573 |
ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);
|
| 574 |
-
|
| 575 |
-
hparams.
|
| 576 |
-
hparams.n_swa
|
|
|
|
| 577 |
|
| 578 |
switch (hparams.n_expert) {
|
| 579 |
case 16: type = LLM_TYPE_17B_16E; break;
|
|
@@ -852,22 +856,17 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|
| 852 |
default: type = LLM_TYPE_UNKNOWN;
|
| 853 |
}
|
| 854 |
|
| 855 |
-
|
| 856 |
-
|
| 857 |
-
|
| 858 |
-
|
| 859 |
-
|
| 860 |
-
|
| 861 |
-
//
|
| 862 |
-
hparams.
|
| 863 |
-
|
| 864 |
-
|
| 865 |
-
|
| 866 |
-
hparams.n_swa = 131072;
|
| 867 |
-
}
|
| 868 |
-
bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
|
| 869 |
-
if (!found_swa && hparams.n_swa == 0) {
|
| 870 |
-
throw std::runtime_error("invalid value for sliding_window");
|
| 871 |
}
|
| 872 |
} break;
|
| 873 |
case LLM_ARCH_PHIMOE:
|
|
@@ -937,8 +936,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|
| 937 |
} break;
|
| 938 |
case LLM_ARCH_GEMMA2:
|
| 939 |
{
|
|
|
|
| 940 |
hparams.n_swa = 4096; // default value of gemma 2
|
| 941 |
-
hparams.
|
| 942 |
hparams.attn_soft_cap = true;
|
| 943 |
|
| 944 |
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
|
|
@@ -955,7 +955,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|
| 955 |
} break;
|
| 956 |
case LLM_ARCH_GEMMA3:
|
| 957 |
{
|
| 958 |
-
hparams.
|
|
|
|
| 959 |
|
| 960 |
hparams.rope_freq_base_train_swa = 10000.0f;
|
| 961 |
hparams.rope_freq_scale_train_swa = 1.0f;
|
|
@@ -1039,7 +1040,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|
| 1039 |
} break;
|
| 1040 |
case LLM_ARCH_COHERE2:
|
| 1041 |
{
|
| 1042 |
-
hparams.
|
|
|
|
| 1043 |
|
| 1044 |
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
|
| 1045 |
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
|
|
@@ -2487,7 +2489,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|
| 2487 |
|
| 2488 |
// output
|
| 2489 |
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
| 2490 |
-
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab},
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2491 |
|
| 2492 |
for (int i = 0; i < n_layer; ++i) {
|
| 2493 |
auto & layer = layers[i];
|
|
@@ -4321,7 +4327,7 @@ void llama_model::print_info() const {
|
|
| 4321 |
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
|
| 4322 |
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
|
| 4323 |
LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
|
| 4324 |
-
LLAMA_LOG_INFO("%s:
|
| 4325 |
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
|
| 4326 |
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
|
| 4327 |
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str());
|
|
@@ -4489,7 +4495,17 @@ const ggml_tensor * llama_model::get_tensor(const char * name) const {
|
|
| 4489 |
return it->second;
|
| 4490 |
}
|
| 4491 |
|
| 4492 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4493 |
// choose long/short freq factors based on the context size
|
| 4494 |
if (layers[il].rope_freqs != nullptr) {
|
| 4495 |
return layers[il].rope_freqs;
|
|
@@ -4517,21 +4533,174 @@ struct llm_build_llama : public llm_graph_context {
|
|
| 4517 |
// inp_pos - contains the positions
|
| 4518 |
ggml_tensor * inp_pos = build_inp_pos();
|
| 4519 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4520 |
// temperature tuning
|
| 4521 |
ggml_tensor * inp_attn_scale = nullptr;
|
| 4522 |
-
|
| 4523 |
-
inp_attn_scale = build_inp_attn_scale();
|
| 4524 |
-
}
|
| 4525 |
|
| 4526 |
-
auto * inp_attn =
|
| 4527 |
|
| 4528 |
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
|
|
|
| 4529 |
for (int il = 0; il < n_layer; ++il) {
|
| 4530 |
ggml_tensor * inpSA = inpL;
|
| 4531 |
|
| 4532 |
-
bool use_rope =
|
| 4533 |
-
? (il + 1) % hparams.n_no_rope_layer_step != 0
|
| 4534 |
-
: true;
|
| 4535 |
|
| 4536 |
// norm
|
| 4537 |
cur = build_norm(inpL,
|
|
@@ -4542,7 +4711,7 @@ struct llm_build_llama : public llm_graph_context {
|
|
| 4542 |
// self-attention
|
| 4543 |
{
|
| 4544 |
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
| 4545 |
-
ggml_tensor * rope_factors = model.get_rope_factors(
|
| 4546 |
|
| 4547 |
// compute Q and K and RoPE them
|
| 4548 |
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
|
@@ -4590,7 +4759,7 @@ struct llm_build_llama : public llm_graph_context {
|
|
| 4590 |
cb(Kcur, "Kcur", il);
|
| 4591 |
cb(Vcur, "Vcur", il);
|
| 4592 |
|
| 4593 |
-
if (
|
| 4594 |
// Llama4TextL2Norm
|
| 4595 |
Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps);
|
| 4596 |
Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps);
|
|
@@ -4616,7 +4785,6 @@ struct llm_build_llama : public llm_graph_context {
|
|
| 4616 |
|
| 4617 |
// feed-forward network (non-MoE)
|
| 4618 |
if (model.layers[il].ffn_gate_inp == nullptr) {
|
| 4619 |
-
|
| 4620 |
cur = build_norm(ffn_inp,
|
| 4621 |
model.layers[il].ffn_norm, NULL,
|
| 4622 |
LLM_NORM_RMS, il);
|
|
@@ -4629,9 +4797,7 @@ struct llm_build_llama : public llm_graph_context {
|
|
| 4629 |
NULL,
|
| 4630 |
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
| 4631 |
cb(cur, "ffn_out", il);
|
| 4632 |
-
|
| 4633 |
-
} else if (arch == LLM_ARCH_LLAMA4) {
|
| 4634 |
-
// llama4 MoE
|
| 4635 |
ggml_tensor * ffn_inp_normed = build_norm(ffn_inp,
|
| 4636 |
model.layers[il].ffn_norm, NULL,
|
| 4637 |
LLM_NORM_RMS, il);
|
|
@@ -4660,26 +4826,6 @@ struct llm_build_llama : public llm_graph_context {
|
|
| 4660 |
|
| 4661 |
cur = ggml_add(ctx0, moe_out, shexp_out);
|
| 4662 |
cb(cur, "ffn_moe_out_merged", il);
|
| 4663 |
-
|
| 4664 |
-
} else {
|
| 4665 |
-
// MoE branch
|
| 4666 |
-
cur = build_norm(ffn_inp,
|
| 4667 |
-
model.layers[il].ffn_norm, NULL,
|
| 4668 |
-
LLM_NORM_RMS, il);
|
| 4669 |
-
cb(cur, "ffn_norm", il);
|
| 4670 |
-
|
| 4671 |
-
cur = build_moe_ffn(cur,
|
| 4672 |
-
model.layers[il].ffn_gate_inp,
|
| 4673 |
-
model.layers[il].ffn_up_exps,
|
| 4674 |
-
model.layers[il].ffn_gate_exps,
|
| 4675 |
-
model.layers[il].ffn_down_exps,
|
| 4676 |
-
nullptr,
|
| 4677 |
-
n_expert, n_expert_used,
|
| 4678 |
-
LLM_FFN_SILU, true,
|
| 4679 |
-
false, 0.0,
|
| 4680 |
-
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
| 4681 |
-
il);
|
| 4682 |
-
cb(cur, "ffn_moe_out", il);
|
| 4683 |
}
|
| 4684 |
|
| 4685 |
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
@@ -4753,7 +4899,7 @@ struct llm_build_deci : public llm_graph_context {
|
|
| 4753 |
} else if (n_head > 0) {
|
| 4754 |
// self-attention
|
| 4755 |
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
| 4756 |
-
ggml_tensor * rope_factors = model.get_rope_factors(
|
| 4757 |
|
| 4758 |
// compute Q and K and RoPE them
|
| 4759 |
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
|
@@ -7202,6 +7348,7 @@ struct llm_build_phi2 : public llm_graph_context {
|
|
| 7202 |
}
|
| 7203 |
};
|
| 7204 |
|
|
|
|
| 7205 |
struct llm_build_phi3 : public llm_graph_context {
|
| 7206 |
llm_build_phi3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
| 7207 |
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
@@ -7217,7 +7364,14 @@ struct llm_build_phi3 : public llm_graph_context {
|
|
| 7217 |
// inp_pos - contains the positions
|
| 7218 |
ggml_tensor * inp_pos = build_inp_pos();
|
| 7219 |
|
| 7220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7221 |
|
| 7222 |
for (int il = 0; il < n_layer; ++il) {
|
| 7223 |
auto * residual = inpL;
|
|
@@ -7225,7 +7379,7 @@ struct llm_build_phi3 : public llm_graph_context {
|
|
| 7225 |
// self-attention
|
| 7226 |
{
|
| 7227 |
// rope freq factors for 128k context
|
| 7228 |
-
ggml_tensor * rope_factors = model.get_rope_factors(
|
| 7229 |
|
| 7230 |
ggml_tensor* attn_norm_output = build_norm(inpL,
|
| 7231 |
model.layers[il].attn_norm,
|
|
@@ -7977,7 +8131,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
|
|
| 7977 |
for (int il = 0; il < n_layer; ++il) {
|
| 7978 |
ggml_tensor * inpSA = inpL;
|
| 7979 |
|
| 7980 |
-
ggml_tensor * rope_factors = model.get_rope_factors(
|
| 7981 |
|
| 7982 |
// norm
|
| 7983 |
cur = build_norm(inpL,
|
|
@@ -8277,8 +8431,8 @@ struct llm_build_gemma : public llm_graph_context {
|
|
| 8277 |
}
|
| 8278 |
};
|
| 8279 |
|
| 8280 |
-
struct
|
| 8281 |
-
|
| 8282 |
const int64_t n_embd_head = hparams.n_embd_head_k;
|
| 8283 |
|
| 8284 |
ggml_tensor * cur;
|
|
@@ -8292,7 +8446,7 @@ struct llm_build_gemma2 : public llm_graph_context {
|
|
| 8292 |
// inp_pos - contains the positions
|
| 8293 |
ggml_tensor * inp_pos = build_inp_pos();
|
| 8294 |
|
| 8295 |
-
auto * inp_attn =
|
| 8296 |
|
| 8297 |
for (int il = 0; il < n_layer; ++il) {
|
| 8298 |
// norm
|
|
@@ -8414,8 +8568,8 @@ struct llm_build_gemma2 : public llm_graph_context {
|
|
| 8414 |
}
|
| 8415 |
};
|
| 8416 |
|
| 8417 |
-
struct
|
| 8418 |
-
|
| 8419 |
const int64_t n_embd_head = hparams.n_embd_head_k;
|
| 8420 |
|
| 8421 |
ggml_tensor * cur;
|
|
@@ -8433,13 +8587,11 @@ struct llm_build_gemma3 : public llm_graph_context {
|
|
| 8433 |
ggml_tensor * inp_pos = build_inp_pos();
|
| 8434 |
|
| 8435 |
// TODO: is causal == true correct? might need some changes
|
| 8436 |
-
auto * inp_attn =
|
| 8437 |
|
| 8438 |
for (int il = 0; il < n_layer; ++il) {
|
| 8439 |
-
const
|
| 8440 |
-
|
| 8441 |
-
const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
|
| 8442 |
-
const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
|
| 8443 |
|
| 8444 |
// norm
|
| 8445 |
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
|
|
@@ -9016,8 +9168,8 @@ struct llm_build_command_r : public llm_graph_context {
|
|
| 9016 |
}
|
| 9017 |
};
|
| 9018 |
|
| 9019 |
-
struct
|
| 9020 |
-
|
| 9021 |
const int64_t n_embd_head = hparams.n_embd_head_v;
|
| 9022 |
|
| 9023 |
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
|
@@ -9032,7 +9184,7 @@ struct llm_build_cohere2 : public llm_graph_context {
|
|
| 9032 |
// inp_pos - contains the positions
|
| 9033 |
ggml_tensor * inp_pos = build_inp_pos();
|
| 9034 |
|
| 9035 |
-
auto * inp_attn =
|
| 9036 |
|
| 9037 |
for (int il = 0; il < n_layer; ++il) {
|
| 9038 |
const bool is_swa = hparams.is_swa(il);
|
|
@@ -9045,7 +9197,7 @@ struct llm_build_cohere2 : public llm_graph_context {
|
|
| 9045 |
// self-attention
|
| 9046 |
{
|
| 9047 |
// rope freq factors for 128k context
|
| 9048 |
-
ggml_tensor * rope_factors = model.get_rope_factors(
|
| 9049 |
|
| 9050 |
// compute Q and K and RoPE them
|
| 9051 |
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
|
@@ -9983,7 +10135,7 @@ struct llm_build_deepseek : public llm_graph_context {
|
|
| 9983 |
// self-attention
|
| 9984 |
{
|
| 9985 |
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
| 9986 |
-
ggml_tensor * rope_factors = model.get_rope_factors(
|
| 9987 |
|
| 9988 |
// compute Q and K and RoPE them
|
| 9989 |
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
|
@@ -11347,7 +11499,7 @@ struct llm_build_exaone : public llm_graph_context {
|
|
| 11347 |
// self-attention
|
| 11348 |
{
|
| 11349 |
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
| 11350 |
-
ggml_tensor * rope_factors = model.get_rope_factors(
|
| 11351 |
|
| 11352 |
// compute Q and K and RoPE them
|
| 11353 |
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
|
@@ -12263,7 +12415,7 @@ struct llm_build_granite : public llm_graph_context {
|
|
| 12263 |
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
| 12264 |
|
| 12265 |
if (use_rope) {
|
| 12266 |
-
ggml_tensor * rope_factors = model.get_rope_factors(
|
| 12267 |
Qcur = ggml_rope_ext(
|
| 12268 |
ctx0, Qcur, inp_pos, rope_factors,
|
| 12269 |
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
@@ -12916,7 +13068,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
|
|
| 12916 |
// self-attention
|
| 12917 |
{
|
| 12918 |
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
| 12919 |
-
ggml_tensor * rope_factors = model.get_rope_factors(
|
| 12920 |
|
| 12921 |
// compute Q and K and RoPE them
|
| 12922 |
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
|
@@ -13044,6 +13196,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|
| 13044 |
case LLM_ARCH_JINA_BERT_V2:
|
| 13045 |
case LLM_ARCH_NOMIC_BERT:
|
| 13046 |
case LLM_ARCH_NOMIC_BERT_MOE:
|
|
|
|
| 13047 |
{
|
| 13048 |
res = nullptr;
|
| 13049 |
} break;
|
|
@@ -13058,7 +13211,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|
| 13058 |
GGML_TYPE_F32,
|
| 13059 |
GGML_TYPE_F32,
|
| 13060 |
cparams.offload_kqv,
|
| 13061 |
-
std::max((uint32_t) 1, cparams.n_seq_max)
|
|
|
|
| 13062 |
} break;
|
| 13063 |
default:
|
| 13064 |
{
|
|
@@ -13068,14 +13222,36 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|
| 13068 |
|
| 13069 |
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
|
| 13070 |
|
| 13071 |
-
|
| 13072 |
-
|
| 13073 |
-
|
| 13074 |
-
|
| 13075 |
-
|
| 13076 |
-
|
| 13077 |
-
|
| 13078 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13079 |
}
|
| 13080 |
}
|
| 13081 |
|
|
@@ -13090,11 +13266,14 @@ llm_graph_result_ptr llama_model::build_graph(
|
|
| 13090 |
|
| 13091 |
switch (arch) {
|
| 13092 |
case LLM_ARCH_LLAMA:
|
| 13093 |
-
case LLM_ARCH_LLAMA4:
|
| 13094 |
case LLM_ARCH_MINICPM:
|
| 13095 |
{
|
| 13096 |
llm = std::make_unique<llm_build_llama>(*this, params, gf);
|
| 13097 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13098 |
case LLM_ARCH_DECI:
|
| 13099 |
{
|
| 13100 |
llm = std::make_unique<llm_build_deci>(*this, params, gf);
|
|
@@ -13169,7 +13348,11 @@ llm_graph_result_ptr llama_model::build_graph(
|
|
| 13169 |
case LLM_ARCH_PHI3:
|
| 13170 |
case LLM_ARCH_PHIMOE:
|
| 13171 |
{
|
| 13172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13173 |
} break;
|
| 13174 |
case LLM_ARCH_PLAMO:
|
| 13175 |
{
|
|
@@ -13201,11 +13384,11 @@ llm_graph_result_ptr llama_model::build_graph(
|
|
| 13201 |
} break;
|
| 13202 |
case LLM_ARCH_GEMMA2:
|
| 13203 |
{
|
| 13204 |
-
llm = std::make_unique<
|
| 13205 |
} break;
|
| 13206 |
case LLM_ARCH_GEMMA3:
|
| 13207 |
{
|
| 13208 |
-
llm = std::make_unique<
|
| 13209 |
} break;
|
| 13210 |
case LLM_ARCH_STARCODER2:
|
| 13211 |
{
|
|
@@ -13225,7 +13408,7 @@ llm_graph_result_ptr llama_model::build_graph(
|
|
| 13225 |
} break;
|
| 13226 |
case LLM_ARCH_COHERE2:
|
| 13227 |
{
|
| 13228 |
-
llm = std::make_unique<
|
| 13229 |
} break;
|
| 13230 |
case LLM_ARCH_DBRX:
|
| 13231 |
{
|
|
|
|
| 463 |
GGML_ASSERT(hparams.n_expert_used == 0);
|
| 464 |
}
|
| 465 |
|
|
|
|
| 466 |
std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
|
| 467 |
std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
|
| 468 |
std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
|
| 469 |
|
| 470 |
+
std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);
|
| 471 |
+
|
| 472 |
+
std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0);
|
| 473 |
+
|
| 474 |
ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false);
|
| 475 |
ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false);
|
| 476 |
|
|
|
|
| 574 |
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
| 575 |
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
| 576 |
ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);
|
| 577 |
+
|
| 578 |
+
hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
|
| 579 |
+
hparams.n_swa = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
|
| 580 |
+
hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
|
| 581 |
|
| 582 |
switch (hparams.n_expert) {
|
| 583 |
case 16: type = LLM_TYPE_17B_16E; break;
|
|
|
|
| 856 |
default: type = LLM_TYPE_UNKNOWN;
|
| 857 |
}
|
| 858 |
|
| 859 |
+
const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
|
| 860 |
+
|
| 861 |
+
if (found_swa && hparams.n_swa > 0) {
|
| 862 |
+
LLAMA_LOG_WARN("%s: Phi SWA is currently disabled - results might be suboptimal for some models (see %s)\n",
|
| 863 |
+
__func__, "https://github.com/ggml-org/llama.cpp/pull/13676");
|
| 864 |
+
|
| 865 |
+
// TODO: fix conversion scripts to correctly populate `n_swa` and `n_swa_pattern`
|
| 866 |
+
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
|
| 867 |
+
|
| 868 |
+
hparams.n_swa = 0;
|
| 869 |
+
hparams.set_swa_pattern(1);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 870 |
}
|
| 871 |
} break;
|
| 872 |
case LLM_ARCH_PHIMOE:
|
|
|
|
| 936 |
} break;
|
| 937 |
case LLM_ARCH_GEMMA2:
|
| 938 |
{
|
| 939 |
+
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
| 940 |
hparams.n_swa = 4096; // default value of gemma 2
|
| 941 |
+
hparams.set_swa_pattern(2);
|
| 942 |
hparams.attn_soft_cap = true;
|
| 943 |
|
| 944 |
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
|
|
|
|
| 955 |
} break;
|
| 956 |
case LLM_ARCH_GEMMA3:
|
| 957 |
{
|
| 958 |
+
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
| 959 |
+
hparams.set_swa_pattern(6);
|
| 960 |
|
| 961 |
hparams.rope_freq_base_train_swa = 10000.0f;
|
| 962 |
hparams.rope_freq_scale_train_swa = 1.0f;
|
|
|
|
| 1040 |
} break;
|
| 1041 |
case LLM_ARCH_COHERE2:
|
| 1042 |
{
|
| 1043 |
+
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
| 1044 |
+
hparams.set_swa_pattern(4);
|
| 1045 |
|
| 1046 |
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
|
| 1047 |
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
|
|
|
|
| 2489 |
|
| 2490 |
// output
|
| 2491 |
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
| 2492 |
+
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
| 2493 |
+
// if output is NULL, init from the input tok embed
|
| 2494 |
+
if (output == NULL) {
|
| 2495 |
+
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
| 2496 |
+
}
|
| 2497 |
|
| 2498 |
for (int i = 0; i < n_layer; ++i) {
|
| 2499 |
auto & layer = layers[i];
|
|
|
|
| 4327 |
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
|
| 4328 |
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
|
| 4329 |
LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
|
| 4330 |
+
LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any());
|
| 4331 |
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
|
| 4332 |
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
|
| 4333 |
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str());
|
|
|
|
| 4495 |
return it->second;
|
| 4496 |
}
|
| 4497 |
|
| 4498 |
+
float llama_model::get_rope_freq_base (const llama_cparams & cparams, int il) const {
|
| 4499 |
+
return hparams.is_swa(il) ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
|
| 4500 |
+
}
|
| 4501 |
+
|
| 4502 |
+
float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) const {
|
| 4503 |
+
return hparams.is_swa(il) ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
|
| 4504 |
+
}
|
| 4505 |
+
|
| 4506 |
+
ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const {
|
| 4507 |
+
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
|
| 4508 |
+
|
| 4509 |
// choose long/short freq factors based on the context size
|
| 4510 |
if (layers[il].rope_freqs != nullptr) {
|
| 4511 |
return layers[il].rope_freqs;
|
|
|
|
| 4533 |
// inp_pos - contains the positions
|
| 4534 |
ggml_tensor * inp_pos = build_inp_pos();
|
| 4535 |
|
| 4536 |
+
auto * inp_attn = build_attn_inp_kv_unified();
|
| 4537 |
+
|
| 4538 |
+
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
| 4539 |
+
|
| 4540 |
+
for (int il = 0; il < n_layer; ++il) {
|
| 4541 |
+
ggml_tensor * inpSA = inpL;
|
| 4542 |
+
|
| 4543 |
+
// norm
|
| 4544 |
+
cur = build_norm(inpL,
|
| 4545 |
+
model.layers[il].attn_norm, NULL,
|
| 4546 |
+
LLM_NORM_RMS, il);
|
| 4547 |
+
cb(cur, "attn_norm", il);
|
| 4548 |
+
|
| 4549 |
+
// self-attention
|
| 4550 |
+
{
|
| 4551 |
+
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
| 4552 |
+
ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
|
| 4553 |
+
|
| 4554 |
+
// compute Q and K and RoPE them
|
| 4555 |
+
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
| 4556 |
+
cb(Qcur, "Qcur", il);
|
| 4557 |
+
if (model.layers[il].bq) {
|
| 4558 |
+
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
| 4559 |
+
cb(Qcur, "Qcur", il);
|
| 4560 |
+
}
|
| 4561 |
+
|
| 4562 |
+
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
| 4563 |
+
cb(Kcur, "Kcur", il);
|
| 4564 |
+
if (model.layers[il].bk) {
|
| 4565 |
+
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
| 4566 |
+
cb(Kcur, "Kcur", il);
|
| 4567 |
+
}
|
| 4568 |
+
|
| 4569 |
+
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
| 4570 |
+
cb(Vcur, "Vcur", il);
|
| 4571 |
+
if (model.layers[il].bv) {
|
| 4572 |
+
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
| 4573 |
+
cb(Vcur, "Vcur", il);
|
| 4574 |
+
}
|
| 4575 |
+
|
| 4576 |
+
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
| 4577 |
+
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
| 4578 |
+
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
| 4579 |
+
|
| 4580 |
+
Qcur = ggml_rope_ext(
|
| 4581 |
+
ctx0, Qcur, inp_pos, rope_factors,
|
| 4582 |
+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
| 4583 |
+
ext_factor, attn_factor, beta_fast, beta_slow
|
| 4584 |
+
);
|
| 4585 |
+
|
| 4586 |
+
Kcur = ggml_rope_ext(
|
| 4587 |
+
ctx0, Kcur, inp_pos, rope_factors,
|
| 4588 |
+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
| 4589 |
+
ext_factor, attn_factor, beta_fast, beta_slow
|
| 4590 |
+
);
|
| 4591 |
+
|
| 4592 |
+
cb(Qcur, "Qcur", il);
|
| 4593 |
+
cb(Kcur, "Kcur", il);
|
| 4594 |
+
cb(Vcur, "Vcur", il);
|
| 4595 |
+
|
| 4596 |
+
cur = build_attn(inp_attn, gf,
|
| 4597 |
+
model.layers[il].wo, model.layers[il].bo,
|
| 4598 |
+
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
|
| 4599 |
+
cb(cur, "attn_out", il);
|
| 4600 |
+
}
|
| 4601 |
+
|
| 4602 |
+
if (il == n_layer - 1) {
|
| 4603 |
+
// skip computing output for unused tokens
|
| 4604 |
+
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
| 4605 |
+
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
| 4606 |
+
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
| 4607 |
+
}
|
| 4608 |
+
|
| 4609 |
+
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
| 4610 |
+
cb(ffn_inp, "ffn_inp", il);
|
| 4611 |
+
|
| 4612 |
+
// feed-forward network (non-MoE)
|
| 4613 |
+
if (model.layers[il].ffn_gate_inp == nullptr) {
|
| 4614 |
+
|
| 4615 |
+
cur = build_norm(ffn_inp,
|
| 4616 |
+
model.layers[il].ffn_norm, NULL,
|
| 4617 |
+
LLM_NORM_RMS, il);
|
| 4618 |
+
cb(cur, "ffn_norm", il);
|
| 4619 |
+
|
| 4620 |
+
cur = build_ffn(cur,
|
| 4621 |
+
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
| 4622 |
+
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
|
| 4623 |
+
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
| 4624 |
+
NULL,
|
| 4625 |
+
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
| 4626 |
+
cb(cur, "ffn_out", il);
|
| 4627 |
+
} else {
|
| 4628 |
+
// MoE branch
|
| 4629 |
+
cur = build_norm(ffn_inp,
|
| 4630 |
+
model.layers[il].ffn_norm, NULL,
|
| 4631 |
+
LLM_NORM_RMS, il);
|
| 4632 |
+
cb(cur, "ffn_norm", il);
|
| 4633 |
+
|
| 4634 |
+
cur = build_moe_ffn(cur,
|
| 4635 |
+
model.layers[il].ffn_gate_inp,
|
| 4636 |
+
model.layers[il].ffn_up_exps,
|
| 4637 |
+
model.layers[il].ffn_gate_exps,
|
| 4638 |
+
model.layers[il].ffn_down_exps,
|
| 4639 |
+
nullptr,
|
| 4640 |
+
n_expert, n_expert_used,
|
| 4641 |
+
LLM_FFN_SILU, true,
|
| 4642 |
+
false, 0.0,
|
| 4643 |
+
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
| 4644 |
+
il);
|
| 4645 |
+
cb(cur, "ffn_moe_out", il);
|
| 4646 |
+
}
|
| 4647 |
+
|
| 4648 |
+
cur = ggml_add(ctx0, cur, ffn_inp);
|
| 4649 |
+
cb(cur, "ffn_out", il);
|
| 4650 |
+
|
| 4651 |
+
cur = build_cvec(cur, il);
|
| 4652 |
+
cb(cur, "l_out", il);
|
| 4653 |
+
|
| 4654 |
+
// input for next layer
|
| 4655 |
+
inpL = cur;
|
| 4656 |
+
}
|
| 4657 |
+
|
| 4658 |
+
cur = inpL;
|
| 4659 |
+
|
| 4660 |
+
cur = build_norm(cur,
|
| 4661 |
+
model.output_norm, NULL,
|
| 4662 |
+
LLM_NORM_RMS, -1);
|
| 4663 |
+
|
| 4664 |
+
cb(cur, "result_norm", -1);
|
| 4665 |
+
res->t_embd = cur;
|
| 4666 |
+
|
| 4667 |
+
// lm_head
|
| 4668 |
+
cur = build_lora_mm(model.output, cur);
|
| 4669 |
+
|
| 4670 |
+
cb(cur, "result_output", -1);
|
| 4671 |
+
res->t_logits = cur;
|
| 4672 |
+
|
| 4673 |
+
ggml_build_forward_expand(gf, cur);
|
| 4674 |
+
}
|
| 4675 |
+
};
|
| 4676 |
+
|
| 4677 |
+
struct llm_build_llama_iswa : public llm_graph_context {
|
| 4678 |
+
llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
| 4679 |
+
const int64_t n_embd_head = hparams.n_embd_head_v;
|
| 4680 |
+
|
| 4681 |
+
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
| 4682 |
+
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
| 4683 |
+
|
| 4684 |
+
ggml_tensor * cur;
|
| 4685 |
+
ggml_tensor * inpL;
|
| 4686 |
+
|
| 4687 |
+
inpL = build_inp_embd(model.tok_embd);
|
| 4688 |
+
|
| 4689 |
+
// inp_pos - contains the positions
|
| 4690 |
+
ggml_tensor * inp_pos = build_inp_pos();
|
| 4691 |
+
|
| 4692 |
// temperature tuning
|
| 4693 |
ggml_tensor * inp_attn_scale = nullptr;
|
| 4694 |
+
inp_attn_scale = build_inp_attn_scale();
|
|
|
|
|
|
|
| 4695 |
|
| 4696 |
+
auto * inp_attn = build_attn_inp_kv_unified_iswa();
|
| 4697 |
|
| 4698 |
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
| 4699 |
+
|
| 4700 |
for (int il = 0; il < n_layer; ++il) {
|
| 4701 |
ggml_tensor * inpSA = inpL;
|
| 4702 |
|
| 4703 |
+
const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0;
|
|
|
|
|
|
|
| 4704 |
|
| 4705 |
// norm
|
| 4706 |
cur = build_norm(inpL,
|
|
|
|
| 4711 |
// self-attention
|
| 4712 |
{
|
| 4713 |
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
| 4714 |
+
ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
|
| 4715 |
|
| 4716 |
// compute Q and K and RoPE them
|
| 4717 |
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
|
|
|
| 4759 |
cb(Kcur, "Kcur", il);
|
| 4760 |
cb(Vcur, "Vcur", il);
|
| 4761 |
|
| 4762 |
+
if (use_rope && hparams.use_kq_norm) {
|
| 4763 |
// Llama4TextL2Norm
|
| 4764 |
Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps);
|
| 4765 |
Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps);
|
|
|
|
| 4785 |
|
| 4786 |
// feed-forward network (non-MoE)
|
| 4787 |
if (model.layers[il].ffn_gate_inp == nullptr) {
|
|
|
|
| 4788 |
cur = build_norm(ffn_inp,
|
| 4789 |
model.layers[il].ffn_norm, NULL,
|
| 4790 |
LLM_NORM_RMS, il);
|
|
|
|
| 4797 |
NULL,
|
| 4798 |
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
| 4799 |
cb(cur, "ffn_out", il);
|
| 4800 |
+
} else {
|
|
|
|
|
|
|
| 4801 |
ggml_tensor * ffn_inp_normed = build_norm(ffn_inp,
|
| 4802 |
model.layers[il].ffn_norm, NULL,
|
| 4803 |
LLM_NORM_RMS, il);
|
|
|
|
| 4826 |
|
| 4827 |
cur = ggml_add(ctx0, moe_out, shexp_out);
|
| 4828 |
cb(cur, "ffn_moe_out_merged", il);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4829 |
}
|
| 4830 |
|
| 4831 |
cur = ggml_add(ctx0, cur, ffn_inp);
|
|
|
|
| 4899 |
} else if (n_head > 0) {
|
| 4900 |
// self-attention
|
| 4901 |
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
| 4902 |
+
ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
|
| 4903 |
|
| 4904 |
// compute Q and K and RoPE them
|
| 4905 |
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
|
|
|
| 7348 |
}
|
| 7349 |
};
|
| 7350 |
|
| 7351 |
+
template<bool iswa>
|
| 7352 |
struct llm_build_phi3 : public llm_graph_context {
|
| 7353 |
llm_build_phi3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
| 7354 |
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
|
|
| 7364 |
// inp_pos - contains the positions
|
| 7365 |
ggml_tensor * inp_pos = build_inp_pos();
|
| 7366 |
|
| 7367 |
+
using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
|
| 7368 |
+
inp_attn_type * inp_attn = nullptr;
|
| 7369 |
+
|
| 7370 |
+
if constexpr (iswa) {
|
| 7371 |
+
inp_attn = build_attn_inp_kv_unified_iswa();
|
| 7372 |
+
} else {
|
| 7373 |
+
inp_attn = build_attn_inp_kv_unified();
|
| 7374 |
+
}
|
| 7375 |
|
| 7376 |
for (int il = 0; il < n_layer; ++il) {
|
| 7377 |
auto * residual = inpL;
|
|
|
|
| 7379 |
// self-attention
|
| 7380 |
{
|
| 7381 |
// rope freq factors for 128k context
|
| 7382 |
+
ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
|
| 7383 |
|
| 7384 |
ggml_tensor* attn_norm_output = build_norm(inpL,
|
| 7385 |
model.layers[il].attn_norm,
|
|
|
|
| 8131 |
for (int il = 0; il < n_layer; ++il) {
|
| 8132 |
ggml_tensor * inpSA = inpL;
|
| 8133 |
|
| 8134 |
+
ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
|
| 8135 |
|
| 8136 |
// norm
|
| 8137 |
cur = build_norm(inpL,
|
|
|
|
| 8431 |
}
|
| 8432 |
};
|
| 8433 |
|
| 8434 |
+
struct llm_build_gemma2_iswa : public llm_graph_context {
|
| 8435 |
+
llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
| 8436 |
const int64_t n_embd_head = hparams.n_embd_head_k;
|
| 8437 |
|
| 8438 |
ggml_tensor * cur;
|
|
|
|
| 8446 |
// inp_pos - contains the positions
|
| 8447 |
ggml_tensor * inp_pos = build_inp_pos();
|
| 8448 |
|
| 8449 |
+
auto * inp_attn = build_attn_inp_kv_unified_iswa();
|
| 8450 |
|
| 8451 |
for (int il = 0; il < n_layer; ++il) {
|
| 8452 |
// norm
|
|
|
|
| 8568 |
}
|
| 8569 |
};
|
| 8570 |
|
| 8571 |
+
struct llm_build_gemma3_iswa : public llm_graph_context {
|
| 8572 |
+
llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
| 8573 |
const int64_t n_embd_head = hparams.n_embd_head_k;
|
| 8574 |
|
| 8575 |
ggml_tensor * cur;
|
|
|
|
| 8587 |
ggml_tensor * inp_pos = build_inp_pos();
|
| 8588 |
|
| 8589 |
// TODO: is causal == true correct? might need some changes
|
| 8590 |
+
auto * inp_attn = build_attn_inp_kv_unified_iswa();
|
| 8591 |
|
| 8592 |
for (int il = 0; il < n_layer; ++il) {
|
| 8593 |
+
const float freq_base_l = model.get_rope_freq_base (cparams, il);
|
| 8594 |
+
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
|
|
|
|
|
|
|
| 8595 |
|
| 8596 |
// norm
|
| 8597 |
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
|
|
|
|
| 9168 |     }
| 9169 | };
| 9170 |
| 9171 | + struct llm_build_cohere2_iswa : public llm_graph_context {
| 9172 | +     llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
| 9173 |           const int64_t n_embd_head = hparams.n_embd_head_v;
| 9174 |
| 9175 |           GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
| 9184 |         // inp_pos - contains the positions
| 9185 |         ggml_tensor * inp_pos = build_inp_pos();
| 9186 |
| 9187 | +       auto * inp_attn = build_attn_inp_kv_unified_iswa();
| 9188 |
| 9189 |         for (int il = 0; il < n_layer; ++il) {
| 9190 |             const bool is_swa = hparams.is_swa(il);
| 9197 |             // self-attention
| 9198 |             {
| 9199 |                 // rope freq factors for 128k context
| 9200 | +               ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
| 9201 |
| 9202 |                 // compute Q and K and RoPE them
| 9203 |                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
| 10135 |             // self-attention
| 10136 |             {
| 10137 |                 // rope freq factors for llama3; may return nullptr for llama2 and other models
| 10138 | +               ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
| 10139 |
| 10140 |                 // compute Q and K and RoPE them
| 10141 |                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
| 11499 |             // self-attention
| 11500 |             {
| 11501 |                 // rope freq factors for llama3; may return nullptr for llama2 and other models
| 11502 | +               ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
| 11503 |
| 11504 |                 // compute Q and K and RoPE them
| 11505 |                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
| 12415 |             Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
| 12416 |
| 12417 |             if (use_rope) {
| 12418 | +               ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
| 12419 |                 Qcur = ggml_rope_ext(
| 12420 |                         ctx0, Qcur, inp_pos, rope_factors,
| 12421 |                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
| 13068 |             // self-attention
| 13069 |             {
| 13070 |                 // rope freq factors for llama3; may return nullptr for llama2 and other models
| 13071 | +               ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
| 13072 |
| 13073 |                 // compute Q and K and RoPE them
| 13074 |                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
| 13196 |         case LLM_ARCH_JINA_BERT_V2:
| 13197 |         case LLM_ARCH_NOMIC_BERT:
| 13198 |         case LLM_ARCH_NOMIC_BERT_MOE:
| 13199 | +       case LLM_ARCH_WAVTOKENIZER_DEC:
| 13200 |             {
| 13201 |                 res = nullptr;
| 13202 |             } break;
| 13211 |                         GGML_TYPE_F32,
| 13212 |                         GGML_TYPE_F32,
| 13213 |                         cparams.offload_kqv,
| 13214 | +                       std::max((uint32_t) 1, cparams.n_seq_max),
| 13215 | +                       cparams.n_seq_max);
| 13216 |             } break;
| 13217 |         default:
| 13218 |             {
| 13222 |
| 13223 |                 LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
| 13224 |
| 13225 | +               if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
| 13226 | +                   GGML_ASSERT(hparams.is_swa_any());
| 13227 | +
| 13228 | +                   res = new llama_kv_cache_unified_iswa(
| 13229 | +                           *this,
| 13230 | +                           params.type_k,
| 13231 | +                           params.type_v,
| 13232 | +                           !cparams.flash_attn,
| 13233 | +                           cparams.offload_kqv,
| 13234 | +                           params.swa_full,
| 13235 | +                           cparams.n_ctx,
| 13236 | +                           cparams.n_seq_max,
| 13237 | +                           cparams.n_batch,
| 13238 | +                           padding);
| 13239 | +               } else {
| 13240 | +                   GGML_ASSERT(!hparams.is_swa_any());
| 13241 | +
| 13242 | +                   res = new llama_kv_cache_unified(
| 13243 | +                           *this,
| 13244 | +                           nullptr,
| 13245 | +                           params.type_k,
| 13246 | +                           params.type_v,
| 13247 | +                           !cparams.flash_attn,
| 13248 | +                           cparams.offload_kqv,
| 13249 | +                           cparams.n_ctx,
| 13250 | +                           cparams.n_seq_max,
| 13251 | +                           padding,
| 13252 | +                           hparams.n_swa,
| 13253 | +                           hparams.swa_type);
| 13254 | +               }
| 13255 |             }
| 13256 |     }
| 13257 |
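The added block picks the memory implementation when the context is created: an interleaved-SWA cache when any layer uses a sliding window, otherwise the plain unified cache. A simplified, self-contained sketch of that selection; the class names here are stand-ins, not the real llama_kv_cache types:

    #include <cstdio>
    #include <memory>

    // simplified stand-ins for the two cache implementations
    struct kv_cache {
        virtual ~kv_cache() = default;
        virtual const char * name() const = 0;
    };

    struct kv_cache_unified : kv_cache {
        const char * name() const override { return "unified"; }
    };

    struct kv_cache_unified_iswa : kv_cache {
        const char * name() const override { return "unified_iswa (per-layer sliding window)"; }
    };

    // mirrors the branch on hparams.swa_type in the hunk above
    static std::unique_ptr<kv_cache> make_kv_cache(bool any_layer_uses_swa) {
        if (any_layer_uses_swa) {
            return std::make_unique<kv_cache_unified_iswa>();
        }
        return std::make_unique<kv_cache_unified>();
    }

    int main() {
        auto cache_a = make_kv_cache(true);
        auto cache_b = make_kv_cache(false);
        std::printf("swa model   -> %s\n", cache_a->name());
        std::printf("dense model -> %s\n", cache_b->name());
    }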
| 13266 |
| 13267 |     switch (arch) {
| 13268 |         case LLM_ARCH_LLAMA:
| 13269 |         case LLM_ARCH_MINICPM:
| 13270 |             {
| 13271 |                 llm = std::make_unique<llm_build_llama>(*this, params, gf);
| 13272 |             } break;
| 13273 | +       case LLM_ARCH_LLAMA4:
| 13274 | +           {
| 13275 | +               llm = std::make_unique<llm_build_llama_iswa>(*this, params, gf);
| 13276 | +           } break;
| 13277 |         case LLM_ARCH_DECI:
| 13278 |             {
| 13279 |                 llm = std::make_unique<llm_build_deci>(*this, params, gf);
| 13348 |         case LLM_ARCH_PHI3:
| 13349 |         case LLM_ARCH_PHIMOE:
| 13350 |             {
| 13351 | +               if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
| 13352 | +                   llm = std::make_unique<llm_build_phi3<true>> (*this, params, gf);
| 13353 | +               } else {
| 13354 | +                   llm = std::make_unique<llm_build_phi3<false>>(*this, params, gf);
| 13355 | +               }
| 13356 |             } break;
| 13357 |         case LLM_ARCH_PLAMO:
| 13358 |             {
| 13384 |             } break;
| 13385 |         case LLM_ARCH_GEMMA2:
| 13386 |             {
| 13387 | +               llm = std::make_unique<llm_build_gemma2_iswa>(*this, params, gf);
| 13388 |             } break;
| 13389 |         case LLM_ARCH_GEMMA3:
| 13390 |             {
| 13391 | +               llm = std::make_unique<llm_build_gemma3_iswa>(*this, params, gf);
| 13392 |             } break;
| 13393 |         case LLM_ARCH_STARCODER2:
| 13394 |             {
| 13408 |             } break;
| 13409 |         case LLM_ARCH_COHERE2:
| 13410 |             {
| 13411 | +               llm = std::make_unique<llm_build_cohere2_iswa>(*this, params, gf);
| 13412 |             } break;
| 13413 |         case LLM_ARCH_DBRX:
| 13414 |             {
examples/talk-llama/llama-model.h
CHANGED
|
@@ -398,7 +398,10 @@ struct llama_model {
| 398 |
| 399 |     const struct ggml_tensor * get_tensor(const char * name) const;
| 400 |
| 401 | -
| 402 |
| 403 |     // note: can mutate `cparams`
| 404 |     // TODO: move this to new llm_arch_model_i interface

| 398 |
| 399 |     const struct ggml_tensor * get_tensor(const char * name) const;
| 400 |
| 401 | +   float get_rope_freq_base (const llama_cparams & cparams, int il) const;
| 402 | +   float get_rope_freq_scale(const llama_cparams & cparams, int il) const;
| 403 | +
| 404 | +   ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const;
| 405 |
| 406 |     // note: can mutate `cparams`
| 407 |     // TODO: move this to new llm_arch_model_i interface
examples/talk-llama/llama-sampling.cpp
CHANGED
|
@@ -798,7 +798,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d
| 798 |     }
| 799 |
| 800 |     // if we have enough values the operation was a success
| 801 | -   if (filtered_tokens.size() >= ctx->min_keep) {
| 802 |         memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
| 803 |         cur_p->size = filtered_tokens.size();
| 804 |         min_p_applied = true;

@@ -909,7 +909,7 @@ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token
| 909 |         cum_sum += cur_p->data[idx].p;
| 910 |
| 911 |         // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
| 912 | -       if (cum_sum > ctx->p && i >= ctx->min_keep - 1) {
| 913 |             last_idx = i + 1;
| 914 |             break;
| 915 |         }

| 798 |     }
| 799 |
| 800 |     // if we have enough values the operation was a success
| 801 | +   if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) {
| 802 |         memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
| 803 |         cur_p->size = filtered_tokens.size();
| 804 |         min_p_applied = true;

| 909 |         cum_sum += cur_p->data[idx].p;
| 910 |
| 911 |         // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
| 912 | +       if (cum_sum > ctx->p && (ctx->min_keep == 0 || i >= ctx->min_keep - 1)) {
| 913 |             last_idx = i + 1;
| 914 |             break;
| 915 |         }
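Both sampler fixes guard the min_keep comparison. min_keep is an unsigned count, so a value of 0 used to make `min_keep - 1` wrap around in the typical sampler, and the min-p filter could report success even with an empty candidate list. A minimal standalone sketch of the pitfall and the guarded forms (not the sampler code itself):

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    int main() {
        const size_t min_keep = 0;

        // pitfall: with an unsigned min_keep of 0, "min_keep - 1" wraps around
        std::printf("min_keep - 1 wraps to %zu\n", min_keep - 1);

        const size_t i = 3; // index of the current candidate
        const bool old_check = (i >= min_keep - 1);                  // false: i is never >= SIZE_MAX
        const bool new_check = (min_keep == 0 || i >= min_keep - 1); // true: 0 means "no minimum"
        std::printf("old: %d, new: %d\n", old_check, new_check);

        // and for the min-p filter: an empty result should not count as success
        std::vector<int> filtered;
        const bool old_ok = (filtered.size() >= min_keep);                 // true even though nothing survived
        const bool new_ok = (!filtered.empty() && filtered.size() >= min_keep);
        std::printf("old_ok: %d, new_ok: %d\n", old_ok, new_ok);
    }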
examples/talk-llama/llama-vocab.cpp
CHANGED
|
@@ -835,7 +835,7 @@ struct llm_tokenizer_ugm_session {
| 835 |     }
| 836 |
| 837 |     // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores
| 838 | -   std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -
| 839 |     // at the beginning tokenization score is zero
| 840 |     tokenization_results[0] = { vocab.token_unk(), 0, 0 };
| 841 |

@@ -867,7 +867,7 @@ struct llm_tokenizer_ugm_session {
| 867 |             const double challenger_score = current_best.score_sum + token_score;
| 868 |             struct best_tokenization & current_champ = tokenization_results[prefix_offset];
| 869 |             if (challenger_score > current_champ.score_sum) {
| 870 | -               struct best_tokenization challenger = { token_id, input_offset,
| 871 |                 current_champ = challenger;
| 872 |             }
| 873 |         }

@@ -881,7 +881,7 @@ struct llm_tokenizer_ugm_session {
| 881 |             prefix_offset = input_offset + n_utf8_code_units;
| 882 |             struct best_tokenization & current_champ = tokenization_results[prefix_offset];
| 883 |             if (challenger_score > current_champ.score_sum) {
| 884 | -               struct best_tokenization challenger = { vocab.token_unk(), input_offset,
| 885 |                 current_champ = challenger;
| 886 |             }
| 887 |         }

@@ -1007,7 +1007,7 @@ private:
| 1007 |     struct best_tokenization {
| 1008 |         llama_token token_id;
| 1009 |         size_t input_offset;
| 1010 | -
| 1011 |     };
| 1012 |
| 1013 |     struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {

| 835 |     }
| 836 |
| 837 |     // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores
| 838 | +   std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -DBL_MAX});
| 839 |     // at the beginning tokenization score is zero
| 840 |     tokenization_results[0] = { vocab.token_unk(), 0, 0 };
| 841 |

| 867 |             const double challenger_score = current_best.score_sum + token_score;
| 868 |             struct best_tokenization & current_champ = tokenization_results[prefix_offset];
| 869 |             if (challenger_score > current_champ.score_sum) {
| 870 | +               struct best_tokenization challenger = { token_id, input_offset, challenger_score };
| 871 |                 current_champ = challenger;
| 872 |             }
| 873 |         }

| 881 |             prefix_offset = input_offset + n_utf8_code_units;
| 882 |             struct best_tokenization & current_champ = tokenization_results[prefix_offset];
| 883 |             if (challenger_score > current_champ.score_sum) {
| 884 | +               struct best_tokenization challenger = { vocab.token_unk(), input_offset, challenger_score };
| 885 |                 current_champ = challenger;
| 886 |             }
| 887 |         }

| 1007 |     struct best_tokenization {
| 1008 |         llama_token token_id;
| 1009 |         size_t input_offset;
| 1010 | +       double score_sum;
| 1011 |     };
| 1012 |
| 1013 |     struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
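The tokenizer change widens best_tokenization::score_sum to double and seeds it with -DBL_MAX, so summing log-probability-style scores over long inputs keeps its precision. A small standalone illustration of why the accumulator type matters; the scores are made up:

    #include <cfloat>
    #include <cstdio>

    int main() {
        // a Viterbi-style "best score so far", as in the UGM tokenizer
        double best_d = -DBL_MAX;  // sentinel: lower than any real sum of scores
        float  best_f = -FLT_MAX;

        // accumulate many small scores (made-up values)
        double sum_d = 0.0;
        float  sum_f = 0.0f;
        for (int i = 0; i < 100000; ++i) {
            sum_d += -13.000001;
            sum_f += -13.000001f;
        }

        if (sum_d > best_d) best_d = sum_d;
        if (sum_f > best_f) best_f = sum_f;

        // the float accumulator drifts noticeably after this many additions
        std::printf("double: %.3f\n", best_d);
        std::printf("float : %.3f\n", best_f);
    }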
examples/talk-llama/llama.h
CHANGED
|
@@ -361,10 +361,11 @@ extern "C" {
| 361 |
| 362 |         // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
| 363 |         bool embeddings;  // if true, extract embeddings (together with logits)
| 364 | -       bool offload_kqv; //
| 365 | -       bool flash_attn;  //
| 366 | -       bool no_perf;     //
| 367 | -       bool op_offload;  //
| 368 |     };
| 369 |
| 370 |     // model quantization parameters

@@ -470,6 +471,7 @@ extern "C" {
| 470 |     LLAMA_API int64_t llama_time_us(void);
| 471 |
| 472 |     LLAMA_API size_t llama_max_devices(void);
| 473 |
| 474 |     LLAMA_API bool llama_supports_mmap (void);
| 475 |     LLAMA_API bool llama_supports_mlock (void);
@@ -607,71 +609,14 @@ extern "C" {
| 607 |     // KV cache
| 608 |     //
| 609 |
| 610 | -   // TODO: start using struct llama_kv_cache
| 611 | -
| 612 | -   // Information associated with an individual cell in the KV cache view.
| 613 | -   struct llama_kv_cache_view_cell {
| 614 | -       // The position for this cell. Takes KV cache shifts into account.
| 615 | -       // May be negative if the cell is not populated.
| 616 | -       llama_pos pos;
| 617 | -   };
| 618 | -
| 619 | -   // An updateable view of the KV cache.
| 620 | -   struct llama_kv_cache_view {
| 621 | -       // Number of KV cache cells. This will be the same as the context size.
| 622 | -       int32_t n_cells;
| 623 | -
| 624 | -       // Maximum number of sequences that can exist in a cell. It's not an error
| 625 | -       // if there are more sequences in a cell than this value, however they will
| 626 | -       // not be visible in the view cells_sequences.
| 627 | -       int32_t n_seq_max;
| 628 | -
| 629 | -       // Number of tokens in the cache. For example, if there are two populated
| 630 | -       // cells, the first with 1 sequence id in it and the second with 2 sequence
| 631 | -       // ids then you'll have 3 tokens.
| 632 | -       int32_t token_count;
| 633 | -
| 634 | -       // Number of populated cache cells.
| 635 | -       int32_t used_cells;
| 636 | -
| 637 | -       // Maximum contiguous empty slots in the cache.
| 638 | -       int32_t max_contiguous;
| 639 | -
| 640 | -       // Index to the start of the max_contiguous slot range. Can be negative
| 641 | -       // when cache is full.
| 642 | -       int32_t max_contiguous_idx;
| 643 | -
| 644 | -       // Information for an individual cell.
| 645 | -       struct llama_kv_cache_view_cell * cells;
| 646 | -
| 647 | -       // The sequences for each cell. There will be n_seq_max items per cell.
| 648 | -       llama_seq_id * cells_sequences;
| 649 | -   };
| 650 | -
| 651 | -   // Create an empty KV cache view. (use only for debugging purposes)
| 652 | -   LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max);
| 653 | -
| 654 | -   // Free a KV cache view. (use only for debugging purposes)
| 655 | -   LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
| 656 | -
| 657 | -   // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
| 658 | -   // TODO: change signature to llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_context * ctx)
| 659 | -   LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
| 660 | -
| 661 | -   ///
| 662 | -
| 663 |     // Returns the number of tokens in the KV cache (slow, use only for debug)
| 664 |     // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
| 665 | -   LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx)
| 666 | -
| 667 | -   DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx),
| 668 | -       "use llama_kv_self_n_tokens instead");
| 669 |
| 670 |     // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
| 671 | -   LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx)
| 672 | -
| 673 | -   DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx),
| 674 | -       "use llama_kv_self_used_cells instead");
| 675 |
| 676 |     // Clear the KV cache - both cell info is erased and KV data is zeroed
| 677 |     LLAMA_API void llama_kv_self_clear(
@@ -730,10 +675,18 @@ extern "C" {
| 730 |             llama_pos p1,
| 731 |             int d);
| 732 |
| 733 |     // Returns the largest position present in the KV cache for the specified sequence
| 734 |     LLAMA_API llama_pos llama_kv_self_seq_pos_max(
| 735 |             struct llama_context * ctx,
| 736 | -
| 737 |
| 738 |     // Defragment the KV cache
| 739 |     // This will be applied:
@@ -747,61 +700,6 @@ extern "C" {
| 747 |     // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
| 748 |     LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
| 749 |
| 750 | -   DEPRECATED(LLAMA_API void llama_kv_cache_clear(
| 751 | -       struct llama_context * ctx),
| 752 | -       "use llama_kv_self_clear instead");
| 753 | -
| 754 | -   DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm(
| 755 | -       struct llama_context * ctx,
| 756 | -       llama_seq_id seq_id,
| 757 | -       llama_pos p0,
| 758 | -       llama_pos p1),
| 759 | -       "use llama_kv_self_seq_rm instead");
| 760 | -
| 761 | -   DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp(
| 762 | -       struct llama_context * ctx,
| 763 | -       llama_seq_id seq_id_src,
| 764 | -       llama_seq_id seq_id_dst,
| 765 | -       llama_pos p0,
| 766 | -       llama_pos p1),
| 767 | -       "use llama_kv_self_seq_cp instead");
| 768 | -
| 769 | -   DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep(
| 770 | -       struct llama_context * ctx,
| 771 | -       llama_seq_id seq_id),
| 772 | -       "use llama_kv_self_seq_keep instead");
| 773 | -
| 774 | -   DEPRECATED(LLAMA_API void llama_kv_cache_seq_add(
| 775 | -       struct llama_context * ctx,
| 776 | -       llama_seq_id seq_id,
| 777 | -       llama_pos p0,
| 778 | -       llama_pos p1,
| 779 | -       llama_pos delta),
| 780 | -       "use llama_kv_self_seq_add instead");
| 781 | -
| 782 | -   DEPRECATED(LLAMA_API void llama_kv_cache_seq_div(
| 783 | -       struct llama_context * ctx,
| 784 | -       llama_seq_id seq_id,
| 785 | -       llama_pos p0,
| 786 | -       llama_pos p1,
| 787 | -       int d),
| 788 | -       "use llama_kv_self_seq_div instead");
| 789 | -
| 790 | -   DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
| 791 | -       struct llama_context * ctx,
| 792 | -       llama_seq_id seq_id),
| 793 | -       "use llama_kv_self_seq_pos_max instead");
| 794 | -
| 795 | -   DEPRECATED(LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx),
| 796 | -       "use llama_kv_self_defrag instead");
| 797 | -
| 798 | -   DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(const struct llama_context * ctx),
| 799 | -       "use llama_kv_self_can_shift instead");
| 800 | -
| 801 | -   DEPRECATED(LLAMA_API void llama_kv_cache_update(struct llama_context * ctx),
| 802 | -       "use llama_kv_self_update instead");
| 803 | -
| 804 | -
| 805 |     //
| 806 |     // State / sessions
| 807 |     //
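Every llama_kv_cache_* wrapper removed here already had a llama_kv_self_* replacement named in its deprecation message, so migrating is a rename. A hedged sketch of the new spellings for a typical remove-and-shift operation; it assumes a valid context and the usual negative-position convention for "until the end":

    #include "llama.h"

    // Drop the cached tokens of `seq_id` in [p0, p1) and shift everything after p1 by `delta`
    // (typically delta = -(p1 - p0)), using the llama_kv_self_* calls that replace the
    // removed llama_kv_cache_* wrappers.
    static void drop_and_shift(llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
        llama_kv_self_seq_rm (ctx, seq_id, p0, p1);
        llama_kv_self_seq_add(ctx, seq_id, p1, -1, delta); // a negative p1 means "up to the end of the sequence"
    }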
@@ -943,9 +841,12 @@ extern "C" {
| 943 |     // Requires KV cache.
| 944 |     // For encode-decoder contexts, processes the batch using the decoder.
| 945 |     // Positive return values does not mean a fatal error, but rather a warning.
| 946 | -   //
| 947 | -   //
| 948 | -   //
| 949 |     LLAMA_API int32_t llama_decode(
| 950 |             struct llama_context * ctx,
| 951 |             struct llama_batch batch);
| 361 |
| 362 |         // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
| 363 |         bool embeddings;  // if true, extract embeddings (together with logits)
| 364 | +       bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
| 365 | +       bool flash_attn;  // use flash attention [EXPERIMENTAL]
| 366 | +       bool no_perf;     // measure performance timings
| 367 | +       bool op_offload;  // offload host tensor operations to device
| 368 | +       bool swa_full;    // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
| 369 |     };
| 370 |
| 371 |     // model quantization parameters
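The boolean context parameters now carry inline descriptions, and swa_full is new: it requests a full-size cache even for SWA models. A hedged sketch of filling these fields starting from the library defaults; the concrete values are arbitrary:

    #include "llama.h"

    static llama_context_params make_ctx_params() {
        llama_context_params cparams = llama_context_default_params();

        cparams.n_ctx       = 8192;   // context size
        cparams.n_batch     = 2048;   // logical batch size

        // the boolean block from the header above
        cparams.offload_kqv = true;   // offload the KQV ops (including the KV cache) to GPU
        cparams.flash_attn  = false;  // flash attention is still marked experimental
        cparams.no_perf     = false;  // keep performance timings
        cparams.op_offload  = true;   // offload host tensor operations to device
        cparams.swa_full    = false;  // let SWA models use the smaller sliding-window cache

        return cparams;
    }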
| 471 |     LLAMA_API int64_t llama_time_us(void);
| 472 |
| 473 |     LLAMA_API size_t llama_max_devices(void);
| 474 | +   LLAMA_API size_t llama_max_parallel_sequences(void);
| 475 |
| 476 |     LLAMA_API bool llama_supports_mmap (void);
| 477 |     LLAMA_API bool llama_supports_mlock (void);
| 609 |     // KV cache
| 610 |     //
| 611 |
| 612 |     // Returns the number of tokens in the KV cache (slow, use only for debug)
| 613 |     // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
| 614 | +   DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx),
| 615 | +       "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
| 616 |
| 617 |     // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
| 618 | +   DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx),
| 619 | +       "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
| 620 |
| 621 |     // Clear the KV cache - both cell info is erased and KV data is zeroed
| 622 |     LLAMA_API void llama_kv_self_clear(
| 675 |             llama_pos p1,
| 676 |             int d);
| 677 |
| 678 | +   // Returns the smallest position present in the KV cache for the specified sequence
| 679 | +   // This is typically non-zero only for SWA caches
| 680 | +   // Return -1 if the sequence is empty
| 681 | +   LLAMA_API llama_pos llama_kv_self_seq_pos_min(
| 682 | +           struct llama_context * ctx,
| 683 | +           llama_seq_id seq_id);
| 684 | +
| 685 |     // Returns the largest position present in the KV cache for the specified sequence
| 686 | +   // Return -1 if the sequence is empty
| 687 |     LLAMA_API llama_pos llama_kv_self_seq_pos_max(
| 688 |             struct llama_context * ctx,
| 689 | +           llama_seq_id seq_id);
| 690 |
| 691 |     // Defragment the KV cache
| 692 |     // This will be applied:
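With the debug view API gone, the seq_pos_min/seq_pos_max pair is the remaining way to see what span of positions a sequence still holds in the cache. A hedged usage sketch against this header; model and context setup are omitted:

    #include "llama.h"

    // Returns how many positions of `seq_id` are currently resident in the KV cache,
    // or 0 if the sequence is empty (both calls return -1 in that case).
    static int32_t kv_positions_in_cache(llama_context * ctx, llama_seq_id seq_id) {
        const llama_pos p_min = llama_kv_self_seq_pos_min(ctx, seq_id);
        const llama_pos p_max = llama_kv_self_seq_pos_max(ctx, seq_id);

        if (p_min < 0 || p_max < 0) {
            return 0; // nothing cached for this sequence
        }

        // for a SWA cache p_min is typically > 0, since older positions get pruned
        return (int32_t) (p_max - p_min + 1);
    }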
| 700 |     // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
| 701 |     LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
| 702 |
| 703 |     //
| 704 |     // State / sessions
| 705 |     //
| 841 |     // Requires KV cache.
| 842 |     // For encode-decoder contexts, processes the batch using the decoder.
| 843 |     // Positive return values does not mean a fatal error, but rather a warning.
| 844 | +   // Upon non-zero return values, the KV cache state is restored to the state before this call
| 845 | +   //    0 - success
| 846 | +   //    1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
| 847 | +   //    2 - aborted
| 848 | +   //   -1 - invalid input batch
| 849 | +   // < -1 - error
| 850 |     LLAMA_API int32_t llama_decode(
| 851 |             struct llama_context * ctx,
| 852 |             struct llama_batch batch);
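The expanded comment enumerates llama_decode()'s return codes, so callers can tell the recoverable "no KV slot" case apart from hard errors. A hedged sketch of a caller interpreting them; it assumes a valid context and batch:

    #include "llama.h"

    #include <cstdio>

    // Returns true if the batch was decoded, false if the caller should give up or retry differently.
    static bool decode_checked(llama_context * ctx, llama_batch batch) {
        const int32_t ret = llama_decode(ctx, batch);

        if (ret == 0) {
            return true; // success; on any non-zero value the KV cache state is restored
        }
        if (ret == 1) {
            std::fprintf(stderr, "no KV slot for this batch - try a smaller batch or a larger context\n");
            return false;
        }
        if (ret == 2) {
            std::fprintf(stderr, "decode aborted\n");
            return false;
        }
        std::fprintf(stderr, "llama_decode failed with code %d\n", ret);
        return false;
    }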