ggerganov commited on
Commit
5d037b9
·
1 Parent(s): 78bfd81

talk-llama : sync llama.cpp

Browse files
examples/talk-llama/llama-batch.cpp CHANGED
@@ -1,5 +1,6 @@
1
  #include "llama-batch.h"
2
 
 
3
  #include <cstring>
4
  #include <algorithm>
5
 
@@ -281,9 +282,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
281
  batch = in_batch;
282
  GGML_ASSERT(batch.n_tokens > 0);
283
  if (!batch.pos) {
 
284
  pos.resize(batch.n_tokens);
285
  for (int32_t i = 0; i < batch.n_tokens; i++) {
286
- pos[i] = i + p0;
287
  }
288
  batch.pos = pos.data();
289
  }
 
1
  #include "llama-batch.h"
2
 
3
+ #include <cassert>
4
  #include <cstring>
5
  #include <algorithm>
6
 
 
282
  batch = in_batch;
283
  GGML_ASSERT(batch.n_tokens > 0);
284
  if (!batch.pos) {
285
+ assert(p0 >= 0);
286
  pos.resize(batch.n_tokens);
287
  for (int32_t i = 0; i < batch.n_tokens; i++) {
288
+ pos[i] = p0 + i;
289
  }
290
  batch.pos = pos.data();
291
  }
examples/talk-llama/llama-context.cpp CHANGED
@@ -25,7 +25,11 @@ llama_context::llama_context(
25
 
26
  const auto & hparams = model.hparams;
27
 
28
- cparams.n_seq_max = std::max(1u, params.n_seq_max);
 
 
 
 
29
  cparams.n_threads = params.n_threads;
30
  cparams.n_threads_batch = params.n_threads_batch;
31
  cparams.yarn_ext_factor = params.yarn_ext_factor;
@@ -93,6 +97,7 @@ llama_context::llama_context(
93
  }
94
 
95
  cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
96
  cparams.op_offload = params.op_offload;
97
 
98
  const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
@@ -176,8 +181,9 @@ llama_context::llama_context(
176
  // init the memory module
177
  if (!hparams.vocab_only) {
178
  llama_memory_params params_mem = {
179
- /*.type_k =*/ params.type_k,
180
- /*.type_v =*/ params.type_v,
 
181
  };
182
 
183
  memory.reset(model.create_memory(params_mem, cparams));
@@ -687,12 +693,18 @@ int llama_context::encode(llama_batch & inp_batch) {
687
 
688
  GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
689
 
 
690
  if (batch.token) {
691
  for (int32_t i = 0; i < n_tokens; ++i) {
692
  if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
693
  LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
694
  return -1;
695
  }
 
 
 
 
 
696
  }
697
  }
698
 
@@ -846,7 +858,7 @@ int llama_context::encode(llama_batch & inp_batch) {
846
 
847
  int llama_context::decode(llama_batch & inp_batch) {
848
  if (!memory) {
849
- LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
850
  return encode(inp_batch);
851
  }
852
 
@@ -855,11 +867,17 @@ int llama_context::decode(llama_batch & inp_batch) {
855
  return -1;
856
  }
857
 
 
 
 
 
 
 
 
858
  llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
859
 
860
  // temporary allocate memory for the input batch if needed
861
- // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
862
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
863
 
864
  const llama_batch & batch = batch_allocr.batch;
865
 
@@ -875,11 +893,17 @@ int llama_context::decode(llama_batch & inp_batch) {
875
 
876
  GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
877
 
 
878
  if (batch.token) {
879
  for (int64_t i = 0; i < n_tokens_all; ++i) {
880
  if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
881
  LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
882
- throw std::runtime_error("invalid token");
 
 
 
 
 
883
  }
884
  }
885
  }
@@ -947,8 +971,6 @@ int llama_context::decode(llama_batch & inp_batch) {
947
 
948
  // find KV slot
949
  if (!kv_self->find_slot(ubatch)) {
950
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
951
-
952
  return 1;
953
  }
954
 
@@ -2093,6 +2115,7 @@ llama_context_params llama_context_default_params() {
2093
  /*.flash_attn =*/ false,
2094
  /*.no_perf =*/ true,
2095
  /*.op_offload =*/ true,
 
2096
  };
2097
 
2098
  return result;
@@ -2287,65 +2310,51 @@ int32_t llama_apply_adapter_cvec(
2287
  return res ? 0 : -1;
2288
  }
2289
 
2290
- //
2291
- // kv cache view
2292
- //
2293
-
2294
- llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
2295
- const auto * kv = ctx->get_kv_self();
2296
- if (kv == nullptr) {
2297
- LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
2298
- return {};
2299
- }
2300
-
2301
- return llama_kv_cache_view_init(*kv, n_seq_max);
2302
- }
2303
-
2304
- void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
2305
- const auto * kv = ctx->get_kv_self();
2306
- if (kv == nullptr) {
2307
- LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
2308
- return;
2309
- }
2310
-
2311
- llama_kv_cache_view_update(view, kv);
2312
- }
2313
-
2314
  //
2315
  // kv cache
2316
  //
2317
 
2318
  // deprecated
2319
- int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
2320
- return llama_kv_self_n_tokens(ctx);
2321
- }
2322
-
2323
  int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
2324
  const auto * kv = ctx->get_kv_self();
2325
  if (!kv) {
2326
  return 0;
2327
  }
2328
 
2329
- return kv->get_n_tokens();
2330
- }
2331
 
2332
- // deprecated
2333
- int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
2334
- return llama_kv_self_used_cells(ctx);
 
 
 
 
 
 
 
2335
  }
2336
 
 
 
2337
  int32_t llama_kv_self_used_cells(const llama_context * ctx) {
2338
  const auto * kv = ctx->get_kv_self();
2339
  if (!kv) {
2340
  return 0;
2341
  }
2342
 
2343
- return kv->get_used_cells();
2344
- }
2345
 
2346
- // deprecated
2347
- void llama_kv_cache_clear(llama_context * ctx) {
2348
- llama_kv_self_clear(ctx);
 
 
 
 
 
 
 
2349
  }
2350
 
2351
  void llama_kv_self_clear(llama_context * ctx) {
@@ -2357,15 +2366,6 @@ void llama_kv_self_clear(llama_context * ctx) {
2357
  kv->clear();
2358
  }
2359
 
2360
- // deprecated
2361
- bool llama_kv_cache_seq_rm(
2362
- llama_context * ctx,
2363
- llama_seq_id seq_id,
2364
- llama_pos p0,
2365
- llama_pos p1) {
2366
- return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
2367
- }
2368
-
2369
  bool llama_kv_self_seq_rm(
2370
  llama_context * ctx,
2371
  llama_seq_id seq_id,
@@ -2379,16 +2379,6 @@ bool llama_kv_self_seq_rm(
2379
  return kv->seq_rm(seq_id, p0, p1);
2380
  }
2381
 
2382
- // deprecated
2383
- void llama_kv_cache_seq_cp(
2384
- llama_context * ctx,
2385
- llama_seq_id seq_id_src,
2386
- llama_seq_id seq_id_dst,
2387
- llama_pos p0,
2388
- llama_pos p1) {
2389
- llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
2390
- }
2391
-
2392
  void llama_kv_self_seq_cp(
2393
  llama_context * ctx,
2394
  llama_seq_id seq_id_src,
@@ -2403,13 +2393,6 @@ void llama_kv_self_seq_cp(
2403
  kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
2404
  }
2405
 
2406
- // deprecated
2407
- void llama_kv_cache_seq_keep(
2408
- llama_context * ctx,
2409
- llama_seq_id seq_id) {
2410
- llama_kv_self_seq_keep(ctx, seq_id);
2411
- }
2412
-
2413
  void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
2414
  auto * kv = ctx->get_kv_self();
2415
  if (!kv) {
@@ -2419,16 +2402,6 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
2419
  kv->seq_keep(seq_id);
2420
  }
2421
 
2422
- // deprecated
2423
- void llama_kv_cache_seq_add(
2424
- llama_context * ctx,
2425
- llama_seq_id seq_id,
2426
- llama_pos p0,
2427
- llama_pos p1,
2428
- llama_pos delta) {
2429
- llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
2430
- }
2431
-
2432
  void llama_kv_self_seq_add(
2433
  llama_context * ctx,
2434
  llama_seq_id seq_id,
@@ -2443,16 +2416,6 @@ void llama_kv_self_seq_add(
2443
  kv->seq_add(seq_id, p0, p1, delta);
2444
  }
2445
 
2446
- // deprecated
2447
- void llama_kv_cache_seq_div(
2448
- llama_context * ctx,
2449
- llama_seq_id seq_id,
2450
- llama_pos p0,
2451
- llama_pos p1,
2452
- int d) {
2453
- llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
2454
- }
2455
-
2456
  void llama_kv_self_seq_div(
2457
  llama_context * ctx,
2458
  llama_seq_id seq_id,
@@ -2467,25 +2430,24 @@ void llama_kv_self_seq_div(
2467
  kv->seq_div(seq_id, p0, p1, d);
2468
  }
2469
 
2470
- // deprecated
2471
- llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
2472
- return llama_kv_self_seq_pos_max(ctx, seq_id);
 
 
 
 
2473
  }
2474
 
2475
  llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
2476
  const auto * kv = ctx->get_kv_self();
2477
  if (!kv) {
2478
- return 0;
2479
  }
2480
 
2481
  return kv->seq_pos_max(seq_id);
2482
  }
2483
 
2484
- // deprecated
2485
- void llama_kv_cache_defrag(llama_context * ctx) {
2486
- llama_kv_self_defrag(ctx);
2487
- }
2488
-
2489
  void llama_kv_self_defrag(llama_context * ctx) {
2490
  auto * kv = ctx->get_kv_self();
2491
  if (!kv) {
@@ -2496,11 +2458,6 @@ void llama_kv_self_defrag(llama_context * ctx) {
2496
  kv->defrag_sched(-1.0f);
2497
  }
2498
 
2499
- // deprecated
2500
- bool llama_kv_cache_can_shift(const llama_context * ctx) {
2501
- return llama_kv_self_can_shift(ctx);
2502
- }
2503
-
2504
  bool llama_kv_self_can_shift(const llama_context * ctx) {
2505
  const auto * kv = ctx->get_kv_self();
2506
  if (!kv) {
@@ -2510,11 +2467,6 @@ bool llama_kv_self_can_shift(const llama_context * ctx) {
2510
  return kv->get_can_shift();
2511
  }
2512
 
2513
- // deprecated
2514
- void llama_kv_cache_update(llama_context * ctx) {
2515
- llama_kv_self_update(ctx);
2516
- }
2517
-
2518
  // llama state API
2519
 
2520
  // deprecated
@@ -2637,7 +2589,21 @@ int32_t llama_encode(
2637
  int32_t llama_decode(
2638
  llama_context * ctx,
2639
  llama_batch batch) {
2640
- const int ret = ctx->decode(batch);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2641
  if (ret != 0) {
2642
  LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
2643
  }
 
25
 
26
  const auto & hparams = model.hparams;
27
 
28
+ cparams.n_seq_max = std::max(1u, params.n_seq_max);
29
+ if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
30
+ throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
31
+ }
32
+
33
  cparams.n_threads = params.n_threads;
34
  cparams.n_threads_batch = params.n_threads_batch;
35
  cparams.yarn_ext_factor = params.yarn_ext_factor;
 
97
  }
98
 
99
  cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
100
+
101
  cparams.op_offload = params.op_offload;
102
 
103
  const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
 
181
  // init the memory module
182
  if (!hparams.vocab_only) {
183
  llama_memory_params params_mem = {
184
+ /*.type_k =*/ params.type_k,
185
+ /*.type_v =*/ params.type_v,
186
+ /*.swa_full =*/ params.swa_full,
187
  };
188
 
189
  memory.reset(model.create_memory(params_mem, cparams));
 
693
 
694
  GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
695
 
696
+ // TODO: move the validation to the llama_batch_allocr
697
  if (batch.token) {
698
  for (int32_t i = 0; i < n_tokens; ++i) {
699
  if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
700
  LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
701
  return -1;
702
  }
703
+
704
+ if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
705
+ LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
706
+ throw -1;
707
+ }
708
  }
709
  }
710
 
 
858
 
859
  int llama_context::decode(llama_batch & inp_batch) {
860
  if (!memory) {
861
+ LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
862
  return encode(inp_batch);
863
  }
864
 
 
867
  return -1;
868
  }
869
 
870
+ if (!inp_batch.pos) {
871
+ if (inp_batch.seq_id) {
872
+ LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
873
+ return -1;
874
+ }
875
+ }
876
+
877
  llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
878
 
879
  // temporary allocate memory for the input batch if needed
880
+ llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
 
881
 
882
  const llama_batch & batch = batch_allocr.batch;
883
 
 
893
 
894
  GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
895
 
896
+ // TODO: move the validation to the llama_batch_allocr
897
  if (batch.token) {
898
  for (int64_t i = 0; i < n_tokens_all; ++i) {
899
  if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
900
  LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
901
+ return -1;
902
+ }
903
+
904
+ if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
905
+ LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
906
+ return -1;
907
  }
908
  }
909
  }
 
971
 
972
  // find KV slot
973
  if (!kv_self->find_slot(ubatch)) {
 
 
974
  return 1;
975
  }
976
 
 
2115
  /*.flash_attn =*/ false,
2116
  /*.no_perf =*/ true,
2117
  /*.op_offload =*/ true,
2118
+ /*.swa_full =*/ true,
2119
  };
2120
 
2121
  return result;
 
2310
  return res ? 0 : -1;
2311
  }
2312
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2313
  //
2314
  // kv cache
2315
  //
2316
 
2317
  // deprecated
 
 
 
 
2318
  int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
2319
  const auto * kv = ctx->get_kv_self();
2320
  if (!kv) {
2321
  return 0;
2322
  }
2323
 
2324
+ int32_t res = 0;
 
2325
 
2326
+ for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
2327
+ const llama_pos p0 = kv->seq_pos_min(s);
2328
+ const llama_pos p1 = kv->seq_pos_max(s);
2329
+
2330
+ if (p0 >= 0) {
2331
+ res += (p1 - p0) + 1;
2332
+ }
2333
+ }
2334
+
2335
+ return res;
2336
  }
2337
 
2338
+ // deprecated
2339
+ // note: this is the same as above - will be removed anyway, so it's ok
2340
  int32_t llama_kv_self_used_cells(const llama_context * ctx) {
2341
  const auto * kv = ctx->get_kv_self();
2342
  if (!kv) {
2343
  return 0;
2344
  }
2345
 
2346
+ int32_t res = 0;
 
2347
 
2348
+ for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
2349
+ const llama_pos p0 = kv->seq_pos_min(s);
2350
+ const llama_pos p1 = kv->seq_pos_max(s);
2351
+
2352
+ if (p0 >= 0) {
2353
+ res += (p1 - p0) + 1;
2354
+ }
2355
+ }
2356
+
2357
+ return res;
2358
  }
2359
 
2360
  void llama_kv_self_clear(llama_context * ctx) {
 
2366
  kv->clear();
2367
  }
2368
 
 
 
 
 
 
 
 
 
 
2369
  bool llama_kv_self_seq_rm(
2370
  llama_context * ctx,
2371
  llama_seq_id seq_id,
 
2379
  return kv->seq_rm(seq_id, p0, p1);
2380
  }
2381
 
 
 
 
 
 
 
 
 
 
 
2382
  void llama_kv_self_seq_cp(
2383
  llama_context * ctx,
2384
  llama_seq_id seq_id_src,
 
2393
  kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
2394
  }
2395
 
 
 
 
 
 
 
 
2396
  void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
2397
  auto * kv = ctx->get_kv_self();
2398
  if (!kv) {
 
2402
  kv->seq_keep(seq_id);
2403
  }
2404
 
 
 
 
 
 
 
 
 
 
 
2405
  void llama_kv_self_seq_add(
2406
  llama_context * ctx,
2407
  llama_seq_id seq_id,
 
2416
  kv->seq_add(seq_id, p0, p1, delta);
2417
  }
2418
 
 
 
 
 
 
 
 
 
 
 
2419
  void llama_kv_self_seq_div(
2420
  llama_context * ctx,
2421
  llama_seq_id seq_id,
 
2430
  kv->seq_div(seq_id, p0, p1, d);
2431
  }
2432
 
2433
+ llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
2434
+ const auto * kv = ctx->get_kv_self();
2435
+ if (!kv) {
2436
+ return -1;
2437
+ }
2438
+
2439
+ return kv->seq_pos_min(seq_id);
2440
  }
2441
 
2442
  llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
2443
  const auto * kv = ctx->get_kv_self();
2444
  if (!kv) {
2445
+ return -1;
2446
  }
2447
 
2448
  return kv->seq_pos_max(seq_id);
2449
  }
2450
 
 
 
 
 
 
2451
  void llama_kv_self_defrag(llama_context * ctx) {
2452
  auto * kv = ctx->get_kv_self();
2453
  if (!kv) {
 
2458
  kv->defrag_sched(-1.0f);
2459
  }
2460
 
 
 
 
 
 
2461
  bool llama_kv_self_can_shift(const llama_context * ctx) {
2462
  const auto * kv = ctx->get_kv_self();
2463
  if (!kv) {
 
2467
  return kv->get_can_shift();
2468
  }
2469
 
 
 
 
 
 
2470
  // llama state API
2471
 
2472
  // deprecated
 
2589
  int32_t llama_decode(
2590
  llama_context * ctx,
2591
  llama_batch batch) {
2592
+ int ret = ctx->decode(batch);
2593
+
2594
+ // defrag and try again
2595
+ // TODO: distinguish return code when we are sure that even after defrag there is no space available
2596
+ if (ret == 1) {
2597
+ llama_kv_self_defrag(ctx);
2598
+ ret = ctx->decode(batch);
2599
+
2600
+ if (ret == 1) {
2601
+ LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
2602
+
2603
+ return ret;
2604
+ }
2605
+ }
2606
+
2607
  if (ret != 0) {
2608
  LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
2609
  }
examples/talk-llama/llama-cparams.cpp CHANGED
@@ -1 +1,5 @@
1
  #include "llama-cparams.h"
 
 
 
 
 
1
  #include "llama-cparams.h"
2
+
3
+ size_t llama_max_parallel_sequences(void) {
4
+ return LLAMA_MAX_PARALLEL_SEQUENCES;
5
+ }
examples/talk-llama/llama-cparams.h CHANGED
@@ -4,6 +4,8 @@
4
 
5
  #include <cstdint>
6
 
 
 
7
  struct llama_cparams {
8
  uint32_t n_ctx; // context size used during inference
9
  uint32_t n_batch;
 
4
 
5
  #include <cstdint>
6
 
7
+ #define LLAMA_MAX_PARALLEL_SEQUENCES 64
8
+
9
  struct llama_cparams {
10
  uint32_t n_ctx; // context size used during inference
11
  uint32_t n_batch;
examples/talk-llama/llama-grammar.cpp CHANGED
@@ -1177,8 +1177,18 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
1177
  for (const auto & trigger_pattern : grammar.trigger_patterns) {
1178
  if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
1179
  grammar.awaiting_trigger = false;
1180
- // get from the first match to the end of the string
1181
- auto constrained_str = grammar.trigger_buffer.substr(match.position(1));
 
 
 
 
 
 
 
 
 
 
1182
  // std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
1183
  grammar.trigger_buffer.clear();
1184
  llama_grammar_accept_str(grammar, constrained_str);
 
1177
  for (const auto & trigger_pattern : grammar.trigger_patterns) {
1178
  if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
1179
  grammar.awaiting_trigger = false;
1180
+ // get from the first matched capturing group to the end of the string
1181
+ size_t start = std::string::npos;
1182
+ for (auto i = 1u; i < match.size(); i++) {
1183
+ if (match.length(i) > 0) {
1184
+ start = match.position(i);
1185
+ break;
1186
+ }
1187
+ }
1188
+ if (start == std::string::npos) {
1189
+ start = match.position(0);
1190
+ }
1191
+ auto constrained_str = grammar.trigger_buffer.substr(start);
1192
  // std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
1193
  grammar.trigger_buffer.clear();
1194
  llama_grammar_accept_str(grammar, constrained_str);
examples/talk-llama/llama-graph.cpp CHANGED
@@ -9,33 +9,6 @@
9
  #include <cmath>
10
  #include <cstring>
11
 
12
- static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
13
- // TODO move to hparams if a T5 variant appears that uses a different value
14
- const int64_t max_distance = 128;
15
-
16
- if (bidirectional) {
17
- n_buckets >>= 1;
18
- }
19
-
20
- const int64_t max_exact = n_buckets >> 1;
21
-
22
- int32_t relative_position = x - y;
23
- int32_t relative_bucket = 0;
24
-
25
- if (bidirectional) {
26
- relative_bucket += (relative_position > 0) * n_buckets;
27
- relative_position = abs(relative_position);
28
- } else {
29
- relative_position = -std::min<int32_t>(relative_position, 0);
30
- }
31
-
32
- int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
33
- relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
34
- relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
35
-
36
- return relative_bucket;
37
- }
38
-
39
  void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
40
  if (ubatch->token) {
41
  const int64_t n_tokens = ubatch->n_tokens;
@@ -110,22 +83,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
110
 
111
  void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
112
  if (pos_bucket) {
113
- const int64_t n_tokens = ubatch->n_tokens;
114
-
115
- GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
116
- GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
117
-
118
- int32_t * data = (int32_t *) pos_bucket->data;
119
-
120
- const int64_t n_kv = kv_self->n;
121
-
122
- for (int h = 0; h < 1; ++h) {
123
- for (int j = 0; j < n_tokens; ++j) {
124
- for (int i = 0; i < n_kv; ++i) {
125
- data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
126
- }
127
- }
128
- }
129
  }
130
  }
131
 
@@ -403,99 +361,18 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
403
  }
404
 
405
  void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
406
- if (self_kq_mask || self_kq_mask_swa) {
407
- const int64_t n_kv = kv_self->n;
408
- const int64_t n_tokens = ubatch->n_tokens;
409
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
410
- const int64_t n_seqs = ubatch->n_seqs;
411
-
412
- float * data = nullptr;
413
- float * data_swa = nullptr;
414
-
415
- if (self_kq_mask) {
416
- GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
417
- data = (float *) self_kq_mask->data;
418
- }
419
-
420
- if (self_kq_mask_swa) {
421
- GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
422
- data_swa = (float *) self_kq_mask_swa->data;
423
- }
424
-
425
- // Use only the previous KV cells of the correct sequence for each token of the ubatch.
426
- // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
427
- // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
428
- // Causal mask:
429
- // xxx-------
430
- // xxxx------
431
- // xxxxx-----
432
- // Non-causal mask:
433
- // xxxxx-----
434
- // xxxxx-----
435
- // xxxxx-----
436
- // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
437
- for (int h = 0; h < 1; ++h) {
438
- for (int s = 0; s < n_seqs; ++s) {
439
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
440
-
441
- for (int j = 0; j < n_seq_tokens; ++j) {
442
- const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
443
- for (int i = 0; i < n_kv; ++i) {
444
- float f;
445
- // mask the token if:
446
- if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
447
- || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
448
- ) {
449
- f = -INFINITY;
450
- } else {
451
- if (hparams.use_alibi) {
452
- f = -std::abs(kv_self->cells[i].pos - pos);
453
- } else {
454
- f = 0.0f;
455
- }
456
- }
457
-
458
- if (data) {
459
- data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
460
- }
461
-
462
- // may need to cut off old tokens for sliding window
463
- // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
464
- if (data_swa) {
465
- if (hparams.n_attn_chunk) {
466
- llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
467
- if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
468
- f = -INFINITY;
469
- }
470
- } else {
471
- if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
472
- f = -INFINITY;
473
- }
474
- }
475
- data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
476
- }
477
- }
478
- }
479
- }
480
 
481
- // mask padded tokens
482
- if (data) {
483
- for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
484
- for (int j = 0; j < n_kv; ++j) {
485
- data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
486
- }
487
- }
488
- }
489
 
490
- // mask padded tokens
491
- if (data_swa) {
492
- for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
493
- for (int j = 0; j < n_kv; ++j) {
494
- data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
495
- }
496
- }
497
- }
498
- }
499
  }
500
  }
501
 
@@ -545,7 +422,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
545
  n_layer (hparams.n_layer),
546
  n_rot (hparams.n_rot),
547
  n_ctx (cparams.n_ctx),
548
- n_ctx_per_seq (cparams.n_ctx / cparams.n_seq_max),
549
  n_head (hparams.n_head()),
550
  n_head_kv (hparams.n_head_kv()),
551
  n_embd_head_k (hparams.n_embd_head_k),
@@ -1153,7 +1029,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
1153
 
1154
  auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
1155
 
1156
- const auto n_kv = kv_self->n;
1157
 
1158
  auto & cur = inp->pos_bucket;
1159
 
@@ -1188,16 +1064,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1188
  ggml_tensor * kq_b,
1189
  ggml_tensor * kq_mask,
1190
  ggml_tensor * v_mla,
1191
- bool v_trans,
1192
  float kq_scale) const {
1193
- //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
1194
- //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1195
-
1196
- //const int64_t n_head = hparams.n_head(il);
1197
- //const int64_t n_head_kv = hparams.n_head_kv(il);
1198
 
1199
- //const auto & n_embd_head_k = hparams.n_embd_head_k;
1200
- //const auto & n_embd_head_v = hparams.n_embd_head_v;
 
1201
 
1202
  const auto n_tokens = q->ne[1];
1203
  const auto n_head = q->ne[2];
@@ -1336,17 +1208,11 @@ ggml_tensor * llm_graph_context::build_attn(
1336
 
1337
  const auto & kq_mask = inp->get_kq_mask();
1338
 
1339
- ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
1340
- //cb(q, "q", il);
1341
-
1342
- ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
1343
- //cb(k, "k", il);
1344
-
1345
- ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
1346
- //cb(k, "v", il);
1347
-
1348
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
1349
 
 
1350
  cb(cur, "kqv_out", il);
1351
 
1352
  if (wo) {
@@ -1369,22 +1235,16 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
1369
 
1370
  auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
1371
 
1372
- const auto n_kv = kv_self->n;
1373
-
1374
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1375
- //cb(inp->self_kq_mask, "KQ_mask", -1);
1376
- ggml_set_input(inp->self_kq_mask);
1377
-
1378
- inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1379
 
1380
- if (hparams.n_swa_pattern > 1) {
1381
- GGML_ASSERT(hparams.n_swa > 0);
1382
 
1383
- inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1384
- //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
1385
- ggml_set_input(inp->self_kq_mask_swa);
1386
 
1387
- inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1388
  }
1389
 
1390
  return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
@@ -1409,81 +1269,104 @@ ggml_tensor * llm_graph_context::build_attn(
1409
  ggml_build_forward_expand(gf, v_cur);
1410
 
1411
  const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1412
- const auto & n_ctx = cparams.n_ctx;
1413
 
1414
- const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
1415
- const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
 
 
 
 
1416
 
1417
- const auto n_tokens = q_cur->ne[2];
 
 
1418
 
1419
- const bool v_trans = !cparams.flash_attn;
 
1420
 
1421
- // store to KV cache
1422
- {
1423
- const auto kv_head = kv_self->head;
 
 
 
 
1424
 
1425
- GGML_ASSERT(kv_self->size == n_ctx);
 
 
1426
 
1427
- ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
1428
- //cb(k_cache_view, "k_cache_view", il);
1429
 
1430
- // note: storing RoPE-ed version of K in the KV cache
1431
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
1432
 
1433
- v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
1434
 
1435
- ggml_tensor * v_cache_view = nullptr;
 
1436
 
1437
- if (!v_trans) {
1438
- v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
1439
- } else {
1440
- // note: the V cache is transposed when not using flash attention
1441
- v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
1442
- ( n_ctx)*ggml_element_size(kv_self->v_l[il]),
1443
- (kv_head)*ggml_element_size(kv_self->v_l[il]));
1444
 
1445
- v_cur = ggml_transpose(ctx0, v_cur);
1446
- }
1447
- //cb(v_cache_view, "v_cache_view", il);
 
 
1448
 
1449
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
 
 
 
 
 
 
1450
  }
1451
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1452
  const bool is_swa = hparams.is_swa(il);
1453
 
 
 
 
 
 
 
 
 
 
 
1454
  const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
1455
 
1456
- const auto n_kv = kv_self->n;
 
 
1457
 
1458
- const int64_t n_head_kv = hparams.n_head_kv(il);
1459
-
1460
- const auto & n_embd_head_k = hparams.n_embd_head_k;
1461
- const auto & n_embd_head_v = hparams.n_embd_head_v;
1462
-
1463
- ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
1464
- //cb(q, "q", il);
1465
-
1466
- ggml_tensor * k =
1467
- ggml_view_3d(ctx0, kv_self->k_l[il],
1468
- n_embd_head_k, n_kv, n_head_kv,
1469
- ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
1470
- ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
1471
- 0);
1472
- //cb(k, "k", il);
1473
-
1474
- ggml_tensor * v = !v_trans ?
1475
- ggml_view_3d(ctx0, kv_self->v_l[il],
1476
- n_embd_head_v, n_kv, n_head_kv,
1477
- ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
1478
- ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
1479
- 0) :
1480
- ggml_view_3d(ctx0, kv_self->v_l[il],
1481
- n_kv, n_embd_head_v, n_head_kv,
1482
- ggml_element_size(kv_self->v_l[il])*n_ctx,
1483
- ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
1484
- 0);
1485
-
1486
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
1487
  cb(cur, "kqv_out", il);
1488
 
1489
  if (wo) {
@@ -1534,17 +1417,11 @@ ggml_tensor * llm_graph_context::build_attn(
1534
 
1535
  const auto & kq_mask = inp->get_kq_mask_cross();
1536
 
1537
- ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
1538
- //cb(q, "q", il);
1539
-
1540
- ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
1541
- //cb(k, "k", il);
1542
-
1543
- ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
1544
- //cb(k, "v", il);
1545
-
1546
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
1547
 
 
1548
  cb(cur, "kqv_out", il);
1549
 
1550
  if (wo) {
@@ -1712,3 +1589,30 @@ void llm_graph_context::build_pooling(
1712
 
1713
  ggml_build_forward_expand(gf, cur);
1714
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  #include <cmath>
10
  #include <cstring>
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
13
  if (ubatch->token) {
14
  const int64_t n_tokens = ubatch->n_tokens;
 
83
 
84
  void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
85
  if (pos_bucket) {
86
+ kv_self->set_input_pos_bucket(pos_bucket, ubatch);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  }
88
  }
89
 
 
361
  }
362
 
363
  void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
364
+ if (self_kq_mask) {
365
+ kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
366
+ }
367
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
 
369
+ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
370
+ if (self_kq_mask) {
371
+ kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
372
+ }
 
 
 
 
373
 
374
+ if (self_kq_mask_swa) {
375
+ kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
 
 
 
 
 
 
 
376
  }
377
  }
378
 
 
422
  n_layer (hparams.n_layer),
423
  n_rot (hparams.n_rot),
424
  n_ctx (cparams.n_ctx),
 
425
  n_head (hparams.n_head()),
426
  n_head_kv (hparams.n_head_kv()),
427
  n_embd_head_k (hparams.n_embd_head_k),
 
1029
 
1030
  auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
1031
 
1032
+ const auto n_kv = kv_self->get_n();
1033
 
1034
  auto & cur = inp->pos_bucket;
1035
 
 
1064
  ggml_tensor * kq_b,
1065
  ggml_tensor * kq_mask,
1066
  ggml_tensor * v_mla,
 
1067
  float kq_scale) const {
1068
+ const bool v_trans = v->nb[1] > v->nb[2];
 
 
 
 
1069
 
1070
+ q = ggml_permute(ctx0, q, 0, 2, 1, 3);
1071
+ k = ggml_permute(ctx0, k, 0, 2, 1, 3);
1072
+ v = ggml_permute(ctx0, v, 0, 2, 1, 3);
1073
 
1074
  const auto n_tokens = q->ne[1];
1075
  const auto n_head = q->ne[2];
 
1208
 
1209
  const auto & kq_mask = inp->get_kq_mask();
1210
 
1211
+ ggml_tensor * q = q_cur;
1212
+ ggml_tensor * k = k_cur;
1213
+ ggml_tensor * v = v_cur;
 
 
 
 
 
 
 
1214
 
1215
+ ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1216
  cb(cur, "kqv_out", il);
1217
 
1218
  if (wo) {
 
1235
 
1236
  auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
1237
 
1238
+ {
1239
+ GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
 
 
 
 
1240
 
1241
+ const auto n_kv = kv_self->get_n();
 
1242
 
1243
+ inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1244
+ //cb(inp->self_kq_mask, "KQ_mask", -1);
1245
+ ggml_set_input(inp->self_kq_mask);
1246
 
1247
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1248
  }
1249
 
1250
  return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
 
1269
  ggml_build_forward_expand(gf, v_cur);
1270
 
1271
  const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
 
1272
 
1273
+ // store to KV cache
1274
+ {
1275
+ ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
1276
+ ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
1277
+ }
1278
+
1279
+ const auto & kq_mask = inp->get_kq_mask();
1280
 
1281
+ ggml_tensor * q = q_cur;
1282
+ ggml_tensor * k = kv_self->get_k(ctx0, il);
1283
+ ggml_tensor * v = kv_self->get_v(ctx0, il);
1284
 
1285
+ ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1286
+ cb(cur, "kqv_out", il);
1287
 
1288
+ if (wo) {
1289
+ cur = build_lora_mm(wo, cur);
1290
+ if (arch == LLM_ARCH_GLM4) {
1291
+ // GLM4 seems to have numerical issues with half-precision accumulators
1292
+ ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
1293
+ }
1294
+ }
1295
 
1296
+ if (wo_b) {
1297
+ cur = ggml_add(ctx0, cur, wo_b);
1298
+ }
1299
 
1300
+ return cur;
1301
+ }
1302
 
1303
+ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
1304
+ const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
1305
 
1306
+ auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
1307
 
1308
+ {
1309
+ const auto n_kv = kv_self->get_kv_base()->get_n();
1310
 
1311
+ inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1312
+ //cb(inp->self_kq_mask, "KQ_mask", -1);
1313
+ ggml_set_input(inp->self_kq_mask);
 
 
 
 
1314
 
1315
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1316
+ }
1317
+
1318
+ {
1319
+ GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
1320
 
1321
+ const auto n_kv = kv_self->get_kv_swa()->get_n();
1322
+
1323
+ inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1324
+ //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
1325
+ ggml_set_input(inp->self_kq_mask_swa);
1326
+
1327
+ inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1328
  }
1329
 
1330
+ return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
1331
+ }
1332
+
1333
+ ggml_tensor * llm_graph_context::build_attn(
1334
+ llm_graph_input_attn_kv_unified_iswa * inp,
1335
+ ggml_cgraph * gf,
1336
+ ggml_tensor * wo,
1337
+ ggml_tensor * wo_b,
1338
+ ggml_tensor * q_cur,
1339
+ ggml_tensor * k_cur,
1340
+ ggml_tensor * v_cur,
1341
+ ggml_tensor * kq_b,
1342
+ ggml_tensor * v_mla,
1343
+ float kq_scale,
1344
+ int il) const {
1345
+ // these nodes are added to the graph together so that they are not reordered
1346
+ // by doing so, the number of splits in the graph is reduced
1347
+ ggml_build_forward_expand(gf, q_cur);
1348
+ ggml_build_forward_expand(gf, k_cur);
1349
+ ggml_build_forward_expand(gf, v_cur);
1350
+
1351
  const bool is_swa = hparams.is_swa(il);
1352
 
1353
+ const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
1354
+
1355
+ const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
1356
+
1357
+ // store to KV cache
1358
+ {
1359
+ ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
1360
+ ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
1361
+ }
1362
+
1363
  const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
1364
 
1365
+ ggml_tensor * q = q_cur;
1366
+ ggml_tensor * k = kv->get_k(ctx0, il);
1367
+ ggml_tensor * v = kv->get_v(ctx0, il);
1368
 
1369
+ ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1370
  cb(cur, "kqv_out", il);
1371
 
1372
  if (wo) {
 
1417
 
1418
  const auto & kq_mask = inp->get_kq_mask_cross();
1419
 
1420
+ ggml_tensor * q = q_cur;
1421
+ ggml_tensor * k = k_cur;
1422
+ ggml_tensor * v = v_cur;
 
 
 
 
 
 
 
1423
 
1424
+ ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1425
  cb(cur, "kqv_out", il);
1426
 
1427
  if (wo) {
 
1589
 
1590
  ggml_build_forward_expand(gf, cur);
1591
  }
1592
+
1593
+ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
1594
+ // TODO move to hparams if a T5 variant appears that uses a different value
1595
+ const int64_t max_distance = 128;
1596
+
1597
+ if (bidirectional) {
1598
+ n_buckets >>= 1;
1599
+ }
1600
+
1601
+ const int64_t max_exact = n_buckets >> 1;
1602
+
1603
+ int32_t relative_position = x - y;
1604
+ int32_t relative_bucket = 0;
1605
+
1606
+ if (bidirectional) {
1607
+ relative_bucket += (relative_position > 0) * n_buckets;
1608
+ relative_position = abs(relative_position);
1609
+ } else {
1610
+ relative_position = -std::min<int32_t>(relative_position, 0);
1611
+ }
1612
+
1613
+ int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
1614
+ relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
1615
+ relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
1616
+
1617
+ return relative_bucket;
1618
+ }
examples/talk-llama/llama-graph.h CHANGED
@@ -19,6 +19,7 @@ struct llama_cparams;
19
 
20
  class llama_memory_i;
21
  class llama_kv_cache_unified;
 
22
  class llama_kv_cache_recurrent;
23
 
24
  // certain models (typically multi-modal) can produce different types of graphs
@@ -255,6 +256,31 @@ public:
255
 
256
  void set_input(const llama_ubatch * ubatch) override;
257
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
259
  ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
260
 
@@ -266,7 +292,7 @@ public:
266
  const llama_hparams & hparams;
267
  const llama_cparams & cparams;
268
 
269
- const llama_kv_cache_unified * kv_self;
270
  };
271
 
272
  class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -378,7 +404,6 @@ struct llm_graph_context {
378
  const int64_t n_layer;
379
  const int64_t n_rot;
380
  const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
381
- const int64_t n_ctx_per_seq;
382
  const int64_t n_head;
383
  const int64_t n_head_kv;
384
  const int64_t n_embd_head_k;
@@ -507,13 +532,12 @@ struct llm_graph_context {
507
 
508
  ggml_tensor * build_attn_mha(
509
  ggml_cgraph * gf,
510
- ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
511
- ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
512
- ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
513
  ggml_tensor * kq_b,
514
  ggml_tensor * kq_mask,
515
- ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
516
- bool v_trans,
517
  float kq_scale) const;
518
 
519
  llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
@@ -546,6 +570,21 @@ struct llm_graph_context {
546
  float kq_scale,
547
  int il) const;
548
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
549
  llm_graph_input_attn_cross * build_attn_inp_cross() const;
550
 
551
  ggml_tensor * build_attn(
@@ -596,3 +635,6 @@ struct llm_graph_context {
596
  ggml_tensor * cls_out,
597
  ggml_tensor * cls_out_b) const;
598
  };
 
 
 
 
19
 
20
  class llama_memory_i;
21
  class llama_kv_cache_unified;
22
+ class llama_kv_cache_unified_iswa;
23
  class llama_kv_cache_recurrent;
24
 
25
  // certain models (typically multi-modal) can produce different types of graphs
 
256
 
257
  void set_input(const llama_ubatch * ubatch) override;
258
 
259
+ ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
260
+
261
+ ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
262
+ ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
263
+
264
+ const llama_hparams & hparams;
265
+ const llama_cparams & cparams;
266
+
267
+ const llama_kv_cache_unified * kv_self;
268
+ };
269
+
270
+ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
271
+ public:
272
+ llm_graph_input_attn_kv_unified_iswa(
273
+ const llama_hparams & hparams,
274
+ const llama_cparams & cparams,
275
+ const llama_kv_cache_unified_iswa * kv_self) :
276
+ hparams(hparams),
277
+ cparams(cparams),
278
+ kv_self(kv_self) {
279
+ }
280
+ ~llm_graph_input_attn_kv_unified_iswa() = default;
281
+
282
+ void set_input(const llama_ubatch * ubatch) override;
283
+
284
  ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
285
  ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
286
 
 
292
  const llama_hparams & hparams;
293
  const llama_cparams & cparams;
294
 
295
+ const llama_kv_cache_unified_iswa * kv_self;
296
  };
297
 
298
  class llm_graph_input_attn_cross : public llm_graph_input_i {
 
404
  const int64_t n_layer;
405
  const int64_t n_rot;
406
  const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
 
407
  const int64_t n_head;
408
  const int64_t n_head_kv;
409
  const int64_t n_embd_head_k;
 
532
 
533
  ggml_tensor * build_attn_mha(
534
  ggml_cgraph * gf,
535
+ ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
536
+ ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
537
+ ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
538
  ggml_tensor * kq_b,
539
  ggml_tensor * kq_mask,
540
+ ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
 
541
  float kq_scale) const;
542
 
543
  llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
 
570
  float kq_scale,
571
  int il) const;
572
 
573
+ llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
574
+
575
+ ggml_tensor * build_attn(
576
+ llm_graph_input_attn_kv_unified_iswa * inp,
577
+ ggml_cgraph * gf,
578
+ ggml_tensor * wo,
579
+ ggml_tensor * wo_b,
580
+ ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
581
+ ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
582
+ ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
583
+ ggml_tensor * kq_b,
584
+ ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
585
+ float kq_scale,
586
+ int il) const;
587
+
588
  llm_graph_input_attn_cross * build_attn_inp_cross() const;
589
 
590
  ggml_tensor * build_attn(
 
635
  ggml_tensor * cls_out,
636
  ggml_tensor * cls_out_b) const;
637
  };
638
+
639
+ // TODO: better name
640
+ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);
examples/talk-llama/llama-hparams.cpp CHANGED
@@ -2,6 +2,22 @@
2
 
3
  #include "ggml.h"
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  uint32_t llama_hparams::n_head(uint32_t il) const {
6
  if (il < n_layer) {
7
  return n_head_arr[il];
@@ -72,7 +88,7 @@ uint32_t llama_hparams::n_embd_v_s() const {
72
 
73
  bool llama_hparams::is_swa(uint32_t il) const {
74
  if (il < n_layer) {
75
- return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
76
  }
77
 
78
  GGML_ABORT("fatal error");
 
2
 
3
  #include "ggml.h"
4
 
5
+ void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
6
+ for (uint32_t il = 0; il < n_layer; ++il) {
7
+ swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
8
+ }
9
+ }
10
+
11
+ bool llama_hparams::is_swa_any() const {
12
+ for (uint32_t il = 0; il < n_layer; ++il) {
13
+ if (swa_layers[il]) {
14
+ return true;
15
+ }
16
+ }
17
+
18
+ return false;
19
+ }
20
+
21
  uint32_t llama_hparams::n_head(uint32_t il) const {
22
  if (il < n_layer) {
23
  return n_head_arr[il];
 
88
 
89
  bool llama_hparams::is_swa(uint32_t il) const {
90
  if (il < n_layer) {
91
+ return swa_layers[il];
92
  }
93
 
94
  GGML_ABORT("fatal error");
examples/talk-llama/llama-hparams.h CHANGED
@@ -14,6 +14,12 @@ enum llama_expert_gating_func_type {
14
  LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
15
  };
16
 
 
 
 
 
 
 
17
  struct llama_hparams_posnet {
18
  uint32_t n_embd;
19
  uint32_t n_layer;
@@ -35,8 +41,6 @@ struct llama_hparams {
35
  uint32_t n_embd_features = 0;
36
  uint32_t n_layer;
37
  uint32_t n_rot;
38
- uint32_t n_swa = 0; // sliding window attention (SWA)
39
- uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
40
  uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
41
  uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
42
  uint32_t n_expert = 0;
@@ -96,6 +100,15 @@ struct llama_hparams {
96
 
97
  std::array<int, 4> rope_sections;
98
 
 
 
 
 
 
 
 
 
 
99
  // for State Space Models
100
  uint32_t ssm_d_conv = 0;
101
  uint32_t ssm_d_inner = 0;
@@ -116,11 +129,10 @@ struct llama_hparams {
116
  bool causal_attn = true;
117
  bool use_alibi = false;
118
  bool attn_soft_cap = false;
 
119
 
 
120
  uint32_t n_moe_layer_step = 0;
121
- bool use_kq_norm = true;
122
- uint32_t n_attn_chunk = 0;
123
- // values below seems to be fixed on llama4
124
  uint32_t n_no_rope_layer_step = 4;
125
  uint32_t n_attn_temp_floor_scale = 8192;
126
  float f_attn_temp_scale = 0.1;
@@ -133,6 +145,23 @@ struct llama_hparams {
133
  enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
134
  enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  uint32_t n_head(uint32_t il = 0) const;
137
 
138
  uint32_t n_head_kv(uint32_t il = 0) const;
 
14
  LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
15
  };
16
 
17
+ enum llama_swa_type {
18
+ LLAMA_SWA_TYPE_NONE = 0,
19
+ LLAMA_SWA_TYPE_STANDARD = 1,
20
+ LLAMA_SWA_TYPE_CHUNKED = 2,
21
+ };
22
+
23
  struct llama_hparams_posnet {
24
  uint32_t n_embd;
25
  uint32_t n_layer;
 
41
  uint32_t n_embd_features = 0;
42
  uint32_t n_layer;
43
  uint32_t n_rot;
 
 
44
  uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
45
  uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
46
  uint32_t n_expert = 0;
 
100
 
101
  std::array<int, 4> rope_sections;
102
 
103
+ // Sliding Window Attention (SWA)
104
+ llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
105
+ // the size of the sliding window (0 - no SWA)
106
+ uint32_t n_swa = 0;
107
+ // if swa_layers[il] == true, then layer il is SWA
108
+ // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
109
+ // by default, all layers are dense
110
+ std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
111
+
112
  // for State Space Models
113
  uint32_t ssm_d_conv = 0;
114
  uint32_t ssm_d_inner = 0;
 
129
  bool causal_attn = true;
130
  bool use_alibi = false;
131
  bool attn_soft_cap = false;
132
+ bool use_kq_norm = true;
133
 
134
+ // llama4
135
  uint32_t n_moe_layer_step = 0;
 
 
 
136
  uint32_t n_no_rope_layer_step = 4;
137
  uint32_t n_attn_temp_floor_scale = 8192;
138
  float f_attn_temp_scale = 0.1;
 
145
  enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
146
  enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
147
 
148
+ // this value n_pattern means that every nth layer is dense (i.e. non-SWA)
149
+ // note that if n_pattern == 0, all layers are SWA
150
+ // if n_pattern == 1, all layers are dense
151
+ // example: n_pattern = 3
152
+ // il == 0: swa
153
+ // il == 1: swa
154
+ // il == 2: dense
155
+ // il == 3: swa
156
+ // il == 4: swa
157
+ // il == 5: dense
158
+ // il == 6: swa
159
+ // etc ...
160
+ void set_swa_pattern(uint32_t n_pattern);
161
+
162
+ // return true if one of the layers is SWA
163
+ bool is_swa_any() const;
164
+
165
  uint32_t n_head(uint32_t il = 0) const;
166
 
167
  uint32_t n_head_kv(uint32_t il = 0) const;
examples/talk-llama/llama-kv-cache.cpp CHANGED
@@ -23,32 +23,21 @@ uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
23
  }
24
 
25
  llama_kv_cache_unified::llama_kv_cache_unified(
26
- const llama_model & model,
27
- ggml_type type_k,
28
- ggml_type type_v,
29
- bool v_trans,
30
- bool offload,
31
- uint32_t kv_size,
32
- uint32_t padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
33
- const int32_t n_layer = hparams.n_layer;
34
-
35
- has_shift = false;
36
- can_shift = true;
37
-
38
- LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d, padding = %d\n",
39
- __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift, padding);
40
-
41
- GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");
42
-
43
- head = 0;
44
- size = kv_size;
45
- used = 0;
46
-
47
- this->type_k = type_k;
48
- this->type_v = type_v;
49
-
50
- cells.clear();
51
- cells.resize(kv_size);
52
 
53
  // create a context for each buffer type
54
  std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
@@ -56,7 +45,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
56
  auto it = ctx_map.find(buft);
57
  if (it == ctx_map.end()) {
58
  ggml_init_params params = {
59
- /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
60
  /*.mem_buffer =*/ NULL,
61
  /*.no_alloc =*/ true,
62
  };
@@ -75,37 +64,48 @@ llama_kv_cache_unified::llama_kv_cache_unified(
75
  return it->second;
76
  };
77
 
78
- k_l.reserve(n_layer);
79
- v_l.reserve(n_layer);
80
 
81
- for (int i = 0; i < n_layer; i++) {
82
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
83
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
 
 
 
 
 
 
84
 
85
  const char * dev_name = "CPU";
86
 
87
  ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
88
 
89
  if (offload) {
90
- auto * dev = model.dev_layer(i);
91
  buft = ggml_backend_dev_buffer_type(dev);
92
 
93
  dev_name = ggml_backend_dev_name(dev);
94
  }
95
 
96
- LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, i, dev_name);
97
 
98
  ggml_context * ctx = ctx_for_buft(buft);
99
  if (!ctx) {
100
  throw std::runtime_error("failed to create ggml context for kv cache");
101
  }
102
 
103
- ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
104
- ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
105
- ggml_format_name(k, "cache_k_l%d", i);
106
- ggml_format_name(v, "cache_v_l%d", i);
107
- k_l.push_back(k);
108
- v_l.push_back(v);
 
 
 
 
 
109
  }
110
 
111
  // allocate tensors and initialize the buffers to avoid NaNs in the padding
@@ -117,8 +117,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
117
  if (!buf) {
118
  throw std::runtime_error("failed to allocate buffer for kv cache");
119
  }
120
- ggml_backend_buffer_clear(buf, 0);
121
  LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
 
 
122
  bufs.emplace_back(buf);
123
  }
124
 
@@ -126,20 +128,17 @@ llama_kv_cache_unified::llama_kv_cache_unified(
126
  const size_t memory_size_k = size_k_bytes();
127
  const size_t memory_size_v = size_v_bytes();
128
 
129
- LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
130
- (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
131
  ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
132
  ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
133
  }
134
  }
135
 
136
  void llama_kv_cache_unified::clear() {
137
- for (int32_t i = 0; i < (int32_t) size; ++i) {
138
- cells[i].pos = -1;
139
- cells[i].seq_id.clear();
140
- }
141
  head = 0;
142
- used = 0;
143
 
144
  for (auto & buf : bufs) {
145
  ggml_backend_buffer_clear(buf.get(), 0);
@@ -147,7 +146,7 @@ void llama_kv_cache_unified::clear() {
147
  }
148
 
149
  bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
150
- uint32_t new_head = size;
151
 
152
  if (p0 < 0) {
153
  p0 = 0;
@@ -157,32 +156,20 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
157
  p1 = std::numeric_limits<llama_pos>::max();
158
  }
159
 
160
- for (uint32_t i = 0; i < size; ++i) {
161
- if (cells[i].pos >= p0 && cells[i].pos < p1) {
162
- if (seq_id < 0) {
163
- cells[i].seq_id.clear();
164
- } else if (cells[i].has_seq_id(seq_id)) {
165
- cells[i].seq_id.erase(seq_id);
166
- } else {
167
- continue;
168
- }
169
- if (cells[i].is_empty()) {
170
- // keep count of the number of used cells
171
- if (cells[i].pos >= 0) {
172
- used--;
173
- }
174
-
175
- cells[i].pos = -1;
176
 
177
- if (new_head == size) {
178
- new_head = i;
179
- }
180
  }
181
  }
182
  }
183
 
184
  // If we freed up a slot, set head to it so searching can start there.
185
- if (new_head != size && new_head < head) {
186
  head = new_head;
187
  }
188
 
@@ -202,49 +189,40 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
202
  p1 = std::numeric_limits<llama_pos>::max();
203
  }
204
 
205
- // otherwise, this is the KV of a Transformer-like model
206
- head = 0;
 
 
207
 
208
- for (uint32_t i = 0; i < size; ++i) {
209
- if (cells[i].has_seq_id(seq_id_src) && cells[i].pos >= p0 && cells[i].pos < p1) {
210
- cells[i].seq_id.insert(seq_id_dst);
211
  }
212
  }
213
  }
214
 
215
  void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
216
- uint32_t new_head = size;
217
 
218
- for (uint32_t i = 0; i < size; ++i) {
219
- if (!cells[i].has_seq_id(seq_id)) {
220
- if (cells[i].pos >= 0) {
221
- used--;
222
- }
223
-
224
- cells[i].pos = -1;
225
- cells[i].seq_id.clear();
226
-
227
- if (new_head == size){
228
  new_head = i;
229
  }
230
- } else {
231
- cells[i].seq_id.clear();
232
- cells[i].seq_id.insert(seq_id);
233
  }
234
  }
235
 
236
  // If we freed up a slot, set head to it so searching can start there.
237
- if (new_head != size && new_head < head) {
238
  head = new_head;
239
  }
240
  }
241
 
242
- void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
243
- if (delta == 0) {
244
  return;
245
  }
246
 
247
- uint32_t new_head = size;
248
 
249
  if (p0 < 0) {
250
  p0 = 0;
@@ -254,24 +232,19 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
254
  p1 = std::numeric_limits<llama_pos>::max();
255
  }
256
 
257
- // If there is no range then return early to avoid looping over the
258
  if (p0 == p1) {
259
  return;
260
  }
261
 
262
- for (uint32_t i = 0; i < size; ++i) {
263
- if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
264
- has_shift = true;
265
- cells[i].pos += delta;
266
- cells[i].delta += delta;
267
 
268
- if (cells[i].pos < 0) {
269
- if (!cells[i].is_empty()) {
270
- used--;
271
- }
272
- cells[i].pos = -1;
273
- cells[i].seq_id.clear();
274
- if (new_head == size) {
275
  new_head = i;
276
  }
277
  }
@@ -280,7 +253,7 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
280
 
281
  // If we freed up a slot, set head to it so searching can start there.
282
  // Otherwise we just start the next search from the beginning.
283
- head = new_head != size ? new_head : 0;
284
  }
285
 
286
  void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
@@ -301,66 +274,41 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
301
  return;
302
  }
303
 
304
- for (uint32_t i = 0; i < size; ++i) {
305
- if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
306
- has_shift = true;
 
307
 
308
- {
309
- llama_pos p_old = cells[i].pos;
310
- cells[i].pos /= d;
311
- cells[i].delta += cells[i].pos - p_old;
312
- }
313
  }
314
  }
315
  }
316
 
317
- llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
318
- llama_pos result = 0;
319
-
320
- for (uint32_t i = 0; i < size; ++i) {
321
- if (cells[i].has_seq_id(seq_id)) {
322
- result = std::max(result, cells[i].pos);
323
- }
324
- }
325
 
326
- return result;
 
327
  }
328
 
329
  void llama_kv_cache_unified::restore() {
330
- if (pending.ranges.empty()) {
331
- return;
332
  }
333
 
334
- uint32_t new_head = size;
335
-
336
- for (auto & range : pending.ranges) {
337
- for (uint32_t i = range.c0; i < range.c1; ++i) {
338
- cells[i].seq_id.clear();
339
-
340
- // keep count of the number of used cells
341
- if (cells[i].pos >= 0) {
342
- used--;
343
- }
344
-
345
- cells[i].pos = -1;
346
- }
347
-
348
- new_head = std::min(new_head, range.c0);
349
- }
350
-
351
- if (new_head != size && new_head < head) {
352
- head = new_head;
353
- }
354
  }
355
 
356
  void llama_kv_cache_unified::commit() {
357
- if (pending.ranges.empty()) {
358
- LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
359
- __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
360
  return;
361
  }
362
 
363
- pending.ranges.clear();
364
  }
365
 
366
  bool llama_kv_cache_unified::update(llama_context & lctx) {
@@ -368,7 +316,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
368
 
369
  auto * sched = lctx.get_sched();
370
 
371
- if (has_shift) {
372
  if (!get_can_shift()) {
373
  GGML_ABORT("The current KV cache / model configuration does not support K-shift");
374
  }
@@ -392,13 +340,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
392
  need_reserve = true;
393
  }
394
 
395
- {
396
- has_shift = false;
397
-
398
- for (uint32_t i = 0; i < size; ++i) {
399
- cells[i].delta = 0;
400
- }
401
- }
402
  }
403
 
404
  if (do_defrag) {
@@ -429,7 +371,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
429
  void llama_kv_cache_unified::defrag_sched(float thold) {
430
  // - do not defrag small contexts (i.e. < 2048 tokens)
431
  // - count the padding towards the number of used tokens
432
- const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + padding)/n)) : 0.0f;
433
 
434
  // queue defragmentation for next llama_kv_cache_update
435
  if (fragmentation > thold) {
@@ -440,7 +382,7 @@ void llama_kv_cache_unified::defrag_sched(float thold) {
440
  }
441
 
442
  void llama_kv_cache_unified::set_full() {
443
- n = size;
444
 
445
  // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
446
  // affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
@@ -450,51 +392,67 @@ void llama_kv_cache_unified::set_full() {
450
  head = 0;
451
  }
452
 
453
- llama_sbatch llama_kv_cache_unified::sbatch_init(
454
- const llama_batch & batch,
455
- bool logits_all) {
456
  return llama_sbatch(batch, hparams.n_embd, true, logits_all);
457
  }
458
 
459
- llama_ubatch llama_kv_cache_unified::ubatch_next(
460
- llama_sbatch & sbatch,
461
- uint32_t n_ubatch,
462
- bool embd_pooled) const {
463
  GGML_UNUSED(embd_pooled);
464
  return sbatch.split_simple(n_ubatch);
465
  }
466
 
467
- bool llama_kv_cache_unified::find_slot(
468
- const llama_ubatch & ubatch) {
469
  const uint32_t n_tokens = ubatch.n_tokens;
470
- const uint32_t n_seqs = ubatch.n_seqs;
471
- const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
472
 
473
  // if we have enough unused cells before the current head ->
474
  // better to start searching from the beginning of the cache, hoping to fill it
475
- if (head > used + 2*ubatch.n_tokens) {
476
  head = 0;
477
  }
478
 
479
  // otherwise, one cell per token.
480
 
481
- if (n_tokens > size) {
482
- LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
483
  return false;
484
  }
485
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
486
  uint32_t n_tested = 0;
487
 
488
  while (true) {
489
- if (head + n_tokens > size) {
490
- n_tested += size - head;
491
  head = 0;
492
  continue;
493
  }
494
 
495
  bool found = true;
496
  for (uint32_t i = 0; i < n_tokens; i++) {
497
- if (cells[head + i].pos >= 0) {
 
498
  found = false;
499
  head += i + 1;
500
  n_tested += i + 1;
@@ -506,66 +464,257 @@ bool llama_kv_cache_unified::find_slot(
506
  break;
507
  }
508
 
509
- if (n_tested >= size) {
510
  //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
511
  return false;
512
  }
513
  }
514
 
515
- for (uint32_t s = 0; s < n_seqs; s++) {
516
- for (uint32_t i = 0; i < n_seq_tokens; ++i) {
517
- uint32_t k = s*n_seq_tokens + i;
518
- cells[head + k].pos = ubatch.pos[k];
519
 
520
- for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
521
- cells[head + k].seq_id.insert(ubatch.seq_id[s][j]);
522
- }
 
 
523
  }
524
  }
525
 
526
- used += n_tokens;
527
-
528
- pending.ranges.push_back({head, head + n_tokens});
529
-
530
  // a heuristic, to avoid attending the full cache if it is not yet utilized
531
  // after enough generations, the benefit from this heuristic disappears
532
  // if we start defragmenting the cache, the benefit from this will be more important
533
- n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding)));
 
 
 
 
534
 
535
- //printf("n = %5d, used = %5d, head = %5d\n", n, used, head);
 
536
 
 
537
  return true;
538
  }
539
 
540
- int32_t llama_kv_cache_unified::get_n_tokens() const {
541
- int32_t result = 0;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
542
 
543
- for (uint32_t i = 0; i < size; i++) {
544
- result += cells[i].seq_id.size();
 
 
 
 
 
545
  }
546
 
547
- return result;
 
 
 
 
 
548
  }
549
 
550
- int32_t llama_kv_cache_unified::get_used_cells() const {
551
- return used;
 
 
 
 
 
 
 
 
 
 
552
  }
553
 
554
- bool llama_kv_cache_unified::get_can_shift() const {
555
- return can_shift;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
556
  }
557
 
558
- llama_pos llama_kv_cache_unified::get_pos_max() const {
559
- llama_pos pos_max = -1;
560
- for (const auto & cell : cells) {
561
- pos_max = std::max(pos_max, cell.pos);
 
 
 
562
  }
 
 
 
 
 
 
 
 
 
 
 
563
 
564
- return pos_max;
 
 
 
 
 
 
 
 
 
565
  }
566
 
567
  size_t llama_kv_cache_unified::total_size() const {
568
  size_t size = 0;
 
569
  for (const auto & buf : bufs) {
570
  size += ggml_backend_buffer_get_size(buf.get());
571
  }
@@ -576,8 +725,8 @@ size_t llama_kv_cache_unified::total_size() const {
576
  size_t llama_kv_cache_unified::size_k_bytes() const {
577
  size_t size_k_bytes = 0;
578
 
579
- for (const auto & k : k_l) {
580
- size_k_bytes += ggml_nbytes(k);
581
  }
582
 
583
  return size_k_bytes;
@@ -586,8 +735,8 @@ size_t llama_kv_cache_unified::size_k_bytes() const {
586
  size_t llama_kv_cache_unified::size_v_bytes() const {
587
  size_t size_v_bytes = 0;
588
 
589
- for (const auto & v : v_l) {
590
- size_v_bytes += ggml_nbytes(v);
591
  }
592
 
593
  return size_v_bytes;
@@ -651,13 +800,7 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
651
  GGML_UNUSED(ubatch);
652
 
653
  if (k_shift) {
654
- assert(ggml_backend_buffer_is_host(k_shift->buffer));
655
-
656
- int32_t * data = (int32_t *) k_shift->data;
657
-
658
- for (uint32_t i = 0; i < kv_self->size; ++i) {
659
- data[i] = kv_self->cells[i].delta;
660
- }
661
  }
662
  }
663
 
@@ -667,13 +810,9 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
667
  ggml_cgraph * gf) const {
668
  auto res = std::make_unique<llm_graph_result>();
669
 
670
- const auto & n_layer = hparams.n_layer;
671
-
672
  const auto & n_embd_head_k = hparams.n_embd_head_k;
673
  //const auto & n_embd_head_v = hparams.n_embd_head_v;
674
 
675
- const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
676
-
677
  //GGML_ASSERT(kv_self->size == n_ctx);
678
 
679
  auto inp = std::make_unique<llm_graph_input_k_shift>(this);
@@ -681,24 +820,22 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
681
  inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
682
  ggml_set_input(inp->k_shift);
683
 
684
- for (uint32_t il = 0; il < n_layer; ++il) {
 
 
685
  const int64_t n_head_kv = hparams.n_head_kv(il);
686
  const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
687
 
688
- const bool is_swa = hparams.is_swa(il);
 
689
 
690
- // note: the swa rope params could become part of the cparams in the future
691
- // if we decide to make them configurable, like the non-sliding ones
692
- const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
693
- const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
694
-
695
- ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
696
 
697
  ggml_tensor * k =
698
- ggml_view_3d(ctx, k_l[il],
699
- n_embd_head_k, n_head_kv, size,
700
- ggml_row_size(k_l[il]->type, n_embd_head_k),
701
- ggml_row_size(k_l[il]->type, n_embd_k_gqa),
702
  0);
703
 
704
  ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
@@ -803,44 +940,46 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
803
  nm++;
804
  }
805
 
806
- for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
 
 
807
  const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
808
  const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
809
 
810
- ggml_tensor * view_k_src = ggml_view_2d(ctx, k_l[il],
811
  n_embd_k_gqa, nm,
812
- ggml_row_size(k_l[il]->type, n_embd_k_gqa),
813
- ggml_row_size(k_l[il]->type, n_embd_k_gqa*i));
814
 
815
- ggml_tensor * view_k_dst = ggml_view_2d(ctx, k_l[il],
816
  n_embd_k_gqa, nm,
817
- ggml_row_size(k_l[il]->type, n_embd_k_gqa),
818
- ggml_row_size(k_l[il]->type, n_embd_k_gqa*id));
819
 
820
  ggml_tensor * view_v_src;
821
  ggml_tensor * view_v_dst;
822
 
823
  if (cparams.flash_attn) {
824
  // NOTE: the V cache is not transposed when using flash attention
825
- view_v_src = ggml_view_2d(ctx, v_l[il],
826
  n_embd_v_gqa, nm,
827
- ggml_row_size(v_l[il]->type, n_embd_v_gqa),
828
- ggml_row_size(v_l[il]->type, n_embd_v_gqa*i));
829
 
830
- view_v_dst = ggml_view_2d(ctx, v_l[il],
831
  n_embd_v_gqa, nm,
832
- ggml_row_size(v_l[il]->type, n_embd_v_gqa),
833
- ggml_row_size(v_l[il]->type, n_embd_v_gqa*id));
834
  } else {
835
- view_v_src = ggml_view_2d(ctx, v_l[il],
836
  nm, n_embd_v_gqa,
837
- ggml_row_size(v_l[il]->type, size),
838
- ggml_row_size(v_l[il]->type, i));
839
 
840
- view_v_dst = ggml_view_2d(ctx, v_l[il],
841
  nm, n_embd_v_gqa,
842
- ggml_row_size(v_l[il]->type, size),
843
- ggml_row_size(v_l[il]->type, id));
844
  }
845
 
846
  ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
@@ -857,10 +996,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
857
  }
858
 
859
  bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
860
- const uint32_t n_layer = hparams.n_layer;
861
 
862
- const uint32_t n_kv = cell_max();
863
- const uint32_t n_used = used;
864
 
865
  assert(n_used <= n_kv);
866
 
@@ -888,9 +1027,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
888
  ids.resize(n_kv, n_kv);
889
 
890
  for (uint32_t i0 = 0; i0 < n_used; ++i0) {
891
- const auto & cell0 = cells[i0];
892
-
893
- if (!cell0.is_empty()) {
894
  ids[i0] = i0;
895
 
896
  continue;
@@ -901,7 +1038,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
901
  uint32_t nh = 1;
902
 
903
  // determine the size of the hole
904
- while (i0 + nh < n_used && cells[i0 + nh].is_empty()) {
905
  nh++;
906
  }
907
 
@@ -910,9 +1047,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
910
 
911
  // starting from the end, find nh non-empty cells
912
  for (; is > i0; --is) {
913
- const auto & cell1 = cells[is];
914
-
915
- if (cell1.is_empty() || ids[is] != n_kv) {
916
  continue;
917
  }
918
 
@@ -939,9 +1074,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
939
 
940
  // go back and move the nf cells to the hole
941
  for (; i1 < n_kv; ++i1) {
942
- auto & cell1 = cells[i1];
943
-
944
- if (cell1.is_empty() || ids[i1] != n_kv) {
945
  if (n_moves == max_moves) {
946
  stop = true;
947
  break;
@@ -955,10 +1088,8 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
955
  ids[i1] = i0 + nf;
956
 
957
  // move the cell meta data
958
- cells[i0 + nf] = cell1;
959
 
960
- // clear the old cell and move the head there
961
- cell1 = kv_cell();
962
  head = n_used;
963
 
964
  if (!cont) {
@@ -993,16 +1124,30 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
993
  return true;
994
  }
995
 
996
- uint32_t llama_kv_cache_unified::cell_max() const {
997
- for (uint32_t i = size; i > 0; --i) {
998
- const kv_cell & cell = cells[i - 1];
999
 
1000
- if (cell.pos >= 0 && !cell.is_empty()) {
1001
- return i;
1002
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1003
  }
1004
 
1005
- return 0;
1006
  }
1007
 
1008
  void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
@@ -1011,23 +1156,24 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
1011
 
1012
  // Count the number of cells with the specified seq_id
1013
  // Find all the ranges of cells with this seq id (or all, when -1)
1014
- uint32_t cell_range_begin = size;
1015
- for (uint32_t i = 0; i < size; ++i) {
1016
- const auto & cell = cells[i];
1017
- if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
1018
  ++cell_count;
1019
- if (cell_range_begin == size) {
1020
  cell_range_begin = i;
1021
  }
1022
  } else {
1023
- if (cell_range_begin != size) {
1024
  cell_ranges.emplace_back(cell_range_begin, i);
1025
- cell_range_begin = size;
1026
  }
1027
  }
1028
  }
1029
- if (cell_range_begin != size) {
1030
- cell_ranges.emplace_back(cell_range_begin, size);
 
1031
  }
1032
 
1033
  // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
@@ -1064,17 +1210,24 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
1064
  void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
1065
  for (const auto & range : cell_ranges) {
1066
  for (uint32_t i = range.first; i < range.second; ++i) {
1067
- const auto & cell = cells[i];
1068
- const llama_pos pos = cell.pos;
1069
- const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
 
 
 
 
 
 
 
 
 
1070
 
1071
  io.write(&pos, sizeof(pos));
1072
  io.write(&n_seq_id, sizeof(n_seq_id));
1073
 
1074
- if (n_seq_id) {
1075
- for (auto seq_id : cell.seq_id) {
1076
- io.write(&seq_id, sizeof(seq_id));
1077
- }
1078
  }
1079
  }
1080
  }
@@ -1082,7 +1235,7 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::
1082
 
1083
  void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
1084
  const uint32_t v_trans = this->v_trans ? 1 : 0;
1085
- const uint32_t n_layer = hparams.n_layer;
1086
 
1087
  io.write(&v_trans, sizeof(v_trans));
1088
  io.write(&n_layer, sizeof(n_layer));
@@ -1091,56 +1244,63 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
1091
 
1092
  // Iterate and write all the keys first, each row is a cell
1093
  // Get whole range at a time
1094
- for (uint32_t il = 0; il < n_layer; ++il) {
 
 
1095
  const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1096
 
1097
  // Write key type
1098
- const int32_t k_type_i = (int32_t)k_l[il]->type;
1099
  io.write(&k_type_i, sizeof(k_type_i));
1100
 
1101
  // Write row size of key
1102
- const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
1103
  io.write(&k_size_row, sizeof(k_size_row));
1104
 
1105
  // Read each range of cells of k_size length each into tmp_buf and write out
1106
  for (const auto & range : cell_ranges) {
1107
  const size_t range_size = range.second - range.first;
1108
  const size_t buf_size = range_size * k_size_row;
1109
- io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
1110
  }
1111
  }
1112
 
1113
  if (!v_trans) {
1114
- for (uint32_t il = 0; il < n_layer; ++il) {
 
 
1115
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1116
 
1117
  // Write value type
1118
- const int32_t v_type_i = (int32_t)v_l[il]->type;
1119
  io.write(&v_type_i, sizeof(v_type_i));
1120
 
1121
  // Write row size of value
1122
- const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
1123
  io.write(&v_size_row, sizeof(v_size_row));
1124
 
1125
  // Read each range of cells of v_size length each into tmp_buf and write out
1126
  for (const auto & range : cell_ranges) {
1127
  const size_t range_size = range.second - range.first;
1128
  const size_t buf_size = range_size * v_size_row;
1129
- io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
1130
  }
1131
  }
1132
  } else {
1133
  // When v is transposed, we also need the element size and get the element ranges from each row
1134
- const uint32_t kv_size = size;
1135
- for (uint32_t il = 0; il < n_layer; ++il) {
 
 
 
1136
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1137
 
1138
  // Write value type
1139
- const int32_t v_type_i = (int32_t)v_l[il]->type;
1140
  io.write(&v_type_i, sizeof(v_type_i));
1141
 
1142
  // Write element size
1143
- const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
1144
  io.write(&v_size_el, sizeof(v_size_el));
1145
 
1146
  // Write GQA embedding size
@@ -1153,7 +1313,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
1153
  const size_t range_size = range.second - range.first;
1154
  const size_t src_offset = (range.first + j * kv_size) * v_size_el;
1155
  const size_t buf_size = range_size * v_size_el;
1156
- io.write_tensor(v_l[il], src_offset, buf_size);
1157
  }
1158
  }
1159
  }
@@ -1170,8 +1330,6 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1170
  llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
1171
 
1172
  batch.n_tokens = cell_count;
1173
- batch.n_seq_tokens = cell_count;
1174
- batch.n_seqs = 1;
1175
 
1176
  for (uint32_t i = 0; i < cell_count; ++i) {
1177
  llama_pos pos;
@@ -1180,32 +1338,40 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1180
  io.read_to(&pos, sizeof(pos));
1181
  io.read_to(&n_seq_id, sizeof(n_seq_id));
1182
 
1183
- if (n_seq_id != 0) {
1184
  LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
1185
  return false;
1186
  }
1187
 
1188
- batch.pos[i] = pos;
 
 
 
 
 
 
 
 
1189
  }
1190
- batch.n_seq_id[0] = 1;
1191
- batch.seq_id[0] = &dest_seq_id;
1192
  if (!find_slot(batch)) {
1193
  LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
1194
  return false;
1195
  }
 
1196
  commit();
1197
 
1198
  // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
1199
  // Assume that this is one contiguous block of cells
1200
- GGML_ASSERT(head + cell_count <= size);
1201
- GGML_ASSERT(cells[head].pos == batch.pos[0]);
1202
- GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
1203
- GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
1204
- GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
1205
  } else {
1206
  // whole KV cache restore
1207
 
1208
- if (cell_count > size) {
1209
  LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
1210
  return false;
1211
  }
@@ -1213,34 +1379,28 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1213
  clear();
1214
 
1215
  for (uint32_t i = 0; i < cell_count; ++i) {
1216
- kv_cell & cell = cells[i];
1217
-
1218
  llama_pos pos;
1219
  uint32_t n_seq_id;
1220
 
1221
  io.read_to(&pos, sizeof(pos));
1222
  io.read_to(&n_seq_id, sizeof(n_seq_id));
1223
 
1224
- cell.pos = pos;
1225
 
1226
  for (uint32_t j = 0; j < n_seq_id; ++j) {
1227
  llama_seq_id seq_id;
1228
  io.read_to(&seq_id, sizeof(seq_id));
1229
 
1230
- // TODO: llama_kv_cache_unified should have a notion of max sequences
1231
- //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
1232
- if (seq_id < 0) {
1233
- //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
1234
- LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
1235
  return false;
1236
  }
1237
 
1238
- cell.seq_id.insert(seq_id);
1239
  }
1240
  }
1241
 
1242
  head = 0;
1243
- used = cell_count;
1244
  }
1245
 
1246
  return true;
@@ -1249,15 +1409,16 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1249
  bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
1250
  uint32_t v_trans;
1251
  uint32_t n_layer;
 
1252
  io.read_to(&v_trans, sizeof(v_trans));
1253
  io.read_to(&n_layer, sizeof(n_layer));
1254
 
1255
- if (n_layer != hparams.n_layer) {
1256
- LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
1257
  return false;
1258
  }
1259
- if (cell_count > size) {
1260
- LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
1261
  return false;
1262
  }
1263
  if (this->v_trans != (bool) v_trans) {
@@ -1266,13 +1427,15 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1266
  }
1267
 
1268
  // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
1269
- for (uint32_t il = 0; il < n_layer; ++il) {
 
 
1270
  const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1271
 
1272
  // Read type of key
1273
  int32_t k_type_i_ref;
1274
  io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
1275
- const int32_t k_type_i = (int32_t) k_l[il]->type;
1276
  if (k_type_i != k_type_i_ref) {
1277
  LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
1278
  return false;
@@ -1281,7 +1444,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1281
  // Read row size of key
1282
  uint64_t k_size_row_ref;
1283
  io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
1284
- const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
1285
  if (k_size_row != k_size_row_ref) {
1286
  LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
1287
  return false;
@@ -1289,18 +1452,20 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1289
 
1290
  if (cell_count) {
1291
  // Read and set the keys for the whole cell range
1292
- ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
1293
  }
1294
  }
1295
 
1296
  if (!this->v_trans) {
1297
- for (uint32_t il = 0; il < n_layer; ++il) {
 
 
1298
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1299
 
1300
  // Read type of value
1301
  int32_t v_type_i_ref;
1302
  io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1303
- const int32_t v_type_i = (int32_t)v_l[il]->type;
1304
  if (v_type_i != v_type_i_ref) {
1305
  LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1306
  return false;
@@ -1309,7 +1474,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1309
  // Read row size of value
1310
  uint64_t v_size_row_ref;
1311
  io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
1312
- const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
1313
  if (v_size_row != v_size_row_ref) {
1314
  LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
1315
  return false;
@@ -1317,18 +1482,20 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1317
 
1318
  if (cell_count) {
1319
  // Read and set the values for the whole cell range
1320
- ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
1321
  }
1322
  }
1323
  } else {
1324
  // For each layer, read the values for each cell (transposed)
1325
- for (uint32_t il = 0; il < n_layer; ++il) {
 
 
1326
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1327
 
1328
  // Read type of value
1329
  int32_t v_type_i_ref;
1330
  io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1331
- const int32_t v_type_i = (int32_t)v_l[il]->type;
1332
  if (v_type_i != v_type_i_ref) {
1333
  LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1334
  return false;
@@ -1337,7 +1504,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1337
  // Read element size of value
1338
  uint32_t v_size_el_ref;
1339
  io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
1340
- const size_t v_size_el = ggml_type_size(v_l[il]->type);
1341
  if (v_size_el != v_size_el_ref) {
1342
  LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
1343
  return false;
@@ -1354,8 +1521,8 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1354
  if (cell_count) {
1355
  // For each row in the transposed matrix, read the values for the whole cell range
1356
  for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1357
- const size_t dst_offset = (head + j * size) * v_size_el;
1358
- ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
1359
  }
1360
  }
1361
  }
@@ -1364,6 +1531,193 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1364
  return true;
1365
  }
1366
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1367
  //
1368
  // llama_kv_cache_recurrent
1369
  //
@@ -1373,19 +1727,17 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
1373
  ggml_type type_k,
1374
  ggml_type type_v,
1375
  bool offload,
1376
- uint32_t kv_size) : hparams(model.hparams) {
 
1377
  const int32_t n_layer = hparams.n_layer;
1378
 
1379
- LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
1380
- __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
1381
 
1382
  head = 0;
1383
  size = kv_size;
1384
  used = 0;
1385
 
1386
- this->type_k = type_k;
1387
- this->type_v = type_v;
1388
-
1389
  cells.clear();
1390
  cells.resize(kv_size);
1391
 
@@ -1623,8 +1975,8 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
1623
  }
1624
  }
1625
 
1626
- void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
1627
- if (delta == 0) {
1628
  return;
1629
  }
1630
 
@@ -1647,7 +1999,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
1647
  if (tail_id >= 0) {
1648
  kv_cell & cell = cells[tail_id];
1649
  if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
1650
- cell.pos += delta;
1651
  }
1652
  }
1653
  }
@@ -1683,8 +2035,24 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
1683
  }
1684
  }
1685
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1686
  llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
1687
- llama_pos result = 0;
1688
 
1689
  for (uint32_t i = 0; i < size; ++i) {
1690
  if (cells[i].has_seq_id(seq_id)) {
@@ -1707,8 +2075,8 @@ void llama_kv_cache_recurrent::commit() {
1707
  pending.ranges.clear();
1708
  }
1709
 
1710
- bool llama_kv_cache_recurrent::update(llama_context & lctx) {
1711
- GGML_UNUSED(lctx);
1712
  return false;
1713
  }
1714
 
@@ -1769,7 +2137,7 @@ bool llama_kv_cache_recurrent::find_slot(
1769
  if (seq_id < 0 || (uint32_t) seq_id >= size) {
1770
  // too big seq_id
1771
  // TODO: would it be possible to resize the cache instead?
1772
- LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
1773
  return false;
1774
  }
1775
  if (j > 0) {
@@ -1912,29 +2280,6 @@ bool llama_kv_cache_recurrent::find_slot(
1912
  return n >= n_seqs;
1913
  }
1914
 
1915
- int32_t llama_kv_cache_recurrent::get_n_tokens() const {
1916
- int32_t result = 0;
1917
-
1918
- for (uint32_t i = 0; i < size; i++) {
1919
- result += cells[i].seq_id.size();
1920
- }
1921
-
1922
- return result;
1923
- }
1924
-
1925
- int32_t llama_kv_cache_recurrent::get_used_cells() const {
1926
- return used;
1927
- }
1928
-
1929
- llama_pos llama_kv_cache_recurrent::get_pos_max() const {
1930
- llama_pos pos_max = -1;
1931
- for (const auto & cell : cells) {
1932
- pos_max = std::max(pos_max, cell.pos);
1933
- }
1934
-
1935
- return pos_max;
1936
- }
1937
-
1938
  bool llama_kv_cache_recurrent::get_can_shift() const {
1939
  return false;
1940
  }
@@ -2063,6 +2408,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
2063
  io.read_to(&cell_count, sizeof(cell_count));
2064
 
2065
  bool res = true;
 
2066
  res = res && state_read_meta(io, cell_count, seq_id);
2067
  res = res && state_read_data(io, cell_count);
2068
 
@@ -2391,104 +2737,3 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
2391
 
2392
  return true;
2393
  }
2394
-
2395
- //
2396
- // kv cache view
2397
- //
2398
-
2399
- llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max) {
2400
- llama_kv_cache_view result = {
2401
- /*.n_cells = */ 0,
2402
- /*.n_seq_max = */ n_seq_max,
2403
- /*.token_count = */ 0,
2404
- /*.used_cells = */ kv.get_used_cells(),
2405
- /*.max_contiguous = */ 0,
2406
- /*.max_contiguous_idx = */ -1,
2407
- /*.cells = */ nullptr,
2408
- /*.cells_sequences = */ nullptr,
2409
- };
2410
-
2411
- return result;
2412
- }
2413
-
2414
- void llama_kv_cache_view_free(llama_kv_cache_view * view) {
2415
- if (view->cells != nullptr) {
2416
- free(view->cells);
2417
- view->cells = nullptr;
2418
- }
2419
- if (view->cells_sequences != nullptr) {
2420
- free(view->cells_sequences);
2421
- view->cells_sequences = nullptr;
2422
- }
2423
- }
2424
-
2425
- void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv) {
2426
- // TODO: rework this in the future, for now quick hack
2427
- const llama_kv_cache_unified * kvu = dynamic_cast<const llama_kv_cache_unified *>(kv);
2428
- if (kvu == nullptr) {
2429
- LLAMA_LOG_ERROR("%s: the kv_cache_view currently works only with llama_kv_cache_unified\n", __func__);
2430
- return;
2431
- }
2432
-
2433
- if (uint32_t(view->n_cells) < kvu->size || view->cells == nullptr) {
2434
- view->n_cells = int32_t(kvu->size);
2435
- void * p = realloc(view->cells, sizeof(llama_kv_cache_view_cell) * view->n_cells);
2436
- GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
2437
- view->cells = (llama_kv_cache_view_cell *)p;
2438
- p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells);
2439
- GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
2440
- view->cells_sequences = (llama_seq_id *)p;
2441
- }
2442
-
2443
- const std::vector<llama_kv_cache_unified::kv_cell> & kv_cells = kvu->cells;
2444
- llama_kv_cache_view_cell * c_curr = view->cells;
2445
- llama_seq_id * cs_curr = view->cells_sequences;
2446
- int32_t used_cells = 0;
2447
- int32_t token_count = 0;
2448
- int32_t curr_contig_idx = -1;
2449
- uint32_t max_contig = 0;
2450
- int32_t max_contig_idx = -1;
2451
-
2452
- for (int32_t i = 0; i < int32_t(kvu->size); i++, c_curr++, cs_curr += view->n_seq_max) {
2453
- const size_t curr_size = kv_cells[i].seq_id.size();
2454
- token_count += curr_size;
2455
- c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
2456
-
2457
- if (curr_size > 0) {
2458
- if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
2459
- max_contig = i - curr_contig_idx;
2460
- max_contig_idx = curr_contig_idx;
2461
- }
2462
- curr_contig_idx = -1;
2463
- } else if (curr_contig_idx < 0) {
2464
- curr_contig_idx = i;
2465
- }
2466
-
2467
- int seq_idx = 0;
2468
- for (const llama_seq_id it : kv_cells[i].seq_id) {
2469
- if (seq_idx >= view->n_seq_max) {
2470
- break;
2471
- }
2472
- cs_curr[seq_idx] = it;
2473
- seq_idx++;
2474
- }
2475
- if (seq_idx != 0) {
2476
- used_cells++;
2477
- }
2478
- for (; seq_idx < view->n_seq_max; seq_idx++) {
2479
- cs_curr[seq_idx] = -1;
2480
- }
2481
- }
2482
- if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
2483
- max_contig_idx = curr_contig_idx;
2484
- max_contig = kv_cells.size() - curr_contig_idx;
2485
- }
2486
- view->max_contiguous = max_contig;
2487
- view->max_contiguous_idx = max_contig_idx;
2488
- view->token_count = token_count;
2489
- view->used_cells = used_cells;
2490
- if (uint32_t(used_cells) != kvu->used) {
2491
- LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
2492
- __func__, kvu->used, used_cells);
2493
- }
2494
- }
 
23
  }
24
 
25
  llama_kv_cache_unified::llama_kv_cache_unified(
26
+ const llama_model & model,
27
+ layer_filter_cb && filter,
28
+ ggml_type type_k,
29
+ ggml_type type_v,
30
+ bool v_trans,
31
+ bool offload,
32
+ uint32_t kv_size,
33
+ uint32_t n_seq_max,
34
+ uint32_t n_pad,
35
+ uint32_t n_swa,
36
+ llama_swa_type swa_type) :
37
+ model(model), hparams(model.hparams), v_trans(v_trans),
38
+ n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
39
+
40
+ GGML_ASSERT(kv_size % n_pad == 0);
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  // create a context for each buffer type
43
  std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
 
45
  auto it = ctx_map.find(buft);
46
  if (it == ctx_map.end()) {
47
  ggml_init_params params = {
48
+ /*.mem_size =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()),
49
  /*.mem_buffer =*/ NULL,
50
  /*.no_alloc =*/ true,
51
  };
 
64
  return it->second;
65
  };
66
 
67
+ head = 0;
 
68
 
69
+ cells.resize(kv_size);
70
+
71
+ for (uint32_t il = 0; il < hparams.n_layer; il++) {
72
+ if (filter && !filter(il)) {
73
+ LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
74
+ continue;
75
+ }
76
+
77
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
78
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
79
 
80
  const char * dev_name = "CPU";
81
 
82
  ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
83
 
84
  if (offload) {
85
+ auto * dev = model.dev_layer(il);
86
  buft = ggml_backend_dev_buffer_type(dev);
87
 
88
  dev_name = ggml_backend_dev_name(dev);
89
  }
90
 
91
+ LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name);
92
 
93
  ggml_context * ctx = ctx_for_buft(buft);
94
  if (!ctx) {
95
  throw std::runtime_error("failed to create ggml context for kv cache");
96
  }
97
 
98
+ ggml_tensor * k;
99
+ ggml_tensor * v;
100
+
101
+ k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
102
+ v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
103
+
104
+ ggml_format_name(k, "cache_k_l%d", il);
105
+ ggml_format_name(v, "cache_v_l%d", il);
106
+
107
+ map_layer_ids[il] = layers.size();
108
+ layers.push_back({ il, k, v });
109
  }
110
 
111
  // allocate tensors and initialize the buffers to avoid NaNs in the padding
 
117
  if (!buf) {
118
  throw std::runtime_error("failed to allocate buffer for kv cache");
119
  }
120
+
121
  LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
122
+
123
+ ggml_backend_buffer_clear(buf, 0);
124
  bufs.emplace_back(buf);
125
  }
126
 
 
128
  const size_t memory_size_k = size_k_bytes();
129
  const size_t memory_size_v = size_v_bytes();
130
 
131
+ LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
132
+ (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
133
  ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
134
  ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
135
  }
136
  }
137
 
138
  void llama_kv_cache_unified::clear() {
139
+ cells.reset();
140
+
 
 
141
  head = 0;
 
142
 
143
  for (auto & buf : bufs) {
144
  ggml_backend_buffer_clear(buf.get(), 0);
 
146
  }
147
 
148
  bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
149
+ uint32_t new_head = cells.size();
150
 
151
  if (p0 < 0) {
152
  p0 = 0;
 
156
  p1 = std::numeric_limits<llama_pos>::max();
157
  }
158
 
159
+ for (uint32_t i = 0; i < cells.size(); ++i) {
160
+ if (!cells.pos_in(i, p0, p1)) {
161
+ continue;
162
+ }
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
+ if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
165
+ if (new_head == cells.size()) {
166
+ new_head = i;
167
  }
168
  }
169
  }
170
 
171
  // If we freed up a slot, set head to it so searching can start there.
172
+ if (new_head != cells.size() && new_head < head) {
173
  head = new_head;
174
  }
175
 
 
189
  p1 = std::numeric_limits<llama_pos>::max();
190
  }
191
 
192
+ for (uint32_t i = 0; i < cells.size(); ++i) {
193
+ if (!cells.pos_in(i, p0, p1)) {
194
+ continue;
195
+ }
196
 
197
+ if (cells.seq_has(i, seq_id_src)) {
198
+ cells.seq_add(i, seq_id_dst);
 
199
  }
200
  }
201
  }
202
 
203
  void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
204
+ uint32_t new_head = cells.size();
205
 
206
+ for (uint32_t i = 0; i < cells.size(); ++i) {
207
+ if (cells.seq_keep(i, seq_id)) {
208
+ if (new_head == cells.size()) {
 
 
 
 
 
 
 
209
  new_head = i;
210
  }
 
 
 
211
  }
212
  }
213
 
214
  // If we freed up a slot, set head to it so searching can start there.
215
+ if (new_head != cells.size() && new_head < head) {
216
  head = new_head;
217
  }
218
  }
219
 
220
+ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
221
+ if (shift == 0) {
222
  return;
223
  }
224
 
225
+ uint32_t new_head = cells.size();
226
 
227
  if (p0 < 0) {
228
  p0 = 0;
 
232
  p1 = std::numeric_limits<llama_pos>::max();
233
  }
234
 
235
+ // If there is no range then return early to avoid looping over all cells.
236
  if (p0 == p1) {
237
  return;
238
  }
239
 
240
+ for (uint32_t i = 0; i < cells.size(); ++i) {
241
+ if (!cells.pos_in(i, p0, p1)) {
242
+ continue;
243
+ }
 
244
 
245
+ if (cells.seq_has(i, seq_id)) {
246
+ if (cells.pos_add(i, shift)) {
247
+ if (new_head == cells.size()) {
 
 
 
 
248
  new_head = i;
249
  }
250
  }
 
253
 
254
  // If we freed up a slot, set head to it so searching can start there.
255
  // Otherwise we just start the next search from the beginning.
256
+ head = new_head != cells.size() ? new_head : 0;
257
  }
258
 
259
  void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
 
274
  return;
275
  }
276
 
277
+ for (uint32_t i = 0; i < cells.size(); ++i) {
278
+ if (!cells.pos_in(i, p0, p1)) {
279
+ continue;
280
+ }
281
 
282
+ if (cells.seq_has(i, seq_id)) {
283
+ cells.pos_div(i, d);
 
 
 
284
  }
285
  }
286
  }
287
 
288
+ llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
289
+ return cells.seq_pos_min(seq_id);
290
+ }
 
 
 
 
 
291
 
292
+ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
293
+ return cells.seq_pos_max(seq_id);
294
  }
295
 
296
  void llama_kv_cache_unified::restore() {
297
+ for (auto & state : recovery.states) {
298
+ cells.set(state.i, state.cells);
299
  }
300
 
301
+ recovery.clear();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
  }
303
 
304
  void llama_kv_cache_unified::commit() {
305
+ if (recovery.states.empty()) {
306
+ LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n",
307
+ __func__, "https://github.com/ggml-org/llama.cpp/pull/13194");
308
  return;
309
  }
310
 
311
+ recovery.clear();
312
  }
313
 
314
  bool llama_kv_cache_unified::update(llama_context & lctx) {
 
316
 
317
  auto * sched = lctx.get_sched();
318
 
319
+ if (cells.get_has_shift()) {
320
  if (!get_can_shift()) {
321
  GGML_ABORT("The current KV cache / model configuration does not support K-shift");
322
  }
 
340
  need_reserve = true;
341
  }
342
 
343
+ cells.reset_shift();
 
 
 
 
 
 
344
  }
345
 
346
  if (do_defrag) {
 
371
  void llama_kv_cache_unified::defrag_sched(float thold) {
372
  // - do not defrag small contexts (i.e. < 2048 tokens)
373
  // - count the padding towards the number of used tokens
374
+ const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n)) : 0.0f;
375
 
376
  // queue defragmentation for next llama_kv_cache_update
377
  if (fragmentation > thold) {
 
382
  }
383
 
384
  void llama_kv_cache_unified::set_full() {
385
+ n = cells.size();
386
 
387
  // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
388
  // affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
 
392
  head = 0;
393
  }
394
 
395
+ llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) {
 
 
396
  return llama_sbatch(batch, hparams.n_embd, true, logits_all);
397
  }
398
 
399
+ llama_ubatch llama_kv_cache_unified::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
 
 
 
400
  GGML_UNUSED(embd_pooled);
401
  return sbatch.split_simple(n_ubatch);
402
  }
403
 
404
+ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
 
405
  const uint32_t n_tokens = ubatch.n_tokens;
 
 
406
 
407
  // if we have enough unused cells before the current head ->
408
  // better to start searching from the beginning of the cache, hoping to fill it
409
+ if (head > cells.get_used() + 2*ubatch.n_tokens) {
410
  head = 0;
411
  }
412
 
413
  // otherwise, one cell per token.
414
 
415
+ if (n_tokens > cells.size()) {
416
+ LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
417
  return false;
418
  }
419
 
420
+ //#define FIND_SLOT_DEBUG 1
421
+ #if FIND_SLOT_DEBUG
422
+ LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
423
+
424
+ // for debugging
425
+ {
426
+ std::string ss;
427
+ if (n_swa > 0) {
428
+ for (uint32_t i = 0; i < size; ++i) {
429
+ if (cells.is_empty(i)) {
430
+ ss += '.';
431
+ } else {
432
+ ss += 'x';
433
+ }
434
+ if (i%256 == 255) {
435
+ ss += '\n';
436
+ }
437
+ }
438
+ }
439
+ LLAMA_LOG_WARN("\n%s\n", ss.c_str());
440
+ }
441
+ #endif
442
+
443
  uint32_t n_tested = 0;
444
 
445
  while (true) {
446
+ if (head + n_tokens > cells.size()) {
447
+ n_tested += cells.size() - head;
448
  head = 0;
449
  continue;
450
  }
451
 
452
  bool found = true;
453
  for (uint32_t i = 0; i < n_tokens; i++) {
454
+ // TODO: improve to accept cells that are masked by the SWA
455
+ if (!cells.is_empty(head + i)) {
456
  found = false;
457
  head += i + 1;
458
  n_tested += i + 1;
 
464
  break;
465
  }
466
 
467
+ if (n_tested >= cells.size()) {
468
  //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
469
  return false;
470
  }
471
  }
472
 
473
+ // store the old state of the cells in the recovery stack
474
+ recovery.states.push_back({head, cells.cp(head, n_tokens)});
 
 
475
 
476
+ for (uint32_t i = 0; i < n_tokens; ++i) {
477
+ cells.pos_set(head + i, ubatch.pos[i]);
478
+
479
+ for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
480
+ cells.seq_add(head + i, ubatch.seq_id[i][j]);
481
  }
482
  }
483
 
 
 
 
 
484
  // a heuristic, to avoid attending the full cache if it is not yet utilized
485
  // after enough generations, the benefit from this heuristic disappears
486
  // if we start defragmenting the cache, the benefit from this will be more important
487
+ n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
488
+
489
+ #ifdef FIND_SLOT_DEBUG
490
+ LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
491
+ #endif
492
 
493
+ return true;
494
+ }
495
 
496
+ bool llama_kv_cache_unified::get_can_shift() const {
497
  return true;
498
  }
499
 
500
+ uint32_t llama_kv_cache_unified::get_n() const {
501
+ return n;
502
+ }
503
+
504
+ uint32_t llama_kv_cache_unified::get_size() const {
505
+ return cells.size();
506
+ }
507
+
508
+ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const {
509
+ const int32_t ikv = map_layer_ids.at(il);
510
+
511
+ auto * k = layers[ikv].k;
512
+
513
+ return ggml_view_3d(ctx, k,
514
+ hparams.n_embd_head_k, hparams.n_head_kv(il), n,
515
+ ggml_row_size(k->type, hparams.n_embd_head_k),
516
+ ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
517
+ 0);
518
+ }
519
+
520
+ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) const {
521
+ const int32_t ikv = map_layer_ids.at(il);
522
+
523
+ auto * v = layers[ikv].v;
524
 
525
+ if (!v_trans) {
526
+ // note: v->nb[1] <= v->nb[2]
527
+ return ggml_view_3d(ctx, v,
528
+ hparams.n_embd_head_v, hparams.n_head_kv(il), n,
529
+ ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
530
+ ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
531
+ 0);
532
  }
533
 
534
+ // note: v->nb[1] > v->nb[2]
535
+ return ggml_view_3d(ctx, v,
536
+ n, hparams.n_head_kv(il), hparams.n_embd_head_v,
537
+ ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
538
+ ggml_row_size(v->type, v->ne[1]), // v->nb[2]
539
+ 0);
540
  }
541
 
542
+ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
543
+ const int32_t ikv = map_layer_ids.at(il);
544
+
545
+ auto * k = layers[ikv].k;
546
+
547
+ const int64_t n_tokens = k_cur->ne[2];
548
+
549
+ ggml_tensor * k_view = ggml_view_1d(ctx, k,
550
+ n_tokens*hparams.n_embd_k_gqa(il),
551
+ ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head);
552
+
553
+ return ggml_cpy(ctx, k_cur, k_view);
554
  }
555
 
556
+ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
557
+ const int32_t ikv = map_layer_ids.at(il);
558
+
559
+ auto * v = layers[ikv].v;
560
+
561
+ const int64_t n_tokens = v_cur->ne[2];
562
+
563
+ v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
564
+
565
+ ggml_tensor * v_view = nullptr;
566
+
567
+ if (!v_trans) {
568
+ v_view = ggml_view_1d(ctx, v,
569
+ n_tokens*hparams.n_embd_v_gqa(il),
570
+ ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head);
571
+ } else {
572
+ // note: the V cache is transposed when not using flash attention
573
+ v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
574
+ (v->ne[1])*ggml_element_size(v),
575
+ ( head)*ggml_element_size(v));
576
+
577
+ v_cur = ggml_transpose(ctx, v_cur);
578
+ }
579
+
580
+ return ggml_cpy(ctx, v_cur, v_view);
581
+ }
582
+
583
+ void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax) {
584
+ // no pruning is needed when the cache does not use SWA
585
+ GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE && "do not prune non-SWA cache");
586
+
587
+ int n_attended = 0;
588
+
589
+ for (uint32_t i = 0; i < cells.size(); ++i) {
590
+ if (!cells.seq_has(i, seq_id)) {
591
+ continue;
592
+ }
593
+
594
+ const llama_pos p0 = cells.pos_get(i);
595
+
596
+ if (p0 <= pmin && !is_masked_swa(p0, pmin)) {
597
+ n_attended++;
598
+ }
599
+
600
+ if (is_masked_swa(p0, pmax)) {
601
+ cells.seq_rm(i, seq_id);
602
+ }
603
+ }
604
+
605
+ if (n_attended < std::min<int>(n_swa, pmin)) {
606
+ LLAMA_LOG_WARN("%s: partial SWA cache detected - possible loss of information, pmin = %d, n_attended = %d, n_swa = %d\n", __func__, pmin, n_attended, n_swa);
607
+ }
608
+ }
609
+
610
+ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
611
+ const int64_t n_tokens = ubatch->n_tokens;
612
+ const int64_t n_seq_tokens = ubatch->n_seq_tokens;
613
+ const int64_t n_seqs = ubatch->n_seqs;
614
+
615
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
616
+ float * data = (float *) dst->data;
617
+
618
+ const int64_t n_kv = n;
619
+
620
+ // Use only the previous KV cells of the correct sequence for each token of the ubatch.
621
+ // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
622
+ // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
623
+ // Causal mask:
624
+ // xxx-------
625
+ // xxxx------
626
+ // xxxxx-----
627
+ // Non-causal mask:
628
+ // xxxxx-----
629
+ // xxxxx-----
630
+ // xxxxx-----
631
+ // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
632
+ for (int h = 0; h < 1; ++h) {
633
+ for (int s = 0; s < n_seqs; ++s) {
634
+ const llama_seq_id seq_id = ubatch->seq_id[s][0];
635
+
636
+ for (int j = 0; j < n_seq_tokens; ++j) {
637
+ const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
638
+
639
+ for (int i = 0; i < n_kv; ++i) {
640
+ float f = 0.0f;
641
+
642
+ bool masked = false;
643
+
644
+ if (cells.is_empty(i)) {
645
+ masked = true;
646
+ } else {
647
+ const llama_pos p0 = cells.pos_get(i);
648
+
649
+ // mask the token if not the same sequence
650
+ masked = masked || (!cells.seq_has(i, seq_id));
651
+
652
+ // mask future tokens
653
+ masked = masked || (causal_attn && p0 > p1);
654
+
655
+ // apply SWA if any
656
+ masked = masked || (is_masked_swa(p0, p1));
657
+
658
+ if (!masked && hparams.use_alibi) {
659
+ f = -std::abs(p0 - p1);
660
+ }
661
+ }
662
+
663
+ if (masked) {
664
+ f = -INFINITY;
665
+ }
666
+
667
+ data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
668
+ }
669
+ }
670
+ }
671
+
672
+ // mask padded tokens
673
+ if (data) {
674
+ for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
675
+ for (int j = 0; j < n_kv; ++j) {
676
+ data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
677
+ }
678
+ }
679
+ }
680
+ }
681
  }
682
 
683
+ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
684
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
685
+
686
+ int32_t * data = (int32_t *) dst->data;
687
+
688
+ for (uint32_t i = 0; i < cells.size(); ++i) {
689
+ data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
690
  }
691
+ }
692
+
693
+ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
694
+ const int64_t n_tokens = ubatch->n_tokens;
695
+
696
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
697
+ GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
698
+
699
+ int32_t * data = (int32_t *) dst->data;
700
+
701
+ const int64_t n_kv = n;
702
 
703
+ for (int h = 0; h < 1; ++h) {
704
+ for (int j = 0; j < n_tokens; ++j) {
705
+ for (int i = 0; i < n_kv; ++i) {
706
+ // the position when the cells is empty is irrelevant - it will be masked out later in the attention
707
+ const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
708
+
709
+ data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
710
+ }
711
+ }
712
+ }
713
  }
714
 
715
  size_t llama_kv_cache_unified::total_size() const {
716
  size_t size = 0;
717
+
718
  for (const auto & buf : bufs) {
719
  size += ggml_backend_buffer_get_size(buf.get());
720
  }
 
725
  size_t llama_kv_cache_unified::size_k_bytes() const {
726
  size_t size_k_bytes = 0;
727
 
728
+ for (const auto & layer : layers) {
729
+ size_k_bytes += ggml_nbytes(layer.k);
730
  }
731
 
732
  return size_k_bytes;
 
735
  size_t llama_kv_cache_unified::size_v_bytes() const {
736
  size_t size_v_bytes = 0;
737
 
738
+ for (const auto & layer : layers) {
739
+ size_v_bytes += ggml_nbytes(layer.v);
740
  }
741
 
742
  return size_v_bytes;
 
800
  GGML_UNUSED(ubatch);
801
 
802
  if (k_shift) {
803
+ kv_self->set_input_k_shift(k_shift);
 
 
 
 
 
 
804
  }
805
  }
806
 
 
810
  ggml_cgraph * gf) const {
811
  auto res = std::make_unique<llm_graph_result>();
812
 
 
 
813
  const auto & n_embd_head_k = hparams.n_embd_head_k;
814
  //const auto & n_embd_head_v = hparams.n_embd_head_v;
815
 
 
 
816
  //GGML_ASSERT(kv_self->size == n_ctx);
817
 
818
  auto inp = std::make_unique<llm_graph_input_k_shift>(this);
 
820
  inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
821
  ggml_set_input(inp->k_shift);
822
 
823
+ for (const auto & layer : layers) {
824
+ const uint32_t il = layer.il;
825
+
826
  const int64_t n_head_kv = hparams.n_head_kv(il);
827
  const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
828
 
829
+ const float freq_base_l = model.get_rope_freq_base (cparams, il);
830
+ const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
831
 
832
+ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
 
 
 
 
 
833
 
834
  ggml_tensor * k =
835
+ ggml_view_3d(ctx, layer.k,
836
+ n_embd_head_k, n_head_kv, cells.size(),
837
+ ggml_row_size(layer.k->type, n_embd_head_k),
838
+ ggml_row_size(layer.k->type, n_embd_k_gqa),
839
  0);
840
 
841
  ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
 
940
  nm++;
941
  }
942
 
943
+ for (const auto & layer : layers) {
944
+ const uint32_t il = layer.il;
945
+
946
  const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
947
  const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
948
 
949
+ ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k,
950
  n_embd_k_gqa, nm,
951
+ ggml_row_size(layer.k->type, n_embd_k_gqa),
952
+ ggml_row_size(layer.k->type, n_embd_k_gqa*i));
953
 
954
+ ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k,
955
  n_embd_k_gqa, nm,
956
+ ggml_row_size(layer.k->type, n_embd_k_gqa),
957
+ ggml_row_size(layer.k->type, n_embd_k_gqa*id));
958
 
959
  ggml_tensor * view_v_src;
960
  ggml_tensor * view_v_dst;
961
 
962
  if (cparams.flash_attn) {
963
  // NOTE: the V cache is not transposed when using flash attention
964
+ view_v_src = ggml_view_2d(ctx, layer.v,
965
  n_embd_v_gqa, nm,
966
+ ggml_row_size(layer.v->type, n_embd_v_gqa),
967
+ ggml_row_size(layer.v->type, n_embd_v_gqa*i));
968
 
969
+ view_v_dst = ggml_view_2d(ctx, layer.v,
970
  n_embd_v_gqa, nm,
971
+ ggml_row_size(layer.v->type, n_embd_v_gqa),
972
+ ggml_row_size(layer.v->type, n_embd_v_gqa*id));
973
  } else {
974
+ view_v_src = ggml_view_2d(ctx, layer.v,
975
  nm, n_embd_v_gqa,
976
+ ggml_row_size(layer.v->type, cells.size()),
977
+ ggml_row_size(layer.v->type, i));
978
 
979
+ view_v_dst = ggml_view_2d(ctx, layer.v,
980
  nm, n_embd_v_gqa,
981
+ ggml_row_size(layer.v->type, cells.size()),
982
+ ggml_row_size(layer.v->type, id));
983
  }
984
 
985
  ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
 
996
  }
997
 
998
  bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
999
+ const uint32_t n_layer = layers.size();
1000
 
1001
+ const uint32_t n_kv = cells.used_max_p1();
1002
+ const uint32_t n_used = cells.get_used();
1003
 
1004
  assert(n_used <= n_kv);
1005
 
 
1027
  ids.resize(n_kv, n_kv);
1028
 
1029
  for (uint32_t i0 = 0; i0 < n_used; ++i0) {
1030
+ if (!cells.is_empty(i0)) {
 
 
1031
  ids[i0] = i0;
1032
 
1033
  continue;
 
1038
  uint32_t nh = 1;
1039
 
1040
  // determine the size of the hole
1041
+ while (i0 + nh < n_used && cells.is_empty(i0 + nh)) {
1042
  nh++;
1043
  }
1044
 
 
1047
 
1048
  // starting from the end, find nh non-empty cells
1049
  for (; is > i0; --is) {
1050
+ if (cells.is_empty(is) || ids[is] != n_kv) {
 
 
1051
  continue;
1052
  }
1053
 
 
1074
 
1075
  // go back and move the nf cells to the hole
1076
  for (; i1 < n_kv; ++i1) {
1077
+ if (cells.is_empty(i1) || ids[i1] != n_kv) {
 
 
1078
  if (n_moves == max_moves) {
1079
  stop = true;
1080
  break;
 
1088
  ids[i1] = i0 + nf;
1089
 
1090
  // move the cell meta data
1091
+ cells.mv(i1, i0 + nf);
1092
 
 
 
1093
  head = n_used;
1094
 
1095
  if (!cont) {
 
1124
  return true;
1125
  }
1126
 
1127
+ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
1128
+ assert(p0 >= 0 && p1 >= 0);
 
1129
 
1130
+ switch (swa_type) {
1131
+ case LLAMA_SWA_TYPE_NONE:
1132
+ {
1133
+ } break;
1134
+ case LLAMA_SWA_TYPE_STANDARD:
1135
+ {
1136
+ if (p1 - p0 >= (int32_t) n_swa) {
1137
+ return true;
1138
+ }
1139
+ } break;
1140
+ case LLAMA_SWA_TYPE_CHUNKED:
1141
+ {
1142
+ const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
1143
+
1144
+ if (p0 < pos_chunk_start) {
1145
+ return true;
1146
+ }
1147
+ } break;
1148
  }
1149
 
1150
+ return false;
1151
  }
1152
 
1153
  void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
 
1156
 
1157
  // Count the number of cells with the specified seq_id
1158
  // Find all the ranges of cells with this seq id (or all, when -1)
1159
+ uint32_t cell_range_begin = cells.size();
1160
+
1161
+ for (uint32_t i = 0; i < cells.size(); ++i) {
1162
+ if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
1163
  ++cell_count;
1164
+ if (cell_range_begin == cells.size()) {
1165
  cell_range_begin = i;
1166
  }
1167
  } else {
1168
+ if (cell_range_begin != cells.size()) {
1169
  cell_ranges.emplace_back(cell_range_begin, i);
1170
+ cell_range_begin = cells.size();
1171
  }
1172
  }
1173
  }
1174
+
1175
+ if (cell_range_begin != cells.size()) {
1176
+ cell_ranges.emplace_back(cell_range_begin, cells.size());
1177
  }
1178
 
1179
  // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
 
1210
  void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
1211
  for (const auto & range : cell_ranges) {
1212
  for (uint32_t i = range.first; i < range.second; ++i) {
1213
+ std::vector<llama_seq_id> seq_ids;
1214
+
1215
+ for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
1216
+ if (cur == seq_id || seq_id == -1) {
1217
+ if (cells.seq_has(i, cur)) {
1218
+ seq_ids.push_back(cur);
1219
+ }
1220
+ }
1221
+ }
1222
+
1223
+ const llama_pos pos = cells.pos_get(i);
1224
+ const uint32_t n_seq_id = seq_ids.size();
1225
 
1226
  io.write(&pos, sizeof(pos));
1227
  io.write(&n_seq_id, sizeof(n_seq_id));
1228
 
1229
+ for (const auto & seq_id : seq_ids) {
1230
+ io.write(&seq_id, sizeof(seq_id));
 
 
1231
  }
1232
  }
1233
  }
 
1235
 
1236
  void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
1237
  const uint32_t v_trans = this->v_trans ? 1 : 0;
1238
+ const uint32_t n_layer = layers.size();
1239
 
1240
  io.write(&v_trans, sizeof(v_trans));
1241
  io.write(&n_layer, sizeof(n_layer));
 
1244
 
1245
  // Iterate and write all the keys first, each row is a cell
1246
  // Get whole range at a time
1247
+ for (const auto & layer : layers) {
1248
+ const uint32_t il = layer.il;
1249
+
1250
  const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1251
 
1252
  // Write key type
1253
+ const int32_t k_type_i = (int32_t)layer.k->type;
1254
  io.write(&k_type_i, sizeof(k_type_i));
1255
 
1256
  // Write row size of key
1257
+ const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
1258
  io.write(&k_size_row, sizeof(k_size_row));
1259
 
1260
  // Read each range of cells of k_size length each into tmp_buf and write out
1261
  for (const auto & range : cell_ranges) {
1262
  const size_t range_size = range.second - range.first;
1263
  const size_t buf_size = range_size * k_size_row;
1264
+ io.write_tensor(layer.k, range.first * k_size_row, buf_size);
1265
  }
1266
  }
1267
 
1268
  if (!v_trans) {
1269
+ for (const auto & layer : layers) {
1270
+ const uint32_t il = layer.il;
1271
+
1272
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1273
 
1274
  // Write value type
1275
+ const int32_t v_type_i = (int32_t)layer.v->type;
1276
  io.write(&v_type_i, sizeof(v_type_i));
1277
 
1278
  // Write row size of value
1279
+ const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
1280
  io.write(&v_size_row, sizeof(v_size_row));
1281
 
1282
  // Read each range of cells of v_size length each into tmp_buf and write out
1283
  for (const auto & range : cell_ranges) {
1284
  const size_t range_size = range.second - range.first;
1285
  const size_t buf_size = range_size * v_size_row;
1286
+ io.write_tensor(layer.v, range.first * v_size_row, buf_size);
1287
  }
1288
  }
1289
  } else {
1290
  // When v is transposed, we also need the element size and get the element ranges from each row
1291
+ const uint32_t kv_size = cells.size();
1292
+
1293
+ for (const auto & layer : layers) {
1294
+ const uint32_t il = layer.il;
1295
+
1296
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1297
 
1298
  // Write value type
1299
+ const int32_t v_type_i = (int32_t)layer.v->type;
1300
  io.write(&v_type_i, sizeof(v_type_i));
1301
 
1302
  // Write element size
1303
+ const uint32_t v_size_el = ggml_type_size(layer.v->type);
1304
  io.write(&v_size_el, sizeof(v_size_el));
1305
 
1306
  // Write GQA embedding size
 
1313
  const size_t range_size = range.second - range.first;
1314
  const size_t src_offset = (range.first + j * kv_size) * v_size_el;
1315
  const size_t buf_size = range_size * v_size_el;
1316
+ io.write_tensor(layer.v, src_offset, buf_size);
1317
  }
1318
  }
1319
  }
 
1330
  llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
1331
 
1332
  batch.n_tokens = cell_count;
 
 
1333
 
1334
  for (uint32_t i = 0; i < cell_count; ++i) {
1335
  llama_pos pos;
 
1338
  io.read_to(&pos, sizeof(pos));
1339
  io.read_to(&n_seq_id, sizeof(n_seq_id));
1340
 
1341
+ if (n_seq_id != 1) {
1342
  LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
1343
  return false;
1344
  }
1345
 
1346
+ // read the sequence id, but directly discard it - we will use dest_seq_id instead
1347
+ {
1348
+ llama_seq_id seq_id;
1349
+ io.read_to(&seq_id, sizeof(seq_id));
1350
+ }
1351
+
1352
+ batch.pos[i] = pos;
1353
+ batch.n_seq_id[i] = n_seq_id;
1354
+ batch.seq_id[i] = &dest_seq_id;
1355
  }
1356
+
 
1357
  if (!find_slot(batch)) {
1358
  LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
1359
  return false;
1360
  }
1361
+
1362
  commit();
1363
 
1364
  // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
1365
  // Assume that this is one contiguous block of cells
1366
+ GGML_ASSERT(head + cell_count <= cells.size());
1367
+ GGML_ASSERT(cells.pos_get(head) == batch.pos[0]);
1368
+ GGML_ASSERT(cells.pos_get(head + cell_count - 1) == batch.pos[cell_count - 1]);
1369
+ GGML_ASSERT(cells.seq_has(head, dest_seq_id));
1370
+ GGML_ASSERT(cells.seq_has(head + cell_count - 1, dest_seq_id));
1371
  } else {
1372
  // whole KV cache restore
1373
 
1374
+ if (cell_count > cells.size()) {
1375
  LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
1376
  return false;
1377
  }
 
1379
  clear();
1380
 
1381
  for (uint32_t i = 0; i < cell_count; ++i) {
 
 
1382
  llama_pos pos;
1383
  uint32_t n_seq_id;
1384
 
1385
  io.read_to(&pos, sizeof(pos));
1386
  io.read_to(&n_seq_id, sizeof(n_seq_id));
1387
 
1388
+ cells.pos_set(i, pos);
1389
 
1390
  for (uint32_t j = 0; j < n_seq_id; ++j) {
1391
  llama_seq_id seq_id;
1392
  io.read_to(&seq_id, sizeof(seq_id));
1393
 
1394
+ if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
1395
+ LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
 
 
 
1396
  return false;
1397
  }
1398
 
1399
+ cells.seq_add(i, seq_id);
1400
  }
1401
  }
1402
 
1403
  head = 0;
 
1404
  }
1405
 
1406
  return true;
 
1409
  bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
1410
  uint32_t v_trans;
1411
  uint32_t n_layer;
1412
+
1413
  io.read_to(&v_trans, sizeof(v_trans));
1414
  io.read_to(&n_layer, sizeof(n_layer));
1415
 
1416
+ if (n_layer != layers.size()) {
1417
+ LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
1418
  return false;
1419
  }
1420
+ if (cell_count > cells.size()) {
1421
+ LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
1422
  return false;
1423
  }
1424
  if (this->v_trans != (bool) v_trans) {
 
1427
  }
1428
 
1429
  // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
1430
+ for (const auto & layer : layers) {
1431
+ const uint32_t il = layer.il;
1432
+
1433
  const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1434
 
1435
  // Read type of key
1436
  int32_t k_type_i_ref;
1437
  io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
1438
+ const int32_t k_type_i = (int32_t) layer.k->type;
1439
  if (k_type_i != k_type_i_ref) {
1440
  LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
1441
  return false;
 
1444
  // Read row size of key
1445
  uint64_t k_size_row_ref;
1446
  io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
1447
+ const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
1448
  if (k_size_row != k_size_row_ref) {
1449
  LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
1450
  return false;
 
1452
 
1453
  if (cell_count) {
1454
  // Read and set the keys for the whole cell range
1455
+ ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
1456
  }
1457
  }
1458
 
1459
  if (!this->v_trans) {
1460
+ for (const auto & layer : layers) {
1461
+ const uint32_t il = layer.il;
1462
+
1463
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1464
 
1465
  // Read type of value
1466
  int32_t v_type_i_ref;
1467
  io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1468
+ const int32_t v_type_i = (int32_t)layer.v->type;
1469
  if (v_type_i != v_type_i_ref) {
1470
  LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1471
  return false;
 
1474
  // Read row size of value
1475
  uint64_t v_size_row_ref;
1476
  io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
1477
+ const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
1478
  if (v_size_row != v_size_row_ref) {
1479
  LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
1480
  return false;
 
1482
 
1483
  if (cell_count) {
1484
  // Read and set the values for the whole cell range
1485
+ ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
1486
  }
1487
  }
1488
  } else {
1489
  // For each layer, read the values for each cell (transposed)
1490
+ for (const auto & layer : layers) {
1491
+ const uint32_t il = layer.il;
1492
+
1493
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1494
 
1495
  // Read type of value
1496
  int32_t v_type_i_ref;
1497
  io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1498
+ const int32_t v_type_i = (int32_t)layer.v->type;
1499
  if (v_type_i != v_type_i_ref) {
1500
  LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1501
  return false;
 
1504
  // Read element size of value
1505
  uint32_t v_size_el_ref;
1506
  io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
1507
+ const size_t v_size_el = ggml_type_size(layer.v->type);
1508
  if (v_size_el != v_size_el_ref) {
1509
  LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
1510
  return false;
 
1521
  if (cell_count) {
1522
  // For each row in the transposed matrix, read the values for the whole cell range
1523
  for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1524
+ const size_t dst_offset = (head + j * cells.size()) * v_size_el;
1525
+ ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
1526
  }
1527
  }
1528
  }
 
1531
  return true;
1532
  }
1533
 
1534
+ //
1535
+ // llama_kv_cache_unified_iswa
1536
+ //
1537
+
1538
+ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
1539
+ const llama_model & model,
1540
+ ggml_type type_k,
1541
+ ggml_type type_v,
1542
+ bool v_trans,
1543
+ bool offload,
1544
+ bool swa_full,
1545
+ uint32_t kv_size,
1546
+ uint32_t n_seq_max,
1547
+ uint32_t n_batch,
1548
+ uint32_t n_pad) : hparams(model.hparams) {
1549
+ llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
1550
+ llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
1551
+
1552
+ const uint32_t size_base = kv_size;
1553
+
1554
+ uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad));
1555
+
1556
+ // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size and disable pruning
1557
+ if (swa_full) {
1558
+ LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
1559
+ __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
1560
+
1561
+ size_swa = size_base;
1562
+ do_prune = false;
1563
+ }
1564
+
1565
+ LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
1566
+
1567
+ kv_base = std::make_unique<llama_kv_cache_unified>(
1568
+ model, std::move(filter_base), type_k, type_v,
1569
+ v_trans, offload, size_base, n_seq_max, n_pad,
1570
+ 0, LLAMA_SWA_TYPE_NONE);
1571
+
1572
+ LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
1573
+
1574
+ kv_swa = std::make_unique<llama_kv_cache_unified>(
1575
+ model, std::move(filter_swa), type_k, type_v,
1576
+ v_trans, offload, size_swa, n_seq_max, n_pad,
1577
+ hparams.n_swa, hparams.swa_type);
1578
+ }
1579
+
1580
+ void llama_kv_cache_unified_iswa::clear() {
1581
+ kv_base->clear();
1582
+ kv_swa ->clear();
1583
+ }
1584
+
1585
+ bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
1586
+ bool res = true;
1587
+
1588
+ res = res & kv_base->seq_rm(seq_id, p0, p1);
1589
+ res = res & kv_swa ->seq_rm(seq_id, p0, p1);
1590
+
1591
+ return res;
1592
+ }
1593
+
1594
+ void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
1595
+ kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
1596
+ kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
1597
+ }
1598
+
1599
+ void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
1600
+ kv_base->seq_keep(seq_id);
1601
+ kv_swa ->seq_keep(seq_id);
1602
+ }
1603
+
1604
+ void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
1605
+ kv_base->seq_add(seq_id, p0, p1, shift);
1606
+ kv_swa ->seq_add(seq_id, p0, p1, shift);
1607
+ }
1608
+
1609
+ void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
1610
+ kv_base->seq_div(seq_id, p0, p1, d);
1611
+ kv_swa ->seq_div(seq_id, p0, p1, d);
1612
+ }
1613
+
1614
+ llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
1615
+ // the base cache is a superset of the SWA cache, so we can just check the SWA cache
1616
+ return kv_swa->seq_pos_min(seq_id);
1617
+ }
1618
+
1619
+ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
1620
+ return kv_swa->seq_pos_max(seq_id);
1621
+ }
1622
+
1623
+ void llama_kv_cache_unified_iswa::restore() {
1624
+ kv_base->restore();
1625
+ kv_swa ->restore();
1626
+ }
1627
+
1628
+ void llama_kv_cache_unified_iswa::commit() {
1629
+ kv_base->commit();
1630
+ kv_swa ->commit();
1631
+
1632
+ // slide the attention window, forgetting/pruning old tokens that are outside the window
1633
+ if (do_prune) {
1634
+ for (const auto & [seq_id, entry] : pending.pos) {
1635
+ kv_swa->prune_swa(seq_id, entry.pmin, entry.pmax);
1636
+ }
1637
+
1638
+ }
1639
+
1640
+ pending.clear();
1641
+ }
1642
+
1643
+ bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
1644
+ bool res = true;
1645
+
1646
+ res = res & kv_base->update(lctx);
1647
+ res = res & kv_swa ->update(lctx);
1648
+
1649
+ return res;
1650
+ }
1651
+
1652
+ void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
1653
+ kv_base->defrag_sched(thold);
1654
+ kv_swa ->defrag_sched(thold);
1655
+ }
1656
+
1657
+ void llama_kv_cache_unified_iswa::set_full() {
1658
+ kv_base->set_full();
1659
+ kv_swa ->set_full();
1660
+ }
1661
+
1662
+ llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
1663
+ pending.clear();
1664
+
1665
+ if (do_prune) {
1666
+ for (int i = 0; i < batch.n_tokens; ++i) {
1667
+ for (int s = 0; s < batch.n_seq_id[i]; ++s) {
1668
+ const llama_seq_id seq_id = batch.seq_id[i][s];
1669
+ const llama_pos pos = batch.pos[i];
1670
+
1671
+ if (pending.pos.find(seq_id) == pending.pos.end()) {
1672
+ pending.pos[seq_id].pmin = pos;
1673
+ pending.pos[seq_id].pmax = pos;
1674
+ } else {
1675
+ pending.pos[seq_id].pmin = std::min(pending.pos[seq_id].pmin, pos);
1676
+ pending.pos[seq_id].pmax = std::max(pending.pos[seq_id].pmax, pos);
1677
+ }
1678
+ }
1679
+ }
1680
+ }
1681
+
1682
+ return llama_sbatch(batch, hparams.n_embd, true, logits_all);
1683
+ }
1684
+
1685
+ llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
1686
+ GGML_UNUSED(embd_pooled);
1687
+ return sbatch.split_simple(n_ubatch);
1688
+ }
1689
+
1690
+ bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
1691
+ bool res = true;
1692
+
1693
+ res = res & kv_base->find_slot(batch);
1694
+ res = res & kv_swa ->find_slot(batch);
1695
+
1696
+ return res;
1697
+ }
1698
+
1699
+ bool llama_kv_cache_unified_iswa::get_can_shift() const {
1700
+ return kv_base->get_size() == kv_swa->get_size();
1701
+ }
1702
+
1703
+ void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
1704
+ kv_base->state_write(io, seq_id);
1705
+ kv_swa ->state_write(io, seq_id);
1706
+ }
1707
+
1708
+ void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
1709
+ kv_base->state_read(io, seq_id);
1710
+ kv_swa ->state_read(io, seq_id);
1711
+ }
1712
+
1713
+ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_base() const {
1714
+ return kv_base.get();
1715
+ }
1716
+
1717
+ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_swa() const {
1718
+ return kv_swa.get();
1719
+ }
1720
+
1721
  //
1722
  // llama_kv_cache_recurrent
1723
  //
 
1727
  ggml_type type_k,
1728
  ggml_type type_v,
1729
  bool offload,
1730
+ uint32_t kv_size,
1731
+ uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
1732
  const int32_t n_layer = hparams.n_layer;
1733
 
1734
+ LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
1735
+ __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
1736
 
1737
  head = 0;
1738
  size = kv_size;
1739
  used = 0;
1740
 
 
 
 
1741
  cells.clear();
1742
  cells.resize(kv_size);
1743
 
 
1975
  }
1976
  }
1977
 
1978
+ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
1979
+ if (shift == 0) {
1980
  return;
1981
  }
1982
 
 
1999
  if (tail_id >= 0) {
2000
  kv_cell & cell = cells[tail_id];
2001
  if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
2002
+ cell.pos += shift;
2003
  }
2004
  }
2005
  }
 
2035
  }
2036
  }
2037
 
2038
+ llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
2039
+ llama_pos result = std::numeric_limits<llama_pos>::max();
2040
+
2041
+ for (uint32_t i = 0; i < size; ++i) {
2042
+ if (cells[i].has_seq_id(seq_id)) {
2043
+ result = std::min(result, cells[i].pos);
2044
+ }
2045
+ }
2046
+
2047
+ if (result == std::numeric_limits<llama_pos>::max()) {
2048
+ result = -1;
2049
+ }
2050
+
2051
+ return result;
2052
+ }
2053
+
2054
  llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
2055
+ llama_pos result = -1;
2056
 
2057
  for (uint32_t i = 0; i < size; ++i) {
2058
  if (cells[i].has_seq_id(seq_id)) {
 
2075
  pending.ranges.clear();
2076
  }
2077
 
2078
+ bool llama_kv_cache_recurrent::update(llama_context & ctx) {
2079
+ GGML_UNUSED(ctx);
2080
  return false;
2081
  }
2082
 
 
2137
  if (seq_id < 0 || (uint32_t) seq_id >= size) {
2138
  // too big seq_id
2139
  // TODO: would it be possible to resize the cache instead?
2140
+ LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max);
2141
  return false;
2142
  }
2143
  if (j > 0) {
 
2280
  return n >= n_seqs;
2281
  }
2282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2283
  bool llama_kv_cache_recurrent::get_can_shift() const {
2284
  return false;
2285
  }
 
2408
  io.read_to(&cell_count, sizeof(cell_count));
2409
 
2410
  bool res = true;
2411
+
2412
  res = res && state_read_meta(io, cell_count, seq_id);
2413
  res = res && state_read_data(io, cell_count);
2414
 
 
2737
 
2738
  return true;
2739
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
examples/talk-llama/llama-kv-cache.h CHANGED
@@ -4,10 +4,12 @@
4
  #include "llama-io.h"
5
  #include "llama-graph.h"
6
  #include "llama-memory.h"
 
7
 
8
  #include "ggml-cpp.h"
9
 
10
  #include <set>
 
11
  #include <vector>
12
 
13
  struct llama_cparams;
@@ -34,12 +36,16 @@ struct llama_kv_cache : public llama_memory_i {
34
  virtual void defrag_sched(float thold) = 0;
35
 
36
  // simulate full cache, used for allocating worst-case compute buffers
 
37
  virtual void set_full() = 0;
38
 
39
  //
40
  // batch processing
41
  //
42
 
 
 
 
43
  virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
44
 
45
  // different KV caches require different batch splitting strategies
@@ -48,11 +54,10 @@ struct llama_kv_cache : public llama_memory_i {
48
  // find an empty slot of size "n_tokens" in the cache
49
  virtual bool find_slot(const llama_ubatch & batch) = 0;
50
 
 
 
51
  // getters
52
- virtual int32_t get_n_tokens() const = 0;
53
- virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
54
- virtual llama_pos get_pos_max() const = 0;
55
- virtual bool get_can_shift() const = 0;
56
 
57
  bool get_can_edit() const override { return get_can_shift(); }
58
 
@@ -87,38 +92,25 @@ private:
87
  // llama_kv_cache_unified
88
  //
89
 
90
- // TODO: add notion of max sequences
91
  class llama_kv_cache_unified : public llama_kv_cache {
92
  public:
93
- struct kv_cell {
94
- llama_pos pos = -1;
95
- llama_pos delta = 0;
96
-
97
- std::set<llama_seq_id> seq_id;
98
-
99
- bool has_seq_id(const llama_seq_id & id) const {
100
- return seq_id.find(id) != seq_id.end();
101
- }
102
-
103
- bool is_empty() const {
104
- return seq_id.empty();
105
- }
106
-
107
- bool is_same_seq(const kv_cell & other) const {
108
- return seq_id == other.seq_id;
109
- }
110
- };
111
-
112
  static uint32_t get_padding(const llama_cparams & cparams);
113
 
 
 
 
114
  llama_kv_cache_unified(
115
- const llama_model & model,
116
- ggml_type type_k,
117
- ggml_type type_v,
118
- bool v_trans,
119
- bool offload,
120
- uint32_t kv_size,
121
- uint32_t padding);
 
 
 
 
122
 
123
  ~llama_kv_cache_unified() = default;
124
 
@@ -130,10 +122,11 @@ public:
130
 
131
  bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
132
  void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
133
- void seq_keep(llama_seq_id seq_id) override;
134
- void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
135
  void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
136
 
 
137
  llama_pos seq_pos_max(llama_seq_id seq_id) const override;
138
 
139
  //
@@ -150,7 +143,6 @@ public:
150
  void set_full() override;
151
 
152
  llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
153
-
154
  llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
155
 
156
  // updates the cache head
@@ -158,50 +150,94 @@ public:
158
  // to the first cell of the slot.
159
  bool find_slot(const llama_ubatch & batch) override;
160
 
161
- int32_t get_n_tokens() const override;
162
- int32_t get_used_cells() const override;
163
-
164
- // TODO: better data structures to reduce the cost of this operation
165
- llama_pos get_pos_max() const override;
166
-
167
  bool get_can_shift() const override;
168
 
169
  // state write/load
170
 
171
  void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
172
- void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
173
 
174
- uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
175
- uint32_t size = 0; // total number of cells, shared across all sequences
176
- uint32_t used = 0; // used cells (i.e. at least one seq_id)
177
 
178
- // computed before each graph build
179
- uint32_t n = 0;
180
 
181
- std::vector<kv_cell> cells;
 
 
182
 
183
- std::vector<ggml_tensor *> k_l; // per layer
184
- std::vector<ggml_tensor *> v_l;
 
 
 
 
 
 
 
185
 
186
  private:
187
  const llama_model & model;
188
  const llama_hparams & hparams;
189
 
190
- bool has_shift = false;
191
- bool do_defrag = false;
 
 
 
 
 
 
192
 
 
193
  bool v_trans = true; // the value tensor is transposed
194
- bool can_shift = false;
 
 
 
 
 
 
 
195
 
196
  // required padding
197
- uint32_t padding = 1;
198
 
199
- ggml_type type_k = GGML_TYPE_F16;
200
- ggml_type type_v = GGML_TYPE_F16;
 
 
201
 
202
  std::vector<ggml_context_ptr> ctxs;
203
  std::vector<ggml_backend_buffer_ptr> bufs;
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  // defrag
206
  struct {
207
  std::vector<uint32_t> ids;
@@ -210,25 +246,13 @@ private:
210
  // return true if cells have been moved
211
  bool defrag_prepare(int32_t n_max_nodes);
212
 
213
- // commit/restore cache
214
- struct slot_range {
215
- uint32_t c0 = 0; // note: these are cell indices, not sequence positions
216
- uint32_t c1 = 0;
217
- };
218
-
219
- // pending cell updates that are not yet committed
220
- struct {
221
- std::vector<slot_range> ranges;
222
- } pending;
223
-
224
- // find how many cells are currently in use
225
- uint32_t cell_max() const;
226
-
227
  size_t total_size() const;
228
 
229
  size_t size_k_bytes() const;
230
  size_t size_v_bytes() const;
231
 
 
 
232
  ggml_tensor * build_rope_shift(
233
  const llama_cparams & cparams,
234
  ggml_context * ctx,
@@ -255,6 +279,100 @@ private:
255
  bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
256
  };
257
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  //
259
  // llama_kv_cache_recurrent
260
  //
@@ -286,7 +404,8 @@ public:
286
  ggml_type type_k,
287
  ggml_type type_v,
288
  bool offload,
289
- uint32_t kv_size);
 
290
 
291
  ~llama_kv_cache_recurrent() = default;
292
 
@@ -298,10 +417,11 @@ public:
298
 
299
  bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
300
  void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
301
- void seq_keep(llama_seq_id seq_id) override;
302
- void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
303
  void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
304
 
 
305
  llama_pos seq_pos_max(llama_seq_id seq_id) const override;
306
 
307
  //
@@ -311,24 +431,17 @@ public:
311
  void restore() override;
312
  void commit() override;
313
 
314
- bool update(llama_context & lctx) override;
315
 
316
  void defrag_sched(float thold) override;
317
 
318
  void set_full() override;
319
 
320
  llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
321
-
322
  llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
323
 
324
  bool find_slot(const llama_ubatch & batch) override;
325
 
326
- int32_t get_n_tokens() const override;
327
- int32_t get_used_cells() const override;
328
-
329
- // TODO: better data structures to reduce the cost of this operation
330
- llama_pos get_pos_max() const override;
331
-
332
  bool get_can_shift() const override;
333
 
334
  // TODO: temporary methods - they are not really const as they do const_cast<>, fix this
@@ -368,8 +481,7 @@ private:
368
  std::vector<slot_range> ranges;
369
  } pending;
370
 
371
- ggml_type type_k = GGML_TYPE_F16;
372
- ggml_type type_v = GGML_TYPE_F16;
373
 
374
  std::vector<ggml_context_ptr> ctxs;
375
  std::vector<ggml_backend_buffer_ptr> bufs;
@@ -388,12 +500,3 @@ private:
388
  bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
389
  bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
390
  };
391
-
392
-
393
- //
394
- // kv cache view
395
- //
396
-
397
- llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max);
398
-
399
- void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv);
 
4
  #include "llama-io.h"
5
  #include "llama-graph.h"
6
  #include "llama-memory.h"
7
+ #include "llama-kv-cells.h"
8
 
9
  #include "ggml-cpp.h"
10
 
11
  #include <set>
12
+ #include <unordered_map>
13
  #include <vector>
14
 
15
  struct llama_cparams;
 
36
  virtual void defrag_sched(float thold) = 0;
37
 
38
  // simulate full cache, used for allocating worst-case compute buffers
39
+ // TODO: remove
40
  virtual void set_full() = 0;
41
 
42
  //
43
  // batch processing
44
  //
45
 
46
+ // =============================================================================================================
47
+ // TODO: refactor and simplify this [TAG: KV_API]
48
+
49
  virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
50
 
51
  // different KV caches require different batch splitting strategies
 
54
  // find an empty slot of size "n_tokens" in the cache
55
  virtual bool find_slot(const llama_ubatch & batch) = 0;
56
 
57
+ // =============================================================================================================
58
+
59
  // getters
60
+ virtual bool get_can_shift() const = 0;
 
 
 
61
 
62
  bool get_can_edit() const override { return get_can_shift(); }
63
 
 
92
  // llama_kv_cache_unified
93
  //
94
 
 
95
  class llama_kv_cache_unified : public llama_kv_cache {
96
  public:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  static uint32_t get_padding(const llama_cparams & cparams);
98
 
99
+ // this callback is used to filter out layers that should not be included in the cache
100
+ using layer_filter_cb = std::function<bool(int32_t il)>;
101
+
102
  llama_kv_cache_unified(
103
+ const llama_model & model,
104
+ layer_filter_cb && filter,
105
+ ggml_type type_k,
106
+ ggml_type type_v,
107
+ bool v_trans,
108
+ bool offload,
109
+ uint32_t kv_size,
110
+ uint32_t n_seq_max,
111
+ uint32_t n_pad,
112
+ uint32_t n_swa,
113
+ llama_swa_type swa_type);
114
 
115
  ~llama_kv_cache_unified() = default;
116
 
 
122
 
123
  bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
124
  void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
125
+ void seq_keep(llama_seq_id seq_id) override;
126
+ void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
127
  void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
128
 
129
+ llama_pos seq_pos_min(llama_seq_id seq_id) const override;
130
  llama_pos seq_pos_max(llama_seq_id seq_id) const override;
131
 
132
  //
 
143
  void set_full() override;
144
 
145
  llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
 
146
  llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
147
 
148
  // updates the cache head
 
150
  // to the first cell of the slot.
151
  bool find_slot(const llama_ubatch & batch) override;
152
 
 
 
 
 
 
 
153
  bool get_can_shift() const override;
154
 
155
  // state write/load
156
 
157
  void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
158
+ void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
159
 
160
+ //
161
+ // llama_kv_cache_unified specific API
162
+ //
163
 
164
+ uint32_t get_n() const;
165
+ uint32_t get_size() const;
166
 
167
+ // get views of the current state of the cache
168
+ ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
169
+ ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
170
 
171
+ // store k_cur and v_cur in the cache based on the current head location
172
+ ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
173
+ ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
174
+
175
+ void prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax);
176
+
177
+ void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
178
+ void set_input_k_shift (ggml_tensor * dst) const;
179
+ void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
180
 
181
  private:
182
  const llama_model & model;
183
  const llama_hparams & hparams;
184
 
185
+ struct kv_layer {
186
+ // layer index in the model
187
+ // note: can be different from the layer index in the KV cache
188
+ uint32_t il;
189
+
190
+ ggml_tensor * k;
191
+ ggml_tensor * v;
192
+ };
193
 
194
+ bool do_defrag = false;
195
  bool v_trans = true; // the value tensor is transposed
196
+
197
+ uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
198
+
199
+ // computed before each graph build
200
+ // TODO: cells should start to maintain this value dynamically based on the edits
201
+ uint32_t n = 0;
202
+
203
+ const uint32_t n_seq_max = 1;
204
 
205
  // required padding
206
+ const uint32_t n_pad = 1;
207
 
208
+ // SWA
209
+ const uint32_t n_swa = 0;
210
+
211
+ const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
212
 
213
  std::vector<ggml_context_ptr> ctxs;
214
  std::vector<ggml_backend_buffer_ptr> bufs;
215
 
216
+ llama_kv_cells_unified cells;
217
+
218
+ std::vector<kv_layer> layers;
219
+
220
+ // model layer id -> KV cache layer id
221
+ std::unordered_map<int32_t, int32_t> map_layer_ids;
222
+
223
+ // recovery information used to restore the KV cells to their original state in case of a failure
224
+ // TODO: do not store as a state in the llama_kv_cache object, instead return upon batch preparation
225
+ // to achieve that, first need to refactor the llama_kv_cache interface [TAG: KV_API]
226
+ struct {
227
+ void clear() {
228
+ states.clear();
229
+ }
230
+
231
+ struct state {
232
+ uint32_t i;
233
+
234
+ llama_kv_cells_unified cells;
235
+ };
236
+
237
+ // stack with the partial states before each ubatch
238
+ std::vector<state> states;
239
+ } recovery;
240
+
241
  // defrag
242
  struct {
243
  std::vector<uint32_t> ids;
 
246
  // return true if cells have been moved
247
  bool defrag_prepare(int32_t n_max_nodes);
248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  size_t total_size() const;
250
 
251
  size_t size_k_bytes() const;
252
  size_t size_v_bytes() const;
253
 
254
+ bool is_masked_swa(llama_pos p0, llama_pos p1) const;
255
+
256
  ggml_tensor * build_rope_shift(
257
  const llama_cparams & cparams,
258
  ggml_context * ctx,
 
279
  bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
280
  };
281
 
282
+ //
283
+ // llama_kv_cache_unified_iswa
284
+ //
285
+
286
+ // utilizes two instances of llama_kv_cache_unified
287
+ // the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
288
+ // upon successful commit, the SWA cache removes old tokens outside the n_swa window
289
+
290
+ class llama_kv_cache_unified_iswa : public llama_kv_cache {
291
+ public:
292
+ llama_kv_cache_unified_iswa(
293
+ const llama_model & model,
294
+ ggml_type type_k,
295
+ ggml_type type_v,
296
+ bool v_trans,
297
+ bool offload,
298
+ bool swa_full,
299
+ uint32_t kv_size,
300
+ uint32_t n_seq_max,
301
+ uint32_t n_batch,
302
+ uint32_t n_pad);
303
+
304
+ ~llama_kv_cache_unified_iswa() = default;
305
+
306
+ //
307
+ // llama_memory_i
308
+ //
309
+
310
+ void clear() override;
311
+
312
+ bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
313
+ void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
314
+ void seq_keep(llama_seq_id seq_id) override;
315
+ void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
316
+ void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
317
+
318
+ llama_pos seq_pos_min(llama_seq_id seq_id) const override;
319
+ llama_pos seq_pos_max(llama_seq_id seq_id) const override;
320
+
321
+ //
322
+ // llama_kv_cache
323
+ //
324
+
325
+ void restore() override;
326
+ void commit() override;
327
+
328
+ bool update(llama_context & ctx) override;
329
+
330
+ void defrag_sched(float thold) override;
331
+
332
+ void set_full() override;
333
+
334
+ llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
335
+ llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
336
+
337
+ bool find_slot(const llama_ubatch & batch) override;
338
+
339
+ bool get_can_shift() const override;
340
+
341
+ // state write/load
342
+
343
+ void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
344
+ void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
345
+
346
+ //
347
+ // llama_kv_cache_unified_iswa specific API
348
+ //
349
+
350
+ llama_kv_cache_unified * get_kv_base() const;
351
+ llama_kv_cache_unified * get_kv_swa () const;
352
+
353
+ private:
354
+ const llama_hparams & hparams;
355
+
356
+ bool do_prune = true;
357
+
358
+ struct {
359
+ struct entry {
360
+ llama_pos pmin;
361
+ llama_pos pmax;
362
+ };
363
+
364
+ void clear() {
365
+ pos.clear();
366
+ }
367
+
368
+ // used to perform SWA pruning of old tokens
369
+ std::unordered_map<llama_seq_id, entry> pos;
370
+ } pending;
371
+
372
+ std::unique_ptr<llama_kv_cache_unified> kv_base;
373
+ std::unique_ptr<llama_kv_cache_unified> kv_swa;
374
+ };
375
+
376
  //
377
  // llama_kv_cache_recurrent
378
  //
 
404
  ggml_type type_k,
405
  ggml_type type_v,
406
  bool offload,
407
+ uint32_t kv_size,
408
+ uint32_t n_seq_max);
409
 
410
  ~llama_kv_cache_recurrent() = default;
411
 
 
417
 
418
  bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
419
  void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
420
+ void seq_keep(llama_seq_id seq_id) override;
421
+ void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
422
  void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
423
 
424
+ llama_pos seq_pos_min(llama_seq_id seq_id) const override;
425
  llama_pos seq_pos_max(llama_seq_id seq_id) const override;
426
 
427
  //
 
431
  void restore() override;
432
  void commit() override;
433
 
434
+ bool update(llama_context & ctx) override;
435
 
436
  void defrag_sched(float thold) override;
437
 
438
  void set_full() override;
439
 
440
  llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
 
441
  llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
442
 
443
  bool find_slot(const llama_ubatch & batch) override;
444
 
 
 
 
 
 
 
445
  bool get_can_shift() const override;
446
 
447
  // TODO: temporary methods - they are not really const as they do const_cast<>, fix this
 
481
  std::vector<slot_range> ranges;
482
  } pending;
483
 
484
+ const uint32_t n_seq_max = 1;
 
485
 
486
  std::vector<ggml_context_ptr> ctxs;
487
  std::vector<ggml_backend_buffer_ptr> bufs;
 
500
  bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
501
  bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
502
  };
 
 
 
 
 
 
 
 
 
examples/talk-llama/llama-kv-cells.h ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include "llama.h"
4
+ #include "llama-cparams.h"
5
+
6
+ #include <bitset>
7
+ #include <cassert>
8
+ #include <vector>
9
+ #include <set>
10
+
11
+ // meta information about KV cells that can be part of multiple sequences at the same time
12
+ // TODO: add unit tests
13
+ class llama_kv_cells_unified {
14
+ public:
15
+ void reset() {
16
+ for (uint32_t i = 0; i < pos.size(); ++i) {
17
+ pos[i] = -1;
18
+ shift[i] = 0;
19
+ seq[i].reset();
20
+ }
21
+
22
+ has_shift = false;
23
+
24
+ used.clear();
25
+
26
+ for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
27
+ seq_pos[s].clear();
28
+ }
29
+ }
30
+
31
+ void reset_shift() {
32
+ has_shift = false;
33
+
34
+ for (uint32_t i = 0; i < shift.size(); ++i) {
35
+ shift[i] = 0;
36
+ }
37
+ }
38
+
39
+ uint32_t size() const {
40
+ return pos.size();
41
+ }
42
+
43
+ void resize(uint32_t n) {
44
+ pos.resize(n);
45
+ shift.resize(n);
46
+ seq.resize(n);
47
+
48
+ reset();
49
+ }
50
+
51
+ bool is_empty(uint32_t i) const {
52
+ assert(i < pos.size());
53
+ assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0);
54
+
55
+ return pos[i] == -1;
56
+ }
57
+
58
+ uint32_t get_used() const {
59
+ return used.size();
60
+ }
61
+
62
+ // the index of the first cell that is used
63
+ // return 0 if no cells are used
64
+ uint32_t used_min() const {
65
+ return used.empty() ? 0 : *used.begin();
66
+ }
67
+
68
+ // the index of the last cell that is used + 1
69
+ // return 0 if no cells are used
70
+ uint32_t used_max_p1() const {
71
+ #if 0
72
+ if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin());
73
+ if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin());
74
+ if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin());
75
+ #endif
76
+
77
+ return used.empty() ? 0 : *used.rbegin() + 1;
78
+ }
79
+
80
+ bool get_has_shift() const {
81
+ return has_shift;
82
+ }
83
+
84
+ // move cell isrc to idst (used during defrag)
85
+ void mv(uint32_t isrc, uint32_t idst) {
86
+ assert(isrc < pos.size());
87
+ assert(idst < pos.size());
88
+
89
+ pos [idst] = pos [isrc];
90
+ shift[idst] = shift[isrc];
91
+ seq [idst] = seq [isrc];
92
+
93
+ pos [isrc] = -1;
94
+ shift[isrc] = 0;
95
+ seq [isrc].reset();
96
+
97
+ used.erase (isrc);
98
+ used.insert(idst);
99
+ }
100
+
101
+ // copy the state of cells [i, i + n) (used for save/restore the state of the cells)
102
+ llama_kv_cells_unified cp(uint32_t i, uint32_t n) const {
103
+ assert(i + n <= pos.size());
104
+
105
+ llama_kv_cells_unified res;
106
+
107
+ res.resize(n);
108
+
109
+ for (uint32_t j = 0; j < n; ++j) {
110
+ res.pos[j] = pos[i + j];
111
+ res.seq[j] = seq[i + j];
112
+
113
+ assert(shift[i + j] == 0);
114
+ }
115
+
116
+ return res;
117
+ }
118
+
119
+ // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
120
+ void set(uint32_t i, const llama_kv_cells_unified & other) {
121
+ assert(i + other.pos.size() <= pos.size());
122
+
123
+ for (uint32_t j = 0; j < other.pos.size(); ++j) {
124
+ if (pos[i + j] == -1 && other.pos[j] != -1) {
125
+ used.insert(i + j);
126
+ }
127
+
128
+ if (pos[i + j] != -1 && other.pos[j] == -1) {
129
+ used.erase(i + j);
130
+ }
131
+
132
+ if (pos[i + j] != -1) {
133
+ seq_pos_rm(i + j);
134
+ }
135
+
136
+ pos[i + j] = other.pos[j];
137
+ seq[i + j] = other.seq[j];
138
+
139
+ if (pos[i + j] != -1) {
140
+ seq_pos_add(i + j);
141
+ }
142
+
143
+ assert(shift[i + j] == 0);
144
+ }
145
+ }
146
+
147
+ // note: call only if the cell has seq_id
148
+ // return true if the cell becomes empty
149
+ bool seq_rm(uint32_t i, llama_seq_id seq_id) {
150
+ assert(i < pos.size());
151
+ assert(seq[i].test(seq_id));
152
+ assert(pos[i] != -1);
153
+ assert(seq_id >= 0);
154
+
155
+ seq[i].reset(seq_id);
156
+ seq_pos[seq_id].erase(pos[i]);
157
+
158
+ if (seq[i].none()) {
159
+ pos[i] = -1;
160
+
161
+ used.erase(i);
162
+
163
+ return true;
164
+ }
165
+
166
+ return false;
167
+ }
168
+
169
+ // return true if the cell becomes empty (i.e. it did not contain seq_id before the call)
170
+ bool seq_keep(uint32_t i, llama_seq_id seq_id) {
171
+ assert(i < pos.size());
172
+
173
+ if (seq[i].test(seq_id)) {
174
+ seq_pos_rm(i);
175
+ seq[i].reset();
176
+
177
+ seq[i].set(seq_id);
178
+ seq_pos[seq_id].insert(pos[i]);
179
+
180
+ return false;
181
+ }
182
+
183
+ if (seq[i].any()) {
184
+ seq_pos_rm(i);
185
+ seq[i].reset();
186
+
187
+ pos[i] = -1;
188
+
189
+ used.erase(i);
190
+
191
+ return true;
192
+ }
193
+
194
+ assert(pos[i] == -1);
195
+
196
+ return false;
197
+ }
198
+
199
+ bool seq_has(uint32_t i, llama_seq_id seq_id) const {
200
+ assert(i < pos.size());
201
+ assert(seq_id >= 0);
202
+
203
+ return seq[i].test(seq_id);
204
+ }
205
+
206
+ // note: call only if the cell is not empty and the seq_id is not in the cell
207
+ void seq_add(uint32_t i, llama_seq_id seq_id) {
208
+ assert(i < pos.size());
209
+ assert(pos[i] != -1);
210
+ assert(!seq[i].test(seq_id));
211
+
212
+ seq[i].set(seq_id);
213
+ seq_pos[seq_id].insert(pos[i]);
214
+ }
215
+
216
+ // the minimum position of sequence seq_id currently present in any of the cells
217
+ // return -1 if the sequence is not present
218
+ llama_pos seq_pos_min(llama_seq_id seq_id) const {
219
+ assert(seq_id >= 0);
220
+ assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
221
+
222
+ if (seq_pos[seq_id].empty()) {
223
+ return -1;
224
+ }
225
+
226
+ return *seq_pos[seq_id].begin();
227
+ }
228
+
229
+ // the maximum position of sequence seq_id currently present in any of the cells
230
+ // return -1 if the sequence is not present
231
+ llama_pos seq_pos_max(llama_seq_id seq_id) const {
232
+ assert(seq_id >= 0);
233
+ assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
234
+
235
+ if (seq_pos[seq_id].empty()) {
236
+ return -1;
237
+ }
238
+
239
+ return *seq_pos[seq_id].rbegin();
240
+ }
241
+
242
+ // note: call only if the cell is not empty
243
+ llama_pos pos_get(uint32_t i) const {
244
+ assert(i < pos.size());
245
+ assert(pos[i] != -1);
246
+
247
+ return pos[i];
248
+ }
249
+
250
+ // note: call only if the cell is not empty
251
+ llama_pos get_shift(uint32_t i) const {
252
+ assert(i < pos.size());
253
+ assert(pos[i] != -1);
254
+
255
+ return shift[i];
256
+ }
257
+
258
+ // check if a cell is not empty and its position is within [p0, p1)
259
+ bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const {
260
+ assert(i < pos.size());
261
+
262
+ return pos[i] >= p0 && pos[i] < p1;
263
+ }
264
+
265
+ // set the position of an empty cell
266
+ // does not modify "has_shift"
267
+ // note: call only if the cell is empty
268
+ void pos_set(uint32_t i, llama_pos p) {
269
+ assert(i < pos.size());
270
+ assert(pos[i] == -1);
271
+
272
+ pos[i] = p;
273
+
274
+ used.insert(i);
275
+ }
276
+
277
+ // pos[i] = pos[i] + d
278
+ // sets "has_shift" to true
279
+ // note: call only if the cell is not empty
280
+ bool pos_add(uint32_t i, llama_pos d) {
281
+ assert(i < pos.size());
282
+ assert(pos[i] != -1);
283
+
284
+ seq_pos_rm(i);
285
+
286
+ pos[i] += d;
287
+ shift[i] += d;
288
+
289
+ seq_pos_add(i);
290
+
291
+ has_shift = true;
292
+
293
+ if (pos[i] < 0) {
294
+ seq_pos_rm(i);
295
+
296
+ seq[i].reset();
297
+ pos[i] = -1;
298
+
299
+ used.erase(i);
300
+
301
+ return true;
302
+ }
303
+
304
+ return false;
305
+ }
306
+
307
+ // pos[i] = pos[i] / d
308
+ // sets "has_shift" to true
309
+ // note: call only if the cell is not empty
310
+ void pos_div(uint32_t i, int d) {
311
+ assert(i < pos.size());
312
+ assert(pos[i] != -1);
313
+
314
+ const llama_pos p_old = pos[i];
315
+
316
+ seq_pos_rm(i);
317
+
318
+ pos[i] /= d;
319
+ shift[i] += p_old - pos[i];
320
+
321
+ seq_pos_add(i);
322
+
323
+ has_shift = true;
324
+ }
325
+
326
+ private:
327
+ bool has_shift = false;
328
+
329
+ // set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
330
+ std::set<uint32_t> used;
331
+
332
+ std::vector<llama_pos> pos;
333
+
334
+ // this array accumulates any applied shifts to the pos array since the last reset_shift() call
335
+ // this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
336
+ //
337
+ // cells.pos_add(x, shift_x);
338
+ // cells.pos_div(y, shift_y);
339
+ // ...
340
+ //
341
+ // if (cells.has_shift()) {
342
+ // for (int i = 0; i < n; ++i) {
343
+ // auto shift_i = cells.get_shift(i);
344
+ // ...
345
+ // }
346
+ // cells.reset_shift();
347
+ // }
348
+ //
349
+ std::vector<llama_pos> shift;
350
+
351
+ using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
352
+
353
+ // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
354
+ std::vector<bits_t> seq;
355
+
356
+ // the set seq_pos[s] tells us which positions are currently present for sequence s
357
+ // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
358
+ std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
359
+
360
+ // helper functions for updating `seq_pos`, once cell at a time:
361
+
362
+ // remove cell i
363
+ void seq_pos_rm(uint32_t i) {
364
+ for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
365
+ if (seq[i].test(s)) {
366
+ seq_pos[s].erase(pos[i]);
367
+ }
368
+ }
369
+ }
370
+
371
+ // add cell i
372
+ void seq_pos_add(uint32_t i) {
373
+ for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
374
+ if (seq[i].test(s)) {
375
+ seq_pos[s].insert(pos[i]);
376
+ }
377
+ }
378
+ }
379
+ };
examples/talk-llama/llama-memory.h CHANGED
@@ -7,8 +7,8 @@ struct llama_memory_params {
7
  ggml_type type_k;
8
  ggml_type type_v;
9
 
10
- // parameters for other types of memory
11
- // ...
12
  };
13
 
14
  // general concept of LLM memory
@@ -22,9 +22,10 @@ public:
22
  virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
23
  virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
24
  virtual void seq_keep(llama_seq_id seq_id) = 0;
25
- virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
26
  virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
27
 
 
28
  virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
29
 
30
  virtual bool get_can_edit() const = 0;
 
7
  ggml_type type_k;
8
  ggml_type type_v;
9
 
10
+ // use full-size SWA cache
11
+ bool swa_full;
12
  };
13
 
14
  // general concept of LLM memory
 
22
  virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
23
  virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
24
  virtual void seq_keep(llama_seq_id seq_id) = 0;
25
+ virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
26
  virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
27
 
28
+ virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
29
  virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
30
 
31
  virtual bool get_can_edit() const = 0;
examples/talk-llama/llama-model.cpp CHANGED
@@ -463,11 +463,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
463
  GGML_ASSERT(hparams.n_expert_used == 0);
464
  }
465
 
466
- // zero-out the array hparams
467
  std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
468
  std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
469
  std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
470
 
 
 
 
 
471
  ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false);
472
  ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false);
473
 
@@ -571,9 +574,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
571
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
572
  ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
573
  ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);
574
- hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full
575
- hparams.n_attn_chunk = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
576
- hparams.n_swa = 1; // TODO @ngxson : this is added to trigger the SWA branch (we store the chunked attn mask in the SWA tensor), will need to clean this up later
 
577
 
578
  switch (hparams.n_expert) {
579
  case 16: type = LLM_TYPE_17B_16E; break;
@@ -852,22 +856,17 @@ void llama_model::load_hparams(llama_model_loader & ml) {
852
  default: type = LLM_TYPE_UNKNOWN;
853
  }
854
 
855
- // for backward compatibility ; see: https://github.com/ggerganov/llama.cpp/pull/8931
856
- if ((hparams.n_layer == 32 || hparams.n_layer == 40) && hparams.n_ctx_train == 4096) {
857
- // default value for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct
858
- hparams.n_swa = 2047;
859
- } else if (hparams.n_layer == 32 && hparams.n_head_kv(0) == 32 && hparams.n_ctx_train == 131072) {
860
- // default value for Phi-3-mini-128k-instruct
861
- // note: this seems incorrect because the window is bigger than the train context?
862
- hparams.n_swa = 262144;
863
- } else if (hparams.n_layer == 40 && hparams.n_ctx_train == 131072) {
864
- // default value for Phi-3-medium-128k-instruct
865
- // note: this seems incorrect because the window is equal to the train context?
866
- hparams.n_swa = 131072;
867
- }
868
- bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
869
- if (!found_swa && hparams.n_swa == 0) {
870
- throw std::runtime_error("invalid value for sliding_window");
871
  }
872
  } break;
873
  case LLM_ARCH_PHIMOE:
@@ -937,8 +936,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
937
  } break;
938
  case LLM_ARCH_GEMMA2:
939
  {
 
940
  hparams.n_swa = 4096; // default value of gemma 2
941
- hparams.n_swa_pattern = 2;
942
  hparams.attn_soft_cap = true;
943
 
944
  ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
@@ -955,7 +955,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
955
  } break;
956
  case LLM_ARCH_GEMMA3:
957
  {
958
- hparams.n_swa_pattern = 6;
 
959
 
960
  hparams.rope_freq_base_train_swa = 10000.0f;
961
  hparams.rope_freq_scale_train_swa = 1.0f;
@@ -1039,7 +1040,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
1039
  } break;
1040
  case LLM_ARCH_COHERE2:
1041
  {
1042
- hparams.n_swa_pattern = 4;
 
1043
 
1044
  ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
1045
  ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
@@ -2487,7 +2489,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
2487
 
2488
  // output
2489
  output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
2490
- output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
 
 
 
 
2491
 
2492
  for (int i = 0; i < n_layer; ++i) {
2493
  auto & layer = layers[i];
@@ -4321,7 +4327,7 @@ void llama_model::print_info() const {
4321
  LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
4322
  LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
4323
  LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
4324
- LLAMA_LOG_INFO("%s: n_swa_pattern = %u\n", __func__, hparams.n_swa_pattern);
4325
  LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
4326
  LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
4327
  LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str());
@@ -4489,7 +4495,17 @@ const ggml_tensor * llama_model::get_tensor(const char * name) const {
4489
  return it->second;
4490
  }
4491
 
4492
- ggml_tensor * llama_model::get_rope_factors(uint32_t n_ctx_per_seq, int il) const {
 
 
 
 
 
 
 
 
 
 
4493
  // choose long/short freq factors based on the context size
4494
  if (layers[il].rope_freqs != nullptr) {
4495
  return layers[il].rope_freqs;
@@ -4517,21 +4533,174 @@ struct llm_build_llama : public llm_graph_context {
4517
  // inp_pos - contains the positions
4518
  ggml_tensor * inp_pos = build_inp_pos();
4519
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4520
  // temperature tuning
4521
  ggml_tensor * inp_attn_scale = nullptr;
4522
- if (arch == LLM_ARCH_LLAMA4) {
4523
- inp_attn_scale = build_inp_attn_scale();
4524
- }
4525
 
4526
- auto * inp_attn = build_attn_inp_kv_unified();
4527
 
4528
  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
4529
  for (int il = 0; il < n_layer; ++il) {
4530
  ggml_tensor * inpSA = inpL;
4531
 
4532
- bool use_rope = arch == LLM_ARCH_LLAMA4
4533
- ? (il + 1) % hparams.n_no_rope_layer_step != 0
4534
- : true;
4535
 
4536
  // norm
4537
  cur = build_norm(inpL,
@@ -4542,7 +4711,7 @@ struct llm_build_llama : public llm_graph_context {
4542
  // self-attention
4543
  {
4544
  // rope freq factors for llama3; may return nullptr for llama2 and other models
4545
- ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
4546
 
4547
  // compute Q and K and RoPE them
4548
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -4590,7 +4759,7 @@ struct llm_build_llama : public llm_graph_context {
4590
  cb(Kcur, "Kcur", il);
4591
  cb(Vcur, "Vcur", il);
4592
 
4593
- if (arch == LLM_ARCH_LLAMA4 && use_rope && hparams.use_kq_norm) {
4594
  // Llama4TextL2Norm
4595
  Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps);
4596
  Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps);
@@ -4616,7 +4785,6 @@ struct llm_build_llama : public llm_graph_context {
4616
 
4617
  // feed-forward network (non-MoE)
4618
  if (model.layers[il].ffn_gate_inp == nullptr) {
4619
-
4620
  cur = build_norm(ffn_inp,
4621
  model.layers[il].ffn_norm, NULL,
4622
  LLM_NORM_RMS, il);
@@ -4629,9 +4797,7 @@ struct llm_build_llama : public llm_graph_context {
4629
  NULL,
4630
  LLM_FFN_SILU, LLM_FFN_PAR, il);
4631
  cb(cur, "ffn_out", il);
4632
-
4633
- } else if (arch == LLM_ARCH_LLAMA4) {
4634
- // llama4 MoE
4635
  ggml_tensor * ffn_inp_normed = build_norm(ffn_inp,
4636
  model.layers[il].ffn_norm, NULL,
4637
  LLM_NORM_RMS, il);
@@ -4660,26 +4826,6 @@ struct llm_build_llama : public llm_graph_context {
4660
 
4661
  cur = ggml_add(ctx0, moe_out, shexp_out);
4662
  cb(cur, "ffn_moe_out_merged", il);
4663
-
4664
- } else {
4665
- // MoE branch
4666
- cur = build_norm(ffn_inp,
4667
- model.layers[il].ffn_norm, NULL,
4668
- LLM_NORM_RMS, il);
4669
- cb(cur, "ffn_norm", il);
4670
-
4671
- cur = build_moe_ffn(cur,
4672
- model.layers[il].ffn_gate_inp,
4673
- model.layers[il].ffn_up_exps,
4674
- model.layers[il].ffn_gate_exps,
4675
- model.layers[il].ffn_down_exps,
4676
- nullptr,
4677
- n_expert, n_expert_used,
4678
- LLM_FFN_SILU, true,
4679
- false, 0.0,
4680
- LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
4681
- il);
4682
- cb(cur, "ffn_moe_out", il);
4683
  }
4684
 
4685
  cur = ggml_add(ctx0, cur, ffn_inp);
@@ -4753,7 +4899,7 @@ struct llm_build_deci : public llm_graph_context {
4753
  } else if (n_head > 0) {
4754
  // self-attention
4755
  // rope freq factors for llama3; may return nullptr for llama2 and other models
4756
- ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
4757
 
4758
  // compute Q and K and RoPE them
4759
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -7202,6 +7348,7 @@ struct llm_build_phi2 : public llm_graph_context {
7202
  }
7203
  };
7204
 
 
7205
  struct llm_build_phi3 : public llm_graph_context {
7206
  llm_build_phi3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
7207
  const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -7217,7 +7364,14 @@ struct llm_build_phi3 : public llm_graph_context {
7217
  // inp_pos - contains the positions
7218
  ggml_tensor * inp_pos = build_inp_pos();
7219
 
7220
- auto * inp_attn = build_attn_inp_kv_unified();
 
 
 
 
 
 
 
7221
 
7222
  for (int il = 0; il < n_layer; ++il) {
7223
  auto * residual = inpL;
@@ -7225,7 +7379,7 @@ struct llm_build_phi3 : public llm_graph_context {
7225
  // self-attention
7226
  {
7227
  // rope freq factors for 128k context
7228
- ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
7229
 
7230
  ggml_tensor* attn_norm_output = build_norm(inpL,
7231
  model.layers[il].attn_norm,
@@ -7977,7 +8131,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
7977
  for (int il = 0; il < n_layer; ++il) {
7978
  ggml_tensor * inpSA = inpL;
7979
 
7980
- ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
7981
 
7982
  // norm
7983
  cur = build_norm(inpL,
@@ -8277,8 +8431,8 @@ struct llm_build_gemma : public llm_graph_context {
8277
  }
8278
  };
8279
 
8280
- struct llm_build_gemma2 : public llm_graph_context {
8281
- llm_build_gemma2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
8282
  const int64_t n_embd_head = hparams.n_embd_head_k;
8283
 
8284
  ggml_tensor * cur;
@@ -8292,7 +8446,7 @@ struct llm_build_gemma2 : public llm_graph_context {
8292
  // inp_pos - contains the positions
8293
  ggml_tensor * inp_pos = build_inp_pos();
8294
 
8295
- auto * inp_attn = build_attn_inp_kv_unified();
8296
 
8297
  for (int il = 0; il < n_layer; ++il) {
8298
  // norm
@@ -8414,8 +8568,8 @@ struct llm_build_gemma2 : public llm_graph_context {
8414
  }
8415
  };
8416
 
8417
- struct llm_build_gemma3 : public llm_graph_context {
8418
- llm_build_gemma3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
8419
  const int64_t n_embd_head = hparams.n_embd_head_k;
8420
 
8421
  ggml_tensor * cur;
@@ -8433,13 +8587,11 @@ struct llm_build_gemma3 : public llm_graph_context {
8433
  ggml_tensor * inp_pos = build_inp_pos();
8434
 
8435
  // TODO: is causal == true correct? might need some changes
8436
- auto * inp_attn = build_attn_inp_kv_unified();
8437
 
8438
  for (int il = 0; il < n_layer; ++il) {
8439
- const bool is_swa = hparams.is_swa(il);
8440
-
8441
- const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
8442
- const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
8443
 
8444
  // norm
8445
  cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
@@ -9016,8 +9168,8 @@ struct llm_build_command_r : public llm_graph_context {
9016
  }
9017
  };
9018
 
9019
- struct llm_build_cohere2 : public llm_graph_context {
9020
- llm_build_cohere2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
9021
  const int64_t n_embd_head = hparams.n_embd_head_v;
9022
 
9023
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -9032,7 +9184,7 @@ struct llm_build_cohere2 : public llm_graph_context {
9032
  // inp_pos - contains the positions
9033
  ggml_tensor * inp_pos = build_inp_pos();
9034
 
9035
- auto * inp_attn = build_attn_inp_kv_unified();
9036
 
9037
  for (int il = 0; il < n_layer; ++il) {
9038
  const bool is_swa = hparams.is_swa(il);
@@ -9045,7 +9197,7 @@ struct llm_build_cohere2 : public llm_graph_context {
9045
  // self-attention
9046
  {
9047
  // rope freq factors for 128k context
9048
- ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
9049
 
9050
  // compute Q and K and RoPE them
9051
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -9983,7 +10135,7 @@ struct llm_build_deepseek : public llm_graph_context {
9983
  // self-attention
9984
  {
9985
  // rope freq factors for llama3; may return nullptr for llama2 and other models
9986
- ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
9987
 
9988
  // compute Q and K and RoPE them
9989
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -11347,7 +11499,7 @@ struct llm_build_exaone : public llm_graph_context {
11347
  // self-attention
11348
  {
11349
  // rope freq factors for llama3; may return nullptr for llama2 and other models
11350
- ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
11351
 
11352
  // compute Q and K and RoPE them
11353
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -12263,7 +12415,7 @@ struct llm_build_granite : public llm_graph_context {
12263
  Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
12264
 
12265
  if (use_rope) {
12266
- ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
12267
  Qcur = ggml_rope_ext(
12268
  ctx0, Qcur, inp_pos, rope_factors,
12269
  n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -12916,7 +13068,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
12916
  // self-attention
12917
  {
12918
  // rope freq factors for llama3; may return nullptr for llama2 and other models
12919
- ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
12920
 
12921
  // compute Q and K and RoPE them
12922
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -13044,6 +13196,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
13044
  case LLM_ARCH_JINA_BERT_V2:
13045
  case LLM_ARCH_NOMIC_BERT:
13046
  case LLM_ARCH_NOMIC_BERT_MOE:
 
13047
  {
13048
  res = nullptr;
13049
  } break;
@@ -13058,7 +13211,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
13058
  GGML_TYPE_F32,
13059
  GGML_TYPE_F32,
13060
  cparams.offload_kqv,
13061
- std::max((uint32_t) 1, cparams.n_seq_max));
 
13062
  } break;
13063
  default:
13064
  {
@@ -13068,14 +13222,36 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
13068
 
13069
  LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
13070
 
13071
- res = new llama_kv_cache_unified(
13072
- *this,
13073
- params.type_k,
13074
- params.type_v,
13075
- !cparams.flash_attn,
13076
- cparams.offload_kqv,
13077
- cparams.n_ctx,
13078
- padding);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13079
  }
13080
  }
13081
 
@@ -13090,11 +13266,14 @@ llm_graph_result_ptr llama_model::build_graph(
13090
 
13091
  switch (arch) {
13092
  case LLM_ARCH_LLAMA:
13093
- case LLM_ARCH_LLAMA4:
13094
  case LLM_ARCH_MINICPM:
13095
  {
13096
  llm = std::make_unique<llm_build_llama>(*this, params, gf);
13097
  } break;
 
 
 
 
13098
  case LLM_ARCH_DECI:
13099
  {
13100
  llm = std::make_unique<llm_build_deci>(*this, params, gf);
@@ -13169,7 +13348,11 @@ llm_graph_result_ptr llama_model::build_graph(
13169
  case LLM_ARCH_PHI3:
13170
  case LLM_ARCH_PHIMOE:
13171
  {
13172
- llm = std::make_unique<llm_build_phi3>(*this, params, gf);
 
 
 
 
13173
  } break;
13174
  case LLM_ARCH_PLAMO:
13175
  {
@@ -13201,11 +13384,11 @@ llm_graph_result_ptr llama_model::build_graph(
13201
  } break;
13202
  case LLM_ARCH_GEMMA2:
13203
  {
13204
- llm = std::make_unique<llm_build_gemma2>(*this, params, gf);
13205
  } break;
13206
  case LLM_ARCH_GEMMA3:
13207
  {
13208
- llm = std::make_unique<llm_build_gemma3>(*this, params, gf);
13209
  } break;
13210
  case LLM_ARCH_STARCODER2:
13211
  {
@@ -13225,7 +13408,7 @@ llm_graph_result_ptr llama_model::build_graph(
13225
  } break;
13226
  case LLM_ARCH_COHERE2:
13227
  {
13228
- llm = std::make_unique<llm_build_cohere2>(*this, params, gf);
13229
  } break;
13230
  case LLM_ARCH_DBRX:
13231
  {
 
463
  GGML_ASSERT(hparams.n_expert_used == 0);
464
  }
465
 
 
466
  std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
467
  std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
468
  std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
469
 
470
+ std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);
471
+
472
+ std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0);
473
+
474
  ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false);
475
  ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false);
476
 
 
574
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
575
  ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
576
  ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);
577
+
578
+ hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
579
+ hparams.n_swa = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
580
+ hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
581
 
582
  switch (hparams.n_expert) {
583
  case 16: type = LLM_TYPE_17B_16E; break;
 
856
  default: type = LLM_TYPE_UNKNOWN;
857
  }
858
 
859
+ const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
860
+
861
+ if (found_swa && hparams.n_swa > 0) {
862
+ LLAMA_LOG_WARN("%s: Phi SWA is currently disabled - results might be suboptimal for some models (see %s)\n",
863
+ __func__, "https://github.com/ggml-org/llama.cpp/pull/13676");
864
+
865
+ // TODO: fix conversion scripts to correctly populate `n_swa` and `n_swa_pattern`
866
+ hparams.swa_type = LLAMA_SWA_TYPE_NONE;
867
+
868
+ hparams.n_swa = 0;
869
+ hparams.set_swa_pattern(1);
 
 
 
 
 
870
  }
871
  } break;
872
  case LLM_ARCH_PHIMOE:
 
936
  } break;
937
  case LLM_ARCH_GEMMA2:
938
  {
939
+ hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
940
  hparams.n_swa = 4096; // default value of gemma 2
941
+ hparams.set_swa_pattern(2);
942
  hparams.attn_soft_cap = true;
943
 
944
  ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
 
955
  } break;
956
  case LLM_ARCH_GEMMA3:
957
  {
958
+ hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
959
+ hparams.set_swa_pattern(6);
960
 
961
  hparams.rope_freq_base_train_swa = 10000.0f;
962
  hparams.rope_freq_scale_train_swa = 1.0f;
 
1040
  } break;
1041
  case LLM_ARCH_COHERE2:
1042
  {
1043
+ hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
1044
+ hparams.set_swa_pattern(4);
1045
 
1046
  ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
1047
  ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
 
2489
 
2490
  // output
2491
  output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
2492
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
2493
+ // if output is NULL, init from the input tok embed
2494
+ if (output == NULL) {
2495
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
2496
+ }
2497
 
2498
  for (int i = 0; i < n_layer; ++i) {
2499
  auto & layer = layers[i];
 
4327
  LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
4328
  LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
4329
  LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
4330
+ LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any());
4331
  LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
4332
  LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
4333
  LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str());
 
4495
  return it->second;
4496
  }
4497
 
4498
+ float llama_model::get_rope_freq_base (const llama_cparams & cparams, int il) const {
4499
+ return hparams.is_swa(il) ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
4500
+ }
4501
+
4502
+ float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) const {
4503
+ return hparams.is_swa(il) ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
4504
+ }
4505
+
4506
+ ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const {
4507
+ const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
4508
+
4509
  // choose long/short freq factors based on the context size
4510
  if (layers[il].rope_freqs != nullptr) {
4511
  return layers[il].rope_freqs;
 
4533
  // inp_pos - contains the positions
4534
  ggml_tensor * inp_pos = build_inp_pos();
4535
 
4536
+ auto * inp_attn = build_attn_inp_kv_unified();
4537
+
4538
+ const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
4539
+
4540
+ for (int il = 0; il < n_layer; ++il) {
4541
+ ggml_tensor * inpSA = inpL;
4542
+
4543
+ // norm
4544
+ cur = build_norm(inpL,
4545
+ model.layers[il].attn_norm, NULL,
4546
+ LLM_NORM_RMS, il);
4547
+ cb(cur, "attn_norm", il);
4548
+
4549
+ // self-attention
4550
+ {
4551
+ // rope freq factors for llama3; may return nullptr for llama2 and other models
4552
+ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
4553
+
4554
+ // compute Q and K and RoPE them
4555
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
4556
+ cb(Qcur, "Qcur", il);
4557
+ if (model.layers[il].bq) {
4558
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
4559
+ cb(Qcur, "Qcur", il);
4560
+ }
4561
+
4562
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
4563
+ cb(Kcur, "Kcur", il);
4564
+ if (model.layers[il].bk) {
4565
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
4566
+ cb(Kcur, "Kcur", il);
4567
+ }
4568
+
4569
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
4570
+ cb(Vcur, "Vcur", il);
4571
+ if (model.layers[il].bv) {
4572
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
4573
+ cb(Vcur, "Vcur", il);
4574
+ }
4575
+
4576
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
4577
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
4578
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
4579
+
4580
+ Qcur = ggml_rope_ext(
4581
+ ctx0, Qcur, inp_pos, rope_factors,
4582
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
4583
+ ext_factor, attn_factor, beta_fast, beta_slow
4584
+ );
4585
+
4586
+ Kcur = ggml_rope_ext(
4587
+ ctx0, Kcur, inp_pos, rope_factors,
4588
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
4589
+ ext_factor, attn_factor, beta_fast, beta_slow
4590
+ );
4591
+
4592
+ cb(Qcur, "Qcur", il);
4593
+ cb(Kcur, "Kcur", il);
4594
+ cb(Vcur, "Vcur", il);
4595
+
4596
+ cur = build_attn(inp_attn, gf,
4597
+ model.layers[il].wo, model.layers[il].bo,
4598
+ Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
4599
+ cb(cur, "attn_out", il);
4600
+ }
4601
+
4602
+ if (il == n_layer - 1) {
4603
+ // skip computing output for unused tokens
4604
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
4605
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
4606
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
4607
+ }
4608
+
4609
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
4610
+ cb(ffn_inp, "ffn_inp", il);
4611
+
4612
+ // feed-forward network (non-MoE)
4613
+ if (model.layers[il].ffn_gate_inp == nullptr) {
4614
+
4615
+ cur = build_norm(ffn_inp,
4616
+ model.layers[il].ffn_norm, NULL,
4617
+ LLM_NORM_RMS, il);
4618
+ cb(cur, "ffn_norm", il);
4619
+
4620
+ cur = build_ffn(cur,
4621
+ model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
4622
+ model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
4623
+ model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
4624
+ NULL,
4625
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
4626
+ cb(cur, "ffn_out", il);
4627
+ } else {
4628
+ // MoE branch
4629
+ cur = build_norm(ffn_inp,
4630
+ model.layers[il].ffn_norm, NULL,
4631
+ LLM_NORM_RMS, il);
4632
+ cb(cur, "ffn_norm", il);
4633
+
4634
+ cur = build_moe_ffn(cur,
4635
+ model.layers[il].ffn_gate_inp,
4636
+ model.layers[il].ffn_up_exps,
4637
+ model.layers[il].ffn_gate_exps,
4638
+ model.layers[il].ffn_down_exps,
4639
+ nullptr,
4640
+ n_expert, n_expert_used,
4641
+ LLM_FFN_SILU, true,
4642
+ false, 0.0,
4643
+ LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
4644
+ il);
4645
+ cb(cur, "ffn_moe_out", il);
4646
+ }
4647
+
4648
+ cur = ggml_add(ctx0, cur, ffn_inp);
4649
+ cb(cur, "ffn_out", il);
4650
+
4651
+ cur = build_cvec(cur, il);
4652
+ cb(cur, "l_out", il);
4653
+
4654
+ // input for next layer
4655
+ inpL = cur;
4656
+ }
4657
+
4658
+ cur = inpL;
4659
+
4660
+ cur = build_norm(cur,
4661
+ model.output_norm, NULL,
4662
+ LLM_NORM_RMS, -1);
4663
+
4664
+ cb(cur, "result_norm", -1);
4665
+ res->t_embd = cur;
4666
+
4667
+ // lm_head
4668
+ cur = build_lora_mm(model.output, cur);
4669
+
4670
+ cb(cur, "result_output", -1);
4671
+ res->t_logits = cur;
4672
+
4673
+ ggml_build_forward_expand(gf, cur);
4674
+ }
4675
+ };
4676
+
4677
+ struct llm_build_llama_iswa : public llm_graph_context {
4678
+ llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
4679
+ const int64_t n_embd_head = hparams.n_embd_head_v;
4680
+
4681
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
4682
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
4683
+
4684
+ ggml_tensor * cur;
4685
+ ggml_tensor * inpL;
4686
+
4687
+ inpL = build_inp_embd(model.tok_embd);
4688
+
4689
+ // inp_pos - contains the positions
4690
+ ggml_tensor * inp_pos = build_inp_pos();
4691
+
4692
  // temperature tuning
4693
  ggml_tensor * inp_attn_scale = nullptr;
4694
+ inp_attn_scale = build_inp_attn_scale();
 
 
4695
 
4696
+ auto * inp_attn = build_attn_inp_kv_unified_iswa();
4697
 
4698
  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
4699
+
4700
  for (int il = 0; il < n_layer; ++il) {
4701
  ggml_tensor * inpSA = inpL;
4702
 
4703
+ const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0;
 
 
4704
 
4705
  // norm
4706
  cur = build_norm(inpL,
 
4711
  // self-attention
4712
  {
4713
  // rope freq factors for llama3; may return nullptr for llama2 and other models
4714
+ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
4715
 
4716
  // compute Q and K and RoPE them
4717
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
 
4759
  cb(Kcur, "Kcur", il);
4760
  cb(Vcur, "Vcur", il);
4761
 
4762
+ if (use_rope && hparams.use_kq_norm) {
4763
  // Llama4TextL2Norm
4764
  Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps);
4765
  Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps);
 
4785
 
4786
  // feed-forward network (non-MoE)
4787
  if (model.layers[il].ffn_gate_inp == nullptr) {
 
4788
  cur = build_norm(ffn_inp,
4789
  model.layers[il].ffn_norm, NULL,
4790
  LLM_NORM_RMS, il);
 
4797
  NULL,
4798
  LLM_FFN_SILU, LLM_FFN_PAR, il);
4799
  cb(cur, "ffn_out", il);
4800
+ } else {
 
 
4801
  ggml_tensor * ffn_inp_normed = build_norm(ffn_inp,
4802
  model.layers[il].ffn_norm, NULL,
4803
  LLM_NORM_RMS, il);
 
4826
 
4827
  cur = ggml_add(ctx0, moe_out, shexp_out);
4828
  cb(cur, "ffn_moe_out_merged", il);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4829
  }
4830
 
4831
  cur = ggml_add(ctx0, cur, ffn_inp);
 
4899
  } else if (n_head > 0) {
4900
  // self-attention
4901
  // rope freq factors for llama3; may return nullptr for llama2 and other models
4902
+ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
4903
 
4904
  // compute Q and K and RoPE them
4905
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
 
7348
  }
7349
  };
7350
 
7351
+ template<bool iswa>
7352
  struct llm_build_phi3 : public llm_graph_context {
7353
  llm_build_phi3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
7354
  const int64_t n_embd_head = hparams.n_embd_head_v;
 
7364
  // inp_pos - contains the positions
7365
  ggml_tensor * inp_pos = build_inp_pos();
7366
 
7367
+ using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
7368
+ inp_attn_type * inp_attn = nullptr;
7369
+
7370
+ if constexpr (iswa) {
7371
+ inp_attn = build_attn_inp_kv_unified_iswa();
7372
+ } else {
7373
+ inp_attn = build_attn_inp_kv_unified();
7374
+ }
7375
 
7376
  for (int il = 0; il < n_layer; ++il) {
7377
  auto * residual = inpL;
 
7379
  // self-attention
7380
  {
7381
  // rope freq factors for 128k context
7382
+ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
7383
 
7384
  ggml_tensor* attn_norm_output = build_norm(inpL,
7385
  model.layers[il].attn_norm,
 
8131
  for (int il = 0; il < n_layer; ++il) {
8132
  ggml_tensor * inpSA = inpL;
8133
 
8134
+ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
8135
 
8136
  // norm
8137
  cur = build_norm(inpL,
 
8431
  }
8432
  };
8433
 
8434
+ struct llm_build_gemma2_iswa : public llm_graph_context {
8435
+ llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
8436
  const int64_t n_embd_head = hparams.n_embd_head_k;
8437
 
8438
  ggml_tensor * cur;
 
8446
  // inp_pos - contains the positions
8447
  ggml_tensor * inp_pos = build_inp_pos();
8448
 
8449
+ auto * inp_attn = build_attn_inp_kv_unified_iswa();
8450
 
8451
  for (int il = 0; il < n_layer; ++il) {
8452
  // norm
 
8568
  }
8569
  };
8570
 
8571
+ struct llm_build_gemma3_iswa : public llm_graph_context {
8572
+ llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
8573
  const int64_t n_embd_head = hparams.n_embd_head_k;
8574
 
8575
  ggml_tensor * cur;
 
8587
  ggml_tensor * inp_pos = build_inp_pos();
8588
 
8589
  // TODO: is causal == true correct? might need some changes
8590
+ auto * inp_attn = build_attn_inp_kv_unified_iswa();
8591
 
8592
  for (int il = 0; il < n_layer; ++il) {
8593
+ const float freq_base_l = model.get_rope_freq_base (cparams, il);
8594
+ const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
 
 
8595
 
8596
  // norm
8597
  cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
 
9168
  }
9169
  };
9170
 
9171
+ struct llm_build_cohere2_iswa : public llm_graph_context {
9172
+ llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
9173
  const int64_t n_embd_head = hparams.n_embd_head_v;
9174
 
9175
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
9184
  // inp_pos - contains the positions
9185
  ggml_tensor * inp_pos = build_inp_pos();
9186
 
9187
+ auto * inp_attn = build_attn_inp_kv_unified_iswa();
9188
 
9189
  for (int il = 0; il < n_layer; ++il) {
9190
  const bool is_swa = hparams.is_swa(il);
 
9197
  // self-attention
9198
  {
9199
  // rope freq factors for 128k context
9200
+ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
9201
 
9202
  // compute Q and K and RoPE them
9203
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
 
10135
  // self-attention
10136
  {
10137
  // rope freq factors for llama3; may return nullptr for llama2 and other models
10138
+ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
10139
 
10140
  // compute Q and K and RoPE them
10141
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
 
11499
  // self-attention
11500
  {
11501
  // rope freq factors for llama3; may return nullptr for llama2 and other models
11502
+ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
11503
 
11504
  // compute Q and K and RoPE them
11505
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
 
12415
  Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
12416
 
12417
  if (use_rope) {
12418
+ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
12419
  Qcur = ggml_rope_ext(
12420
  ctx0, Qcur, inp_pos, rope_factors,
12421
  n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
 
13068
  // self-attention
13069
  {
13070
  // rope freq factors for llama3; may return nullptr for llama2 and other models
13071
+ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
13072
 
13073
  // compute Q and K and RoPE them
13074
  ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
 
13196
  case LLM_ARCH_JINA_BERT_V2:
13197
  case LLM_ARCH_NOMIC_BERT:
13198
  case LLM_ARCH_NOMIC_BERT_MOE:
13199
+ case LLM_ARCH_WAVTOKENIZER_DEC:
13200
  {
13201
  res = nullptr;
13202
  } break;
 
13211
  GGML_TYPE_F32,
13212
  GGML_TYPE_F32,
13213
  cparams.offload_kqv,
13214
+ std::max((uint32_t) 1, cparams.n_seq_max),
13215
+ cparams.n_seq_max);
13216
  } break;
13217
  default:
13218
  {
 
13222
 
13223
  LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
13224
 
13225
+ if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
13226
+ GGML_ASSERT(hparams.is_swa_any());
13227
+
13228
+ res = new llama_kv_cache_unified_iswa(
13229
+ *this,
13230
+ params.type_k,
13231
+ params.type_v,
13232
+ !cparams.flash_attn,
13233
+ cparams.offload_kqv,
13234
+ params.swa_full,
13235
+ cparams.n_ctx,
13236
+ cparams.n_seq_max,
13237
+ cparams.n_batch,
13238
+ padding);
13239
+ } else {
13240
+ GGML_ASSERT(!hparams.is_swa_any());
13241
+
13242
+ res = new llama_kv_cache_unified(
13243
+ *this,
13244
+ nullptr,
13245
+ params.type_k,
13246
+ params.type_v,
13247
+ !cparams.flash_attn,
13248
+ cparams.offload_kqv,
13249
+ cparams.n_ctx,
13250
+ cparams.n_seq_max,
13251
+ padding,
13252
+ hparams.n_swa,
13253
+ hparams.swa_type);
13254
+ }
13255
  }
13256
  }
13257
 
 
13266
 
13267
  switch (arch) {
13268
  case LLM_ARCH_LLAMA:
 
13269
  case LLM_ARCH_MINICPM:
13270
  {
13271
  llm = std::make_unique<llm_build_llama>(*this, params, gf);
13272
  } break;
13273
+ case LLM_ARCH_LLAMA4:
13274
+ {
13275
+ llm = std::make_unique<llm_build_llama_iswa>(*this, params, gf);
13276
+ } break;
13277
  case LLM_ARCH_DECI:
13278
  {
13279
  llm = std::make_unique<llm_build_deci>(*this, params, gf);
 
13348
  case LLM_ARCH_PHI3:
13349
  case LLM_ARCH_PHIMOE:
13350
  {
13351
+ if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
13352
+ llm = std::make_unique<llm_build_phi3<true>> (*this, params, gf);
13353
+ } else {
13354
+ llm = std::make_unique<llm_build_phi3<false>>(*this, params, gf);
13355
+ }
13356
  } break;
13357
  case LLM_ARCH_PLAMO:
13358
  {
 
13384
  } break;
13385
  case LLM_ARCH_GEMMA2:
13386
  {
13387
+ llm = std::make_unique<llm_build_gemma2_iswa>(*this, params, gf);
13388
  } break;
13389
  case LLM_ARCH_GEMMA3:
13390
  {
13391
+ llm = std::make_unique<llm_build_gemma3_iswa>(*this, params, gf);
13392
  } break;
13393
  case LLM_ARCH_STARCODER2:
13394
  {
 
13408
  } break;
13409
  case LLM_ARCH_COHERE2:
13410
  {
13411
+ llm = std::make_unique<llm_build_cohere2_iswa>(*this, params, gf);
13412
  } break;
13413
  case LLM_ARCH_DBRX:
13414
  {
examples/talk-llama/llama-model.h CHANGED
@@ -398,7 +398,10 @@ struct llama_model {
398
 
399
  const struct ggml_tensor * get_tensor(const char * name) const;
400
 
401
- ggml_tensor * get_rope_factors(uint32_t n_ctx_per_seq, int il) const;
 
 
 
402
 
403
  // note: can mutate `cparams`
404
  // TODO: move this to new llm_arch_model_i interface
 
398
 
399
  const struct ggml_tensor * get_tensor(const char * name) const;
400
 
401
+ float get_rope_freq_base (const llama_cparams & cparams, int il) const;
402
+ float get_rope_freq_scale(const llama_cparams & cparams, int il) const;
403
+
404
+ ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const;
405
 
406
  // note: can mutate `cparams`
407
  // TODO: move this to new llm_arch_model_i interface
examples/talk-llama/llama-sampling.cpp CHANGED
@@ -798,7 +798,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d
798
  }
799
 
800
  // if we have enough values the operation was a success
801
- if (filtered_tokens.size() >= ctx->min_keep) {
802
  memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
803
  cur_p->size = filtered_tokens.size();
804
  min_p_applied = true;
@@ -909,7 +909,7 @@ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token
909
  cum_sum += cur_p->data[idx].p;
910
 
911
  // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
912
- if (cum_sum > ctx->p && i >= ctx->min_keep - 1) {
913
  last_idx = i + 1;
914
  break;
915
  }
 
798
  }
799
 
800
  // if we have enough values the operation was a success
801
+ if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) {
802
  memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
803
  cur_p->size = filtered_tokens.size();
804
  min_p_applied = true;
 
909
  cum_sum += cur_p->data[idx].p;
910
 
911
  // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
912
+ if (cum_sum > ctx->p && (ctx->min_keep == 0 || i >= ctx->min_keep - 1)) {
913
  last_idx = i + 1;
914
  break;
915
  }
examples/talk-llama/llama-vocab.cpp CHANGED
@@ -835,7 +835,7 @@ struct llm_tokenizer_ugm_session {
835
  }
836
 
837
  // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores
838
- std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -FLT_MAX});
839
  // at the beginning tokenization score is zero
840
  tokenization_results[0] = { vocab.token_unk(), 0, 0 };
841
 
@@ -867,7 +867,7 @@ struct llm_tokenizer_ugm_session {
867
  const double challenger_score = current_best.score_sum + token_score;
868
  struct best_tokenization & current_champ = tokenization_results[prefix_offset];
869
  if (challenger_score > current_champ.score_sum) {
870
- struct best_tokenization challenger = { token_id, input_offset, (float) challenger_score };
871
  current_champ = challenger;
872
  }
873
  }
@@ -881,7 +881,7 @@ struct llm_tokenizer_ugm_session {
881
  prefix_offset = input_offset + n_utf8_code_units;
882
  struct best_tokenization & current_champ = tokenization_results[prefix_offset];
883
  if (challenger_score > current_champ.score_sum) {
884
- struct best_tokenization challenger = { vocab.token_unk(), input_offset, (float) challenger_score };
885
  current_champ = challenger;
886
  }
887
  }
@@ -1007,7 +1007,7 @@ private:
1007
  struct best_tokenization {
1008
  llama_token token_id;
1009
  size_t input_offset;
1010
- float score_sum;
1011
  };
1012
 
1013
  struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
 
835
  }
836
 
837
  // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores
838
+ std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -DBL_MAX});
839
  // at the beginning tokenization score is zero
840
  tokenization_results[0] = { vocab.token_unk(), 0, 0 };
841
 
 
867
  const double challenger_score = current_best.score_sum + token_score;
868
  struct best_tokenization & current_champ = tokenization_results[prefix_offset];
869
  if (challenger_score > current_champ.score_sum) {
870
+ struct best_tokenization challenger = { token_id, input_offset, challenger_score };
871
  current_champ = challenger;
872
  }
873
  }
 
881
  prefix_offset = input_offset + n_utf8_code_units;
882
  struct best_tokenization & current_champ = tokenization_results[prefix_offset];
883
  if (challenger_score > current_champ.score_sum) {
884
+ struct best_tokenization challenger = { vocab.token_unk(), input_offset, challenger_score };
885
  current_champ = challenger;
886
  }
887
  }
 
1007
  struct best_tokenization {
1008
  llama_token token_id;
1009
  size_t input_offset;
1010
+ double score_sum;
1011
  };
1012
 
1013
  struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
examples/talk-llama/llama.h CHANGED
@@ -361,10 +361,11 @@ extern "C" {
361
 
362
  // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
363
  bool embeddings; // if true, extract embeddings (together with logits)
364
- bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
365
- bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
366
- bool no_perf; // whether to measure performance timings
367
- bool op_offload; // whether to offload host tensor operations to device
 
368
  };
369
 
370
  // model quantization parameters
@@ -470,6 +471,7 @@ extern "C" {
470
  LLAMA_API int64_t llama_time_us(void);
471
 
472
  LLAMA_API size_t llama_max_devices(void);
 
473
 
474
  LLAMA_API bool llama_supports_mmap (void);
475
  LLAMA_API bool llama_supports_mlock (void);
@@ -607,71 +609,14 @@ extern "C" {
607
  // KV cache
608
  //
609
 
610
- // TODO: start using struct llama_kv_cache
611
-
612
- // Information associated with an individual cell in the KV cache view.
613
- struct llama_kv_cache_view_cell {
614
- // The position for this cell. Takes KV cache shifts into account.
615
- // May be negative if the cell is not populated.
616
- llama_pos pos;
617
- };
618
-
619
- // An updateable view of the KV cache.
620
- struct llama_kv_cache_view {
621
- // Number of KV cache cells. This will be the same as the context size.
622
- int32_t n_cells;
623
-
624
- // Maximum number of sequences that can exist in a cell. It's not an error
625
- // if there are more sequences in a cell than this value, however they will
626
- // not be visible in the view cells_sequences.
627
- int32_t n_seq_max;
628
-
629
- // Number of tokens in the cache. For example, if there are two populated
630
- // cells, the first with 1 sequence id in it and the second with 2 sequence
631
- // ids then you'll have 3 tokens.
632
- int32_t token_count;
633
-
634
- // Number of populated cache cells.
635
- int32_t used_cells;
636
-
637
- // Maximum contiguous empty slots in the cache.
638
- int32_t max_contiguous;
639
-
640
- // Index to the start of the max_contiguous slot range. Can be negative
641
- // when cache is full.
642
- int32_t max_contiguous_idx;
643
-
644
- // Information for an individual cell.
645
- struct llama_kv_cache_view_cell * cells;
646
-
647
- // The sequences for each cell. There will be n_seq_max items per cell.
648
- llama_seq_id * cells_sequences;
649
- };
650
-
651
- // Create an empty KV cache view. (use only for debugging purposes)
652
- LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max);
653
-
654
- // Free a KV cache view. (use only for debugging purposes)
655
- LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
656
-
657
- // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
658
- // TODO: change signature to llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_context * ctx)
659
- LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
660
-
661
- ///
662
-
663
  // Returns the number of tokens in the KV cache (slow, use only for debug)
664
  // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
665
- LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx);
666
-
667
- DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx),
668
- "use llama_kv_self_n_tokens instead");
669
 
670
  // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
671
- LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx);
672
-
673
- DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx),
674
- "use llama_kv_self_used_cells instead");
675
 
676
  // Clear the KV cache - both cell info is erased and KV data is zeroed
677
  LLAMA_API void llama_kv_self_clear(
@@ -730,10 +675,18 @@ extern "C" {
730
  llama_pos p1,
731
  int d);
732
 
 
 
 
 
 
 
 
733
  // Returns the largest position present in the KV cache for the specified sequence
 
734
  LLAMA_API llama_pos llama_kv_self_seq_pos_max(
735
  struct llama_context * ctx,
736
- llama_seq_id seq_id);
737
 
738
  // Defragment the KV cache
739
  // This will be applied:
@@ -747,61 +700,6 @@ extern "C" {
747
  // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
748
  LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
749
 
750
- DEPRECATED(LLAMA_API void llama_kv_cache_clear(
751
- struct llama_context * ctx),
752
- "use llama_kv_self_clear instead");
753
-
754
- DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm(
755
- struct llama_context * ctx,
756
- llama_seq_id seq_id,
757
- llama_pos p0,
758
- llama_pos p1),
759
- "use llama_kv_self_seq_rm instead");
760
-
761
- DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp(
762
- struct llama_context * ctx,
763
- llama_seq_id seq_id_src,
764
- llama_seq_id seq_id_dst,
765
- llama_pos p0,
766
- llama_pos p1),
767
- "use llama_kv_self_seq_cp instead");
768
-
769
- DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep(
770
- struct llama_context * ctx,
771
- llama_seq_id seq_id),
772
- "use llama_kv_self_seq_keep instead");
773
-
774
- DEPRECATED(LLAMA_API void llama_kv_cache_seq_add(
775
- struct llama_context * ctx,
776
- llama_seq_id seq_id,
777
- llama_pos p0,
778
- llama_pos p1,
779
- llama_pos delta),
780
- "use llama_kv_self_seq_add instead");
781
-
782
- DEPRECATED(LLAMA_API void llama_kv_cache_seq_div(
783
- struct llama_context * ctx,
784
- llama_seq_id seq_id,
785
- llama_pos p0,
786
- llama_pos p1,
787
- int d),
788
- "use llama_kv_self_seq_div instead");
789
-
790
- DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
791
- struct llama_context * ctx,
792
- llama_seq_id seq_id),
793
- "use llama_kv_self_seq_pos_max instead");
794
-
795
- DEPRECATED(LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx),
796
- "use llama_kv_self_defrag instead");
797
-
798
- DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(const struct llama_context * ctx),
799
- "use llama_kv_self_can_shift instead");
800
-
801
- DEPRECATED(LLAMA_API void llama_kv_cache_update(struct llama_context * ctx),
802
- "use llama_kv_self_update instead");
803
-
804
-
805
  //
806
  // State / sessions
807
  //
@@ -943,9 +841,12 @@ extern "C" {
943
  // Requires KV cache.
944
  // For encode-decoder contexts, processes the batch using the decoder.
945
  // Positive return values does not mean a fatal error, but rather a warning.
946
- // 0 - success
947
- // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
948
- // < 0 - error. the KV cache state is restored to the state before this call
 
 
 
949
  LLAMA_API int32_t llama_decode(
950
  struct llama_context * ctx,
951
  struct llama_batch batch);
 
361
 
362
  // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
363
  bool embeddings; // if true, extract embeddings (together with logits)
364
+ bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
365
+ bool flash_attn; // use flash attention [EXPERIMENTAL]
366
+ bool no_perf; // measure performance timings
367
+ bool op_offload; // offload host tensor operations to device
368
+ bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
369
  };
370
 
371
  // model quantization parameters
 
471
  LLAMA_API int64_t llama_time_us(void);
472
 
473
  LLAMA_API size_t llama_max_devices(void);
474
+ LLAMA_API size_t llama_max_parallel_sequences(void);
475
 
476
  LLAMA_API bool llama_supports_mmap (void);
477
  LLAMA_API bool llama_supports_mlock (void);
 
609
  // KV cache
610
  //
611
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
612
  // Returns the number of tokens in the KV cache (slow, use only for debug)
613
  // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
614
+ DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx),
615
+ "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
 
 
616
 
617
  // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
618
+ DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx),
619
+ "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
 
 
620
 
621
  // Clear the KV cache - both cell info is erased and KV data is zeroed
622
  LLAMA_API void llama_kv_self_clear(
 
675
  llama_pos p1,
676
  int d);
677
 
678
+ // Returns the smallest position present in the KV cache for the specified sequence
679
+ // This is typically non-zero only for SWA caches
680
+ // Return -1 if the sequence is empty
681
+ LLAMA_API llama_pos llama_kv_self_seq_pos_min(
682
+ struct llama_context * ctx,
683
+ llama_seq_id seq_id);
684
+
685
  // Returns the largest position present in the KV cache for the specified sequence
686
+ // Return -1 if the sequence is empty
687
  LLAMA_API llama_pos llama_kv_self_seq_pos_max(
688
  struct llama_context * ctx,
689
+ llama_seq_id seq_id);
690
 
691
  // Defragment the KV cache
692
  // This will be applied:
 
700
  // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
701
  LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
702
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
703
  //
704
  // State / sessions
705
  //
 
841
  // Requires KV cache.
842
  // For encode-decoder contexts, processes the batch using the decoder.
843
  // Positive return values does not mean a fatal error, but rather a warning.
844
+ // Upon non-zero return values, the KV cache state is restored to the state before this call
845
+ // 0 - success
846
+ // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
847
+ // 2 - aborted
848
+ // -1 - invalid input batch
849
+ // < -1 - error
850
  LLAMA_API int32_t llama_decode(
851
  struct llama_context * ctx,
852
  struct llama_batch batch);