ggerganov commited on
Commit
5ef1601
·
1 Parent(s): 6ac9e73

talk-llama : sync llama.cpp

Browse files
examples/talk-llama/CMakeLists.txt CHANGED
@@ -16,7 +16,6 @@ if (WHISPER_SDL2)
16
  llama-hparams.cpp
17
  llama-impl.cpp
18
  llama-io.cpp
19
- llama-kv-cache.cpp
20
  llama-kv-cache-unified.cpp
21
  llama-kv-cache-unified-iswa.cpp
22
  llama-kv-cache-recurrent.cpp
 
16
  llama-hparams.cpp
17
  llama-impl.cpp
18
  llama-io.cpp
 
19
  llama-kv-cache-unified.cpp
20
  llama-kv-cache-unified-iswa.cpp
21
  llama-kv-cache-recurrent.cpp
examples/talk-llama/llama-arch.cpp CHANGED
@@ -200,7 +200,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
200
  { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
201
  { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
202
  { LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" },
203
- { LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" },
204
  { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" },
205
  { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" },
206
  { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" },
@@ -1707,8 +1706,14 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
1707
  LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
1708
 
1709
  std::string LLM_KV::operator()(llm_kv kv) const {
1710
- return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix)
1711
- : ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
 
 
 
 
 
 
1712
  }
1713
 
1714
  std::string LLM_TN_IMPL::str() const {
 
200
  { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
201
  { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
202
  { LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" },
 
203
  { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" },
204
  { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" },
205
  { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" },
 
1706
  LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
1707
 
1708
  std::string LLM_KV::operator()(llm_kv kv) const {
1709
+ std::string name = ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
1710
+
1711
+ if (suffix != nullptr) {
1712
+ name += ".";
1713
+ name += suffix;
1714
+ }
1715
+
1716
+ return name;
1717
  }
1718
 
1719
  std::string LLM_TN_IMPL::str() const {
examples/talk-llama/llama-arch.h CHANGED
@@ -196,7 +196,6 @@ enum llm_kv {
196
  LLM_KV_TOKENIZER_HF_JSON,
197
  LLM_KV_TOKENIZER_RWKV,
198
  LLM_KV_TOKENIZER_CHAT_TEMPLATE,
199
- LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
200
  LLM_KV_TOKENIZER_FIM_PRE_ID,
201
  LLM_KV_TOKENIZER_FIM_SUF_ID,
202
  LLM_KV_TOKENIZER_FIM_MID_ID,
 
196
  LLM_KV_TOKENIZER_HF_JSON,
197
  LLM_KV_TOKENIZER_RWKV,
198
  LLM_KV_TOKENIZER_CHAT_TEMPLATE,
 
199
  LLM_KV_TOKENIZER_FIM_PRE_ID,
200
  LLM_KV_TOKENIZER_FIM_SUF_ID,
201
  LLM_KV_TOKENIZER_FIM_MID_ID,
examples/talk-llama/llama-context.cpp CHANGED
@@ -2,9 +2,9 @@
2
 
3
  #include "llama-impl.h"
4
  #include "llama-io.h"
 
5
  #include "llama-mmap.h"
6
  #include "llama-model.h"
7
- #include "llama-kv-cache.h"
8
 
9
  #include <cinttypes>
10
  #include <cstring>
@@ -123,7 +123,7 @@ llama_context::llama_context(
123
  __func__, n_ctx_per_seq, hparams.n_ctx_train);
124
  }
125
 
126
- if (!params.swa_full && cparams.n_seq_max > 1) {
127
  LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
128
  __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
129
  }
@@ -277,10 +277,9 @@ llama_context::llama_context(
277
  int n_nodes_tg = -1;
278
 
279
  // simulate full KV cache
280
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
281
 
282
- const auto kv_state = kv_self->init_full();
283
- if (!kv_state) {
284
  throw std::runtime_error("failed to initialize KV cache");
285
  }
286
 
@@ -288,7 +287,7 @@ llama_context::llama_context(
288
 
289
  // reserve pp graph first so that buffers are only allocated once
290
  {
291
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
292
  if (!gf) {
293
  throw std::runtime_error("failed to allocate compute pp buffers");
294
  }
@@ -299,7 +298,7 @@ llama_context::llama_context(
299
 
300
  // reserve with tg graph to get the number of splits and nodes
301
  {
302
- auto * gf = graph_reserve(1, 1, 1, kv_state.get());
303
  if (!gf) {
304
  throw std::runtime_error("failed to allocate compute tg buffers");
305
  }
@@ -310,7 +309,7 @@ llama_context::llama_context(
310
 
311
  // reserve again with pp graph to avoid ggml-alloc reallocations during inference
312
  {
313
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
314
  if (!gf) {
315
  throw std::runtime_error("failed to allocate compute pp buffers");
316
  }
@@ -419,40 +418,68 @@ uint32_t llama_context::n_threads_batch() const {
419
  return cparams.n_threads_batch;
420
  }
421
 
422
- llama_kv_cache * llama_context::get_kv_self() {
423
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
424
- return kv_self;
425
  }
426
 
427
- const llama_kv_cache * llama_context::get_kv_self() const {
428
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
429
- return kv_self;
 
 
 
 
430
  }
431
 
432
- bool llama_context::kv_self_update() {
 
433
  if (!memory) {
434
  return false;
435
  }
436
 
437
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
 
 
438
 
439
- if (!kv_self->update(*this)) {
440
- // no updates have been performed
441
- return false;
442
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
 
444
- // if the KV cache did any computation, we have to reserve a new worst-case graph
445
- const auto kv_state = kv_self->init_full();
446
- if (!kv_state) {
447
- throw std::runtime_error("failed to initialize KV cache");
448
  }
449
 
450
- const uint32_t n_seqs = cparams.n_seq_max;
451
- const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
 
 
 
452
 
453
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
454
- if (!gf) {
455
- LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
 
 
 
 
456
  }
457
 
458
  return true;
@@ -814,16 +841,17 @@ int llama_context::encode(llama_batch & inp_batch) {
814
  } break;
815
  case LLAMA_POOLING_TYPE_RANK:
816
  {
817
- // extract the rerank score - a single float per sequence
818
  auto & embd_seq_out = embd_seq;
 
819
 
820
  for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
821
  const llama_seq_id seq_id = ubatch.seq_id[s][0];
822
  if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
823
  continue;
824
  }
825
- embd_seq_out[seq_id].resize(1);
826
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
827
  }
828
  } break;
829
  case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -880,10 +908,8 @@ int llama_context::decode(llama_batch & inp_batch) {
880
  }
881
  }
882
 
883
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
884
-
885
  // temporary allocate memory for the input batch if needed
886
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
887
 
888
  const llama_batch & batch = batch_allocr.batch;
889
 
@@ -940,42 +966,49 @@ int llama_context::decode(llama_batch & inp_batch) {
940
  n_outputs_all = 1;
941
  }
942
 
943
- // handle any pending defrags/shifts
944
- kv_self_update();
945
 
946
- llama_memory_state_ptr kv_state;
 
947
 
948
- bool did_defrag = false;
949
 
950
  while (true) {
951
- kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
952
- if (!kv_state) {
953
  return -2;
954
  }
955
 
956
- switch (kv_state->get_status()) {
957
  case LLAMA_MEMORY_STATUS_SUCCESS:
958
  {
959
  } break;
 
 
 
 
 
 
960
  case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
961
  {
962
- if (!did_defrag) {
963
- did_defrag = true;
964
 
965
- kv_self->defrag_sched(-1.0f);
966
- if (kv_self_update()) {
967
- LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
968
 
969
  continue;
970
  }
971
  }
972
 
973
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
974
 
975
  return 1;
976
  }
977
  case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
978
  {
 
 
979
  return -2;
980
  }
981
  }
@@ -992,7 +1025,7 @@ int llama_context::decode(llama_batch & inp_batch) {
992
  int64_t n_outputs_prev = 0;
993
 
994
  do {
995
- const auto & ubatch = kv_state->get_ubatch();
996
 
997
  // count the outputs in this u_batch
998
  {
@@ -1015,11 +1048,14 @@ int llama_context::decode(llama_batch & inp_batch) {
1015
  ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
1016
 
1017
  ggml_status status;
1018
- const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status);
1019
 
1020
  if (!res) {
1021
  // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
1022
- llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits<llama_pos>::max() };
 
 
 
1023
 
1024
  for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
1025
  const auto & seq_id = ubatch.seq_id[i][0];
@@ -1034,7 +1070,7 @@ int llama_context::decode(llama_batch & inp_batch) {
1034
 
1035
  LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
1036
 
1037
- llama_kv_self_seq_rm(this, s, pos_min[s], -1);
1038
  }
1039
 
1040
  switch (status) {
@@ -1128,7 +1164,7 @@ int llama_context::decode(llama_batch & inp_batch) {
1128
  }
1129
 
1130
  n_outputs_prev += n_outputs;
1131
- } while (kv_state->next());
1132
 
1133
  // set to total number of outputs in the batch, for use in llama_get_logits_ith
1134
  n_outputs = n_outputs_all;
@@ -1137,7 +1173,7 @@ int llama_context::decode(llama_batch & inp_batch) {
1137
  {
1138
  bool sorted_output = true;
1139
 
1140
- auto & out_ids = kv_state->out_ids();
1141
 
1142
  GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
1143
 
@@ -1189,11 +1225,6 @@ int llama_context::decode(llama_batch & inp_batch) {
1189
  // wait for the computation to finish (automatically done when obtaining the model output)
1190
  //synchronize();
1191
 
1192
- // decide if we need to defrag the kv cache
1193
- if (cparams.defrag_thold > 0.0f) {
1194
- kv_self->defrag_sched(cparams.defrag_thold);
1195
- }
1196
-
1197
  // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
1198
  // overlap with device computation.
1199
  ggml_backend_sched_reset(sched.get());
@@ -1810,11 +1841,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
1810
  }
1811
  }
1812
 
1813
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1814
-
1815
- if (kv_self != nullptr) {
1816
  LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
1817
- kv_self->state_write(io);
1818
  }
1819
 
1820
  return io.n_bytes();
@@ -1901,9 +1930,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
1901
  if (memory) {
1902
  LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
1903
 
1904
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1905
-
1906
- kv_self->state_read(io);
1907
  }
1908
 
1909
  return io.n_bytes();
@@ -1913,9 +1940,7 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
1913
  GGML_UNUSED(seq_id);
1914
 
1915
  if (memory) {
1916
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1917
-
1918
- kv_self->state_write(io, seq_id);
1919
  }
1920
 
1921
  return io.n_bytes();
@@ -1925,9 +1950,7 @@ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq
1925
  GGML_UNUSED(seq_id);
1926
 
1927
  if (memory) {
1928
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
1929
-
1930
- kv_self->state_read(io, seq_id);
1931
  }
1932
 
1933
  return io.n_bytes();
@@ -2032,9 +2055,7 @@ void llama_context::opt_epoch_iter(
2032
  const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
2033
  const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
2034
 
2035
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
2036
-
2037
- kv_self->clear();
2038
 
2039
  for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
2040
  batch.n_tokens = n_batch;
@@ -2057,8 +2078,8 @@ void llama_context::opt_epoch_iter(
2057
 
2058
  int64_t n_outputs_all = n_tokens_all;
2059
 
2060
- auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
2061
- if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
2062
  LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
2063
  break;
2064
  }
@@ -2071,17 +2092,17 @@ void llama_context::opt_epoch_iter(
2071
 
2072
  uint32_t pos_batch = 0;
2073
  do {
2074
- const auto & ubatch = kv_state->get_ubatch();
2075
 
2076
  n_outputs = ubatch.n_tokens;
2077
 
2078
- if (!kv_state->apply()) {
2079
  LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
2080
  break;
2081
  }
2082
 
2083
  auto * gf = graph_init();
2084
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get());
2085
 
2086
  struct ggml_context * ctx_compute_opt;
2087
  {
@@ -2116,7 +2137,7 @@ void llama_context::opt_epoch_iter(
2116
  ggml_free(ctx_compute_opt);
2117
 
2118
  pos_batch += ubatch.n_tokens;
2119
- } while (kv_state->next());
2120
  }
2121
  }
2122
 
@@ -2277,13 +2298,14 @@ const llama_model * llama_get_model(const llama_context * ctx) {
2277
  return &ctx->get_model();
2278
  }
2279
 
 
2280
  llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
2281
- return ctx->get_kv_self();
2282
  }
2283
 
2284
  // deprecated
2285
  void llama_kv_self_update(llama_context * ctx) {
2286
- ctx->kv_self_update();
2287
  }
2288
 
2289
  enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
@@ -2398,13 +2420,118 @@ int32_t llama_apply_adapter_cvec(
2398
  return res ? 0 : -1;
2399
  }
2400
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2401
  //
2402
  // kv cache
2403
  //
2404
 
2405
  // deprecated
2406
  int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
2407
- const auto * kv = ctx->get_kv_self();
2408
  if (!kv) {
2409
  return 0;
2410
  }
@@ -2426,7 +2553,7 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
2426
  // deprecated
2427
  // note: this is the same as above - will be removed anyway, so it's ok
2428
  int32_t llama_kv_self_used_cells(const llama_context * ctx) {
2429
- const auto * kv = ctx->get_kv_self();
2430
  if (!kv) {
2431
  return 0;
2432
  }
@@ -2445,115 +2572,119 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) {
2445
  return res;
2446
  }
2447
 
 
2448
  void llama_kv_self_clear(llama_context * ctx) {
2449
- auto * kv = ctx->get_kv_self();
2450
  if (!kv) {
2451
  return;
2452
  }
2453
 
2454
- kv->clear();
2455
  }
2456
 
 
2457
  bool llama_kv_self_seq_rm(
2458
  llama_context * ctx,
2459
  llama_seq_id seq_id,
2460
  llama_pos p0,
2461
  llama_pos p1) {
2462
- auto * kv = ctx->get_kv_self();
2463
  if (!kv) {
2464
  return true;
2465
  }
2466
 
2467
- return kv->seq_rm(seq_id, p0, p1);
2468
  }
2469
 
 
2470
  void llama_kv_self_seq_cp(
2471
  llama_context * ctx,
2472
  llama_seq_id seq_id_src,
2473
  llama_seq_id seq_id_dst,
2474
  llama_pos p0,
2475
  llama_pos p1) {
2476
- auto * kv = ctx->get_kv_self();
2477
  if (!kv) {
2478
  return;
2479
  }
2480
 
2481
- kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
2482
  }
2483
 
 
2484
  void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
2485
- auto * kv = ctx->get_kv_self();
2486
  if (!kv) {
2487
  return;
2488
  }
2489
 
2490
- kv->seq_keep(seq_id);
2491
  }
2492
 
 
2493
  void llama_kv_self_seq_add(
2494
  llama_context * ctx,
2495
  llama_seq_id seq_id,
2496
  llama_pos p0,
2497
  llama_pos p1,
2498
  llama_pos delta) {
2499
- auto * kv = ctx->get_kv_self();
2500
  if (!kv) {
2501
  return;
2502
  }
2503
 
2504
- kv->seq_add(seq_id, p0, p1, delta);
2505
  }
2506
 
 
2507
  void llama_kv_self_seq_div(
2508
  llama_context * ctx,
2509
  llama_seq_id seq_id,
2510
  llama_pos p0,
2511
  llama_pos p1,
2512
  int d) {
2513
- auto * kv = ctx->get_kv_self();
2514
  if (!kv) {
2515
  return;
2516
  }
2517
 
2518
- kv->seq_div(seq_id, p0, p1, d);
2519
  }
2520
 
 
2521
  llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
2522
- const auto * kv = ctx->get_kv_self();
2523
  if (!kv) {
2524
  return -1;
2525
  }
2526
 
2527
- return kv->seq_pos_min(seq_id);
2528
  }
2529
 
 
2530
  llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
2531
- const auto * kv = ctx->get_kv_self();
2532
  if (!kv) {
2533
  return -1;
2534
  }
2535
 
2536
- return kv->seq_pos_max(seq_id);
2537
  }
2538
 
2539
  // deprecated
2540
  void llama_kv_self_defrag(llama_context * ctx) {
2541
- auto * kv = ctx->get_kv_self();
2542
- if (!kv) {
2543
- return;
2544
- }
2545
-
2546
  // force defrag
2547
- kv->defrag_sched(-1.0f);
2548
  }
2549
 
 
2550
  bool llama_kv_self_can_shift(const llama_context * ctx) {
2551
- const auto * kv = ctx->get_kv_self();
2552
  if (!kv) {
2553
  return false;
2554
  }
2555
 
2556
- return kv->get_can_shift();
2557
  }
2558
 
2559
  // llama state API
 
2
 
3
  #include "llama-impl.h"
4
  #include "llama-io.h"
5
+ #include "llama-memory.h"
6
  #include "llama-mmap.h"
7
  #include "llama-model.h"
 
8
 
9
  #include <cinttypes>
10
  #include <cstring>
 
123
  __func__, n_ctx_per_seq, hparams.n_ctx_train);
124
  }
125
 
126
+ if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) {
127
  LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
128
  __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
129
  }
 
277
  int n_nodes_tg = -1;
278
 
279
  // simulate full KV cache
 
280
 
281
+ const auto mstate = memory->init_full();
282
+ if (!mstate) {
283
  throw std::runtime_error("failed to initialize KV cache");
284
  }
285
 
 
287
 
288
  // reserve pp graph first so that buffers are only allocated once
289
  {
290
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
291
  if (!gf) {
292
  throw std::runtime_error("failed to allocate compute pp buffers");
293
  }
 
298
 
299
  // reserve with tg graph to get the number of splits and nodes
300
  {
301
+ auto * gf = graph_reserve(1, 1, 1, mstate.get());
302
  if (!gf) {
303
  throw std::runtime_error("failed to allocate compute tg buffers");
304
  }
 
309
 
310
  // reserve again with pp graph to avoid ggml-alloc reallocations during inference
311
  {
312
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
313
  if (!gf) {
314
  throw std::runtime_error("failed to allocate compute pp buffers");
315
  }
 
418
  return cparams.n_threads_batch;
419
  }
420
 
421
+ llama_memory_t llama_context::get_memory() const {
422
+ return memory.get();
 
423
  }
424
 
425
+ // deprecated
426
+ void llama_context::kv_self_defrag_sched() {
427
+ if (!memory) {
428
+ return;
429
+ }
430
+
431
+ memory_force_optimize = true;
432
  }
433
 
434
+ // deprecated
435
+ bool llama_context::kv_self_update(bool optimize) {
436
  if (!memory) {
437
  return false;
438
  }
439
 
440
+ {
441
+ // TODO: remove in the future
442
+ optimize |= memory_force_optimize;
443
+ memory_force_optimize = false;
444
 
445
+ const auto mstate = memory->init_update(this, optimize);
446
+ switch (mstate->get_status()) {
447
+ case LLAMA_MEMORY_STATUS_SUCCESS:
448
+ {
449
+ // noop
450
+ } break;
451
+ case LLAMA_MEMORY_STATUS_NO_UPDATE:
452
+ {
453
+ // no updates need to be performed
454
+ return false;
455
+ }
456
+ case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
457
+ case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
458
+ {
459
+ LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
460
+ return false;
461
+ }
462
+ }
463
 
464
+ if (!mstate->apply()) {
465
+ LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
466
+ }
 
467
  }
468
 
469
+ // if the memory module did any computation, we have to reserve a new worst-case graph
470
+ {
471
+ const auto mstate = memory->init_full();
472
+ if (!mstate) {
473
+ throw std::runtime_error("failed to initialize memory state");
474
+ }
475
 
476
+ const uint32_t n_seqs = cparams.n_seq_max;
477
+ const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
478
+
479
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
480
+ if (!gf) {
481
+ LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
482
+ }
483
  }
484
 
485
  return true;
 
841
  } break;
842
  case LLAMA_POOLING_TYPE_RANK:
843
  {
844
+ // extract the rerank score - n_cls_out floats per sequence
845
  auto & embd_seq_out = embd_seq;
846
+ const uint32_t n_cls_out = hparams.n_cls_out;
847
 
848
  for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
849
  const llama_seq_id seq_id = ubatch.seq_id[s][0];
850
  if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
851
  continue;
852
  }
853
+ embd_seq_out[seq_id].resize(n_cls_out);
854
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float));
855
  }
856
  } break;
857
  case LLAMA_POOLING_TYPE_UNSPECIFIED:
 
908
  }
909
  }
910
 
 
 
911
  // temporary allocate memory for the input batch if needed
912
+ llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);
913
 
914
  const llama_batch & batch = batch_allocr.batch;
915
 
 
966
  n_outputs_all = 1;
967
  }
968
 
969
+ bool did_optimize = false;
 
970
 
971
+ // handle any pending defrags/shifts
972
+ kv_self_update(false);
973
 
974
+ llama_memory_state_ptr mstate;
975
 
976
  while (true) {
977
+ mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
978
+ if (!mstate) {
979
  return -2;
980
  }
981
 
982
+ switch (mstate->get_status()) {
983
  case LLAMA_MEMORY_STATUS_SUCCESS:
984
  {
985
  } break;
986
+ case LLAMA_MEMORY_STATUS_NO_UPDATE:
987
+ {
988
+ LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
989
+
990
+ return -2;
991
+ }
992
  case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
993
  {
994
+ if (!did_optimize) {
995
+ did_optimize = true;
996
 
997
+ if (kv_self_update(true)) {
998
+ LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);
 
999
 
1000
  continue;
1001
  }
1002
  }
1003
 
1004
+ LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
1005
 
1006
  return 1;
1007
  }
1008
  case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
1009
  {
1010
+ LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
1011
+
1012
  return -2;
1013
  }
1014
  }
 
1025
  int64_t n_outputs_prev = 0;
1026
 
1027
  do {
1028
+ const auto & ubatch = mstate->get_ubatch();
1029
 
1030
  // count the outputs in this u_batch
1031
  {
 
1048
  ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
1049
 
1050
  ggml_status status;
1051
+ const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
1052
 
1053
  if (!res) {
1054
  // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
1055
+ llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
1056
+ for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
1057
+ pos_min[s] = std::numeric_limits<llama_pos>::max();
1058
+ }
1059
 
1060
  for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
1061
  const auto & seq_id = ubatch.seq_id[i][0];
 
1070
 
1071
  LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
1072
 
1073
+ memory->seq_rm(s, pos_min[s], -1);
1074
  }
1075
 
1076
  switch (status) {
 
1164
  }
1165
 
1166
  n_outputs_prev += n_outputs;
1167
+ } while (mstate->next());
1168
 
1169
  // set to total number of outputs in the batch, for use in llama_get_logits_ith
1170
  n_outputs = n_outputs_all;
 
1173
  {
1174
  bool sorted_output = true;
1175
 
1176
+ auto & out_ids = mstate->out_ids();
1177
 
1178
  GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
1179
 
 
1225
  // wait for the computation to finish (automatically done when obtaining the model output)
1226
  //synchronize();
1227
 
 
 
 
 
 
1228
  // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
1229
  // overlap with device computation.
1230
  ggml_backend_sched_reset(sched.get());
 
1841
  }
1842
  }
1843
 
1844
+ if (memory != nullptr) {
 
 
1845
  LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
1846
+ memory->state_write(io);
1847
  }
1848
 
1849
  return io.n_bytes();
 
1930
  if (memory) {
1931
  LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
1932
 
1933
+ memory->state_read(io);
 
 
1934
  }
1935
 
1936
  return io.n_bytes();
 
1940
  GGML_UNUSED(seq_id);
1941
 
1942
  if (memory) {
1943
+ memory->state_write(io, seq_id);
 
 
1944
  }
1945
 
1946
  return io.n_bytes();
 
1950
  GGML_UNUSED(seq_id);
1951
 
1952
  if (memory) {
1953
+ memory->state_read(io, seq_id);
 
 
1954
  }
1955
 
1956
  return io.n_bytes();
 
2055
  const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
2056
  const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
2057
 
2058
+ memory->clear(true);
 
 
2059
 
2060
  for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
2061
  batch.n_tokens = n_batch;
 
2078
 
2079
  int64_t n_outputs_all = n_tokens_all;
2080
 
2081
+ auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
2082
+ if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
2083
  LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
2084
  break;
2085
  }
 
2092
 
2093
  uint32_t pos_batch = 0;
2094
  do {
2095
+ const auto & ubatch = mstate->get_ubatch();
2096
 
2097
  n_outputs = ubatch.n_tokens;
2098
 
2099
+ if (!mstate->apply()) {
2100
  LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
2101
  break;
2102
  }
2103
 
2104
  auto * gf = graph_init();
2105
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
2106
 
2107
  struct ggml_context * ctx_compute_opt;
2108
  {
 
2137
  ggml_free(ctx_compute_opt);
2138
 
2139
  pos_batch += ubatch.n_tokens;
2140
+ } while (mstate->next());
2141
  }
2142
  }
2143
 
 
2298
  return &ctx->get_model();
2299
  }
2300
 
2301
+ // deprecated
2302
  llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
2303
+ return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
2304
  }
2305
 
2306
  // deprecated
2307
  void llama_kv_self_update(llama_context * ctx) {
2308
+ ctx->kv_self_update(false);
2309
  }
2310
 
2311
  enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
 
2420
  return res ? 0 : -1;
2421
  }
2422
 
2423
+ //
2424
+ // memory
2425
+ //
2426
+
2427
+ llama_memory_t llama_get_memory(const struct llama_context * ctx) {
2428
+ return ctx->get_memory();
2429
+ }
2430
+
2431
+ void llama_memory_clear(llama_memory_t mem, bool data) {
2432
+ if (!mem) {
2433
+ return;
2434
+ }
2435
+
2436
+ mem->clear(data);
2437
+ }
2438
+
2439
+ bool llama_memory_seq_rm(
2440
+ llama_memory_t mem,
2441
+ llama_seq_id seq_id,
2442
+ llama_pos p0,
2443
+ llama_pos p1) {
2444
+ if (!mem) {
2445
+ return true;
2446
+ }
2447
+
2448
+ return mem->seq_rm(seq_id, p0, p1);
2449
+ }
2450
+
2451
+ void llama_memory_seq_cp(
2452
+ llama_memory_t mem,
2453
+ llama_seq_id seq_id_src,
2454
+ llama_seq_id seq_id_dst,
2455
+ llama_pos p0,
2456
+ llama_pos p1) {
2457
+ if (!mem) {
2458
+ return;
2459
+ }
2460
+
2461
+ mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
2462
+ }
2463
+
2464
+ void llama_memory_seq_keep(
2465
+ llama_memory_t mem,
2466
+ llama_seq_id seq_id) {
2467
+ if (!mem) {
2468
+ return;
2469
+ }
2470
+
2471
+ mem->seq_keep(seq_id);
2472
+ }
2473
+
2474
+ void llama_memory_seq_add(
2475
+ llama_memory_t mem,
2476
+ llama_seq_id seq_id,
2477
+ llama_pos p0,
2478
+ llama_pos p1,
2479
+ llama_pos delta) {
2480
+ if (!mem) {
2481
+ return;
2482
+ }
2483
+
2484
+ mem->seq_add(seq_id, p0, p1, delta);
2485
+ }
2486
+
2487
+ void llama_memory_seq_div(
2488
+ llama_memory_t mem,
2489
+ llama_seq_id seq_id,
2490
+ llama_pos p0,
2491
+ llama_pos p1,
2492
+ int d) {
2493
+ if (!mem) {
2494
+ return;
2495
+ }
2496
+
2497
+ mem->seq_div(seq_id, p0, p1, d);
2498
+ }
2499
+
2500
+ llama_pos llama_memory_seq_pos_min(
2501
+ llama_memory_t mem,
2502
+ llama_seq_id seq_id) {
2503
+ if (!mem) {
2504
+ return -1;
2505
+ }
2506
+
2507
+ return mem->seq_pos_min(seq_id);
2508
+ }
2509
+
2510
+ llama_pos llama_memory_seq_pos_max(
2511
+ llama_memory_t mem,
2512
+ llama_seq_id seq_id) {
2513
+ if (!mem) {
2514
+ return -1;
2515
+ }
2516
+
2517
+ return mem->seq_pos_max(seq_id);
2518
+ }
2519
+
2520
+ bool llama_memory_can_shift(llama_memory_t mem) {
2521
+ if (!mem) {
2522
+ return false;
2523
+ }
2524
+
2525
+ return mem->get_can_shift();
2526
+ }
2527
+
2528
  //
2529
  // kv cache
2530
  //
2531
 
2532
  // deprecated
2533
  int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
2534
+ const auto * kv = llama_get_memory(ctx);
2535
  if (!kv) {
2536
  return 0;
2537
  }
 
2553
  // deprecated
2554
  // note: this is the same as above - will be removed anyway, so it's ok
2555
  int32_t llama_kv_self_used_cells(const llama_context * ctx) {
2556
+ const auto * kv = llama_get_memory(ctx);
2557
  if (!kv) {
2558
  return 0;
2559
  }
 
2572
  return res;
2573
  }
2574
 
2575
+ // deprecated
2576
  void llama_kv_self_clear(llama_context * ctx) {
2577
+ auto * kv = llama_get_memory(ctx);
2578
  if (!kv) {
2579
  return;
2580
  }
2581
 
2582
+ llama_memory_clear(kv, true);
2583
  }
2584
 
2585
+ // deprecated
2586
  bool llama_kv_self_seq_rm(
2587
  llama_context * ctx,
2588
  llama_seq_id seq_id,
2589
  llama_pos p0,
2590
  llama_pos p1) {
2591
+ auto * kv = llama_get_memory(ctx);
2592
  if (!kv) {
2593
  return true;
2594
  }
2595
 
2596
+ return llama_memory_seq_rm(kv, seq_id, p0, p1);
2597
  }
2598
 
2599
+ // deprecated
2600
  void llama_kv_self_seq_cp(
2601
  llama_context * ctx,
2602
  llama_seq_id seq_id_src,
2603
  llama_seq_id seq_id_dst,
2604
  llama_pos p0,
2605
  llama_pos p1) {
2606
+ auto * kv = llama_get_memory(ctx);
2607
  if (!kv) {
2608
  return;
2609
  }
2610
 
2611
+ llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
2612
  }
2613
 
2614
+ // deprecated
2615
  void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
2616
+ auto * kv = llama_get_memory(ctx);
2617
  if (!kv) {
2618
  return;
2619
  }
2620
 
2621
+ llama_memory_seq_keep(kv, seq_id);
2622
  }
2623
 
2624
+ // deprecated
2625
  void llama_kv_self_seq_add(
2626
  llama_context * ctx,
2627
  llama_seq_id seq_id,
2628
  llama_pos p0,
2629
  llama_pos p1,
2630
  llama_pos delta) {
2631
+ auto * kv = llama_get_memory(ctx);
2632
  if (!kv) {
2633
  return;
2634
  }
2635
 
2636
+ llama_memory_seq_add(kv, seq_id, p0, p1, delta);
2637
  }
2638
 
2639
+ // deprecated
2640
  void llama_kv_self_seq_div(
2641
  llama_context * ctx,
2642
  llama_seq_id seq_id,
2643
  llama_pos p0,
2644
  llama_pos p1,
2645
  int d) {
2646
+ auto * kv = llama_get_memory(ctx);
2647
  if (!kv) {
2648
  return;
2649
  }
2650
 
2651
+ llama_memory_seq_div(kv, seq_id, p0, p1, d);
2652
  }
2653
 
2654
+ // deprecated
2655
  llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
2656
+ auto * kv = llama_get_memory(ctx);
2657
  if (!kv) {
2658
  return -1;
2659
  }
2660
 
2661
+ return llama_memory_seq_pos_min(kv, seq_id);
2662
  }
2663
 
2664
+ // deprecated
2665
  llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
2666
+ auto * kv = llama_get_memory(ctx);
2667
  if (!kv) {
2668
  return -1;
2669
  }
2670
 
2671
+ return llama_memory_seq_pos_max(kv, seq_id);
2672
  }
2673
 
2674
  // deprecated
2675
  void llama_kv_self_defrag(llama_context * ctx) {
 
 
 
 
 
2676
  // force defrag
2677
+ ctx->kv_self_defrag_sched();
2678
  }
2679
 
2680
+ // deprecated
2681
  bool llama_kv_self_can_shift(const llama_context * ctx) {
2682
+ auto * kv = llama_get_memory(ctx);
2683
  if (!kv) {
2684
  return false;
2685
  }
2686
 
2687
+ return llama_memory_can_shift(kv);
2688
  }
2689
 
2690
  // llama state API
examples/talk-llama/llama-context.h CHANGED
@@ -13,13 +13,12 @@
13
  #include <vector>
14
 
15
  struct llama_model;
16
- struct llama_kv_cache;
17
 
18
  class llama_io_read_i;
19
  class llama_io_write_i;
20
 
21
- class llama_memory_i;
22
- class llama_memory_state_i;
23
 
24
  struct llama_context {
25
  // init scheduler and compute buffers, reserve worst-case graphs
@@ -47,12 +46,12 @@ struct llama_context {
47
  uint32_t n_threads() const;
48
  uint32_t n_threads_batch() const;
49
 
50
- llama_kv_cache * get_kv_self();
51
- const llama_kv_cache * get_kv_self() const;
52
 
53
  // return true of the KV cache was updated
54
  // TODO: remove
55
- bool kv_self_update();
 
56
 
57
  enum llama_pooling_type pooling_type() const;
58
 
@@ -231,6 +230,9 @@ private:
231
 
232
  std::unique_ptr<llama_memory_i> memory;
233
 
 
 
 
234
  // decode output (2-dimensional array: [n_outputs][n_vocab])
235
  size_t logits_size = 0; // capacity (of floats) for logits
236
  float * logits = nullptr;
 
13
  #include <vector>
14
 
15
  struct llama_model;
 
16
 
17
  class llama_io_read_i;
18
  class llama_io_write_i;
19
 
20
+ struct llama_memory_i;
21
+ struct llama_memory_state_i;
22
 
23
  struct llama_context {
24
  // init scheduler and compute buffers, reserve worst-case graphs
 
46
  uint32_t n_threads() const;
47
  uint32_t n_threads_batch() const;
48
 
49
+ llama_memory_t get_memory() const;
 
50
 
51
  // return true of the KV cache was updated
52
  // TODO: remove
53
+ bool kv_self_update(bool optimize);
54
+ void kv_self_defrag_sched();
55
 
56
  enum llama_pooling_type pooling_type() const;
57
 
 
230
 
231
  std::unique_ptr<llama_memory_i> memory;
232
 
233
+ // TODO: temporary, until the llama_kv_self_defrag() API is removed
234
+ bool memory_force_optimize = false;
235
+
236
  // decode output (2-dimensional array: [n_outputs][n_vocab])
237
  size_t logits_size = 0; // capacity (of floats) for logits
238
  float * logits = nullptr;
examples/talk-llama/llama-graph.cpp CHANGED
@@ -659,6 +659,20 @@ ggml_tensor * llm_graph_context::build_ffn(
659
  cur = ggml_mul(ctx0, x0, x1);
660
  cb(cur, "ffn_mul", il);
661
  } break;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
662
  }
663
 
664
  if (gate && type_gate == LLM_FFN_PAR) {
@@ -769,9 +783,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
769
  cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
770
 
771
  if (weight_before_ffn) {
772
- // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
773
- ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
774
- repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
775
  cur = ggml_mul(ctx0, repeated, weights);
776
  cb(cur, "ffn_moe_weighted", il);
777
  }
 
659
  cur = ggml_mul(ctx0, x0, x1);
660
  cb(cur, "ffn_mul", il);
661
  } break;
662
+ case LLM_FFN_GEGLU:
663
+ {
664
+ // Split into two equal parts
665
+ int64_t split_point = cur->ne[0] / 2;
666
+ // TODO: these conts should not be needed
667
+ ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
668
+ ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
669
+
670
+ x0 = ggml_gelu(ctx0, x0);
671
+ cb(x0, "ffn_gelu", il);
672
+
673
+ cur = ggml_mul(ctx0, x0, x1);
674
+ cb(cur, "ffn_geglu", il);
675
+ } break;
676
  }
677
 
678
  if (gate && type_gate == LLM_FFN_PAR) {
 
783
  cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
784
 
785
  if (weight_before_ffn) {
786
+ // repeat cur to [n_embd, n_expert_used, n_tokens]
787
+ ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
 
788
  cur = ggml_mul(ctx0, repeated, weights);
789
  cb(cur, "ffn_moe_weighted", il);
790
  }
examples/talk-llama/llama-graph.h CHANGED
@@ -17,7 +17,7 @@ struct ggml_tensor;
17
  struct llama_ubatch;
18
  struct llama_cparams;
19
 
20
- class llama_memory_state_i;
21
 
22
  class llama_kv_cache_unified_state;
23
  class llama_kv_cache_unified_iswa_state;
@@ -36,6 +36,7 @@ enum llm_ffn_op_type {
36
  LLM_FFN_RELU,
37
  LLM_FFN_RELU_SQR,
38
  LLM_FFN_SWIGLU,
 
39
  };
40
 
41
  enum llm_ffn_gate_type {
 
17
  struct llama_ubatch;
18
  struct llama_cparams;
19
 
20
+ struct llama_memory_state_i;
21
 
22
  class llama_kv_cache_unified_state;
23
  class llama_kv_cache_unified_iswa_state;
 
36
  LLM_FFN_RELU,
37
  LLM_FFN_RELU_SQR,
38
  LLM_FFN_SWIGLU,
39
+ LLM_FFN_GEGLU,
40
  };
41
 
42
  enum llm_ffn_gate_type {
examples/talk-llama/llama-kv-cache-recurrent.cpp CHANGED
@@ -1,6 +1,7 @@
1
  #include "llama-kv-cache-recurrent.h"
2
 
3
  #include "llama-impl.h"
 
4
  #include "llama-batch.h"
5
  #include "llama-model.h"
6
 
@@ -116,18 +117,21 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
116
  }
117
  }
118
 
119
- void llama_kv_cache_recurrent::clear() {
120
  for (int32_t i = 0; i < (int32_t) size; ++i) {
121
  cells[i].pos = -1;
122
  cells[i].seq_id.clear();
123
  cells[i].src = -1;
124
  cells[i].tail = -1;
125
  }
 
126
  head = 0;
127
  used = 0;
128
 
129
- for (auto & buf : bufs) {
130
- ggml_backend_buffer_clear(buf.get(), 0);
 
 
131
  }
132
  }
133
 
@@ -386,6 +390,13 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
386
  return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
387
  }
388
 
 
 
 
 
 
 
 
389
  bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
390
  // simply remember the full state because it is very small for this type of cache
391
  // TODO: optimize
@@ -419,17 +430,6 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
419
  return success;
420
  }
421
 
422
- bool llama_kv_cache_recurrent::update(llama_context & lctx) {
423
- GGML_UNUSED(lctx);
424
- // noop
425
- return false;
426
- }
427
-
428
- void llama_kv_cache_recurrent::defrag_sched(float thold) {
429
- GGML_UNUSED(thold);
430
- // noop
431
- }
432
-
433
  bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
434
  const uint32_t n_tokens = ubatch.n_tokens;
435
  const uint32_t n_seqs = ubatch.n_seqs;
@@ -726,7 +726,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
726
 
727
  if (!res) {
728
  if (seq_id == -1) {
729
- clear();
730
  } else {
731
  seq_rm(seq_id, -1, -1);
732
  }
@@ -883,7 +883,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
883
  return false;
884
  }
885
 
886
- clear();
887
 
888
  for (uint32_t i = 0; i < cell_count; ++i) {
889
  kv_cell & cell = cells[i];
 
1
  #include "llama-kv-cache-recurrent.h"
2
 
3
  #include "llama-impl.h"
4
+ #include "llama-io.h"
5
  #include "llama-batch.h"
6
  #include "llama-model.h"
7
 
 
117
  }
118
  }
119
 
120
+ void llama_kv_cache_recurrent::clear(bool data) {
121
  for (int32_t i = 0; i < (int32_t) size; ++i) {
122
  cells[i].pos = -1;
123
  cells[i].seq_id.clear();
124
  cells[i].src = -1;
125
  cells[i].tail = -1;
126
  }
127
+
128
  head = 0;
129
  used = 0;
130
 
131
+ if (data) {
132
+ for (auto & buf : bufs) {
133
+ ggml_backend_buffer_clear(buf.get(), 0);
134
+ }
135
  }
136
  }
137
 
 
390
  return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
391
  }
392
 
393
+ llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
394
+ GGML_UNUSED(lctx);
395
+ GGML_UNUSED(optimize);
396
+
397
+ return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
398
+ }
399
+
400
  bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
401
  // simply remember the full state because it is very small for this type of cache
402
  // TODO: optimize
 
430
  return success;
431
  }
432
 
 
 
 
 
 
 
 
 
 
 
 
433
  bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
434
  const uint32_t n_tokens = ubatch.n_tokens;
435
  const uint32_t n_seqs = ubatch.n_seqs;
 
726
 
727
  if (!res) {
728
  if (seq_id == -1) {
729
+ clear(true);
730
  } else {
731
  seq_rm(seq_id, -1, -1);
732
  }
 
883
  return false;
884
  }
885
 
886
+ clear(true);
887
 
888
  for (uint32_t i = 0; i < cell_count; ++i) {
889
  kv_cell & cell = cells[i];
examples/talk-llama/llama-kv-cache-recurrent.h CHANGED
@@ -2,7 +2,7 @@
2
 
3
  #include "llama-batch.h"
4
  #include "llama-graph.h"
5
- #include "llama-kv-cache.h"
6
 
7
  #include <set>
8
  #include <vector>
@@ -13,7 +13,7 @@
13
 
14
  // TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
15
  // see the implementation of llama_kv_cache_unified_state_i for an example how to do it
16
- class llama_kv_cache_recurrent : public llama_kv_cache {
17
  public:
18
  llama_kv_cache_recurrent(
19
  const llama_model & model,
@@ -29,21 +29,6 @@ public:
29
  // llama_memory_i
30
  //
31
 
32
- void clear() override;
33
-
34
- bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
35
- void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
36
- void seq_keep(llama_seq_id seq_id) override;
37
- void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
38
- void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
39
-
40
- llama_pos seq_pos_min(llama_seq_id seq_id) const override;
41
- llama_pos seq_pos_max(llama_seq_id seq_id) const override;
42
-
43
- //
44
- // llama_kv_cache
45
- //
46
-
47
  llama_memory_state_ptr init_batch(
48
  const llama_batch & batch,
49
  uint32_t n_ubatch,
@@ -52,9 +37,18 @@ public:
52
 
53
  llama_memory_state_ptr init_full() override;
54
 
55
- bool update(llama_context & lctx) override;
 
 
56
 
57
- void defrag_sched(float thold) override;
 
 
 
 
 
 
 
58
 
59
  bool prepare(const std::vector<llama_ubatch> & ubatches);
60
 
 
2
 
3
  #include "llama-batch.h"
4
  #include "llama-graph.h"
5
+ #include "llama-memory.h"
6
 
7
  #include <set>
8
  #include <vector>
 
13
 
14
  // TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
15
  // see the implementation of llama_kv_cache_unified_state_i for an example how to do it
16
+ class llama_kv_cache_recurrent : public llama_memory_i {
17
  public:
18
  llama_kv_cache_recurrent(
19
  const llama_model & model,
 
29
  // llama_memory_i
30
  //
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  llama_memory_state_ptr init_batch(
33
  const llama_batch & batch,
34
  uint32_t n_ubatch,
 
37
 
38
  llama_memory_state_ptr init_full() override;
39
 
40
+ llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
41
+
42
+ void clear(bool data) override;
43
 
44
+ bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
45
+ void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
46
+ void seq_keep(llama_seq_id seq_id) override;
47
+ void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
48
+ void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
49
+
50
+ llama_pos seq_pos_min(llama_seq_id seq_id) const override;
51
+ llama_pos seq_pos_max(llama_seq_id seq_id) const override;
52
 
53
  bool prepare(const std::vector<llama_ubatch> & ubatches);
54
 
examples/talk-llama/llama-kv-cache-unified-iswa.cpp CHANGED
@@ -52,9 +52,9 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
52
  hparams.n_swa, hparams.swa_type);
53
  }
54
 
55
- void llama_kv_cache_unified_iswa::clear() {
56
- kv_base->clear();
57
- kv_swa ->clear();
58
  }
59
 
60
  bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
@@ -123,26 +123,16 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
123
 
124
  assert(heads_base.size() == heads_swa.size());
125
 
126
- return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
127
  this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
128
  }
129
 
130
  llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
131
- return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
132
  }
133
 
134
- bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
135
- bool res = false;
136
-
137
- res = res | kv_base->update(lctx);
138
- res = res | kv_swa ->update(lctx);
139
-
140
- return res;
141
- }
142
-
143
- void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
144
- kv_base->defrag_sched(thold);
145
- kv_swa ->defrag_sched(thold);
146
  }
147
 
148
  bool llama_kv_cache_unified_iswa::get_can_shift() const {
@@ -174,26 +164,38 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
174
  llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
175
 
176
  llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
177
- llama_memory_status status,
178
- llama_kv_cache_unified_iswa * kv) : status(status) {
179
- state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base()));
180
- state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
 
 
 
 
 
 
 
 
 
 
 
181
  }
182
 
183
  llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
184
- llama_memory_status status,
185
  llama_kv_cache_unified_iswa * kv,
186
  llama_sbatch sbatch,
187
  std::vector<uint32_t> heads_base,
188
  std::vector<uint32_t> heads_swa,
189
  std::vector<llama_ubatch> ubatches)
190
- : status(status),
191
- sbatch(std::move(sbatch)),
192
- ubatches(std::move(ubatches)) {
193
- // note: here we copy the ubatches. not sure if this is ideal
194
- state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches));
195
- state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
196
- }
 
 
197
 
198
  llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
199
 
@@ -233,17 +235,18 @@ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
233
 
234
  const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
235
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
236
  return ubatches[i_next];
237
  }
238
 
239
  const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
240
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
241
 
242
- return state_base.get();
243
  }
244
 
245
  const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
246
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
247
 
248
- return state_swa.get();
249
  }
 
52
  hparams.n_swa, hparams.swa_type);
53
  }
54
 
55
+ void llama_kv_cache_unified_iswa::clear(bool data) {
56
+ kv_base->clear(data);
57
+ kv_swa ->clear(data);
58
  }
59
 
60
  bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
 
123
 
124
  assert(heads_base.size() == heads_swa.size());
125
 
126
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(
127
  this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
128
  }
129
 
130
  llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
131
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
132
  }
133
 
134
+ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
135
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
 
 
 
 
 
 
 
 
 
 
136
  }
137
 
138
  bool llama_kv_cache_unified_iswa::get_can_shift() const {
 
164
  llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
165
 
166
  llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
167
+ llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
168
+ state_base = kv->get_base()->init_full();
169
+ state_swa = kv->get_swa ()->init_full();
170
+
171
+ status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
172
+ }
173
+
174
+ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
175
+ llama_kv_cache_unified_iswa * kv,
176
+ llama_context * lctx,
177
+ bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
178
+ state_base = kv->get_base()->init_update(lctx, optimize);
179
+ state_swa = kv->get_swa ()->init_update(lctx, optimize);
180
+
181
+ status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
182
  }
183
 
184
  llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
 
185
  llama_kv_cache_unified_iswa * kv,
186
  llama_sbatch sbatch,
187
  std::vector<uint32_t> heads_base,
188
  std::vector<uint32_t> heads_swa,
189
  std::vector<llama_ubatch> ubatches)
190
+ : status(LLAMA_MEMORY_STATUS_SUCCESS),
191
+ sbatch(std::move(sbatch)),
192
+ ubatches(std::move(ubatches)) {
193
+ // note: here we copy the ubatches. not sure if this is ideal
194
+ state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
195
+ state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
196
+
197
+ status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
198
+ }
199
 
200
  llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
201
 
 
235
 
236
  const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
237
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
238
+
239
  return ubatches[i_next];
240
  }
241
 
242
  const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
243
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
244
 
245
+ return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
246
  }
247
 
248
  const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
249
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
250
 
251
+ return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
252
  }
examples/talk-llama/llama-kv-cache-unified-iswa.h CHANGED
@@ -11,7 +11,7 @@
11
  // utilizes two instances of llama_kv_cache_unified
12
  // the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
13
 
14
- class llama_kv_cache_unified_iswa : public llama_kv_cache {
15
  public:
16
  llama_kv_cache_unified_iswa(
17
  const llama_model & model,
@@ -31,21 +31,6 @@ public:
31
  // llama_memory_i
32
  //
33
 
34
- void clear() override;
35
-
36
- bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
37
- void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
38
- void seq_keep(llama_seq_id seq_id) override;
39
- void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
40
- void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
41
-
42
- llama_pos seq_pos_min(llama_seq_id seq_id) const override;
43
- llama_pos seq_pos_max(llama_seq_id seq_id) const override;
44
-
45
- //
46
- // llama_kv_cache
47
- //
48
-
49
  llama_memory_state_ptr init_batch(
50
  const llama_batch & batch,
51
  uint32_t n_ubatch,
@@ -54,12 +39,21 @@ public:
54
 
55
  llama_memory_state_ptr init_full() override;
56
 
57
- bool update(llama_context & lctx) override;
58
-
59
- void defrag_sched(float thold) override;
60
 
61
  bool get_can_shift() const override;
62
 
 
 
 
 
 
 
 
 
 
 
 
63
  // state write/load
64
 
65
  void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
@@ -86,12 +80,16 @@ public:
86
 
87
  // used to create a full-cache state
88
  llama_kv_cache_unified_iswa_state(
89
- llama_memory_status status,
90
  llama_kv_cache_unified_iswa * kv);
91
 
 
 
 
 
 
 
92
  // used to create a state from a batch
93
  llama_kv_cache_unified_iswa_state(
94
- llama_memory_status status,
95
  llama_kv_cache_unified_iswa * kv,
96
  llama_sbatch sbatch,
97
  std::vector<uint32_t> heads_base,
@@ -120,7 +118,7 @@ public:
120
  const llama_kv_cache_unified_state * get_swa() const;
121
 
122
  private:
123
- const llama_memory_status status;
124
 
125
  //llama_kv_cache_unified_iswa * kv;
126
 
@@ -131,6 +129,6 @@ private:
131
 
132
  std::vector<llama_ubatch> ubatches;
133
 
134
- std::unique_ptr<llama_kv_cache_unified_state> state_base;
135
- std::unique_ptr<llama_kv_cache_unified_state> state_swa;
136
  };
 
11
  // utilizes two instances of llama_kv_cache_unified
12
  // the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
13
 
14
+ class llama_kv_cache_unified_iswa : public llama_memory_i {
15
  public:
16
  llama_kv_cache_unified_iswa(
17
  const llama_model & model,
 
31
  // llama_memory_i
32
  //
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  llama_memory_state_ptr init_batch(
35
  const llama_batch & batch,
36
  uint32_t n_ubatch,
 
39
 
40
  llama_memory_state_ptr init_full() override;
41
 
42
+ llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
 
 
43
 
44
  bool get_can_shift() const override;
45
 
46
+ void clear(bool data) override;
47
+
48
+ bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
49
+ void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
50
+ void seq_keep(llama_seq_id seq_id) override;
51
+ void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
52
+ void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
53
+
54
+ llama_pos seq_pos_min(llama_seq_id seq_id) const override;
55
+ llama_pos seq_pos_max(llama_seq_id seq_id) const override;
56
+
57
  // state write/load
58
 
59
  void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
 
80
 
81
  // used to create a full-cache state
82
  llama_kv_cache_unified_iswa_state(
 
83
  llama_kv_cache_unified_iswa * kv);
84
 
85
+ // used to create an update state
86
+ llama_kv_cache_unified_iswa_state(
87
+ llama_kv_cache_unified_iswa * kv,
88
+ llama_context * lctx,
89
+ bool optimize);
90
+
91
  // used to create a state from a batch
92
  llama_kv_cache_unified_iswa_state(
 
93
  llama_kv_cache_unified_iswa * kv,
94
  llama_sbatch sbatch,
95
  std::vector<uint32_t> heads_base,
 
118
  const llama_kv_cache_unified_state * get_swa() const;
119
 
120
  private:
121
+ llama_memory_status status;
122
 
123
  //llama_kv_cache_unified_iswa * kv;
124
 
 
129
 
130
  std::vector<llama_ubatch> ubatches;
131
 
132
+ llama_memory_state_ptr state_base;
133
+ llama_memory_state_ptr state_swa;
134
  };
examples/talk-llama/llama-kv-cache-unified.cpp CHANGED
@@ -1,6 +1,7 @@
1
  #include "llama-kv-cache-unified.h"
2
 
3
  #include "llama-impl.h"
 
4
  #include "llama-model.h"
5
  #include "llama-context.h"
6
 
@@ -128,13 +129,15 @@ llama_kv_cache_unified::llama_kv_cache_unified(
128
  }
129
  }
130
 
131
- void llama_kv_cache_unified::clear() {
132
  cells.reset();
133
 
134
  head = 0;
135
 
136
- for (auto & buf : bufs) {
137
- ggml_backend_buffer_clear(buf.get(), 0);
 
 
138
  }
139
  }
140
 
@@ -149,12 +152,27 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
149
  p1 = std::numeric_limits<llama_pos>::max();
150
  }
151
 
152
- for (uint32_t i = 0; i < cells.size(); ++i) {
153
- if (!cells.pos_in(i, p0, p1)) {
154
- continue;
 
 
 
 
 
 
 
 
155
  }
 
 
 
 
 
 
 
 
156
 
157
- if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
158
  if (new_head == cells.size()) {
159
  new_head = i;
160
  }
@@ -305,16 +323,49 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
305
  return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
306
  }
307
 
308
- return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS,
309
  this, std::move(sbatch), std::move(heads), std::move(ubatches));
310
  }
311
 
312
  llama_memory_state_ptr llama_kv_cache_unified::init_full() {
313
- return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
  }
315
 
316
- std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
317
- std::vector<uint32_t> res;
318
 
319
  struct state {
320
  uint32_t head_old; // old position of the head, before placing the ubatch
@@ -359,12 +410,12 @@ std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ub
359
  return res;
360
  }
361
 
362
- bool llama_kv_cache_unified::update(llama_context & lctx) {
363
  bool updated = false;
364
 
365
- auto * sched = lctx.get_sched();
366
 
367
- if (cells.get_has_shift()) {
368
  if (!get_can_shift()) {
369
  GGML_ABORT("The current KV cache / model configuration does not support K-shift");
370
  }
@@ -375,9 +426,9 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
375
  if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
376
  ggml_backend_sched_reset(sched);
377
 
378
- auto * gf = lctx.graph_init();
379
 
380
- auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
381
  if (!res) {
382
  LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
383
  return updated;
@@ -390,7 +441,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
390
 
391
  res->set_inputs(nullptr);
392
 
393
- if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
394
  LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
395
  return updated;
396
  }
@@ -401,54 +452,53 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
401
  cells.reset_shift();
402
  }
403
 
404
- if (do_defrag) {
405
  LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
406
 
407
- if (defrag_prepare(lctx.graph_max_nodes())) {
408
- ggml_backend_sched_reset(sched);
409
-
410
- auto * gf = lctx.graph_init();
411
-
412
- auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
413
- if (!res) {
414
- LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
415
- return updated;
416
- }
417
 
418
- if (!ggml_backend_sched_alloc_graph(sched, gf)) {
419
- LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
420
- return updated;
421
- }
422
 
423
- res->set_inputs(nullptr);
 
 
424
 
425
- if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
426
- LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
427
- return updated;
428
  }
429
 
430
- updated = true;
 
431
  }
432
 
433
- do_defrag = false;
434
- }
435
 
436
- return updated;
437
- }
438
 
439
- void llama_kv_cache_unified::defrag_sched(float thold) {
440
- const auto n_kv = cells.used_max_p1();
 
 
 
 
 
 
 
 
441
 
442
- // - do not defrag small contexts (i.e. < 2048 tokens)
443
- // - count the padding towards the number of used tokens
444
- const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
445
 
446
- // queue defragmentation for next llama_kv_cache_update
447
- if (fragmentation > thold) {
448
- LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
 
449
 
450
- do_defrag = true;
451
  }
 
 
452
  }
453
 
454
  int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
@@ -597,6 +647,10 @@ uint32_t llama_kv_cache_unified::get_size() const {
597
  return cells.size();
598
  }
599
 
 
 
 
 
600
  uint32_t llama_kv_cache_unified::get_n_kv() const {
601
  return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
602
  }
@@ -890,11 +944,9 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
890
  const auto & n_embd_head_k = hparams.n_embd_head_k;
891
  //const auto & n_embd_head_v = hparams.n_embd_head_v;
892
 
893
- //GGML_ASSERT(kv_self->size == n_ctx);
894
-
895
  auto inp = std::make_unique<llm_graph_input_k_shift>(this);
896
 
897
- inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
898
  ggml_set_input(inp->k_shift);
899
 
900
  for (const auto & layer : layers) {
@@ -926,12 +978,13 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
926
  }
927
 
928
  llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
929
- const llama_cparams & cparams,
930
- ggml_context * ctx,
931
- ggml_cgraph * gf) const {
 
932
  auto res = std::make_unique<llm_graph_result>();
933
 
934
- const auto & ids = defrag_info.ids;
935
 
936
  #if 0
937
  // CPU defrag
@@ -1072,7 +1125,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
1072
  return res;
1073
  }
1074
 
1075
- bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
1076
  const uint32_t n_layer = layers.size();
1077
 
1078
  const uint32_t n_kv = cells.used_max_p1();
@@ -1093,14 +1146,9 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
1093
  const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
1094
 
1095
  // determine which KV cells to move where
1096
- //
1097
- // cell i moves to ids[i]
1098
- //
1099
- // if ids[i] == i || ids[i] == n_kv, then cell i is not moved
1100
- //
1101
- auto & ids = defrag_info.ids;
1102
 
1103
- ids.clear();
1104
  ids.resize(n_kv, n_kv);
1105
 
1106
  for (uint32_t i0 = 0; i0 < n_used; ++i0) {
@@ -1164,11 +1212,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
1164
  // this cell goes to (i0 + nf)
1165
  ids[i1] = i0 + nf;
1166
 
1167
- // move the cell meta data
1168
- cells.mv(i1, i0 + nf);
1169
-
1170
- head = n_used;
1171
-
1172
  if (!cont) {
1173
  n_moves++;
1174
  cont = true;
@@ -1191,14 +1234,14 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
1191
  }
1192
 
1193
  if (n_moves == 0) {
1194
- return false;
1195
  }
1196
 
1197
  LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
1198
 
1199
  LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
1200
 
1201
- return true;
1202
  }
1203
 
1204
  bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
@@ -1276,7 +1319,7 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
1276
 
1277
  if (!res) {
1278
  if (seq_id == -1) {
1279
- clear();
1280
  } else {
1281
  seq_rm(seq_id, -1, -1);
1282
  }
@@ -1457,7 +1500,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1457
  return false;
1458
  }
1459
 
1460
- clear();
1461
 
1462
  for (uint32_t i = 0; i < cell_count; ++i) {
1463
  llama_pos pos;
@@ -1621,24 +1664,27 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1621
  llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
1622
 
1623
  llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1624
- llama_memory_status status,
1625
- llama_kv_cache_unified * kv) : status(status), kv(kv) {
1626
- n_kv = kv->get_size();
1627
- head = 0;
1628
- }
1629
 
1630
  llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1631
- llama_memory_status status,
1632
- llama_kv_cache_unified * kv,
1633
- llama_sbatch sbatch,
1634
- std::vector<uint32_t> heads,
1635
- std::vector<llama_ubatch> ubatches)
1636
- : status(status),
1637
- kv(kv),
1638
- sbatch(std::move(sbatch)),
1639
- heads(std::move(heads)),
1640
- ubatches(std::move(ubatches)) {
1641
  }
 
 
 
 
 
 
 
 
1642
 
1643
  llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
1644
 
@@ -1655,6 +1701,13 @@ bool llama_kv_cache_unified_state::next() {
1655
  bool llama_kv_cache_unified_state::apply() {
1656
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1657
 
 
 
 
 
 
 
 
1658
  kv->apply_ubatch(heads[i_next], ubatches[i_next]);
1659
 
1660
  n_kv = kv->get_n_kv();
 
1
  #include "llama-kv-cache-unified.h"
2
 
3
  #include "llama-impl.h"
4
+ #include "llama-io.h"
5
  #include "llama-model.h"
6
  #include "llama-context.h"
7
 
 
129
  }
130
  }
131
 
132
+ void llama_kv_cache_unified::clear(bool data) {
133
  cells.reset();
134
 
135
  head = 0;
136
 
137
+ if (data) {
138
+ for (auto & buf : bufs) {
139
+ ggml_backend_buffer_clear(buf.get(), 0);
140
+ }
141
  }
142
  }
143
 
 
152
  p1 = std::numeric_limits<llama_pos>::max();
153
  }
154
 
155
+ if (seq_id >= 0) {
156
+ for (uint32_t i = 0; i < cells.size(); ++i) {
157
+ if (!cells.pos_in(i, p0, p1)) {
158
+ continue;
159
+ }
160
+
161
+ if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
162
+ if (new_head == cells.size()) {
163
+ new_head = i;
164
+ }
165
+ }
166
  }
167
+ } else {
168
+ // match any sequence
169
+ for (uint32_t i = 0; i < cells.size(); ++i) {
170
+ if (!cells.pos_in(i, p0, p1)) {
171
+ continue;
172
+ }
173
+
174
+ cells.rm(i);
175
 
 
176
  if (new_head == cells.size()) {
177
  new_head = i;
178
  }
 
323
  return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
324
  }
325
 
326
+ return std::make_unique<llama_kv_cache_unified_state>(
327
  this, std::move(sbatch), std::move(heads), std::move(ubatches));
328
  }
329
 
330
  llama_memory_state_ptr llama_kv_cache_unified::init_full() {
331
+ return std::make_unique<llama_kv_cache_unified_state>(this);
332
+ }
333
+
334
+ llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
335
+ bool do_shift = get_has_shift();
336
+
337
+ defrag_info dinfo;
338
+
339
+ // see if we need to defrag
340
+ {
341
+ bool do_defrag = optimize;
342
+
343
+ const auto thold = lctx->get_cparams().defrag_thold;
344
+
345
+ if (!do_defrag && thold > 0.0f) {
346
+ const auto n_kv = cells.used_max_p1();
347
+
348
+ // - do not defrag small contexts (i.e. < 2048 tokens)
349
+ // - count the padding towards the number of used tokens
350
+ const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
351
+
352
+ if (fragmentation > thold) {
353
+ LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
354
+
355
+ do_defrag = true;
356
+ }
357
+ }
358
+
359
+ if (do_defrag) {
360
+ dinfo = defrag_prepare(lctx->graph_max_nodes());
361
+ }
362
+ }
363
+
364
+ return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
365
  }
366
 
367
+ llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
368
+ llama_kv_cache_unified::ubatch_heads res;
369
 
370
  struct state {
371
  uint32_t head_old; // old position of the head, before placing the ubatch
 
410
  return res;
411
  }
412
 
413
+ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
414
  bool updated = false;
415
 
416
+ auto * sched = lctx->get_sched();
417
 
418
+ if (do_shift) {
419
  if (!get_can_shift()) {
420
  GGML_ABORT("The current KV cache / model configuration does not support K-shift");
421
  }
 
426
  if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
427
  ggml_backend_sched_reset(sched);
428
 
429
+ auto * gf = lctx->graph_init();
430
 
431
+ auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf);
432
  if (!res) {
433
  LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
434
  return updated;
 
441
 
442
  res->set_inputs(nullptr);
443
 
444
+ if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
445
  LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
446
  return updated;
447
  }
 
452
  cells.reset_shift();
453
  }
454
 
455
+ if (!dinfo.empty()) {
456
  LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
457
 
458
+ // apply moves:
459
+ {
460
+ const auto n_kv = dinfo.ids.size();
 
 
 
 
 
 
 
461
 
462
+ for (uint32_t i = 0; i < n_kv; ++i) {
463
+ assert(dinfo.ids[i] <= n_kv);
 
 
464
 
465
+ if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) {
466
+ continue;
467
+ }
468
 
469
+ cells.mv(i, dinfo.ids[i]);
 
 
470
  }
471
 
472
+ // reset the head so we can find the first free slot during the next ubatch
473
+ head = 0;
474
  }
475
 
476
+ ggml_backend_sched_reset(sched);
 
477
 
478
+ auto * gf = lctx->graph_init();
 
479
 
480
+ auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo);
481
+ if (!res) {
482
+ LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
483
+ return updated;
484
+ }
485
+
486
+ if (!ggml_backend_sched_alloc_graph(sched, gf)) {
487
+ LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
488
+ return updated;
489
+ }
490
 
491
+ res->set_inputs(nullptr);
 
 
492
 
493
+ if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
494
+ LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
495
+ return updated;
496
+ }
497
 
498
+ updated = true;
499
  }
500
+
501
+ return updated;
502
  }
503
 
504
  int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 
647
  return cells.size();
648
  }
649
 
650
+ bool llama_kv_cache_unified::get_has_shift() const {
651
+ return cells.get_has_shift();
652
+ }
653
+
654
  uint32_t llama_kv_cache_unified::get_n_kv() const {
655
  return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
656
  }
 
944
  const auto & n_embd_head_k = hparams.n_embd_head_k;
945
  //const auto & n_embd_head_v = hparams.n_embd_head_v;
946
 
 
 
947
  auto inp = std::make_unique<llm_graph_input_k_shift>(this);
948
 
949
+ inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size());
950
  ggml_set_input(inp->k_shift);
951
 
952
  for (const auto & layer : layers) {
 
978
  }
979
 
980
  llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
981
+ const llama_cparams & cparams,
982
+ ggml_context * ctx,
983
+ ggml_cgraph * gf,
984
+ const defrag_info & dinfo) const {
985
  auto res = std::make_unique<llm_graph_result>();
986
 
987
+ const auto & ids = dinfo.ids;
988
 
989
  #if 0
990
  // CPU defrag
 
1125
  return res;
1126
  }
1127
 
1128
+ llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
1129
  const uint32_t n_layer = layers.size();
1130
 
1131
  const uint32_t n_kv = cells.used_max_p1();
 
1146
  const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
1147
 
1148
  // determine which KV cells to move where
1149
+ defrag_info res;
1150
+ auto & ids = res.ids;
 
 
 
 
1151
 
 
1152
  ids.resize(n_kv, n_kv);
1153
 
1154
  for (uint32_t i0 = 0; i0 < n_used; ++i0) {
 
1212
  // this cell goes to (i0 + nf)
1213
  ids[i1] = i0 + nf;
1214
 
 
 
 
 
 
1215
  if (!cont) {
1216
  n_moves++;
1217
  cont = true;
 
1234
  }
1235
 
1236
  if (n_moves == 0) {
1237
+ return {};
1238
  }
1239
 
1240
  LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
1241
 
1242
  LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
1243
 
1244
+ return res;
1245
  }
1246
 
1247
  bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
 
1319
 
1320
  if (!res) {
1321
  if (seq_id == -1) {
1322
+ clear(true);
1323
  } else {
1324
  seq_rm(seq_id, -1, -1);
1325
  }
 
1500
  return false;
1501
  }
1502
 
1503
+ clear(true);
1504
 
1505
  for (uint32_t i = 0; i < cell_count; ++i) {
1506
  llama_pos pos;
 
1664
  llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
1665
 
1666
  llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1667
+ llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
1668
+ n_kv = kv->get_size();
1669
+ head = 0;
1670
+ }
 
1671
 
1672
  llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1673
+ llama_kv_cache_unified * kv,
1674
+ llama_context * lctx,
1675
+ bool do_shift,
1676
+ defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
1677
+ if (!do_shift && dinfo.empty()) {
1678
+ status = LLAMA_MEMORY_STATUS_NO_UPDATE;
 
 
 
 
1679
  }
1680
+ }
1681
+
1682
+ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1683
+ llama_kv_cache_unified * kv,
1684
+ llama_sbatch sbatch,
1685
+ llama_kv_cache_unified::ubatch_heads heads,
1686
+ std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) {
1687
+ }
1688
 
1689
  llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
1690
 
 
1701
  bool llama_kv_cache_unified_state::apply() {
1702
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1703
 
1704
+ // no ubatches -> this is a KV cache update
1705
+ if (ubatches.empty()) {
1706
+ kv->update(lctx, do_shift, dinfo);
1707
+
1708
+ return true;
1709
+ }
1710
+
1711
  kv->apply_ubatch(heads[i_next], ubatches[i_next]);
1712
 
1713
  n_kv = kv->get_n_kv();
examples/talk-llama/llama-kv-cache-unified.h CHANGED
@@ -2,8 +2,8 @@
2
 
3
  #include "llama-batch.h"
4
  #include "llama-graph.h"
5
- #include "llama-kv-cache.h"
6
  #include "llama-kv-cells.h"
 
7
 
8
  #include <unordered_map>
9
  #include <vector>
@@ -17,13 +17,26 @@ struct llama_context;
17
  // llama_kv_cache_unified
18
  //
19
 
20
- class llama_kv_cache_unified : public llama_kv_cache {
21
  public:
22
  static uint32_t get_padding(const llama_cparams & cparams);
23
 
24
  // this callback is used to filter out layers that should not be included in the cache
25
  using layer_filter_cb = std::function<bool(int32_t il)>;
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  llama_kv_cache_unified(
28
  const llama_model & model,
29
  layer_filter_cb && filter,
@@ -43,21 +56,6 @@ public:
43
  // llama_memory_i
44
  //
45
 
46
- void clear() override;
47
-
48
- bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
49
- void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
50
- void seq_keep(llama_seq_id seq_id) override;
51
- void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
52
- void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
53
-
54
- llama_pos seq_pos_min(llama_seq_id seq_id) const override;
55
- llama_pos seq_pos_max(llama_seq_id seq_id) const override;
56
-
57
- //
58
- // llama_kv_cache
59
- //
60
-
61
  llama_memory_state_ptr init_batch(
62
  const llama_batch & batch,
63
  uint32_t n_ubatch,
@@ -66,12 +64,21 @@ public:
66
 
67
  llama_memory_state_ptr init_full() override;
68
 
69
- bool update(llama_context & lctx) override;
70
-
71
- void defrag_sched(float thold) override;
72
 
73
  bool get_can_shift() const override;
74
 
 
 
 
 
 
 
 
 
 
 
 
75
  // state write/load
76
 
77
  void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
@@ -83,6 +90,8 @@ public:
83
 
84
  uint32_t get_size() const;
85
 
 
 
86
  //
87
  // graph_build API
88
  //
@@ -103,7 +112,9 @@ public:
103
 
104
  // find places for the provided ubatches in the cache, returns the head locations
105
  // return empty vector on failure
106
- std::vector<uint32_t> prepare(const std::vector<llama_ubatch> & ubatches);
 
 
107
 
108
  // return the cell position where we can insert the ubatch
109
  // return -1 on failure to find a contiguous slot of kv cells
@@ -133,8 +144,7 @@ private:
133
  ggml_tensor * v;
134
  };
135
 
136
- bool do_defrag = false;
137
- bool v_trans = true; // the value tensor is transposed
138
 
139
  // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
140
  // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
@@ -160,13 +170,8 @@ private:
160
  // model layer id -> KV cache layer id
161
  std::unordered_map<int32_t, int32_t> map_layer_ids;
162
 
163
- // defrag
164
- struct {
165
- std::vector<uint32_t> ids;
166
- } defrag_info;
167
-
168
- // return true if cells have been moved
169
- bool defrag_prepare(int32_t n_max_nodes);
170
 
171
  size_t total_size() const;
172
 
@@ -192,7 +197,8 @@ private:
192
  llm_graph_result_ptr build_graph_defrag(
193
  const llama_cparams & cparams,
194
  ggml_context * ctx,
195
- ggml_cgraph * gf) const;
 
196
 
197
  void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
198
  void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
@@ -203,20 +209,29 @@ private:
203
 
204
  class llama_kv_cache_unified_state : public llama_memory_state_i {
205
  public:
 
 
 
 
206
  // used for errors
207
  llama_kv_cache_unified_state(llama_memory_status status);
208
 
209
  // used to create a full-cache state
210
  llama_kv_cache_unified_state(
211
- llama_memory_status status,
212
  llama_kv_cache_unified * kv);
213
 
214
- // used to create a state from a batch
 
 
 
 
 
 
 
215
  llama_kv_cache_unified_state(
216
- llama_memory_status status,
217
  llama_kv_cache_unified * kv,
218
  llama_sbatch sbatch,
219
- std::vector<uint32_t> heads,
220
  std::vector<llama_ubatch> ubatches);
221
 
222
  virtual ~llama_kv_cache_unified_state();
@@ -253,16 +268,30 @@ public:
253
  void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
254
 
255
  private:
256
- const llama_memory_status status;
257
 
258
  llama_kv_cache_unified * kv;
 
 
 
 
 
 
 
 
 
 
 
 
 
259
 
260
  llama_sbatch sbatch;
261
 
262
  // the index of the next ubatch to process
263
  size_t i_next = 0;
264
 
265
- std::vector<uint32_t> heads;
 
266
  std::vector<llama_ubatch> ubatches;
267
 
268
  //
 
2
 
3
  #include "llama-batch.h"
4
  #include "llama-graph.h"
 
5
  #include "llama-kv-cells.h"
6
+ #include "llama-memory.h"
7
 
8
  #include <unordered_map>
9
  #include <vector>
 
17
  // llama_kv_cache_unified
18
  //
19
 
20
+ class llama_kv_cache_unified : public llama_memory_i {
21
  public:
22
  static uint32_t get_padding(const llama_cparams & cparams);
23
 
24
  // this callback is used to filter out layers that should not be included in the cache
25
  using layer_filter_cb = std::function<bool(int32_t il)>;
26
 
27
+ using ubatch_heads = std::vector<uint32_t>;
28
+
29
+ struct defrag_info {
30
+ bool empty() const {
31
+ return ids.empty();
32
+ }
33
+
34
+ // contains information about which cell moves where:
35
+ // - cell i moves to ids[i]
36
+ // - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
37
+ std::vector<uint32_t> ids;
38
+ };
39
+
40
  llama_kv_cache_unified(
41
  const llama_model & model,
42
  layer_filter_cb && filter,
 
56
  // llama_memory_i
57
  //
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  llama_memory_state_ptr init_batch(
60
  const llama_batch & batch,
61
  uint32_t n_ubatch,
 
64
 
65
  llama_memory_state_ptr init_full() override;
66
 
67
+ llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
 
 
68
 
69
  bool get_can_shift() const override;
70
 
71
+ void clear(bool data) override;
72
+
73
+ bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
74
+ void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
75
+ void seq_keep(llama_seq_id seq_id) override;
76
+ void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
77
+ void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
78
+
79
+ llama_pos seq_pos_min(llama_seq_id seq_id) const override;
80
+ llama_pos seq_pos_max(llama_seq_id seq_id) const override;
81
+
82
  // state write/load
83
 
84
  void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
 
90
 
91
  uint32_t get_size() const;
92
 
93
+ bool get_has_shift() const;
94
+
95
  //
96
  // graph_build API
97
  //
 
112
 
113
  // find places for the provided ubatches in the cache, returns the head locations
114
  // return empty vector on failure
115
+ ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
116
+
117
+ bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
118
 
119
  // return the cell position where we can insert the ubatch
120
  // return -1 on failure to find a contiguous slot of kv cells
 
144
  ggml_tensor * v;
145
  };
146
 
147
+ bool v_trans = true; // the value tensor is transposed
 
148
 
149
  // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
150
  // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
 
170
  // model layer id -> KV cache layer id
171
  std::unordered_map<int32_t, int32_t> map_layer_ids;
172
 
173
+ // return non-empty vector if cells have been moved
174
+ defrag_info defrag_prepare(int32_t n_max_nodes) const;
 
 
 
 
 
175
 
176
  size_t total_size() const;
177
 
 
197
  llm_graph_result_ptr build_graph_defrag(
198
  const llama_cparams & cparams,
199
  ggml_context * ctx,
200
+ ggml_cgraph * gf,
201
+ const defrag_info & dinfo) const;
202
 
203
  void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
204
  void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
 
209
 
210
  class llama_kv_cache_unified_state : public llama_memory_state_i {
211
  public:
212
+ // some shorthands
213
+ using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
214
+ using defrag_info = llama_kv_cache_unified::defrag_info;
215
+
216
  // used for errors
217
  llama_kv_cache_unified_state(llama_memory_status status);
218
 
219
  // used to create a full-cache state
220
  llama_kv_cache_unified_state(
 
221
  llama_kv_cache_unified * kv);
222
 
223
+ // used to create an update state
224
+ llama_kv_cache_unified_state(
225
+ llama_kv_cache_unified * kv,
226
+ llama_context * lctx,
227
+ bool do_shift,
228
+ defrag_info dinfo);
229
+
230
+ // used to create a decode state from a batch
231
  llama_kv_cache_unified_state(
 
232
  llama_kv_cache_unified * kv,
233
  llama_sbatch sbatch,
234
+ ubatch_heads heads,
235
  std::vector<llama_ubatch> ubatches);
236
 
237
  virtual ~llama_kv_cache_unified_state();
 
268
  void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
269
 
270
  private:
271
+ llama_memory_status status;
272
 
273
  llama_kv_cache_unified * kv;
274
+ llama_context * lctx;
275
+
276
+ //
277
+ // update state
278
+ //
279
+
280
+ bool do_shift = false;
281
+
282
+ defrag_info dinfo;
283
+
284
+ //
285
+ // batch processing state
286
+ //
287
 
288
  llama_sbatch sbatch;
289
 
290
  // the index of the next ubatch to process
291
  size_t i_next = 0;
292
 
293
+ ubatch_heads heads;
294
+
295
  std::vector<llama_ubatch> ubatches;
296
 
297
  //
examples/talk-llama/llama-kv-cache.cpp DELETED
@@ -1 +0,0 @@
1
- #include "llama-kv-cache.h"
 
 
examples/talk-llama/llama-kv-cells.h CHANGED
@@ -80,6 +80,9 @@ public:
80
  assert(isrc < pos.size());
81
  assert(idst < pos.size());
82
 
 
 
 
83
  pos [idst] = pos [isrc];
84
  shift[idst] = shift[isrc];
85
  seq [idst] = seq [isrc];
@@ -144,9 +147,10 @@ public:
144
  assert(pos[i] != -1);
145
 
146
  seq_pos_rm(i);
 
147
 
148
  pos[i] = -1;
149
- seq[i].reset();
150
 
151
  used.erase(i);
152
  }
@@ -164,6 +168,7 @@ public:
164
 
165
  if (seq[i].none()) {
166
  pos[i] = -1;
 
167
 
168
  used.erase(i);
169
 
@@ -192,6 +197,7 @@ public:
192
  seq[i].reset();
193
 
194
  pos[i] = -1;
 
195
 
196
  used.erase(i);
197
 
@@ -317,21 +323,20 @@ public:
317
  pos[i] += d;
318
  shift[i] += d;
319
 
320
- seq_pos_add(i);
321
-
322
  has_shift = true;
323
 
324
  if (pos[i] < 0) {
325
- seq_pos_rm(i);
326
-
327
  seq[i].reset();
328
  pos[i] = -1;
 
329
 
330
  used.erase(i);
331
 
332
  return true;
333
  }
334
 
 
 
335
  return false;
336
  }
337
 
 
80
  assert(isrc < pos.size());
81
  assert(idst < pos.size());
82
 
83
+ assert(pos[idst] == -1);
84
+ assert(pos[isrc] != -1);
85
+
86
  pos [idst] = pos [isrc];
87
  shift[idst] = shift[isrc];
88
  seq [idst] = seq [isrc];
 
147
  assert(pos[i] != -1);
148
 
149
  seq_pos_rm(i);
150
+ seq[i].reset();
151
 
152
  pos[i] = -1;
153
+ shift[i] = 0;
154
 
155
  used.erase(i);
156
  }
 
168
 
169
  if (seq[i].none()) {
170
  pos[i] = -1;
171
+ shift[i] = 0;
172
 
173
  used.erase(i);
174
 
 
197
  seq[i].reset();
198
 
199
  pos[i] = -1;
200
+ shift[i] = 0;
201
 
202
  used.erase(i);
203
 
 
323
  pos[i] += d;
324
  shift[i] += d;
325
 
 
 
326
  has_shift = true;
327
 
328
  if (pos[i] < 0) {
 
 
329
  seq[i].reset();
330
  pos[i] = -1;
331
+ shift[i] = 0;
332
 
333
  used.erase(i);
334
 
335
  return true;
336
  }
337
 
338
+ seq_pos_add(i);
339
+
340
  return false;
341
  }
342
 
examples/talk-llama/llama-memory.cpp CHANGED
@@ -1 +1,42 @@
1
  #include "llama-memory.h"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  #include "llama-memory.h"
2
+
3
+ llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1) {
4
+ bool has_update = false;
5
+
6
+ switch (s0) {
7
+ case LLAMA_MEMORY_STATUS_SUCCESS:
8
+ {
9
+ has_update = true;
10
+ break;
11
+ }
12
+ case LLAMA_MEMORY_STATUS_NO_UPDATE:
13
+ {
14
+ break;
15
+ }
16
+ case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
17
+ case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
18
+ {
19
+ return s0;
20
+ }
21
+ }
22
+
23
+ switch (s1) {
24
+ case LLAMA_MEMORY_STATUS_SUCCESS:
25
+ {
26
+ has_update = true;
27
+ break;
28
+ }
29
+ case LLAMA_MEMORY_STATUS_NO_UPDATE:
30
+ {
31
+ break;
32
+ }
33
+ case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
34
+ case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
35
+ {
36
+ return s1;
37
+ }
38
+ }
39
+
40
+ // if either status has an update, then the combined status has an update
41
+ return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE;
42
+ }
examples/talk-llama/llama-memory.h CHANGED
@@ -7,6 +7,9 @@
7
 
8
  struct llama_ubatch;
9
 
 
 
 
10
  struct llama_memory_params {
11
  // kv cache
12
  ggml_type type_k;
@@ -16,32 +19,17 @@ struct llama_memory_params {
16
  bool swa_full;
17
  };
18
 
19
- // general concept of LLM memory
20
- // the KV cache is a type of LLM memory, but there can be other types
21
- class llama_memory_i {
22
- public:
23
- virtual ~llama_memory_i() = default;
24
-
25
- virtual void clear() = 0;
26
-
27
- virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
28
- virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
29
- virtual void seq_keep(llama_seq_id seq_id) = 0;
30
- virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
31
- virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
32
-
33
- virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
34
- virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
35
-
36
- virtual bool get_can_edit() const = 0;
37
- };
38
-
39
  enum llama_memory_status {
40
  LLAMA_MEMORY_STATUS_SUCCESS = 0,
 
41
  LLAMA_MEMORY_STATUS_FAILED_PREPARE,
42
  LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
43
  };
44
 
 
 
 
 
45
  // the interface for managing the memory state during batch processing
46
  // this interface is implemented per memory type. see:
47
  // - llama_kv_cache_unified_state
@@ -51,8 +39,7 @@ enum llama_memory_status {
51
  // the only method that can mutate the memory and the memory state is llama_memory_i::apply()
52
  //
53
  // TODO: rename to llama_memory_context_i ?
54
- class llama_memory_state_i {
55
- public:
56
  virtual ~llama_memory_state_i() = default;
57
 
58
  // consume the current ubatch from the state and proceed to the next one
@@ -69,8 +56,63 @@ public:
69
  // get the current ubatch
70
  virtual const llama_ubatch & get_ubatch() const = 0;
71
 
72
- // get the status of the memory state
73
  virtual llama_memory_status get_status() const = 0;
74
  };
75
 
76
  using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  struct llama_ubatch;
9
 
10
+ class llama_io_write_i;
11
+ class llama_io_read_i;
12
+
13
  struct llama_memory_params {
14
  // kv cache
15
  ggml_type type_k;
 
19
  bool swa_full;
20
  };
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  enum llama_memory_status {
23
  LLAMA_MEMORY_STATUS_SUCCESS = 0,
24
+ LLAMA_MEMORY_STATUS_NO_UPDATE,
25
  LLAMA_MEMORY_STATUS_FAILED_PREPARE,
26
  LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
27
  };
28
 
29
+ // helper function for combining the status of two memory states
30
+ // useful for implementing hybrid memory types (e.g. iSWA)
31
+ llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
32
+
33
  // the interface for managing the memory state during batch processing
34
  // this interface is implemented per memory type. see:
35
  // - llama_kv_cache_unified_state
 
39
  // the only method that can mutate the memory and the memory state is llama_memory_i::apply()
40
  //
41
  // TODO: rename to llama_memory_context_i ?
42
+ struct llama_memory_state_i {
 
43
  virtual ~llama_memory_state_i() = default;
44
 
45
  // consume the current ubatch from the state and proceed to the next one
 
56
  // get the current ubatch
57
  virtual const llama_ubatch & get_ubatch() const = 0;
58
 
59
+ // get the status of the memory state - used for error handling and checking if any updates would be applied
60
  virtual llama_memory_status get_status() const = 0;
61
  };
62
 
63
  using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
64
+
65
+ // general concept of LLM memory
66
+ // the KV cache is a type of LLM memory, but there can be other types
67
+ struct llama_memory_i {
68
+ virtual ~llama_memory_i() = default;
69
+
70
+ // split the input batch into a set of ubatches and verify that they can fit into the cache
71
+ // return a state object containing the ubatches and KV cache state required to process them
72
+ // check the llama_memory_state_i::get_status() for the result
73
+ virtual llama_memory_state_ptr init_batch(
74
+ const llama_batch & batch,
75
+ uint32_t n_ubatch,
76
+ bool embd_pooled,
77
+ bool logits_all) = 0;
78
+
79
+ // simulate full cache, used for allocating worst-case compute buffers
80
+ virtual llama_memory_state_ptr init_full() = 0;
81
+
82
+ // prepare for any pending memory updates, such as shifts, defrags, etc.
83
+ // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
84
+ virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
85
+
86
+ // getters
87
+ virtual bool get_can_shift() const = 0;
88
+
89
+ //
90
+ // ops
91
+ //
92
+
93
+ // if data == true, the data buffers will also be cleared together with the metadata
94
+ virtual void clear(bool data) = 0;
95
+
96
+ virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
97
+ virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
98
+ virtual void seq_keep(llama_seq_id seq_id) = 0;
99
+ virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
100
+ virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
101
+
102
+ virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
103
+ virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
104
+
105
+ //
106
+ // state write/read
107
+ //
108
+
109
+ virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
110
+ virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
111
+ };
112
+
113
+ using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
114
+
115
+ // TODO: temporary until the llama_kv_cache is removed from the public API
116
+ struct llama_kv_cache : public llama_memory_i {
117
+ virtual ~llama_kv_cache() = default;
118
+ };
examples/talk-llama/llama-mmap.cpp CHANGED
@@ -401,7 +401,7 @@ struct llama_mmap::impl {
401
  }
402
  }
403
  #else
404
- throw std::runtime_error("PrefetchVirtualMemory unavailable");
405
  #endif
406
  }
407
  }
 
401
  }
402
  }
403
  #else
404
+ LLAMA_LOG_DEBUG("skipping PrefetchVirtualMemory because _WIN32_WINNT < 0x602\n");
405
  #endif
406
  }
407
  }
examples/talk-llama/llama-model-loader.cpp CHANGED
@@ -288,9 +288,10 @@ namespace GGUFMeta {
288
 
289
  template<typename T>
290
  bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & result, bool required) {
291
- const int kid = gguf_find_key(meta.get(), key.c_str());
 
292
 
293
- if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) {
294
  if (required) {
295
  throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
296
  }
@@ -298,28 +299,40 @@ namespace GGUFMeta {
298
  }
299
 
300
  struct GGUFMeta::ArrayInfo arr_info =
301
- GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
302
 
303
  switch (arr_info.gt) {
304
  case GGUF_TYPE_UINT32:
305
- case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
306
- (std::is_same<T, uint32_t>::value)); break;
307
- case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
 
308
  default:
309
- throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
310
  }
311
 
312
- result.resize(arr_info.length);
313
- result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
 
 
 
 
 
 
 
 
 
 
314
 
315
  return true;
316
  }
317
 
318
  template<typename T, size_t N_MAX>
319
  bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required) {
320
- const int kid = gguf_find_key(meta.get(), key.c_str());
 
321
 
322
- if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) {
323
  if (required) {
324
  throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
325
  }
@@ -327,22 +340,32 @@ namespace GGUFMeta {
327
  }
328
 
329
  struct GGUFMeta::ArrayInfo arr_info =
330
- GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
331
 
332
  switch (arr_info.gt) {
333
  case GGUF_TYPE_UINT32:
334
- case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
335
- (std::is_same<T, uint32_t>::value)); break;
336
- case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
 
337
  default:
338
- throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
339
  }
340
 
341
  if (arr_info.length > N_MAX) {
342
  throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX));
343
  }
344
 
345
- std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
 
 
 
 
 
 
 
 
 
346
 
347
  return true;
348
  }
@@ -352,6 +375,8 @@ namespace GGUFMeta {
352
  return get_arr(llm_kv(kid), result, required);
353
  }
354
 
 
 
355
  template<typename T>
356
  bool llama_model_loader::get_key(const std::string & key, T & result, bool required) {
357
  auto it = kv_overrides.find(key);
 
288
 
289
  template<typename T>
290
  bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & result, bool required) {
291
+ const gguf_context * ctx = meta.get();
292
+ const int kid = gguf_find_key(ctx, key.c_str());
293
 
294
+ if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) {
295
  if (required) {
296
  throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
297
  }
 
299
  }
300
 
301
  struct GGUFMeta::ArrayInfo arr_info =
302
+ GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid);
303
 
304
  switch (arr_info.gt) {
305
  case GGUF_TYPE_UINT32:
306
+ case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
307
+ (std::is_same<T, uint32_t>::value)); break;
308
+ case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
309
+ case GGUF_TYPE_STRING: GGML_ASSERT((std::is_same<T, std::string>::value)); break;
310
  default:
311
+ throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str()));
312
  }
313
 
314
+ if constexpr (std::is_same<T, std::string>::value) {
315
+ const size_t n_items = gguf_get_arr_n(ctx, kid);
316
+ result.clear();
317
+
318
+ for (size_t i = 0; i < n_items; i++) {
319
+ const T value = gguf_get_arr_str(ctx, kid, i);
320
+ result.emplace_back(value);
321
+ }
322
+ } else {
323
+ result.resize(arr_info.length);
324
+ result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
325
+ }
326
 
327
  return true;
328
  }
329
 
330
  template<typename T, size_t N_MAX>
331
  bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required) {
332
+ const gguf_context * ctx = meta.get();
333
+ const int kid = gguf_find_key(ctx, key.c_str());
334
 
335
+ if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) {
336
  if (required) {
337
  throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
338
  }
 
340
  }
341
 
342
  struct GGUFMeta::ArrayInfo arr_info =
343
+ GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid);
344
 
345
  switch (arr_info.gt) {
346
  case GGUF_TYPE_UINT32:
347
+ case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
348
+ (std::is_same<T, uint32_t>::value)); break;
349
+ case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
350
+ case GGUF_TYPE_STRING: GGML_ASSERT((std::is_same<T, std::string>::value)); break;
351
  default:
352
+ throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str()));
353
  }
354
 
355
  if (arr_info.length > N_MAX) {
356
  throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX));
357
  }
358
 
359
+ if constexpr (std::is_same<T, std::string>::value) {
360
+ const size_t n_items = gguf_get_arr_n(ctx, kid);
361
+
362
+ for (size_t i = 0; i < n_items; i++) {
363
+ const T value = gguf_get_arr_str(ctx, kid, i);
364
+ result[i] = value;
365
+ }
366
+ } else {
367
+ std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
368
+ }
369
 
370
  return true;
371
  }
 
375
  return get_arr(llm_kv(kid), result, required);
376
  }
377
 
378
+ template bool llama_model_loader::get_arr<std::vector<std::string>>(enum llm_kv kid, std::vector<std::string> & result, bool required);
379
+
380
  template<typename T>
381
  bool llama_model_loader::get_key(const std::string & key, T & result, bool required) {
382
  auto it = kv_overrides.find(key);
examples/talk-llama/llama-model.cpp CHANGED
@@ -543,6 +543,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
543
  uint32_t n_vocab = 0;
544
  ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false);
545
 
 
 
 
 
 
 
546
  // arch-specific KVs
547
  switch (arch) {
548
  case LLM_ARCH_LLAMA:
@@ -686,7 +692,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
686
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
687
  ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
688
  ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
689
- ml.get_arr_n(LLM_KV_CLASSIFIER_OUTPUT_LABELS, hparams.n_cls_out, false);
690
 
691
  switch (hparams.n_layer) {
692
  case 3:
@@ -956,6 +961,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
956
  case 46: type = LLM_TYPE_27B; break;
957
  default: type = LLM_TYPE_UNKNOWN;
958
  }
 
 
 
 
 
959
  } break;
960
  case LLM_ARCH_GEMMA3:
961
  {
@@ -976,6 +986,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
976
  default: type = LLM_TYPE_UNKNOWN;
977
  }
978
 
 
979
  hparams.f_attention_scale = type == LLM_TYPE_27B
980
  ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
981
  : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
@@ -4356,6 +4367,15 @@ void llama_model::print_info() const {
4356
  LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
4357
  LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
4358
  LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
 
 
 
 
 
 
 
 
 
4359
  }
4360
 
4361
  LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str());
@@ -8484,14 +8504,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
8484
  cb(Kcur, "Kcur", il);
8485
  cb(Vcur, "Vcur", il);
8486
 
8487
- // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
8488
- switch (model.type) {
8489
- case LLM_TYPE_2B:
8490
- case LLM_TYPE_9B:
8491
- case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); break;
8492
- default: GGML_ABORT("fatal error");
8493
- };
8494
- cb(Qcur, "Qcur_scaled", il);
8495
 
8496
  cur = build_attn(inp_attn, gf,
8497
  model.layers[il].wo, NULL,
@@ -8632,9 +8645,12 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
8632
  cb(Kcur, "Kcur", il);
8633
  cb(Vcur, "Vcur", il);
8634
 
 
 
 
8635
  cur = build_attn(inp_attn, gf,
8636
  model.layers[il].wo, NULL,
8637
- Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
8638
  }
8639
 
8640
  cur = build_norm(cur,
@@ -13600,6 +13616,18 @@ int32_t llama_model_n_swa(const llama_model * model) {
13600
  return model->hparams.n_swa;
13601
  }
13602
 
 
 
 
 
 
 
 
 
 
 
 
 
13603
  // deprecated
13604
  int32_t llama_n_ctx_train(const llama_model * model) {
13605
  return llama_model_n_ctx_train(model);
@@ -13760,7 +13788,7 @@ uint64_t llama_model_size(const llama_model * model) {
13760
  }
13761
 
13762
  const char * llama_model_chat_template(const llama_model * model, const char * name) {
13763
- const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N)
13764
  : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
13765
  const auto & it = model->gguf_kv.find(key);
13766
  if (it == model->gguf_kv.end()) {
 
543
  uint32_t n_vocab = 0;
544
  ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false);
545
 
546
+ // for classifier models
547
+ ml.get_arr(LLM_KV_CLASSIFIER_OUTPUT_LABELS, classifier_labels, false);
548
+ if (!classifier_labels.empty()) {
549
+ hparams.n_cls_out = classifier_labels.size();
550
+ }
551
+
552
  // arch-specific KVs
553
  switch (arch) {
554
  case LLM_ARCH_LLAMA:
 
692
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
693
  ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
694
  ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
 
695
 
696
  switch (hparams.n_layer) {
697
  case 3:
 
961
  case 46: type = LLM_TYPE_27B; break;
962
  default: type = LLM_TYPE_UNKNOWN;
963
  }
964
+
965
+ // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173
966
+ hparams.f_attention_scale = type == LLM_TYPE_27B
967
+ ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
968
+ : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
969
  } break;
970
  case LLM_ARCH_GEMMA3:
971
  {
 
986
  default: type = LLM_TYPE_UNKNOWN;
987
  }
988
 
989
+ // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289
990
  hparams.f_attention_scale = type == LLM_TYPE_27B
991
  ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
992
  : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
 
4367
  LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
4368
  LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
4369
  LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
4370
+
4371
+ if (!classifier_labels.empty()) {
4372
+ LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out);
4373
+
4374
+ size_t i = 0;
4375
+ for (auto label : classifier_labels) {
4376
+ LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str());
4377
+ }
4378
+ }
4379
  }
4380
 
4381
  LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str());
 
8504
  cb(Kcur, "Kcur", il);
8505
  cb(Vcur, "Vcur", il);
8506
 
8507
+ Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
 
 
 
 
 
 
 
8508
 
8509
  cur = build_attn(inp_attn, gf,
8510
  model.layers[il].wo, NULL,
 
8645
  cb(Kcur, "Kcur", il);
8646
  cb(Vcur, "Vcur", il);
8647
 
8648
+ // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
8649
+ Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
8650
+
8651
  cur = build_attn(inp_attn, gf,
8652
  model.layers[il].wo, NULL,
8653
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
8654
  }
8655
 
8656
  cur = build_norm(cur,
 
13616
  return model->hparams.n_swa;
13617
  }
13618
 
13619
+ uint32_t llama_model_n_cls_out(const struct llama_model * model) {
13620
+ return model->hparams.n_cls_out;
13621
+ }
13622
+
13623
+ const char * llama_model_cls_label(const struct llama_model * model, uint32_t i) {
13624
+ if (i < model->classifier_labels.size()) {
13625
+ return model->classifier_labels[i].c_str();
13626
+ }
13627
+
13628
+ return nullptr;
13629
+ }
13630
+
13631
  // deprecated
13632
  int32_t llama_n_ctx_train(const llama_model * model) {
13633
  return llama_model_n_ctx_train(model);
 
13788
  }
13789
 
13790
  const char * llama_model_chat_template(const llama_model * model, const char * name) {
13791
+ const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE)
13792
  : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
13793
  const auto & it = model->gguf_kv.find(key);
13794
  if (it == model->gguf_kv.end()) {
examples/talk-llama/llama-model.h CHANGED
@@ -329,6 +329,9 @@ struct llama_model {
329
  llama_hparams hparams = {};
330
  llama_vocab vocab;
331
 
 
 
 
332
  struct ggml_tensor * tok_embd = nullptr;
333
  struct ggml_tensor * type_embd = nullptr;
334
  struct ggml_tensor * pos_embd = nullptr;
 
329
  llama_hparams hparams = {};
330
  llama_vocab vocab;
331
 
332
+ // for classifier models
333
+ std::vector<std::string> classifier_labels;
334
+
335
  struct ggml_tensor * tok_embd = nullptr;
336
  struct ggml_tensor * type_embd = nullptr;
337
  struct ggml_tensor * pos_embd = nullptr;
examples/talk-llama/llama-vocab.cpp CHANGED
@@ -2080,9 +2080,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
2080
 
2081
  std::string model_name;
2082
  std::string tokenizer_pre;
 
2083
 
2084
  ml.get_key(LLM_KV_GENERAL_NAME, model_name, false);
2085
  ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
 
2086
 
2087
  // model name to lowercase
2088
  std::transform(model_name.begin(), model_name.end(), model_name.begin(),
@@ -2091,9 +2093,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
2091
  }
2092
  );
2093
 
2094
- // set attributes by model/tokenizer name
2095
- if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
2096
- _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
 
 
 
 
 
 
 
2097
  } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
2098
  for (auto id : cache_special_tokens) {
2099
  _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
 
2080
 
2081
  std::string model_name;
2082
  std::string tokenizer_pre;
2083
+ std::string general_arch;
2084
 
2085
  ml.get_key(LLM_KV_GENERAL_NAME, model_name, false);
2086
  ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
2087
+ ml.get_key(LLM_KV_GENERAL_ARCHITECTURE, general_arch, false);
2088
 
2089
  // model name to lowercase
2090
  std::transform(model_name.begin(), model_name.end(), model_name.begin(),
 
2093
  }
2094
  );
2095
 
2096
+ // set attributes by model/tokenizer/architecture name
2097
+ if (false
2098
+ || _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})
2099
+ || _contains_any(general_arch, {"nomic-bert-moe"})
2100
+ ) {
2101
+ if (token_to_id.count("<mask>") == 0) {
2102
+ LLAMA_LOG_WARN("%s: Mask token is missing in vocab, please reconvert model!\n", __func__);
2103
+ } else {
2104
+ _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
2105
+ }
2106
  } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
2107
  for (auto id : cache_special_tokens) {
2108
  _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
examples/talk-llama/llama.h CHANGED
@@ -61,7 +61,10 @@ extern "C" {
61
  struct llama_model;
62
  struct llama_context;
63
  struct llama_sampler;
64
- struct llama_kv_cache;
 
 
 
65
 
66
  typedef int32_t llama_pos;
67
  typedef int32_t llama_token;
@@ -493,9 +496,11 @@ extern "C" {
493
  DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
494
 
495
  LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
496
- LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx);
497
  LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
498
 
 
 
499
  LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
500
  LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
501
 
@@ -509,6 +514,13 @@ extern "C" {
509
  // Get the model's RoPE frequency scaling factor
510
  LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
511
 
 
 
 
 
 
 
 
512
  LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab);
513
 
514
  LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);
@@ -609,7 +621,81 @@ extern "C" {
609
  int32_t il_end);
610
 
611
  //
612
- // KV cache
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
613
  //
614
 
615
  // Returns the number of tokens in the KV cache (slow, use only for debug)
@@ -622,86 +708,95 @@ extern "C" {
622
  "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
623
 
624
  // Clear the KV cache - both cell info is erased and KV data is zeroed
625
- LLAMA_API void llama_kv_self_clear(
626
- struct llama_context * ctx);
 
627
 
628
  // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
629
  // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
630
  // seq_id < 0 : match any sequence
631
  // p0 < 0 : [0, p1]
632
  // p1 < 0 : [p0, inf)
633
- LLAMA_API bool llama_kv_self_seq_rm(
634
  struct llama_context * ctx,
635
  llama_seq_id seq_id,
636
  llama_pos p0,
637
- llama_pos p1);
 
638
 
639
  // Copy all tokens that belong to the specified sequence to another sequence
640
  // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
641
  // p0 < 0 : [0, p1]
642
  // p1 < 0 : [p0, inf)
643
- LLAMA_API void llama_kv_self_seq_cp(
644
  struct llama_context * ctx,
645
  llama_seq_id seq_id_src,
646
  llama_seq_id seq_id_dst,
647
  llama_pos p0,
648
- llama_pos p1);
 
649
 
650
  // Removes all tokens that do not belong to the specified sequence
651
- LLAMA_API void llama_kv_self_seq_keep(
652
  struct llama_context * ctx,
653
- llama_seq_id seq_id);
 
654
 
655
  // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
656
  // If the KV cache is RoPEd, the KV data is updated accordingly:
657
  // - lazily on next llama_decode()
658
  // p0 < 0 : [0, p1]
659
  // p1 < 0 : [p0, inf)
660
- LLAMA_API void llama_kv_self_seq_add(
661
  struct llama_context * ctx,
662
  llama_seq_id seq_id,
663
  llama_pos p0,
664
  llama_pos p1,
665
- llama_pos delta);
 
666
 
667
  // Integer division of the positions by factor of `d > 1`
668
  // If the KV cache is RoPEd, the KV data is updated accordingly:
669
  // - lazily on next llama_decode()
670
  // p0 < 0 : [0, p1]
671
  // p1 < 0 : [p0, inf)
672
- LLAMA_API void llama_kv_self_seq_div(
673
  struct llama_context * ctx,
674
  llama_seq_id seq_id,
675
  llama_pos p0,
676
  llama_pos p1,
677
- int d);
 
678
 
679
  // Returns the smallest position present in the KV cache for the specified sequence
680
  // This is typically non-zero only for SWA caches
681
  // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
682
  // Return -1 if the sequence is empty
683
- LLAMA_API llama_pos llama_kv_self_seq_pos_min(
684
  struct llama_context * ctx,
685
- llama_seq_id seq_id);
 
686
 
687
  // Returns the largest position present in the KV cache for the specified sequence
688
  // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
689
  // Return -1 if the sequence is empty
690
- LLAMA_API llama_pos llama_kv_self_seq_pos_max(
691
  struct llama_context * ctx,
692
- llama_seq_id seq_id);
 
693
 
694
  // Defragment the KV cache
695
  // This will be applied:
696
  // - lazily on next llama_decode()
697
- LLAMA_API DEPRECATED(void llama_kv_self_defrag(struct llama_context * ctx),
698
  "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
699
 
700
  // Check if the context supports KV cache shifting
701
- LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
 
702
 
703
  // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
704
- LLAMA_API DEPRECATED(void llama_kv_self_update(struct llama_context * ctx),
705
  "simply remove this call, updates are applied lazily on the next llama_decode()");
706
 
707
  //
@@ -709,7 +804,7 @@ extern "C" {
709
  //
710
 
711
  // Returns the *actual* size in bytes of the state
712
- // (logits, embedding and kv_cache)
713
  // Only use when saving the state, not when restoring it, otherwise the size may be too small.
714
  LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
715
  LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
@@ -765,12 +860,12 @@ extern "C" {
765
  size_t n_token_count),
766
  "use llama_state_save_file instead");
767
 
768
- // Get the exact size needed to copy the KV cache of a single sequence
769
  LLAMA_API size_t llama_state_seq_get_size(
770
  struct llama_context * ctx,
771
  llama_seq_id seq_id);
772
 
773
- // Copy the KV cache of a single sequence into the specified buffer
774
  LLAMA_API size_t llama_state_seq_get_data(
775
  struct llama_context * ctx,
776
  uint8_t * dst,
@@ -836,16 +931,16 @@ extern "C" {
836
  // For encode-decoder contexts, processes the batch using the encoder.
837
  // Can store the encoder output internally for later use by the decoder's cross-attention layers.
838
  // 0 - success
839
- // < 0 - error. the KV cache state is restored to the state before this call
840
  LLAMA_API int32_t llama_encode(
841
  struct llama_context * ctx,
842
  struct llama_batch batch);
843
 
844
  // Process a batch of tokens.
845
- // Requires KV cache.
846
  // For encode-decoder contexts, processes the batch using the decoder.
847
  // Positive return values does not mean a fatal error, but rather a warning.
848
- // Upon non-zero return values, the KV cache state is restored to the state before this call
849
  // 0 - success
850
  // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
851
  // 2 - aborted
@@ -916,7 +1011,7 @@ extern "C" {
916
 
917
  // Get the embeddings for a sequence id
918
  // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
919
- // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
920
  // otherwise: float[n_embd] (1-dimensional)
921
  LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
922
 
 
61
  struct llama_model;
62
  struct llama_context;
63
  struct llama_sampler;
64
+
65
+ typedef struct llama_memory_i * llama_memory_t;
66
+
67
+ struct llama_kv_cache; // DEPRECATED (use llama_memory instead)
68
 
69
  typedef int32_t llama_pos;
70
  typedef int32_t llama_token;
 
496
  DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
497
 
498
  LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
499
+ LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx);
500
  LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
501
 
502
+ DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
503
+
504
  LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
505
  LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
506
 
 
514
  // Get the model's RoPE frequency scaling factor
515
  LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
516
 
517
+ // Returns the number of classifier outputs (only valid for classifier models)
518
+ // Undefined behavior for non-classifier models
519
+ LLAMA_API uint32_t llama_model_n_cls_out(const struct llama_model * model);
520
+
521
+ // Returns label of classifier output by index (<n_cls_out). Returns nullptr if no label provided
522
+ LLAMA_API const char * llama_model_cls_label(const struct llama_model * model, uint32_t i);
523
+
524
  LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab);
525
 
526
  LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);
 
621
  int32_t il_end);
622
 
623
  //
624
+ // Memory
625
+ //
626
+
627
+ // Clear the memory contents
628
+ // If data == true, the data buffers will also be cleared together with the metadata
629
+ LLAMA_API void llama_memory_clear(
630
+ llama_memory_t mem,
631
+ bool data);
632
+
633
+ // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
634
+ // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
635
+ // seq_id < 0 : match any sequence
636
+ // p0 < 0 : [0, p1]
637
+ // p1 < 0 : [p0, inf)
638
+ LLAMA_API bool llama_memory_seq_rm(
639
+ llama_memory_t mem,
640
+ llama_seq_id seq_id,
641
+ llama_pos p0,
642
+ llama_pos p1);
643
+
644
+ // Copy all tokens that belong to the specified sequence to another sequence
645
+ // p0 < 0 : [0, p1]
646
+ // p1 < 0 : [p0, inf)
647
+ LLAMA_API void llama_memory_seq_cp(
648
+ llama_memory_t mem,
649
+ llama_seq_id seq_id_src,
650
+ llama_seq_id seq_id_dst,
651
+ llama_pos p0,
652
+ llama_pos p1);
653
+
654
+ // Removes all tokens that do not belong to the specified sequence
655
+ LLAMA_API void llama_memory_seq_keep(
656
+ llama_memory_t mem,
657
+ llama_seq_id seq_id);
658
+
659
+ // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
660
+ // p0 < 0 : [0, p1]
661
+ // p1 < 0 : [p0, inf)
662
+ LLAMA_API void llama_memory_seq_add(
663
+ llama_memory_t mem,
664
+ llama_seq_id seq_id,
665
+ llama_pos p0,
666
+ llama_pos p1,
667
+ llama_pos delta);
668
+
669
+ // Integer division of the positions by factor of `d > 1`
670
+ // p0 < 0 : [0, p1]
671
+ // p1 < 0 : [p0, inf)
672
+ LLAMA_API void llama_memory_seq_div(
673
+ llama_memory_t mem,
674
+ llama_seq_id seq_id,
675
+ llama_pos p0,
676
+ llama_pos p1,
677
+ int d);
678
+
679
+ // Returns the smallest position present in the memory for the specified sequence
680
+ // This is typically non-zero only for SWA caches
681
+ // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
682
+ // Return -1 if the sequence is empty
683
+ LLAMA_API llama_pos llama_memory_seq_pos_min(
684
+ llama_memory_t mem,
685
+ llama_seq_id seq_id);
686
+
687
+ // Returns the largest position present in the memory for the specified sequence
688
+ // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
689
+ // Return -1 if the sequence is empty
690
+ LLAMA_API llama_pos llama_memory_seq_pos_max(
691
+ llama_memory_t mem,
692
+ llama_seq_id seq_id);
693
+
694
+ // Check if the memory supports shifting
695
+ LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
696
+
697
+ //
698
+ // KV cache for self-attention (TODO: deprecate in favor of llama_memory)
699
  //
700
 
701
  // Returns the number of tokens in the KV cache (slow, use only for debug)
 
708
  "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
709
 
710
  // Clear the KV cache - both cell info is erased and KV data is zeroed
711
+ DEPRECATED(LLAMA_API void llama_kv_self_clear(
712
+ struct llama_context * ctx),
713
+ "Use llama_memory_clear() instead");
714
 
715
  // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
716
  // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
717
  // seq_id < 0 : match any sequence
718
  // p0 < 0 : [0, p1]
719
  // p1 < 0 : [p0, inf)
720
+ DEPRECATED(LLAMA_API bool llama_kv_self_seq_rm(
721
  struct llama_context * ctx,
722
  llama_seq_id seq_id,
723
  llama_pos p0,
724
+ llama_pos p1),
725
+ "Use llama_memory_seq_rm() instead");
726
 
727
  // Copy all tokens that belong to the specified sequence to another sequence
728
  // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
729
  // p0 < 0 : [0, p1]
730
  // p1 < 0 : [p0, inf)
731
+ DEPRECATED(LLAMA_API void llama_kv_self_seq_cp(
732
  struct llama_context * ctx,
733
  llama_seq_id seq_id_src,
734
  llama_seq_id seq_id_dst,
735
  llama_pos p0,
736
+ llama_pos p1),
737
+ "Use llama_memory_seq_cp() instead");
738
 
739
  // Removes all tokens that do not belong to the specified sequence
740
+ DEPRECATED(LLAMA_API void llama_kv_self_seq_keep(
741
  struct llama_context * ctx,
742
+ llama_seq_id seq_id),
743
+ "Use llama_memory_seq_keep() instead");
744
 
745
  // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
746
  // If the KV cache is RoPEd, the KV data is updated accordingly:
747
  // - lazily on next llama_decode()
748
  // p0 < 0 : [0, p1]
749
  // p1 < 0 : [p0, inf)
750
+ DEPRECATED(LLAMA_API void llama_kv_self_seq_add(
751
  struct llama_context * ctx,
752
  llama_seq_id seq_id,
753
  llama_pos p0,
754
  llama_pos p1,
755
+ llama_pos delta),
756
+ "Use llama_memory_seq_add() instead");
757
 
758
  // Integer division of the positions by factor of `d > 1`
759
  // If the KV cache is RoPEd, the KV data is updated accordingly:
760
  // - lazily on next llama_decode()
761
  // p0 < 0 : [0, p1]
762
  // p1 < 0 : [p0, inf)
763
+ DEPRECATED(void llama_kv_self_seq_div(
764
  struct llama_context * ctx,
765
  llama_seq_id seq_id,
766
  llama_pos p0,
767
  llama_pos p1,
768
+ int d),
769
+ "Use llama_memory_seq_div() instead");
770
 
771
  // Returns the smallest position present in the KV cache for the specified sequence
772
  // This is typically non-zero only for SWA caches
773
  // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
774
  // Return -1 if the sequence is empty
775
+ DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_min(
776
  struct llama_context * ctx,
777
+ llama_seq_id seq_id),
778
+ "Use llama_memory_seq_pos_min() instead");
779
 
780
  // Returns the largest position present in the KV cache for the specified sequence
781
  // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
782
  // Return -1 if the sequence is empty
783
+ DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_max(
784
  struct llama_context * ctx,
785
+ llama_seq_id seq_id),
786
+ "Use llama_memory_seq_pos_max() instead");
787
 
788
  // Defragment the KV cache
789
  // This will be applied:
790
  // - lazily on next llama_decode()
791
+ DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx),
792
  "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
793
 
794
  // Check if the context supports KV cache shifting
795
+ DEPRECATED(LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx),
796
+ "use llama_memory_can_shift() instead");
797
 
798
  // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
799
+ DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx),
800
  "simply remove this call, updates are applied lazily on the next llama_decode()");
801
 
802
  //
 
804
  //
805
 
806
  // Returns the *actual* size in bytes of the state
807
+ // (logits, embedding and memory)
808
  // Only use when saving the state, not when restoring it, otherwise the size may be too small.
809
  LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
810
  LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
 
860
  size_t n_token_count),
861
  "use llama_state_save_file instead");
862
 
863
+ // Get the exact size needed to copy the state of a single sequence
864
  LLAMA_API size_t llama_state_seq_get_size(
865
  struct llama_context * ctx,
866
  llama_seq_id seq_id);
867
 
868
+ // Copy the state of a single sequence into the specified buffer
869
  LLAMA_API size_t llama_state_seq_get_data(
870
  struct llama_context * ctx,
871
  uint8_t * dst,
 
931
  // For encode-decoder contexts, processes the batch using the encoder.
932
  // Can store the encoder output internally for later use by the decoder's cross-attention layers.
933
  // 0 - success
934
+ // < 0 - error. the memory state is restored to the state before this call
935
  LLAMA_API int32_t llama_encode(
936
  struct llama_context * ctx,
937
  struct llama_batch batch);
938
 
939
  // Process a batch of tokens.
940
+ // Requires the context to have a memory.
941
  // For encode-decoder contexts, processes the batch using the decoder.
942
  // Positive return values does not mean a fatal error, but rather a warning.
943
+ // Upon non-zero return values, the memory state is restored to the state before this call
944
  // 0 - success
945
  // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
946
  // 2 - aborted
 
1011
 
1012
  // Get the embeddings for a sequence id
1013
  // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
1014
+ // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[n_cls_out] with the rank(s) of the sequence
1015
  // otherwise: float[n_embd] (1-dimensional)
1016
  LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
1017