talk-llama : sync llama.cpp

- examples/talk-llama/CMakeLists.txt +0 -1
- examples/talk-llama/llama-arch.cpp +8 -3
- examples/talk-llama/llama-arch.h +0 -1
- examples/talk-llama/llama-context.cpp +240 -109
- examples/talk-llama/llama-context.h +8 -6
- examples/talk-llama/llama-graph.cpp +16 -3
- examples/talk-llama/llama-graph.h +2 -1
- examples/talk-llama/llama-kv-cache-recurrent.cpp +16 -16
- examples/talk-llama/llama-kv-cache-recurrent.h +13 -19
- examples/talk-llama/llama-kv-cache-unified-iswa.cpp +34 -31
- examples/talk-llama/llama-kv-cache-unified-iswa.h +22 -24
- examples/talk-llama/llama-kv-cache-unified.cpp +142 -89
- examples/talk-llama/llama-kv-cache-unified.h +66 -37
- examples/talk-llama/llama-kv-cache.cpp +0 -1
- examples/talk-llama/llama-kv-cells.h +10 -5
- examples/talk-llama/llama-memory.cpp +41 -0
- examples/talk-llama/llama-memory.h +65 -23
- examples/talk-llama/llama-mmap.cpp +1 -1
- examples/talk-llama/llama-model-loader.cpp +42 -17
- examples/talk-llama/llama-model.cpp +39 -11
- examples/talk-llama/llama-model.h +3 -0
- examples/talk-llama/llama-vocab.cpp +12 -3
- examples/talk-llama/llama.h +124 -29
examples/talk-llama/CMakeLists.txt

@@ -16,7 +16,6 @@ if (WHISPER_SDL2)
         llama-hparams.cpp
         llama-impl.cpp
         llama-io.cpp
-        llama-kv-cache.cpp
         llama-kv-cache-unified.cpp
         llama-kv-cache-unified-iswa.cpp
         llama-kv-cache-recurrent.cpp
examples/talk-llama/llama-arch.cpp

@@ -200,7 +200,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_HF_JSON,           "tokenizer.huggingface.json" },
     { LLM_KV_TOKENIZER_RWKV,              "tokenizer.rwkv.world" },
     { LLM_KV_TOKENIZER_CHAT_TEMPLATE,     "tokenizer.chat_template" },
-    { LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,   "tokenizer.chat_template.%s" },
     { LLM_KV_TOKENIZER_FIM_PRE_ID,        "tokenizer.ggml.fim_pre_token_id" },
     { LLM_KV_TOKENIZER_FIM_SUF_ID,        "tokenizer.ggml.fim_suf_token_id" },
     { LLM_KV_TOKENIZER_FIM_MID_ID,        "tokenizer.ggml.fim_mid_token_id" },

@@ -1707,8 +1706,14 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
 LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}

 std::string LLM_KV::operator()(llm_kv kv) const {
-    …
+    std::string name = ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
+
+    if (suffix != nullptr) {
+        name += ".";
+        name += suffix;
+    }
+
+    return name;
 }

 std::string LLM_TN_IMPL::str() const {
examples/talk-llama/llama-arch.h

@@ -196,7 +196,6 @@ enum llm_kv {
     LLM_KV_TOKENIZER_HF_JSON,
     LLM_KV_TOKENIZER_RWKV,
     LLM_KV_TOKENIZER_CHAT_TEMPLATE,
-    LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
     LLM_KV_TOKENIZER_FIM_PRE_ID,
     LLM_KV_TOKENIZER_FIM_SUF_ID,
     LLM_KV_TOKENIZER_FIM_MID_ID,
examples/talk-llama/llama-context.cpp

@@ -2,9 +2,9 @@

 #include "llama-impl.h"
 #include "llama-io.h"
+#include "llama-memory.h"
 #include "llama-mmap.h"
 #include "llama-model.h"
-#include "llama-kv-cache.h"

 #include <cinttypes>
 #include <cstring>

@@ -123,7 +123,7 @@ llama_context::llama_context(
                 __func__, n_ctx_per_seq, hparams.n_ctx_train);
     }

-    if (!params.swa_full && cparams.n_seq_max > 1) {
+    if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) {
         LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
                 __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
     }

@@ -277,10 +277,9 @@ llama_context::llama_context(
         int n_nodes_tg = -1;

         // simulate full KV cache
-        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());

-        const auto …
-        if (!…
+        const auto mstate = memory->init_full();
+        if (!mstate) {
             throw std::runtime_error("failed to initialize KV cache");
         }

@@ -288,7 +287,7 @@ llama_context::llama_context(

         // reserve pp graph first so that buffers are only allocated once
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, …
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }

@@ -299,7 +298,7 @@ llama_context::llama_context(

         // reserve with tg graph to get the number of splits and nodes
         {
-            auto * gf = graph_reserve(1, 1, 1, …
+            auto * gf = graph_reserve(1, 1, 1, mstate.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }

@@ -310,7 +309,7 @@ llama_context::llama_context(

         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, …
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }

@@ -419,40 +418,68 @@ uint32_t llama_context::n_threads_batch() const {
     return cparams.n_threads_batch;
 }

-…
-    return kv_self;
+llama_memory_t llama_context::get_memory() const {
+    return memory.get();
 }

-…
+// deprecated
+void llama_context::kv_self_defrag_sched() {
+    if (!memory) {
+        return;
+    }
+
+    memory_force_optimize = true;
 }

-…
+// deprecated
+bool llama_context::kv_self_update(bool optimize) {
     if (!memory) {
         return false;
     }

-…
-            throw std::runtime_error("failed to initialize KV cache");
-        }
-…
+    {
+        // TODO: remove in the future
+        optimize |= memory_force_optimize;
+        memory_force_optimize = false;
+
+        const auto mstate = memory->init_update(this, optimize);
+        switch (mstate->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                    // noop
+                } break;
+            case LLAMA_MEMORY_STATUS_NO_UPDATE:
+                {
+                    // no updates need to be performed
+                    return false;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
+                    return false;
+                }
+        }
+
+        if (!mstate->apply()) {
+            LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
+        }
     }

+    // if the memory module did any computation, we have to reserve a new worst-case graph
+    {
+        const auto mstate = memory->init_full();
+        if (!mstate) {
+            throw std::runtime_error("failed to initialize memory state");
+        }
+
+        const uint32_t n_seqs   = cparams.n_seq_max;
+        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+        if (!gf) {
+            LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
+        }
     }

     return true;

@@ -814,16 +841,17 @@ int llama_context::encode(llama_batch & inp_batch) {
                 } break;
             case LLAMA_POOLING_TYPE_RANK:
                 {
-                    // extract the rerank score - …
+                    // extract the rerank score - n_cls_out floats per sequence
                     auto & embd_seq_out = embd_seq;
+                    const uint32_t n_cls_out = hparams.n_cls_out;

                     for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
                         const llama_seq_id seq_id = ubatch.seq_id[s][0];
                         if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                             continue;
                         }
-                        embd_seq_out[seq_id].resize(…
-                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                        embd_seq_out[seq_id].resize(n_cls_out);
+                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float));
                     }
                 } break;
             case LLAMA_POOLING_TYPE_UNSPECIFIED:

@@ -880,10 +908,8 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
     }

-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
     // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : …
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);

     const llama_batch & batch = batch_allocr.batch;

@@ -940,42 +966,49 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_all = 1;
     }

-…
-    kv_self_update();
+    bool did_optimize = false;

-…
+    // handle any pending defrags/shifts
+    kv_self_update(false);

-…
+    llama_memory_state_ptr mstate;

     while (true) {
-        …
-        if (!…
+        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        if (!mstate) {
             return -2;
         }

-        switch (…
+        switch (mstate->get_status()) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                 } break;
+            case LLAMA_MEMORY_STATUS_NO_UPDATE:
+                {
+                    LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
+
+                    return -2;
+                }
             case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
                 {
-                    if (!…
-                    …
-                        LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
+                    if (!did_optimize) {
+                        did_optimize = true;
+
+                        if (kv_self_update(true)) {
+                            LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);

                             continue;
                         }
                     }

-                    LLAMA_LOG_WARN("%s: failed to find …
+                    LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);

                     return 1;
                 }
             case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
                 {
+                    LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
+
                     return -2;
                 }
         }

@@ -992,7 +1025,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     int64_t n_outputs_prev = 0;

     do {
-        const auto & ubatch = …
+        const auto & ubatch = mstate->get_ubatch();

         // count the outputs in this u_batch
         {

@@ -1015,11 +1048,14 @@ int llama_context::decode(llama_batch & inp_batch) {
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

         ggml_status status;
-        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, …
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);

         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
-            llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] …
+            llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
+            for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+                pos_min[s] = std::numeric_limits<llama_pos>::max();
+            }

             for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
                 const auto & seq_id = ubatch.seq_id[i][0];

@@ -1034,7 +1070,7 @@ int llama_context::decode(llama_batch & inp_batch) {

                 LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);

-                …
+                memory->seq_rm(s, pos_min[s], -1);
             }

             switch (status) {

@@ -1128,7 +1164,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         }

         n_outputs_prev += n_outputs;
-    } while (…
+    } while (mstate->next());

     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     n_outputs = n_outputs_all;

@@ -1137,7 +1173,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     {
         bool sorted_output = true;

-        auto & out_ids = …
+        auto & out_ids = mstate->out_ids();

         GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);

@@ -1189,11 +1225,6 @@ int llama_context::decode(llama_batch & inp_batch) {
     // wait for the computation to finish (automatically done when obtaining the model output)
     //synchronize();

-    // decide if we need to defrag the kv cache
-    if (cparams.defrag_thold > 0.0f) {
-        kv_self->defrag_sched(cparams.defrag_thold);
-    }
-
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
     ggml_backend_sched_reset(sched.get());

@@ -1810,11 +1841,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
         }
     }

-…
-    if (kv_self != nullptr) {
+    if (memory != nullptr) {
         LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
-        …
+        memory->state_write(io);
     }

     return io.n_bytes();

@@ -1901,9 +1930,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
     if (memory) {
         LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);

-        …
-        kv_self->state_read(io);
+        memory->state_read(io);
     }

     return io.n_bytes();

@@ -1913,9 +1940,7 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
     GGML_UNUSED(seq_id);

     if (memory) {
-        …
-        kv_self->state_write(io, seq_id);
+        memory->state_write(io, seq_id);
     }

     return io.n_bytes();

@@ -1925,9 +1950,7 @@ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq
     GGML_UNUSED(seq_id);

     if (memory) {
-        …
-        kv_self->state_read(io, seq_id);
+        memory->state_read(io, seq_id);
     }

     return io.n_bytes();

@@ -2032,9 +2055,7 @@ void llama_context::opt_epoch_iter(
     const uint32_t n_batch  = std::min(this->n_batch(),  n_ctx);
     const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);

-    …
-    kv_self->clear();
+    memory->clear(true);

     for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
         batch.n_tokens = n_batch;

@@ -2057,8 +2078,8 @@ void llama_context::opt_epoch_iter(

         int64_t n_outputs_all = n_tokens_all;

-        auto …
-        if (!…
+        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
+        if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
         }

@@ -2071,17 +2092,17 @@ void llama_context::opt_epoch_iter(

         uint32_t pos_batch = 0;
         do {
-            const auto & ubatch = …
+            const auto & ubatch = mstate->get_ubatch();

             n_outputs = ubatch.n_tokens;

-            if (!…
+            if (!mstate->apply()) {
                 LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
                 break;
             }

             auto * gf = graph_init();
-            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, …
+            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());

             struct ggml_context * ctx_compute_opt;
             {

@@ -2116,7 +2137,7 @@ void llama_context::opt_epoch_iter(
             ggml_free(ctx_compute_opt);

             pos_batch += ubatch.n_tokens;
-        } while (…
+        } while (mstate->next());
     }
 }

@@ -2277,13 +2298,14 @@ const llama_model * llama_get_model(const llama_context * ctx) {
     return &ctx->get_model();
 }

+// deprecated
 llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
-    return ctx->…
+    return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
 }

 // deprecated
 void llama_kv_self_update(llama_context * ctx) {
-    ctx->kv_self_update();
+    ctx->kv_self_update(false);
 }

 enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {

@@ -2398,13 +2420,118 @@ int32_t llama_apply_adapter_cvec(
     return res ? 0 : -1;
 }

+//
+// memory
+//
+
+llama_memory_t llama_get_memory(const struct llama_context * ctx) {
+    return ctx->get_memory();
+}
+
+void llama_memory_clear(llama_memory_t mem, bool data) {
+    if (!mem) {
+        return;
+    }
+
+    mem->clear(data);
+}
+
+bool llama_memory_seq_rm(
+        llama_memory_t mem,
+        llama_seq_id   seq_id,
+        llama_pos      p0,
+        llama_pos      p1) {
+    if (!mem) {
+        return true;
+    }
+
+    return mem->seq_rm(seq_id, p0, p1);
+}
+
+void llama_memory_seq_cp(
+        llama_memory_t mem,
+        llama_seq_id   seq_id_src,
+        llama_seq_id   seq_id_dst,
+        llama_pos      p0,
+        llama_pos      p1) {
+    if (!mem) {
+        return;
+    }
+
+    mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_memory_seq_keep(
+        llama_memory_t mem,
+        llama_seq_id   seq_id) {
+    if (!mem) {
+        return;
+    }
+
+    mem->seq_keep(seq_id);
+}
+
+void llama_memory_seq_add(
+        llama_memory_t mem,
+        llama_seq_id   seq_id,
+        llama_pos      p0,
+        llama_pos      p1,
+        llama_pos      delta) {
+    if (!mem) {
+        return;
+    }
+
+    mem->seq_add(seq_id, p0, p1, delta);
+}
+
+void llama_memory_seq_div(
+        llama_memory_t mem,
+        llama_seq_id   seq_id,
+        llama_pos      p0,
+        llama_pos      p1,
+        int            d) {
+    if (!mem) {
+        return;
+    }
+
+    mem->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_memory_seq_pos_min(
+        llama_memory_t mem,
+        llama_seq_id   seq_id) {
+    if (!mem) {
+        return -1;
+    }
+
+    return mem->seq_pos_min(seq_id);
+}
+
+llama_pos llama_memory_seq_pos_max(
+        llama_memory_t mem,
+        llama_seq_id   seq_id) {
+    if (!mem) {
+        return -1;
+    }
+
+    return mem->seq_pos_max(seq_id);
+}
+
+bool llama_memory_can_shift(llama_memory_t mem) {
+    if (!mem) {
+        return false;
+    }
+
+    return mem->get_can_shift();
+}
+
 //
 // kv cache
 //

 // deprecated
 int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
-    const auto * kv = ctx…
+    const auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return 0;
     }

@@ -2426,7 +2553,7 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
 // deprecated
 // note: this is the same as above - will be removed anyway, so it's ok
 int32_t llama_kv_self_used_cells(const llama_context * ctx) {
-    const auto * kv = ctx…
+    const auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return 0;
     }

@@ -2445,115 +2572,119 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) {
     return res;
 }

+// deprecated
 void llama_kv_self_clear(llama_context * ctx) {
-    auto * kv = ctx…
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }

-    kv…
+    llama_memory_clear(kv, true);
 }

+// deprecated
 bool llama_kv_self_seq_rm(
         llama_context * ctx,
          llama_seq_id   seq_id,
             llama_pos   p0,
             llama_pos   p1) {
-    auto * kv = ctx…
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return true;
     }

-    return kv…
+    return llama_memory_seq_rm(kv, seq_id, p0, p1);
 }

+// deprecated
 void llama_kv_self_seq_cp(
         llama_context * ctx,
          llama_seq_id   seq_id_src,
          llama_seq_id   seq_id_dst,
             llama_pos   p0,
             llama_pos   p1) {
-    auto * kv = ctx…
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }

-    kv…
+    llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
 }

+// deprecated
 void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
-    auto * kv = ctx…
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }

-    kv…
+    llama_memory_seq_keep(kv, seq_id);
 }

+// deprecated
 void llama_kv_self_seq_add(
         llama_context * ctx,
          llama_seq_id   seq_id,
             llama_pos   p0,
             llama_pos   p1,
             llama_pos   delta) {
-    auto * kv = ctx…
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }

-    kv…
+    llama_memory_seq_add(kv, seq_id, p0, p1, delta);
 }

+// deprecated
 void llama_kv_self_seq_div(
         llama_context * ctx,
          llama_seq_id   seq_id,
             llama_pos   p0,
             llama_pos   p1,
                   int   d) {
-    auto * kv = ctx…
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }

-    kv…
+    llama_memory_seq_div(kv, seq_id, p0, p1, d);
 }

+// deprecated
 llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
-    …
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return -1;
     }

-    return kv…
+    return llama_memory_seq_pos_min(kv, seq_id);
 }

+// deprecated
 llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
-    …
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return -1;
     }

-    return kv…
+    return llama_memory_seq_pos_max(kv, seq_id);
 }

 // deprecated
 void llama_kv_self_defrag(llama_context * ctx) {
-    auto * kv = ctx->get_kv_self();
-    if (!kv) {
-        return;
-    }
-
     // force defrag
-    …
+    ctx->kv_self_defrag_sched();
 }

+// deprecated
 bool llama_kv_self_can_shift(const llama_context * ctx) {
-    …
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return false;
     }

-    return kv…
+    return llama_memory_can_shift(kv);
 }

 // llama state API
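For orientation, here is a minimal usage sketch (not part of the commit) of the new llama_memory_* calls that llama-context.cpp now implements, next to the deprecated llama_kv_self_* wrappers they replace. The helper name reset_sequence and the surrounding setup are assumptions for illustration; only the llama_get_memory, llama_memory_seq_rm and llama_memory_seq_pos_max signatures are taken from the diff above.

#include "llama.h"

// assumes an already created llama_context * ctx (model/context setup not shown)
static void reset_sequence(llama_context * ctx, llama_seq_id seq_id) {
    llama_memory_t mem = llama_get_memory(ctx);    // replaces llama_get_kv_self()

    // drop all cached tokens of this sequence (replaces llama_kv_self_seq_rm)
    llama_memory_seq_rm(mem, seq_id, /*p0=*/-1, /*p1=*/-1);

    // query the current max position of the sequence (replaces llama_kv_self_seq_pos_max)
    const llama_pos pos_max = llama_memory_seq_pos_max(mem, seq_id);
    (void) pos_max;
}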
examples/talk-llama/llama-context.h

@@ -13,13 +13,12 @@
 #include <vector>

 struct llama_model;
-struct llama_kv_cache;

 class llama_io_read_i;
 class llama_io_write_i;

-…
+struct llama_memory_i;
+struct llama_memory_state_i;

 struct llama_context {
     // init scheduler and compute buffers, reserve worst-case graphs

@@ -47,12 +46,12 @@ struct llama_context {
     uint32_t n_threads()       const;
     uint32_t n_threads_batch() const;

-    …
-    const llama_kv_cache * get_kv_self() const;
+    llama_memory_t get_memory() const;

     // return true of the KV cache was updated
     // TODO: remove
-    bool kv_self_update();
+    bool kv_self_update(bool optimize);
+    void kv_self_defrag_sched();

     enum llama_pooling_type pooling_type() const;

@@ -231,6 +230,9 @@ private:

     std::unique_ptr<llama_memory_i> memory;

+    // TODO: temporary, until the llama_kv_self_defrag() API is removed
+    bool memory_force_optimize = false;
+
     // decode output (2-dimensional array: [n_outputs][n_vocab])
     size_t logits_size = 0; // capacity (of floats) for logits
     float * logits = nullptr;
examples/talk-llama/llama-graph.cpp

@@ -659,6 +659,20 @@ ggml_tensor * llm_graph_context::build_ffn(
                 cur = ggml_mul(ctx0, x0, x1);
                 cb(cur, "ffn_mul", il);
             } break;
+        case LLM_FFN_GEGLU:
+            {
+                // Split into two equal parts
+                int64_t split_point = cur->ne[0] / 2;
+                // TODO: these conts should not be needed
+                ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
+                ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
+
+                x0 = ggml_gelu(ctx0, x0);
+                cb(x0, "ffn_gelu", il);
+
+                cur = ggml_mul(ctx0, x0, x1);
+                cb(cur, "ffn_geglu", il);
+            } break;
     }

     if (gate && type_gate == LLM_FFN_PAR) {

@@ -769,9 +783,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);

     if (weight_before_ffn) {
-        // …
-        ggml_tensor * repeated = …
-        repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
+        // repeat cur to [n_embd, n_expert_used, n_tokens]
+        ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
         cur = ggml_mul(ctx0, repeated, weights);
         cb(cur, "ffn_moe_weighted", il);
     }
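For reference, a scalar sketch (not part of the commit) of what the new LLM_FFN_GEGLU branch computes per row: the first half of the features is passed through GELU and multiplied element-wise with the second half. The exact erf-based GELU is used here for clarity; ggml's ggml_gelu may use a faster approximation.

#include <cmath>
#include <vector>

// out[i] = GELU(x[i]) * x[i + n], where the input row holds 2*n values
static std::vector<float> geglu_ref(const std::vector<float> & x) {
    const size_t n = x.size() / 2; // split point, matches cur->ne[0] / 2 above
    std::vector<float> out(n);
    for (size_t i = 0; i < n; ++i) {
        const float a = x[i];
        const float g = 0.5f * a * (1.0f + std::erf(a / std::sqrt(2.0f))); // exact GELU
        out[i] = g * x[n + i];
    }
    return out;
}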
examples/talk-llama/llama-graph.h

@@ -17,7 +17,7 @@ struct ggml_tensor;
 struct llama_ubatch;
 struct llama_cparams;

-…
+struct llama_memory_state_i;

 class llama_kv_cache_unified_state;
 class llama_kv_cache_unified_iswa_state;

@@ -36,6 +36,7 @@ enum llm_ffn_op_type {
     LLM_FFN_RELU,
     LLM_FFN_RELU_SQR,
     LLM_FFN_SWIGLU,
+    LLM_FFN_GEGLU,
 };

 enum llm_ffn_gate_type {
examples/talk-llama/llama-kv-cache-recurrent.cpp

@@ -1,6 +1,7 @@
 #include "llama-kv-cache-recurrent.h"

 #include "llama-impl.h"
+#include "llama-io.h"
 #include "llama-batch.h"
 #include "llama-model.h"

@@ -116,18 +117,21 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
     }
 }

-void llama_kv_cache_recurrent::clear() {
+void llama_kv_cache_recurrent::clear(bool data) {
     for (int32_t i = 0; i < (int32_t) size; ++i) {
         cells[i].pos = -1;
         cells[i].seq_id.clear();
         cells[i].src = -1;
         cells[i].tail = -1;
     }
+
     head = 0;
     used = 0;

-    …
+    if (data) {
+        for (auto & buf : bufs) {
+            ggml_backend_buffer_clear(buf.get(), 0);
+        }
     }
 }

@@ -386,6 +390,13 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
     return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
 }

+llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
+    GGML_UNUSED(lctx);
+    GGML_UNUSED(optimize);
+
+    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
+}
+
 bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
     // simply remember the full state because it is very small for this type of cache
     // TODO: optimize

@@ -419,17 +430,6 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
     return success;
 }

-bool llama_kv_cache_recurrent::update(llama_context & lctx) {
-    GGML_UNUSED(lctx);
-    // noop
-    return false;
-}
-
-void llama_kv_cache_recurrent::defrag_sched(float thold) {
-    GGML_UNUSED(thold);
-    // noop
-}
-
 bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
     const uint32_t n_tokens = ubatch.n_tokens;
     const uint32_t n_seqs   = ubatch.n_seqs;

@@ -726,7 +726,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq

     if (!res) {
         if (seq_id == -1) {
-            clear();
+            clear(true);
         } else {
             seq_rm(seq_id, -1, -1);
         }

@@ -883,7 +883,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
         return false;
     }

-    clear();
+    clear(true);

     for (uint32_t i = 0; i < cell_count; ++i) {
         kv_cell & cell = cells[i];
examples/talk-llama/llama-kv-cache-recurrent.h

@@ -2,7 +2,7 @@

 #include "llama-batch.h"
 #include "llama-graph.h"
-#include "llama-…
+#include "llama-memory.h"

 #include <set>
 #include <vector>

@@ -13,7 +13,7 @@

 // TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
 // see the implementation of llama_kv_cache_unified_state_i for an example how to do it
-class llama_kv_cache_recurrent : public …
+class llama_kv_cache_recurrent : public llama_memory_i {
 public:
     llama_kv_cache_recurrent(
             const llama_model & model,

@@ -29,21 +29,6 @@ public:
     // llama_memory_i
     //

-    void clear() override;
-
-    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
-    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
-    void seq_keep(llama_seq_id seq_id) override;
-    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
-    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
-
-    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
-    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
-
-    //
-    // llama_kv_cache
-    //
-
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,

@@ -52,9 +37,18 @@ public:

     llama_memory_state_ptr init_full() override;

-    …
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+
+    void clear(bool data) override;

-    …
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id) override;
+    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

     bool prepare(const std::vector<llama_ubatch> & ubatches);

examples/talk-llama/llama-kv-cache-unified-iswa.cpp
CHANGED
|
@@ -52,9 +52,9 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
|
| 52 |
hparams.n_swa, hparams.swa_type);
|
| 53 |
}
|
| 54 |
|
| 55 |
-
void llama_kv_cache_unified_iswa::clear() {
|
| 56 |
-
kv_base->clear();
|
| 57 |
-
kv_swa ->clear();
|
| 58 |
}
|
| 59 |
|
| 60 |
bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
|
@@ -123,26 +123,16 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
|
|
| 123 |
|
| 124 |
assert(heads_base.size() == heads_swa.size());
|
| 125 |
|
| 126 |
-
return std::make_unique<llama_kv_cache_unified_iswa_state>(
|
| 127 |
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
| 128 |
}
|
| 129 |
|
| 130 |
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
|
| 131 |
-
return std::make_unique<llama_kv_cache_unified_iswa_state>(
|
| 132 |
}
|
| 133 |
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
res = res | kv_base->update(lctx);
|
| 138 |
-
res = res | kv_swa ->update(lctx);
|
| 139 |
-
|
| 140 |
-
return res;
|
| 141 |
-
}
|
| 142 |
-
|
| 143 |
-
void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
|
| 144 |
-
kv_base->defrag_sched(thold);
|
| 145 |
-
kv_swa ->defrag_sched(thold);
|
| 146 |
}
|
| 147 |
|
| 148 |
bool llama_kv_cache_unified_iswa::get_can_shift() const {
|
|
@@ -174,26 +164,38 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
|
|
| 174 |
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
|
| 175 |
|
| 176 |
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
}
|
| 182 |
|
| 183 |
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
| 184 |
-
llama_memory_status status,
|
| 185 |
llama_kv_cache_unified_iswa * kv,
|
| 186 |
llama_sbatch sbatch,
|
| 187 |
std::vector<uint32_t> heads_base,
|
| 188 |
std::vector<uint32_t> heads_swa,
|
| 189 |
std::vector<llama_ubatch> ubatches)
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
|
|
|
|
|
|
| 197 |
|
| 198 |
llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
|
| 199 |
|
|
@@ -233,17 +235,18 @@ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
|
|
| 233 |
|
| 234 |
const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
|
| 235 |
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
|
|
| 236 |
return ubatches[i_next];
|
| 237 |
}
|
| 238 |
|
| 239 |
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
|
| 240 |
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
| 241 |
|
| 242 |
-
return state_base.get();
|
| 243 |
}
|
| 244 |
|
| 245 |
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
|
| 246 |
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
| 247 |
|
| 248 |
-
return state_swa.get();
|
| 249 |
}
|
|
|
|
| 52 |
hparams.n_swa, hparams.swa_type);
|
| 53 |
}
|
| 54 |
|
| 55 |
+
void llama_kv_cache_unified_iswa::clear(bool data) {
|
| 56 |
+
kv_base->clear(data);
|
| 57 |
+
kv_swa ->clear(data);
|
| 58 |
}
|
| 59 |
|
| 60 |
bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
|
|
|
| 123 |
|
| 124 |
assert(heads_base.size() == heads_swa.size());
|
| 125 |
|
| 126 |
+
return std::make_unique<llama_kv_cache_unified_iswa_state>(
|
| 127 |
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
| 128 |
}
|
| 129 |
|
| 130 |
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
|
| 131 |
+
return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
|
| 132 |
}
|
| 133 |
|
| 134 |
+
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
|
| 135 |
+
return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
}
|
| 137 |
|
| 138 |
bool llama_kv_cache_unified_iswa::get_can_shift() const {
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}

 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+        llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_base = kv->get_base()->init_full();
+    state_swa  = kv->get_swa ()->init_full();
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+}
+
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+        llama_kv_cache_unified_iswa * kv,
+        llama_context * lctx,
+        bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_base = kv->get_base()->init_update(lctx, optimize);
+    state_swa  = kv->get_swa ()->init_update(lctx, optimize);
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
 }

 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
         llama_kv_cache_unified_iswa * kv,
         llama_sbatch sbatch,
         std::vector<uint32_t> heads_base,
         std::vector<uint32_t> heads_swa,
         std::vector<llama_ubatch> ubatches)
+    : status(LLAMA_MEMORY_STATUS_SUCCESS),
+    sbatch(std::move(sbatch)),
+    ubatches(std::move(ubatches)) {
+    // note: here we copy the ubatches. not sure if this is ideal
+    state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
+    state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+}

 llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;

 const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
     return ubatches[i_next];
 }

 const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

+    return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
 }

 const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

+    return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
 }
|
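Note on the hunks above: llama_kv_cache_unified_iswa owns two llama_kv_cache_unified instances and simply forwards every operation (clear, init_full, init_update, ...) to both of them, then folds the two child statuses into one with llama_memory_status_combine. A minimal standalone sketch of that forwarding shape, using hypothetical stand-in types rather than the real llama.cpp classes:

#include <memory>

// hypothetical stand-in for one child cache
struct child_cache {
    void clear(bool data) { (void) data; /* drop metadata, optionally data buffers */ }
};

// sketch of the delegation pattern used by the iSWA cache: one non-SWA child, one SWA child
struct dual_cache {
    std::unique_ptr<child_cache> kv_base = std::make_unique<child_cache>();
    std::unique_ptr<child_cache> kv_swa  = std::make_unique<child_cache>();

    // every operation is applied to both children
    void clear(bool data) {
        kv_base->clear(data);
        kv_swa ->clear(data);
    }
};

int main() {
    dual_cache kv;
    kv.clear(/*data =*/ true); // clears both the base and the SWA cache
}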
examples/talk-llama/llama-kv-cache-unified-iswa.h
CHANGED
|
@@ -11,7 +11,7 @@
|
|
| 11 |
// utilizes two instances of llama_kv_cache_unified
|
| 12 |
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
|
| 13 |
|
| 14 |
-
class llama_kv_cache_unified_iswa : public
|
| 15 |
public:
|
| 16 |
llama_kv_cache_unified_iswa(
|
| 17 |
const llama_model & model,
|
|
@@ -31,21 +31,6 @@ public:
|
|
| 31 |
// llama_memory_i
|
| 32 |
//
|
| 33 |
|
| 34 |
-
void clear() override;
|
| 35 |
-
|
| 36 |
-
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
| 37 |
-
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
| 38 |
-
void seq_keep(llama_seq_id seq_id) override;
|
| 39 |
-
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
| 40 |
-
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
| 41 |
-
|
| 42 |
-
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
| 43 |
-
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
| 44 |
-
|
| 45 |
-
//
|
| 46 |
-
// llama_kv_cache
|
| 47 |
-
//
|
| 48 |
-
|
| 49 |
llama_memory_state_ptr init_batch(
|
| 50 |
const llama_batch & batch,
|
| 51 |
uint32_t n_ubatch,
|
|
@@ -54,12 +39,21 @@ public:
|
|
| 54 |
|
| 55 |
llama_memory_state_ptr init_full() override;
|
| 56 |
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
void defrag_sched(float thold) override;
|
| 60 |
|
| 61 |
bool get_can_shift() const override;
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
// state write/load
|
| 64 |
|
| 65 |
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
|
@@ -86,12 +80,16 @@ public:
|
|
| 86 |
|
| 87 |
// used to create a full-cache state
|
| 88 |
llama_kv_cache_unified_iswa_state(
|
| 89 |
-
llama_memory_status status,
|
| 90 |
llama_kv_cache_unified_iswa * kv);
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
// used to create a state from a batch
|
| 93 |
llama_kv_cache_unified_iswa_state(
|
| 94 |
-
llama_memory_status status,
|
| 95 |
llama_kv_cache_unified_iswa * kv,
|
| 96 |
llama_sbatch sbatch,
|
| 97 |
std::vector<uint32_t> heads_base,
|
|
@@ -120,7 +118,7 @@ public:
|
|
| 120 |
const llama_kv_cache_unified_state * get_swa() const;
|
| 121 |
|
| 122 |
private:
|
| 123 |
-
|
| 124 |
|
| 125 |
//llama_kv_cache_unified_iswa * kv;
|
| 126 |
|
|
@@ -131,6 +129,6 @@ private:
|
|
| 131 |
|
| 132 |
std::vector<llama_ubatch> ubatches;
|
| 133 |
|
| 134 |
-
|
| 135 |
-
|
| 136 |
};
|
|
|
|
| 11 |
// utilizes two instances of llama_kv_cache_unified
|
| 12 |
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
|
| 13 |
|
| 14 |
+
class llama_kv_cache_unified_iswa : public llama_memory_i {
|
| 15 |
public:
|
| 16 |
llama_kv_cache_unified_iswa(
|
| 17 |
const llama_model & model,
|
|
|
|
| 31 |
// llama_memory_i
|
| 32 |
//
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
llama_memory_state_ptr init_batch(
|
| 35 |
const llama_batch & batch,
|
| 36 |
uint32_t n_ubatch,
|
|
|
|
| 39 |
|
| 40 |
llama_memory_state_ptr init_full() override;
|
| 41 |
|
| 42 |
+
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
|
|
|
|
|
|
|
| 43 |
|
| 44 |
bool get_can_shift() const override;
|
| 45 |
|
| 46 |
+
void clear(bool data) override;
|
| 47 |
+
|
| 48 |
+
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
| 49 |
+
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
| 50 |
+
void seq_keep(llama_seq_id seq_id) override;
|
| 51 |
+
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
| 52 |
+
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
| 53 |
+
|
| 54 |
+
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
| 55 |
+
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
| 56 |
+
|
| 57 |
// state write/load
|
| 58 |
|
| 59 |
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
|
|
|
| 80 |
|
| 81 |
// used to create a full-cache state
|
| 82 |
llama_kv_cache_unified_iswa_state(
|
|
|
|
| 83 |
llama_kv_cache_unified_iswa * kv);
|
| 84 |
|
| 85 |
+
// used to create an update state
|
| 86 |
+
llama_kv_cache_unified_iswa_state(
|
| 87 |
+
llama_kv_cache_unified_iswa * kv,
|
| 88 |
+
llama_context * lctx,
|
| 89 |
+
bool optimize);
|
| 90 |
+
|
| 91 |
// used to create a state from a batch
|
| 92 |
llama_kv_cache_unified_iswa_state(
|
|
|
|
| 93 |
llama_kv_cache_unified_iswa * kv,
|
| 94 |
llama_sbatch sbatch,
|
| 95 |
std::vector<uint32_t> heads_base,
|
|
|
|
| 118 |
const llama_kv_cache_unified_state * get_swa() const;
|
| 119 |
|
| 120 |
private:
|
| 121 |
+
llama_memory_status status;
|
| 122 |
|
| 123 |
//llama_kv_cache_unified_iswa * kv;
|
| 124 |
|
|
|
|
| 129 |
|
| 130 |
std::vector<llama_ubatch> ubatches;
|
| 131 |
|
| 132 |
+
llama_memory_state_ptr state_base;
|
| 133 |
+
llama_memory_state_ptr state_swa;
|
| 134 |
};
|
examples/talk-llama/llama-kv-cache-unified.cpp
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
#include "llama-kv-cache-unified.h"
|
| 2 |
|
| 3 |
#include "llama-impl.h"
|
|
|
|
| 4 |
#include "llama-model.h"
|
| 5 |
#include "llama-context.h"
|
| 6 |
|
|
@@ -128,13 +129,15 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
|
| 128 |
}
|
| 129 |
}
|
| 130 |
|
| 131 |
-
void llama_kv_cache_unified::clear() {
|
| 132 |
cells.reset();
|
| 133 |
|
| 134 |
head = 0;
|
| 135 |
|
| 136 |
-
|
| 137 |
-
|
|
|
|
|
|
|
| 138 |
}
|
| 139 |
}
|
| 140 |
|
|
@@ -149,12 +152,27 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
|
|
| 149 |
p1 = std::numeric_limits<llama_pos>::max();
|
| 150 |
}
|
| 151 |
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
-
if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
|
| 158 |
if (new_head == cells.size()) {
|
| 159 |
new_head = i;
|
| 160 |
}
|
|
@@ -305,16 +323,49 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
|
|
| 305 |
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
| 306 |
}
|
| 307 |
|
| 308 |
-
return std::make_unique<llama_kv_cache_unified_state>(
|
| 309 |
this, std::move(sbatch), std::move(heads), std::move(ubatches));
|
| 310 |
}
|
| 311 |
|
| 312 |
llama_memory_state_ptr llama_kv_cache_unified::init_full() {
|
| 313 |
-
return std::make_unique<llama_kv_cache_unified_state>(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
}
|
| 315 |
|
| 316 |
-
|
| 317 |
-
|
| 318 |
|
| 319 |
struct state {
|
| 320 |
uint32_t head_old; // old position of the head, before placing the ubatch
|
|
@@ -359,12 +410,12 @@ std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ub
|
|
| 359 |
return res;
|
| 360 |
}
|
| 361 |
|
| 362 |
-bool llama_kv_cache_unified::update(llama_context & lctx) {
|
| 363 |
bool updated = false;
|
| 364 |
|
| 365 |
-
auto * sched = lctx
|
| 366 |
|
| 367 |
-
if (
|
| 368 |
if (!get_can_shift()) {
|
| 369 |
GGML_ABORT("The current KV cache / model configuration does not support K-shift");
|
| 370 |
}
|
|
@@ -375,9 +426,9 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
|
|
| 375 |
if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
|
| 376 |
ggml_backend_sched_reset(sched);
|
| 377 |
|
| 378 |
-
auto * gf = lctx
|
| 379 |
|
| 380 |
-
auto res = build_graph_shift(lctx
|
| 381 |
if (!res) {
|
| 382 |
LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
|
| 383 |
return updated;
|
|
@@ -390,7 +441,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
|
|
| 390 |
|
| 391 |
res->set_inputs(nullptr);
|
| 392 |
|
| 393 |
-
if (lctx
|
| 394 |
LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
|
| 395 |
return updated;
|
| 396 |
}
|
|
@@ -401,54 +452,53 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
|
|
| 401 |
cells.reset_shift();
|
| 402 |
}
|
| 403 |
|
| 404 |
-
if (
|
| 405 |
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
|
| 406 |
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
auto * gf = lctx.graph_init();
|
| 411 |
-
|
| 412 |
-
auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
|
| 413 |
-
if (!res) {
|
| 414 |
-
LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
|
| 415 |
-
return updated;
|
| 416 |
-
}
|
| 417 |
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
return updated;
|
| 421 |
-
}
|
| 422 |
|
| 423 |
-
|
|
|
|
|
|
|
| 424 |
|
| 425 |
-
|
| 426 |
-
LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
|
| 427 |
-
return updated;
|
| 428 |
}
|
| 429 |
|
| 430 |
-
|
|
|
|
| 431 |
}
|
| 432 |
|
| 433 |
-
|
| 434 |
-
}
|
| 435 |
|
| 436 |
-
|
| 437 |
-
}
|
| 438 |
|
| 439 |
-
|
| 440 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 441 |
|
| 442 |
-
|
| 443 |
-
// - count the padding towards the number of used tokens
|
| 444 |
-
const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
|
| 445 |
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
|
|
|
| 449 |
|
| 450 |
-
|
| 451 |
}
|
|
|
|
|
|
|
| 452 |
}
|
| 453 |
|
| 454 |
int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
|
|
@@ -597,6 +647,10 @@ uint32_t llama_kv_cache_unified::get_size() const {
|
|
| 597 |
return cells.size();
|
| 598 |
}
|
| 599 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 600 |
uint32_t llama_kv_cache_unified::get_n_kv() const {
|
| 601 |
return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
|
| 602 |
}
|
|
@@ -890,11 +944,9 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
|
|
| 890 |
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
| 891 |
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
| 892 |
|
| 893 |
-
//GGML_ASSERT(kv_self->size == n_ctx);
|
| 894 |
-
|
| 895 |
auto inp = std::make_unique<llm_graph_input_k_shift>(this);
|
| 896 |
|
| 897 |
-
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32,
|
| 898 |
ggml_set_input(inp->k_shift);
|
| 899 |
|
| 900 |
for (const auto & layer : layers) {
|
|
@@ -926,12 +978,13 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
|
|
| 926 |
}
|
| 927 |
|
| 928 |
llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
|
| 929 |
-
|
| 930 |
-
|
| 931 |
-
|
|
|
|
| 932 |
auto res = std::make_unique<llm_graph_result>();
|
| 933 |
|
| 934 |
-
const auto & ids =
|
| 935 |
|
| 936 |
#if 0
|
| 937 |
// CPU defrag
|
|
@@ -1072,7 +1125,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
|
|
| 1072 |
return res;
|
| 1073 |
}
|
| 1074 |
|
| 1075 |
-
|
| 1076 |
const uint32_t n_layer = layers.size();
|
| 1077 |
|
| 1078 |
const uint32_t n_kv = cells.used_max_p1();
|
|
@@ -1093,14 +1146,9 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|
| 1093 |
const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
|
| 1094 |
|
| 1095 |
// determine which KV cells to move where
|
| 1096 |
-
|
| 1097 |
-
|
| 1098 |
-
//
|
| 1099 |
-
// if ids[i] == i || ids[i] == n_kv, then cell i is not moved
|
| 1100 |
-
//
|
| 1101 |
-
auto & ids = defrag_info.ids;
|
| 1102 |
|
| 1103 |
-
ids.clear();
|
| 1104 |
ids.resize(n_kv, n_kv);
|
| 1105 |
|
| 1106 |
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
|
|
@@ -1164,11 +1212,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|
| 1164 |
// this cell goes to (i0 + nf)
|
| 1165 |
ids[i1] = i0 + nf;
|
| 1166 |
|
| 1167 |
-
// move the cell meta data
|
| 1168 |
-
cells.mv(i1, i0 + nf);
|
| 1169 |
-
|
| 1170 |
-
head = n_used;
|
| 1171 |
-
|
| 1172 |
if (!cont) {
|
| 1173 |
n_moves++;
|
| 1174 |
cont = true;
|
|
@@ -1191,14 +1234,14 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|
| 1191 |
}
|
| 1192 |
|
| 1193 |
if (n_moves == 0) {
|
| 1194 |
-
return
|
| 1195 |
}
|
| 1196 |
|
| 1197 |
LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
|
| 1198 |
|
| 1199 |
LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
|
| 1200 |
|
| 1201 |
-
return
|
| 1202 |
}
|
| 1203 |
|
| 1204 |
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
|
|
@@ -1276,7 +1319,7 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
|
|
| 1276 |
|
| 1277 |
if (!res) {
|
| 1278 |
if (seq_id == -1) {
|
| 1279 |
-
clear();
|
| 1280 |
} else {
|
| 1281 |
seq_rm(seq_id, -1, -1);
|
| 1282 |
}
|
|
@@ -1457,7 +1500,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|
| 1457 |
return false;
|
| 1458 |
}
|
| 1459 |
|
| 1460 |
-
clear();
|
| 1461 |
|
| 1462 |
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 1463 |
llama_pos pos;
|
|
@@ -1621,24 +1664,27 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
|
| 1621 |
llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
|
| 1622 |
|
| 1623 |
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
| 1624 |
-
|
| 1625 |
-
|
| 1626 |
-
|
| 1627 |
-
|
| 1628 |
-
}
|
| 1629 |
|
| 1630 |
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
| 1631 |
-
|
| 1632 |
-
|
| 1633 |
-
|
| 1634 |
-
|
| 1635 |
-
|
| 1636 |
-
|
| 1637 |
-
kv(kv),
|
| 1638 |
-
sbatch(std::move(sbatch)),
|
| 1639 |
-
heads(std::move(heads)),
|
| 1640 |
-
ubatches(std::move(ubatches)) {
|
| 1641 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1642 |
|
| 1643 |
llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
|
| 1644 |
|
|
@@ -1655,6 +1701,13 @@ bool llama_kv_cache_unified_state::next() {
|
|
| 1655 |
bool llama_kv_cache_unified_state::apply() {
|
| 1656 |
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
| 1657 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1658 |
kv->apply_ubatch(heads[i_next], ubatches[i_next]);
|
| 1659 |
|
| 1660 |
n_kv = kv->get_n_kv();
|
|
|
|
| 1 |
#include "llama-kv-cache-unified.h"
|
| 2 |
|
| 3 |
#include "llama-impl.h"
|
| 4 |
+
#include "llama-io.h"
|
| 5 |
#include "llama-model.h"
|
| 6 |
#include "llama-context.h"
|
| 7 |
|
|
|
|
| 129 |
}
|
| 130 |
}
|
| 131 |
|
| 132 |
+
void llama_kv_cache_unified::clear(bool data) {
|
| 133 |
cells.reset();
|
| 134 |
|
| 135 |
head = 0;
|
| 136 |
|
| 137 |
+
if (data) {
|
| 138 |
+
for (auto & buf : bufs) {
|
| 139 |
+
ggml_backend_buffer_clear(buf.get(), 0);
|
| 140 |
+
}
|
| 141 |
}
|
| 142 |
}
|
| 143 |
|
|
|
|
| 152 |
p1 = std::numeric_limits<llama_pos>::max();
|
| 153 |
}
|
| 154 |
|
| 155 |
+
if (seq_id >= 0) {
|
| 156 |
+
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 157 |
+
if (!cells.pos_in(i, p0, p1)) {
|
| 158 |
+
continue;
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
|
| 162 |
+
if (new_head == cells.size()) {
|
| 163 |
+
new_head = i;
|
| 164 |
+
}
|
| 165 |
+
}
|
| 166 |
}
|
| 167 |
+
} else {
|
| 168 |
+
// match any sequence
|
| 169 |
+
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 170 |
+
if (!cells.pos_in(i, p0, p1)) {
|
| 171 |
+
continue;
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
cells.rm(i);
|
| 175 |
|
|
|
|
| 176 |
if (new_head == cells.size()) {
|
| 177 |
new_head = i;
|
| 178 |
}
|
|
|
|
| 323 |
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
| 324 |
}
|
| 325 |
|
| 326 |
+
return std::make_unique<llama_kv_cache_unified_state>(
|
| 327 |
this, std::move(sbatch), std::move(heads), std::move(ubatches));
|
| 328 |
}
|
| 329 |
|
| 330 |
llama_memory_state_ptr llama_kv_cache_unified::init_full() {
|
| 331 |
+
return std::make_unique<llama_kv_cache_unified_state>(this);
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
|
| 335 |
+
bool do_shift = get_has_shift();
|
| 336 |
+
|
| 337 |
+
defrag_info dinfo;
|
| 338 |
+
|
| 339 |
+
// see if we need to defrag
|
| 340 |
+
{
|
| 341 |
+
bool do_defrag = optimize;
|
| 342 |
+
|
| 343 |
+
const auto thold = lctx->get_cparams().defrag_thold;
|
| 344 |
+
|
| 345 |
+
if (!do_defrag && thold > 0.0f) {
|
| 346 |
+
const auto n_kv = cells.used_max_p1();
|
| 347 |
+
|
| 348 |
+
// - do not defrag small contexts (i.e. < 2048 tokens)
|
| 349 |
+
// - count the padding towards the number of used tokens
|
| 350 |
+
const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
|
| 351 |
+
|
| 352 |
+
if (fragmentation > thold) {
|
| 353 |
+
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
|
| 354 |
+
|
| 355 |
+
do_defrag = true;
|
| 356 |
+
}
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
if (do_defrag) {
|
| 360 |
+
dinfo = defrag_prepare(lctx->graph_max_nodes());
|
| 361 |
+
}
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
|
| 365 |
}
|
| 366 |
|
| 367 |
+
llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
|
| 368 |
+
llama_kv_cache_unified::ubatch_heads res;
|
| 369 |
|
| 370 |
struct state {
|
| 371 |
uint32_t head_old; // old position of the head, before placing the ubatch
|
|
|
|
| 410 |
return res;
|
| 411 |
}
|
| 412 |
|
| 413 |
+
bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
|
| 414 |
bool updated = false;
|
| 415 |
|
| 416 |
+
auto * sched = lctx->get_sched();
|
| 417 |
|
| 418 |
+
if (do_shift) {
|
| 419 |
if (!get_can_shift()) {
|
| 420 |
GGML_ABORT("The current KV cache / model configuration does not support K-shift");
|
| 421 |
}
|
|
|
|
| 426 |
if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
|
| 427 |
ggml_backend_sched_reset(sched);
|
| 428 |
|
| 429 |
+
auto * gf = lctx->graph_init();
|
| 430 |
|
| 431 |
+
auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf);
|
| 432 |
if (!res) {
|
| 433 |
LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
|
| 434 |
return updated;
|
|
|
|
| 441 |
|
| 442 |
res->set_inputs(nullptr);
|
| 443 |
|
| 444 |
+
if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
|
| 445 |
LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
|
| 446 |
return updated;
|
| 447 |
}
|
|
|
|
| 452 |
cells.reset_shift();
|
| 453 |
}
|
| 454 |
|
| 455 |
+
if (!dinfo.empty()) {
|
| 456 |
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
|
| 457 |
|
| 458 |
+
// apply moves:
|
| 459 |
+
{
|
| 460 |
+
const auto n_kv = dinfo.ids.size();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 461 |
|
| 462 |
+
for (uint32_t i = 0; i < n_kv; ++i) {
|
| 463 |
+
assert(dinfo.ids[i] <= n_kv);
|
|
|
|
|
|
|
| 464 |
|
| 465 |
+
if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) {
|
| 466 |
+
continue;
|
| 467 |
+
}
|
| 468 |
|
| 469 |
+
cells.mv(i, dinfo.ids[i]);
|
|
|
|
|
|
|
| 470 |
}
|
| 471 |
|
| 472 |
+
// reset the head so we can find the first free slot during the next ubatch
|
| 473 |
+
head = 0;
|
| 474 |
}
|
| 475 |
|
| 476 |
+
ggml_backend_sched_reset(sched);
|
|
|
|
| 477 |
|
| 478 |
+
auto * gf = lctx->graph_init();
|
|
|
|
| 479 |
|
| 480 |
+
auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo);
|
| 481 |
+
if (!res) {
|
| 482 |
+
LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
|
| 483 |
+
return updated;
|
| 484 |
+
}
|
| 485 |
+
|
| 486 |
+
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
|
| 487 |
+
LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
|
| 488 |
+
return updated;
|
| 489 |
+
}
|
| 490 |
|
| 491 |
+
res->set_inputs(nullptr);
|
|
|
|
|
|
|
| 492 |
|
| 493 |
+
if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
|
| 494 |
+
LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
|
| 495 |
+
return updated;
|
| 496 |
+
}
|
| 497 |
|
| 498 |
+
updated = true;
|
| 499 |
}
|
| 500 |
+
|
| 501 |
+
return updated;
|
| 502 |
}
|
| 503 |
|
| 504 |
int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
|
|
|
|
| 647 |
return cells.size();
|
| 648 |
}
|
| 649 |
|
| 650 |
+
bool llama_kv_cache_unified::get_has_shift() const {
|
| 651 |
+
return cells.get_has_shift();
|
| 652 |
+
}
|
| 653 |
+
|
| 654 |
uint32_t llama_kv_cache_unified::get_n_kv() const {
|
| 655 |
return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
|
| 656 |
}
|
|
|
|
| 944 |
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
| 945 |
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
| 946 |
|
|
|
|
|
|
|
| 947 |
auto inp = std::make_unique<llm_graph_input_k_shift>(this);
|
| 948 |
|
| 949 |
+
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size());
|
| 950 |
ggml_set_input(inp->k_shift);
|
| 951 |
|
| 952 |
for (const auto & layer : layers) {
|
|
|
|
| 978 |
}
|
| 979 |
|
| 980 |
llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
|
| 981 |
+
const llama_cparams & cparams,
|
| 982 |
+
ggml_context * ctx,
|
| 983 |
+
ggml_cgraph * gf,
|
| 984 |
+
const defrag_info & dinfo) const {
|
| 985 |
auto res = std::make_unique<llm_graph_result>();
|
| 986 |
|
| 987 |
+
const auto & ids = dinfo.ids;
|
| 988 |
|
| 989 |
#if 0
|
| 990 |
// CPU defrag
|
|
|
|
| 1125 |
return res;
|
| 1126 |
}
|
| 1127 |
|
| 1128 |
+
llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
|
| 1129 |
const uint32_t n_layer = layers.size();
|
| 1130 |
|
| 1131 |
const uint32_t n_kv = cells.used_max_p1();
|
|
|
|
| 1146 |
const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
|
| 1147 |
|
| 1148 |
// determine which KV cells to move where
|
| 1149 |
+
defrag_info res;
|
| 1150 |
+
auto & ids = res.ids;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1151 |
|
|
|
|
| 1152 |
ids.resize(n_kv, n_kv);
|
| 1153 |
|
| 1154 |
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
|
|
|
|
| 1212 |
// this cell goes to (i0 + nf)
|
| 1213 |
ids[i1] = i0 + nf;
|
| 1214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1215 |
if (!cont) {
|
| 1216 |
n_moves++;
|
| 1217 |
cont = true;
|
|
|
|
| 1234 |
}
|
| 1235 |
|
| 1236 |
if (n_moves == 0) {
|
| 1237 |
+
return {};
|
| 1238 |
}
|
| 1239 |
|
| 1240 |
LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
|
| 1241 |
|
| 1242 |
LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
|
| 1243 |
|
| 1244 |
+
return res;
|
| 1245 |
}
|
| 1246 |
|
| 1247 |
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
|
|
|
|
| 1319 |
|
| 1320 |
if (!res) {
|
| 1321 |
if (seq_id == -1) {
|
| 1322 |
+
clear(true);
|
| 1323 |
} else {
|
| 1324 |
seq_rm(seq_id, -1, -1);
|
| 1325 |
}
|
|
|
|
| 1500 |
return false;
|
| 1501 |
}
|
| 1502 |
|
| 1503 |
+
clear(true);
|
| 1504 |
|
| 1505 |
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 1506 |
llama_pos pos;
|
|
|
|
| 1664 |
llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
|
| 1665 |
|
| 1666 |
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
| 1667 |
+
llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
|
| 1668 |
+
n_kv = kv->get_size();
|
| 1669 |
+
head = 0;
|
| 1670 |
+
}
|
|
|
|
| 1671 |
|
| 1672 |
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
| 1673 |
+
llama_kv_cache_unified * kv,
|
| 1674 |
+
llama_context * lctx,
|
| 1675 |
+
bool do_shift,
|
| 1676 |
+
defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
|
| 1677 |
+
if (!do_shift && dinfo.empty()) {
|
| 1678 |
+
status = LLAMA_MEMORY_STATUS_NO_UPDATE;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1679 |
}
|
| 1680 |
+
}
|
| 1681 |
+
|
| 1682 |
+
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
| 1683 |
+
llama_kv_cache_unified * kv,
|
| 1684 |
+
llama_sbatch sbatch,
|
| 1685 |
+
llama_kv_cache_unified::ubatch_heads heads,
|
| 1686 |
+
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) {
|
| 1687 |
+
}
|
| 1688 |
|
| 1689 |
llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
|
| 1690 |
|
|
|
|
| 1701 |
bool llama_kv_cache_unified_state::apply() {
|
| 1702 |
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
| 1703 |
|
| 1704 |
+
// no ubatches -> this is a KV cache update
|
| 1705 |
+
if (ubatches.empty()) {
|
| 1706 |
+
kv->update(lctx, do_shift, dinfo);
|
| 1707 |
+
|
| 1708 |
+
return true;
|
| 1709 |
+
}
|
| 1710 |
+
|
| 1711 |
kv->apply_ubatch(heads[i_next], ubatches[i_next]);
|
| 1712 |
|
| 1713 |
n_kv = kv->get_n_kv();
|
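The defrag trigger added to init_update() above boils down to one ratio: how much of the currently addressed part of the cache is actually holding tokens. A small self-contained sketch of that check with made-up numbers (the helper function below is illustrative, not part of llama.cpp):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// mirrors the fragmentation test from llama_kv_cache_unified::init_update()
static bool should_defrag(uint32_t n_kv, uint32_t n_used, uint32_t n_pad, float thold) {
    // small contexts (< 2048 addressed cells) are never defragmented
    const float fragmentation = n_kv >= 2048
        ? std::max(0.0f, 1.0f - float(n_used + n_pad)/float(n_kv))
        : 0.0f;

    return fragmentation > thold;
}

int main() {
    // 4096 addressed cells, only 2900 used, padding 32, threshold 0.1 -> defrag
    std::printf("%d\n", should_defrag(4096, 2900, 32, 0.1f)); // prints 1
    // 1024 addressed cells -> below the 2048 cutoff, never defrag
    std::printf("%d\n", should_defrag(1024,  100, 32, 0.1f)); // prints 0
}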
examples/talk-llama/llama-kv-cache-unified.h
CHANGED
|
@@ -2,8 +2,8 @@
|
|
| 2 |
|
| 3 |
#include "llama-batch.h"
|
| 4 |
#include "llama-graph.h"
|
| 5 |
-
#include "llama-kv-cache.h"
|
| 6 |
#include "llama-kv-cells.h"
|
|
|
|
| 7 |
|
| 8 |
#include <unordered_map>
|
| 9 |
#include <vector>
|
|
@@ -17,13 +17,26 @@ struct llama_context;
|
|
| 17 |
// llama_kv_cache_unified
|
| 18 |
//
|
| 19 |
|
| 20 |
-
class llama_kv_cache_unified : public
|
| 21 |
public:
|
| 22 |
static uint32_t get_padding(const llama_cparams & cparams);
|
| 23 |
|
| 24 |
// this callback is used to filter out layers that should not be included in the cache
|
| 25 |
using layer_filter_cb = std::function<bool(int32_t il)>;
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
llama_kv_cache_unified(
|
| 28 |
const llama_model & model,
|
| 29 |
layer_filter_cb && filter,
|
|
@@ -43,21 +56,6 @@ public:
|
|
| 43 |
// llama_memory_i
|
| 44 |
//
|
| 45 |
|
| 46 |
-
void clear() override;
|
| 47 |
-
|
| 48 |
-
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
| 49 |
-
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
| 50 |
-
void seq_keep(llama_seq_id seq_id) override;
|
| 51 |
-
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
| 52 |
-
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
| 53 |
-
|
| 54 |
-
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
| 55 |
-
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
| 56 |
-
|
| 57 |
-
//
|
| 58 |
-
// llama_kv_cache
|
| 59 |
-
//
|
| 60 |
-
|
| 61 |
llama_memory_state_ptr init_batch(
|
| 62 |
const llama_batch & batch,
|
| 63 |
uint32_t n_ubatch,
|
|
@@ -66,12 +64,21 @@ public:
|
|
| 66 |
|
| 67 |
llama_memory_state_ptr init_full() override;
|
| 68 |
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
void defrag_sched(float thold) override;
|
| 72 |
|
| 73 |
bool get_can_shift() const override;
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
// state write/load
|
| 76 |
|
| 77 |
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
|
@@ -83,6 +90,8 @@ public:
|
|
| 83 |
|
| 84 |
uint32_t get_size() const;
|
| 85 |
|
|
|
|
|
|
|
| 86 |
//
|
| 87 |
// graph_build API
|
| 88 |
//
|
|
@@ -103,7 +112,9 @@ public:
|
|
| 103 |
|
| 104 |
// find places for the provided ubatches in the cache, returns the head locations
|
| 105 |
// return empty vector on failure
|
| 106 |
-
|
|
|
|
|
|
|
| 107 |
|
| 108 |
// return the cell position where we can insert the ubatch
|
| 109 |
// return -1 on failure to find a contiguous slot of kv cells
|
|
@@ -133,8 +144,7 @@ private:
|
|
| 133 |
ggml_tensor * v;
|
| 134 |
};
|
| 135 |
|
| 136 |
-
bool
|
| 137 |
-
bool v_trans = true; // the value tensor is transposed
|
| 138 |
|
| 139 |
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
| 140 |
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
|
@@ -160,13 +170,8 @@ private:
|
|
| 160 |
// model layer id -> KV cache layer id
|
| 161 |
std::unordered_map<int32_t, int32_t> map_layer_ids;
|
| 162 |
|
| 163 |
-
//
|
| 164 |
-
|
| 165 |
-
std::vector<uint32_t> ids;
|
| 166 |
-
} defrag_info;
|
| 167 |
-
|
| 168 |
-
// return true if cells have been moved
|
| 169 |
-
bool defrag_prepare(int32_t n_max_nodes);
|
| 170 |
|
| 171 |
size_t total_size() const;
|
| 172 |
|
|
@@ -192,7 +197,8 @@ private:
|
|
| 192 |
llm_graph_result_ptr build_graph_defrag(
|
| 193 |
const llama_cparams & cparams,
|
| 194 |
ggml_context * ctx,
|
| 195 |
-
ggml_cgraph * gf
|
|
|
|
| 196 |
|
| 197 |
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
| 198 |
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
|
@@ -203,20 +209,29 @@ private:
|
|
| 203 |
|
| 204 |
class llama_kv_cache_unified_state : public llama_memory_state_i {
|
| 205 |
public:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
// used for errors
|
| 207 |
llama_kv_cache_unified_state(llama_memory_status status);
|
| 208 |
|
| 209 |
// used to create a full-cache state
|
| 210 |
llama_kv_cache_unified_state(
|
| 211 |
-
llama_memory_status status,
|
| 212 |
llama_kv_cache_unified * kv);
|
| 213 |
|
| 214 |
-
// used to create
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
llama_kv_cache_unified_state(
|
| 216 |
-
llama_memory_status status,
|
| 217 |
llama_kv_cache_unified * kv,
|
| 218 |
llama_sbatch sbatch,
|
| 219 |
-
|
| 220 |
std::vector<llama_ubatch> ubatches);
|
| 221 |
|
| 222 |
virtual ~llama_kv_cache_unified_state();
|
|
@@ -253,16 +268,30 @@ public:
|
|
| 253 |
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
| 254 |
|
| 255 |
private:
|
| 256 |
-
|
| 257 |
|
| 258 |
llama_kv_cache_unified * kv;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
|
| 260 |
llama_sbatch sbatch;
|
| 261 |
|
| 262 |
// the index of the next ubatch to process
|
| 263 |
size_t i_next = 0;
|
| 264 |
|
| 265 |
-
|
|
|
|
| 266 |
std::vector<llama_ubatch> ubatches;
|
| 267 |
|
| 268 |
//
|
|
|
|
| 2 |
|
| 3 |
#include "llama-batch.h"
|
| 4 |
#include "llama-graph.h"
|
|
|
|
| 5 |
#include "llama-kv-cells.h"
|
| 6 |
+
#include "llama-memory.h"
|
| 7 |
|
| 8 |
#include <unordered_map>
|
| 9 |
#include <vector>
|
|
|
|
| 17 |
// llama_kv_cache_unified
|
| 18 |
//
|
| 19 |
|
| 20 |
+
class llama_kv_cache_unified : public llama_memory_i {
|
| 21 |
public:
|
| 22 |
static uint32_t get_padding(const llama_cparams & cparams);
|
| 23 |
|
| 24 |
// this callback is used to filter out layers that should not be included in the cache
|
| 25 |
using layer_filter_cb = std::function<bool(int32_t il)>;
|
| 26 |
|
| 27 |
+
using ubatch_heads = std::vector<uint32_t>;
|
| 28 |
+
|
| 29 |
+
struct defrag_info {
|
| 30 |
+
bool empty() const {
|
| 31 |
+
return ids.empty();
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
// contains information about which cell moves where:
|
| 35 |
+
// - cell i moves to ids[i]
|
| 36 |
+
// - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
|
| 37 |
+
std::vector<uint32_t> ids;
|
| 38 |
+
};
|
| 39 |
+
|
| 40 |
llama_kv_cache_unified(
|
| 41 |
const llama_model & model,
|
| 42 |
layer_filter_cb && filter,
|
|
|
|
| 56 |
// llama_memory_i
|
| 57 |
//
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
llama_memory_state_ptr init_batch(
|
| 60 |
const llama_batch & batch,
|
| 61 |
uint32_t n_ubatch,
|
|
|
|
| 64 |
|
| 65 |
llama_memory_state_ptr init_full() override;
|
| 66 |
|
| 67 |
+
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
|
|
|
|
|
|
|
| 68 |
|
| 69 |
bool get_can_shift() const override;
|
| 70 |
|
| 71 |
+
void clear(bool data) override;
|
| 72 |
+
|
| 73 |
+
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
| 74 |
+
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
| 75 |
+
void seq_keep(llama_seq_id seq_id) override;
|
| 76 |
+
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
| 77 |
+
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
| 78 |
+
|
| 79 |
+
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
| 80 |
+
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
| 81 |
+
|
| 82 |
// state write/load
|
| 83 |
|
| 84 |
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
|
|
|
| 90 |
|
| 91 |
uint32_t get_size() const;
|
| 92 |
|
| 93 |
+
bool get_has_shift() const;
|
| 94 |
+
|
| 95 |
//
|
| 96 |
// graph_build API
|
| 97 |
//
|
|
|
|
| 112 |
|
| 113 |
// find places for the provided ubatches in the cache, returns the head locations
|
| 114 |
// return empty vector on failure
|
| 115 |
+
ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
|
| 116 |
+
|
| 117 |
+
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
|
| 118 |
|
| 119 |
// return the cell position where we can insert the ubatch
|
| 120 |
// return -1 on failure to find a contiguous slot of kv cells
|
|
|
|
| 144 |
ggml_tensor * v;
|
| 145 |
};
|
| 146 |
|
| 147 |
+
bool v_trans = true; // the value tensor is transposed
|
|
|
|
| 148 |
|
| 149 |
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
| 150 |
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
|
|
|
| 170 |
// model layer id -> KV cache layer id
|
| 171 |
std::unordered_map<int32_t, int32_t> map_layer_ids;
|
| 172 |
|
| 173 |
+
// return non-empty vector if cells have been moved
|
| 174 |
+
defrag_info defrag_prepare(int32_t n_max_nodes) const;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
size_t total_size() const;
|
| 177 |
|
|
|
|
| 197 |
llm_graph_result_ptr build_graph_defrag(
|
| 198 |
const llama_cparams & cparams,
|
| 199 |
ggml_context * ctx,
|
| 200 |
+
ggml_cgraph * gf,
|
| 201 |
+
const defrag_info & dinfo) const;
|
| 202 |
|
| 203 |
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
| 204 |
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
|
|
|
| 209 |
|
| 210 |
class llama_kv_cache_unified_state : public llama_memory_state_i {
|
| 211 |
public:
|
| 212 |
+
// some shorthands
|
| 213 |
+
using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
|
| 214 |
+
using defrag_info = llama_kv_cache_unified::defrag_info;
|
| 215 |
+
|
| 216 |
// used for errors
|
| 217 |
llama_kv_cache_unified_state(llama_memory_status status);
|
| 218 |
|
| 219 |
// used to create a full-cache state
|
| 220 |
llama_kv_cache_unified_state(
|
|
|
|
| 221 |
llama_kv_cache_unified * kv);
|
| 222 |
|
| 223 |
+
// used to create an update state
|
| 224 |
+
llama_kv_cache_unified_state(
|
| 225 |
+
llama_kv_cache_unified * kv,
|
| 226 |
+
llama_context * lctx,
|
| 227 |
+
bool do_shift,
|
| 228 |
+
defrag_info dinfo);
|
| 229 |
+
|
| 230 |
+
// used to create a decode state from a batch
|
| 231 |
llama_kv_cache_unified_state(
|
|
|
|
| 232 |
llama_kv_cache_unified * kv,
|
| 233 |
llama_sbatch sbatch,
|
| 234 |
+
ubatch_heads heads,
|
| 235 |
std::vector<llama_ubatch> ubatches);
|
| 236 |
|
| 237 |
virtual ~llama_kv_cache_unified_state();
|
|
|
|
| 268 |
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
| 269 |
|
| 270 |
private:
|
| 271 |
+
llama_memory_status status;
|
| 272 |
|
| 273 |
llama_kv_cache_unified * kv;
|
| 274 |
+
llama_context * lctx;
|
| 275 |
+
|
| 276 |
+
//
|
| 277 |
+
// update state
|
| 278 |
+
//
|
| 279 |
+
|
| 280 |
+
bool do_shift = false;
|
| 281 |
+
|
| 282 |
+
defrag_info dinfo;
|
| 283 |
+
|
| 284 |
+
//
|
| 285 |
+
// batch processing state
|
| 286 |
+
//
|
| 287 |
|
| 288 |
llama_sbatch sbatch;
|
| 289 |
|
| 290 |
// the index of the next ubatch to process
|
| 291 |
size_t i_next = 0;
|
| 292 |
|
| 293 |
+
ubatch_heads heads;
|
| 294 |
+
|
| 295 |
std::vector<llama_ubatch> ubatches;
|
| 296 |
|
| 297 |
//
|
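The defrag_info struct declared in this header is just a move table: cell i moves to ids[i], and ids[i] == i or ids[i] == ids.size() means the cell stays where it is. A hedged sketch of applying such a table to a plain array (a toy stand-in for the real cells container, mirroring the apply-moves loop from llama-kv-cache-unified.cpp above):

#include <cassert>
#include <cstdint>
#include <vector>

// toy stand-in for the KV cells: applies a defrag_info-style move table
static void apply_moves(std::vector<int> & cells, const std::vector<uint32_t> & ids) {
    const uint32_t n_kv = (uint32_t) ids.size();

    for (uint32_t i = 0; i < n_kv; ++i) {
        assert(ids[i] <= n_kv);

        if (ids[i] == n_kv || ids[i] == i) {
            continue; // cell i is not moved
        }

        cells[ids[i]] = cells[i]; // move cell i into the hole at ids[i]
        cells[i]      = -1;       // the source cell becomes free
    }
}

int main() {
    //                            used  free  used
    std::vector<int>      cells = {  7,   -1,    9 };
    // cell 2 moves into the hole at index 1; cells 0 and 1 stay put (n_kv == 3)
    std::vector<uint32_t> ids   = {  0,    3,    1 };

    apply_moves(cells, ids);
    // cells is now { 7, 9, -1 }
}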
examples/talk-llama/llama-kv-cache.cpp
DELETED
|
@@ -1 +0,0 @@
-#include "llama-kv-cache.h"
examples/talk-llama/llama-kv-cells.h
CHANGED
|
@@ -80,6 +80,9 @@ public:
|
|
| 80 |
assert(isrc < pos.size());
|
| 81 |
assert(idst < pos.size());
|
| 82 |
|
|
|
|
|
|
|
|
|
|
| 83 |
pos [idst] = pos [isrc];
|
| 84 |
shift[idst] = shift[isrc];
|
| 85 |
seq [idst] = seq [isrc];
|
|
@@ -144,9 +147,10 @@ public:
|
|
| 144 |
assert(pos[i] != -1);
|
| 145 |
|
| 146 |
seq_pos_rm(i);
|
|
|
|
| 147 |
|
| 148 |
pos[i] = -1;
|
| 149 |
-
|
| 150 |
|
| 151 |
used.erase(i);
|
| 152 |
}
|
|
@@ -164,6 +168,7 @@ public:
|
|
| 164 |
|
| 165 |
if (seq[i].none()) {
|
| 166 |
pos[i] = -1;
|
|
|
|
| 167 |
|
| 168 |
used.erase(i);
|
| 169 |
|
|
@@ -192,6 +197,7 @@ public:
|
|
| 192 |
seq[i].reset();
|
| 193 |
|
| 194 |
pos[i] = -1;
|
|
|
|
| 195 |
|
| 196 |
used.erase(i);
|
| 197 |
|
|
@@ -317,21 +323,20 @@ public:
|
|
| 317 |
pos[i] += d;
|
| 318 |
shift[i] += d;
|
| 319 |
|
| 320 |
-
seq_pos_add(i);
|
| 321 |
-
|
| 322 |
has_shift = true;
|
| 323 |
|
| 324 |
if (pos[i] < 0) {
|
| 325 |
-
seq_pos_rm(i);
|
| 326 |
-
|
| 327 |
seq[i].reset();
|
| 328 |
pos[i] = -1;
|
|
|
|
| 329 |
|
| 330 |
used.erase(i);
|
| 331 |
|
| 332 |
return true;
|
| 333 |
}
|
| 334 |
|
|
|
|
|
|
|
| 335 |
return false;
|
| 336 |
}
|
| 337 |
|
|
|
|
| 80 |
assert(isrc < pos.size());
|
| 81 |
assert(idst < pos.size());
|
| 82 |
|
| 83 |
+
assert(pos[idst] == -1);
|
| 84 |
+
assert(pos[isrc] != -1);
|
| 85 |
+
|
| 86 |
pos [idst] = pos [isrc];
|
| 87 |
shift[idst] = shift[isrc];
|
| 88 |
seq [idst] = seq [isrc];
|
|
|
|
| 147 |
assert(pos[i] != -1);
|
| 148 |
|
| 149 |
seq_pos_rm(i);
|
| 150 |
+
seq[i].reset();
|
| 151 |
|
| 152 |
pos[i] = -1;
|
| 153 |
+
shift[i] = 0;
|
| 154 |
|
| 155 |
used.erase(i);
|
| 156 |
}
|
|
|
|
| 168 |
|
| 169 |
if (seq[i].none()) {
|
| 170 |
pos[i] = -1;
|
| 171 |
+
shift[i] = 0;
|
| 172 |
|
| 173 |
used.erase(i);
|
| 174 |
|
|
|
|
| 197 |
seq[i].reset();
|
| 198 |
|
| 199 |
pos[i] = -1;
|
| 200 |
+
shift[i] = 0;
|
| 201 |
|
| 202 |
used.erase(i);
|
| 203 |
|
|
|
|
| 323 |
pos[i] += d;
|
| 324 |
shift[i] += d;
|
| 325 |
|
|
|
|
|
|
|
| 326 |
has_shift = true;
|
| 327 |
|
| 328 |
if (pos[i] < 0) {
|
|
|
|
|
|
|
| 329 |
seq[i].reset();
|
| 330 |
pos[i] = -1;
|
| 331 |
+
shift[i] = 0;
|
| 332 |
|
| 333 |
used.erase(i);
|
| 334 |
|
| 335 |
return true;
|
| 336 |
}
|
| 337 |
|
| 338 |
+
seq_pos_add(i);
|
| 339 |
+
|
| 340 |
return false;
|
| 341 |
}
|
| 342 |
|
examples/talk-llama/llama-memory.cpp
CHANGED
|
@@ -1 +1,42 @@
 #include "llama-memory.h"
+
+llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1) {
+    bool has_update = false;
+
+    switch (s0) {
+        case LLAMA_MEMORY_STATUS_SUCCESS:
+            {
+                has_update = true;
+                break;
+            }
+        case LLAMA_MEMORY_STATUS_NO_UPDATE:
+            {
+                break;
+            }
+        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+            {
+                return s0;
+            }
+    }
+
+    switch (s1) {
+        case LLAMA_MEMORY_STATUS_SUCCESS:
+            {
+                has_update = true;
+                break;
+            }
+        case LLAMA_MEMORY_STATUS_NO_UPDATE:
+            {
+                break;
+            }
+        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+            {
+                return s1;
+            }
+    }
+
+    // if either status has an update, then the combined status has an update
+    return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE;
+}
|
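In effect, llama_memory_status_combine() lets a failure on either side win, and otherwise reports SUCCESS as soon as at least one side has work to apply. A tiny standalone mirror of that truth table (local enum and function, for illustration only; the real declarations live in llama-memory.h above and llama-memory.cpp here):

#include <cassert>

enum status { SUCCESS, NO_UPDATE, FAILED_PREPARE, FAILED_COMPUTE };

// same combination rule as llama_memory_status_combine(), on the local enum
static status combine(status s0, status s1) {
    if (s0 == FAILED_PREPARE || s0 == FAILED_COMPUTE) return s0;
    if (s1 == FAILED_PREPARE || s1 == FAILED_COMPUTE) return s1;
    return (s0 == SUCCESS || s1 == SUCCESS) ? SUCCESS : NO_UPDATE;
}

int main() {
    assert(combine(SUCCESS,   NO_UPDATE)      == SUCCESS);        // one side has an update
    assert(combine(NO_UPDATE, NO_UPDATE)      == NO_UPDATE);      // nothing to do on either side
    assert(combine(SUCCESS,   FAILED_PREPARE) == FAILED_PREPARE); // failures dominate
}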
examples/talk-llama/llama-memory.h
CHANGED
|
@@ -7,6 +7,9 @@
|
|
| 7 |
|
| 8 |
struct llama_ubatch;
|
| 9 |
|
|
|
|
|
|
|
|
|
|
| 10 |
struct llama_memory_params {
|
| 11 |
// kv cache
|
| 12 |
ggml_type type_k;
|
|
@@ -16,32 +19,17 @@ struct llama_memory_params {
|
|
| 16 |
bool swa_full;
|
| 17 |
};
|
| 18 |
|
| 19 |
-
// general concept of LLM memory
|
| 20 |
-
// the KV cache is a type of LLM memory, but there can be other types
|
| 21 |
-
class llama_memory_i {
|
| 22 |
-
public:
|
| 23 |
-
virtual ~llama_memory_i() = default;
|
| 24 |
-
|
| 25 |
-
virtual void clear() = 0;
|
| 26 |
-
|
| 27 |
-
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
|
| 28 |
-
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
|
| 29 |
-
virtual void seq_keep(llama_seq_id seq_id) = 0;
|
| 30 |
-
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
|
| 31 |
-
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
|
| 32 |
-
|
| 33 |
-
virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
|
| 34 |
-
virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
|
| 35 |
-
|
| 36 |
-
virtual bool get_can_edit() const = 0;
|
| 37 |
-
};
|
| 38 |
-
|
| 39 |
enum llama_memory_status {
|
| 40 |
LLAMA_MEMORY_STATUS_SUCCESS = 0,
|
|
|
|
| 41 |
LLAMA_MEMORY_STATUS_FAILED_PREPARE,
|
| 42 |
LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
|
| 43 |
};
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
// the interface for managing the memory state during batch processing
|
| 46 |
// this interface is implemented per memory type. see:
|
| 47 |
// - llama_kv_cache_unified_state
|
|
@@ -51,8 +39,7 @@ enum llama_memory_status {
|
|
| 51 |
// the only method that can mutate the memory and the memory state is llama_memory_i::apply()
|
| 52 |
//
|
| 53 |
// TODO: rename to llama_memory_context_i ?
|
| 54 |
-
|
| 55 |
-
public:
|
| 56 |
virtual ~llama_memory_state_i() = default;
|
| 57 |
|
| 58 |
// consume the current ubatch from the state and proceed to the next one
|
|
@@ -69,8 +56,63 @@ public:
|
|
| 69 |
// get the current ubatch
|
| 70 |
virtual const llama_ubatch & get_ubatch() const = 0;
|
| 71 |
|
| 72 |
-
// get the status of the memory state
|
| 73 |
virtual llama_memory_status get_status() const = 0;
|
| 74 |
};
|
| 75 |
|
| 76 |
using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
struct llama_ubatch;
|
| 9 |
|
| 10 |
+
class llama_io_write_i;
|
| 11 |
+
class llama_io_read_i;
|
| 12 |
+
|
| 13 |
struct llama_memory_params {
|
| 14 |
// kv cache
|
| 15 |
ggml_type type_k;
|
|
|
|
| 19 |
bool swa_full;
|
| 20 |
};
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
enum llama_memory_status {
|
| 23 |
LLAMA_MEMORY_STATUS_SUCCESS = 0,
|
| 24 |
+
LLAMA_MEMORY_STATUS_NO_UPDATE,
|
| 25 |
LLAMA_MEMORY_STATUS_FAILED_PREPARE,
|
| 26 |
LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
|
| 27 |
};
|
| 28 |
|
| 29 |
+
// helper function for combining the status of two memory states
|
| 30 |
+
// useful for implementing hybrid memory types (e.g. iSWA)
|
| 31 |
+
llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
|
| 32 |
+
|
| 33 |
// the interface for managing the memory state during batch processing
|
| 34 |
// this interface is implemented per memory type. see:
|
| 35 |
// - llama_kv_cache_unified_state
|
|
|
|
| 39 |
// the only method that can mutate the memory and the memory state is llama_memory_i::apply()
|
| 40 |
//
|
| 41 |
// TODO: rename to llama_memory_context_i ?
|
| 42 |
+
struct llama_memory_state_i {
|
|
|
|
| 43 |
virtual ~llama_memory_state_i() = default;
|
| 44 |
|
| 45 |
// consume the current ubatch from the state and proceed to the next one
|
|
|
|
| 56 |
// get the current ubatch
|
| 57 |
virtual const llama_ubatch & get_ubatch() const = 0;
|
| 58 |
|
| 59 |
+
// get the status of the memory state - used for error handling and checking if any updates would be applied
|
| 60 |
virtual llama_memory_status get_status() const = 0;
|
| 61 |
};
|
| 62 |
|
| 63 |
using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
|
| 64 |
+
|
| 65 |
+
// general concept of LLM memory
|
| 66 |
+
// the KV cache is a type of LLM memory, but there can be other types
|
| 67 |
+
struct llama_memory_i {
|
| 68 |
+
virtual ~llama_memory_i() = default;
|
| 69 |
+
|
| 70 |
+
// split the input batch into a set of ubatches and verify that they can fit into the cache
|
| 71 |
+
// return a state object containing the ubatches and KV cache state required to process them
|
| 72 |
+
// check the llama_memory_state_i::get_status() for the result
|
| 73 |
+
virtual llama_memory_state_ptr init_batch(
|
| 74 |
+
const llama_batch & batch,
|
| 75 |
+
uint32_t n_ubatch,
|
| 76 |
+
bool embd_pooled,
|
| 77 |
+
bool logits_all) = 0;
|
| 78 |
+
|
| 79 |
+
// simulate full cache, used for allocating worst-case compute buffers
|
| 80 |
+
virtual llama_memory_state_ptr init_full() = 0;
|
| 81 |
+
|
| 82 |
+
// prepare for any pending memory updates, such as shifts, defrags, etc.
|
| 83 |
+
// status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
|
| 84 |
+
virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
|
| 85 |
+
|
| 86 |
+
// getters
|
| 87 |
+
virtual bool get_can_shift() const = 0;
|
| 88 |
+
|
| 89 |
+
//
|
| 90 |
+
// ops
|
| 91 |
+
//
|
| 92 |
+
|
| 93 |
+
// if data == true, the data buffers will also be cleared together with the metadata
|
| 94 |
+
virtual void clear(bool data) = 0;
|
| 95 |
+
|
| 96 |
+
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
|
| 97 |
+
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
|
| 98 |
+
virtual void seq_keep(llama_seq_id seq_id) = 0;
|
| 99 |
+
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
|
| 100 |
+
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
|
| 101 |
+
|
| 102 |
+
virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
|
| 103 |
+
virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
|
| 104 |
+
|
| 105 |
+
//
|
| 106 |
+
// state write/read
|
| 107 |
+
//
|
| 108 |
+
|
| 109 |
+
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
|
| 110 |
+
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
|
| 111 |
+
};
|
| 112 |
+
|
| 113 |
+
using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
|
| 114 |
+
|
| 115 |
+
// TODO: temporary until the llama_kv_cache is removed from the public API
|
| 116 |
+
struct llama_kv_cache : public llama_memory_i {
|
| 117 |
+
virtual ~llama_kv_cache() = default;
|
| 118 |
+
};
|
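With the llama_memory_i / llama_memory_state_i split above, a caller first builds a state (init_batch, init_full or init_update), checks its status, and then walks it ubatch by ubatch, calling apply() before each compute step. A schematic of that driver loop against minimal local stubs (everything here is a mock with the same shape as the interface; it is not the real llama.cpp driver code):

#include <cstdio>
#include <memory>

enum status { SUCCESS, NO_UPDATE, FAILED_PREPARE };

// minimal mock with the same shape as llama_memory_state_i
struct mock_state {
    int    n_ubatch = 3;
    int    i_next   = 0;
    status st       = SUCCESS;

    status get_status() const { return st; }
    bool   apply()            { std::printf("apply ubatch %d\n", i_next); return true; }
    bool   next()             { return ++i_next < n_ubatch; } // false once all ubatches are consumed
};

int main() {
    auto state = std::make_unique<mock_state>();

    if (state->get_status() != SUCCESS) {
        return 1; // NO_UPDATE or a failure - nothing to process
    }

    do {
        state->apply(); // the only call that mutates the underlying memory
    } while (state->next());
}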
examples/talk-llama/llama-mmap.cpp
CHANGED
|
@@ -401,7 +401,7 @@ struct llama_mmap::impl {
            }
        }
#else
-
+        LLAMA_LOG_DEBUG("skipping PrefetchVirtualMemory because _WIN32_WINNT < 0x602\n");
#endif
    }
}
|
examples/talk-llama/llama-model-loader.cpp
CHANGED
|
@@ -288,9 +288,10 @@ namespace GGUFMeta {
|
|
| 288 |
|
| 289 |
template<typename T>
|
| 290 |
bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & result, bool required) {
|
| 291 |
-
const
|
|
|
|
| 292 |
|
| 293 |
-
if (kid < 0 || gguf_get_kv_type(
|
| 294 |
if (required) {
|
| 295 |
throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
|
| 296 |
}
|
|
@@ -298,28 +299,40 @@ namespace GGUFMeta {
|
|
| 298 |
}
|
| 299 |
|
| 300 |
struct GGUFMeta::ArrayInfo arr_info =
|
| 301 |
-
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(
|
| 302 |
|
| 303 |
switch (arr_info.gt) {
|
| 304 |
case GGUF_TYPE_UINT32:
|
| 305 |
-
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T,
|
| 306 |
-
(std::is_same<T,
|
| 307 |
-
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T,
|
|
|
|
| 308 |
default:
|
| 309 |
-
throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
|
| 310 |
}
|
| 311 |
|
| 312 |
-
|
| 313 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
|
| 315 |
return true;
|
| 316 |
}
|
| 317 |
|
| 318 |
template<typename T, size_t N_MAX>
|
| 319 |
bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required) {
|
| 320 |
-
const
|
|
|
|
| 321 |
|
| 322 |
-
if (kid < 0 || gguf_get_kv_type(
|
| 323 |
if (required) {
|
| 324 |
throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
|
| 325 |
}
|
|
@@ -327,22 +340,32 @@ namespace GGUFMeta {
|
|
| 327 |
}
|
| 328 |
|
| 329 |
struct GGUFMeta::ArrayInfo arr_info =
|
| 330 |
-
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(
|
| 331 |
|
| 332 |
switch (arr_info.gt) {
|
| 333 |
case GGUF_TYPE_UINT32:
|
| 334 |
-
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T,
|
| 335 |
-
(std::is_same<T,
|
| 336 |
-
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T,
|
|
|
|
| 337 |
default:
|
| 338 |
-
throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
|
| 339 |
}
|
| 340 |
|
| 341 |
if (arr_info.length > N_MAX) {
|
| 342 |
throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX));
|
| 343 |
}
|
| 344 |
|
| 345 |
-
std::
|
|
|
|
|
|
|
|
|
|
|
| 346 |
|
| 347 |
return true;
|
| 348 |
}
|
|
@@ -352,6 +375,8 @@ namespace GGUFMeta {
|
|
| 352 |
return get_arr(llm_kv(kid), result, required);
|
| 353 |
}
|
| 354 |
|
|
|
|
|
|
|
| 355 |
template<typename T>
|
| 356 |
bool llama_model_loader::get_key(const std::string & key, T & result, bool required) {
|
| 357 |
auto it = kv_overrides.find(key);
|
|
|
|
| 288 |
|
| 289 |
template<typename T>
|
| 290 |
bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & result, bool required) {
|
| 291 |
+
const gguf_context * ctx = meta.get();
|
| 292 |
+
const int kid = gguf_find_key(ctx, key.c_str());
|
| 293 |
|
| 294 |
+
if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) {
|
| 295 |
if (required) {
|
| 296 |
throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
|
| 297 |
}
|
|
|
|
| 299 |
}
|
| 300 |
|
| 301 |
struct GGUFMeta::ArrayInfo arr_info =
|
| 302 |
+
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid);
|
| 303 |
|
| 304 |
switch (arr_info.gt) {
|
| 305 |
case GGUF_TYPE_UINT32:
|
| 306 |
+
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
|
| 307 |
+
(std::is_same<T, uint32_t>::value)); break;
|
| 308 |
+
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
|
| 309 |
+
case GGUF_TYPE_STRING: GGML_ASSERT((std::is_same<T, std::string>::value)); break;
|
| 310 |
default:
|
| 311 |
+
throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str()));
|
| 312 |
}
|
| 313 |
|
| 314 |
+
if constexpr (std::is_same<T, std::string>::value) {
|
| 315 |
+
const size_t n_items = gguf_get_arr_n(ctx, kid);
|
| 316 |
+
result.clear();
|
| 317 |
+
|
| 318 |
+
for (size_t i = 0; i < n_items; i++) {
|
| 319 |
+
const T value = gguf_get_arr_str(ctx, kid, i);
|
| 320 |
+
result.emplace_back(value);
|
| 321 |
+
}
|
| 322 |
+
} else {
|
| 323 |
+
result.resize(arr_info.length);
|
| 324 |
+
result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
|
| 325 |
+
}
|
| 326 |
|
| 327 |
return true;
|
| 328 |
}
|
| 329 |
|
| 330 |
template<typename T, size_t N_MAX>
|
| 331 |
bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required) {
|
| 332 |
+
const gguf_context * ctx = meta.get();
|
| 333 |
+
const int kid = gguf_find_key(ctx, key.c_str());
|
| 334 |
|
| 335 |
+
if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) {
|
| 336 |
if (required) {
|
| 337 |
throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
|
| 338 |
}
|
|
|
|
| 340 |
}
|
| 341 |
|
| 342 |
struct GGUFMeta::ArrayInfo arr_info =
|
| 343 |
+
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid);
|
| 344 |
|
| 345 |
switch (arr_info.gt) {
|
| 346 |
case GGUF_TYPE_UINT32:
|
| 347 |
+
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
|
| 348 |
+
(std::is_same<T, uint32_t>::value)); break;
|
| 349 |
+
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
|
| 350 |
+
case GGUF_TYPE_STRING: GGML_ASSERT((std::is_same<T, std::string>::value)); break;
|
| 351 |
default:
|
| 352 |
+
throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str()));
|
| 353 |
}
|
| 354 |
|
| 355 |
if (arr_info.length > N_MAX) {
|
| 356 |
throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX));
|
| 357 |
}
|
| 358 |
|
| 359 |
+
if constexpr (std::is_same<T, std::string>::value) {
|
| 360 |
+
const size_t n_items = gguf_get_arr_n(ctx, kid);
|
| 361 |
+
|
| 362 |
+
for (size_t i = 0; i < n_items; i++) {
|
| 363 |
+
const T value = gguf_get_arr_str(ctx, kid, i);
|
| 364 |
+
result[i] = value;
|
| 365 |
+
}
|
| 366 |
+
} else {
|
| 367 |
+
std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
|
| 368 |
+
}
|
| 369 |
|
| 370 |
return true;
|
| 371 |
}
|
|
|
|
| 375 |
return get_arr(llm_kv(kid), result, required);
|
| 376 |
}
|
| 377 |
|
| 378 |
+
template bool llama_model_loader::get_arr<std::vector<std::string>>(enum llm_kv kid, std::vector<std::string> & result, bool required);
|
| 379 |
+
|
| 380 |
template<typename T>
|
| 381 |
bool llama_model_loader::get_key(const std::string & key, T & result, bool required) {
|
| 382 |
auto it = kv_overrides.find(key);
|
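The loader change above makes get_arr() handle GGUF string arrays by branching at compile time on the element type: string elements are copied one by one through gguf_get_arr_str(), everything else is bulk-copied from the raw array data. This is what lets llama-model.cpp below read LLM_KV_CLASSIFIER_OUTPUT_LABELS straight into a std::vector<std::string>. A reduced standalone illustration of that if constexpr dispatch (the two reader functions are toys, not the gguf API):

#include <string>
#include <type_traits>
#include <vector>

// toy stand-ins for the two ways the loader can read an array element
static const char * toy_get_str (size_t i) { static const char * v[] = {"yes", "no"}; return v[i]; }
static const void * toy_get_data()         { static const float  v[] = {1.0f, 2.0f}; return v; }

template <typename T>
static void read_arr(std::vector<T> & result, size_t n) {
    if constexpr (std::is_same<T, std::string>::value) {
        // strings: element-wise copy, as get_arr() does via gguf_get_arr_str()
        result.clear();
        for (size_t i = 0; i < n; i++) {
            result.emplace_back(toy_get_str(i));
        }
    } else {
        // plain old data: bulk copy of the underlying buffer
        const T * data = (const T *) toy_get_data();
        result.assign(data, data + n);
    }
}

int main() {
    std::vector<std::string> labels;
    std::vector<float>       scales;

    read_arr(labels, 2); // {"yes", "no"}
    read_arr(scales, 2); // {1.0f, 2.0f}
}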
examples/talk-llama/llama-model.cpp
CHANGED
|
@@ -543,6 +543,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     uint32_t n_vocab = 0;
     ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false);

+    // for classifier models
+    ml.get_arr(LLM_KV_CLASSIFIER_OUTPUT_LABELS, classifier_labels, false);
+    if (!classifier_labels.empty()) {
+        hparams.n_cls_out = classifier_labels.size();
+    }
+
     // arch-specific KVs
     switch (arch) {
         case LLM_ARCH_LLAMA:
@@ -686,7 +692,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                 ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
                 ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
-                ml.get_arr_n(LLM_KV_CLASSIFIER_OUTPUT_LABELS, hparams.n_cls_out, false);

                 switch (hparams.n_layer) {
                     case 3:
@@ -956,6 +961,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     case 46: type = LLM_TYPE_27B; break;
                     default: type = LLM_TYPE_UNKNOWN;
                 }
+
+                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173
+                hparams.f_attention_scale = type == LLM_TYPE_27B
+                    ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
+                    : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
             } break;
         case LLM_ARCH_GEMMA3:
             {
@@ -976,6 +986,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }

+                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289
                 hparams.f_attention_scale = type == LLM_TYPE_27B
                     ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
                     : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
@@ -4356,6 +4367,15 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
         LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
         LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
+
+        if (!classifier_labels.empty()) {
+            LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out);
+
+            size_t i = 0;
+            for (auto label : classifier_labels) {
+                LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str());
+            }
+        }
     }

     LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str());
@@ -8484,14 +8504,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);

-
-                switch (model.type) {
-                    case LLM_TYPE_2B:
-                    case LLM_TYPE_9B:
-                    case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); break;
-                    default: GGML_ABORT("fatal error");
-                };
-                cb(Qcur, "Qcur_scaled", il);
+                Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);

                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
@@ -8632,9 +8645,12 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);

+                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
+                Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
+
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
             }

             cur = build_norm(cur,
@@ -13600,6 +13616,18 @@ int32_t llama_model_n_swa(const llama_model * model) {
     return model->hparams.n_swa;
 }

+uint32_t llama_model_n_cls_out(const struct llama_model * model) {
+    return model->hparams.n_cls_out;
+}
+
+const char * llama_model_cls_label(const struct llama_model * model, uint32_t i) {
+    if (i < model->classifier_labels.size()) {
+        return model->classifier_labels[i].c_str();
+    }
+
+    return nullptr;
+}
+
 // deprecated
 int32_t llama_n_ctx_train(const llama_model * model) {
     return llama_model_n_ctx_train(model);
@@ -13760,7 +13788,7 @@ uint64_t llama_model_size(const llama_model * model) {
 }

 const char * llama_model_chat_template(const llama_model * model, const char * name) {
-    const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N)
+    const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE)
                           : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
     const auto & it = model->gguf_kv.find(key);
     if (it == model->gguf_kv.end()) {
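These llama-model.cpp changes wire the GGUF classifier labels through to the public API: the labels are read in load_hparams(), reported by print_info(), and exposed via the new llama_model_n_cls_out() / llama_model_cls_label() accessors. A minimal usage sketch, assuming a model has already been loaded (for example with llama_model_load_from_file()); this helper is illustrative and not part of the sync:

    #include <cstdio>
    #include "llama.h"

    // Print the classifier output labels of a loaded model, if any.
    // llama_model_n_cls_out() is only meaningful for classifier models;
    // llama_model_cls_label() returns nullptr when no label was stored in the GGUF.
    static void print_cls_labels(const struct llama_model * model) {
        const uint32_t n_cls_out = llama_model_n_cls_out(model);

        for (uint32_t i = 0; i < n_cls_out; i++) {
            const char * label = llama_model_cls_label(model, i);
            printf("class %u: %s\n", i, label ? label : "(unnamed)");
        }
    }

With LLAMA_POOLING_TYPE_RANK, llama_get_embeddings_seq() then returns n_cls_out scores per sequence, one per label, as documented in the llama.h changes further below.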
examples/talk-llama/llama-model.h
CHANGED
|
@@ -329,6 +329,9 @@ struct llama_model {
     llama_hparams hparams = {};
     llama_vocab   vocab;

+    // for classifier models
+    std::vector<std::string> classifier_labels;
+
     struct ggml_tensor * tok_embd  = nullptr;
     struct ggml_tensor * type_embd = nullptr;
     struct ggml_tensor * pos_embd  = nullptr;
examples/talk-llama/llama-vocab.cpp
CHANGED
|
@@ -2080,9 +2080,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {

     std::string model_name;
     std::string tokenizer_pre;
+    std::string general_arch;

     ml.get_key(LLM_KV_GENERAL_NAME, model_name, false);
     ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
+    ml.get_key(LLM_KV_GENERAL_ARCHITECTURE, general_arch, false);

     // model name to lowercase
     std::transform(model_name.begin(), model_name.end(), model_name.begin(),
@@ -2091,9 +2093,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
         }
     );

-    // set attributes by model/tokenizer name
-    if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
-        _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
+    // set attributes by model/tokenizer/architecture name
+    if (false
+            || _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})
+            || _contains_any(general_arch, {"nomic-bert-moe"})
+            ) {
+        if (token_to_id.count("<mask>") == 0) {
+            LLAMA_LOG_WARN("%s: Mask token is missing in vocab, please reconvert model!\n", __func__);
+        } else {
+            _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
+        }
     } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
         for (auto id : cache_special_tokens) {
             _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
examples/talk-llama/llama.h
CHANGED
|
@@ -61,7 +61,10 @@ extern "C" {
     struct llama_model;
     struct llama_context;
     struct llama_sampler;
-    struct llama_kv_cache;
+
+    typedef struct llama_memory_i * llama_memory_t;
+
+    struct llama_kv_cache; // DEPRECATED (use llama_memory instead)

     typedef int32_t llama_pos;
     typedef int32_t llama_token;
@@ -493,9 +496,11 @@ extern "C" {
    DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");

    LLAMA_API const struct llama_model * llama_get_model   (const struct llama_context * ctx);
-   LLAMA_API struct llama_kv_cache *    llama_get_kv_self (      struct llama_context * ctx);
+   LLAMA_API           llama_memory_t   llama_get_memory  (const struct llama_context * ctx);
    LLAMA_API  enum llama_pooling_type   llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type

+   DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
+
    LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
    LLAMA_API enum llama_rope_type       llama_model_rope_type(const struct llama_model * model);
@@ -509,6 +514,13 @@ extern "C" {
    // Get the model's RoPE frequency scaling factor
    LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);

+   // Returns the number of classifier outputs (only valid for classifier models)
+   // Undefined behavior for non-classifier models
+   LLAMA_API uint32_t llama_model_n_cls_out(const struct llama_model * model);
+
+   // Returns label of classifier output by index (<n_cls_out). Returns nullptr if no label provided
+   LLAMA_API const char * llama_model_cls_label(const struct llama_model * model, uint32_t i);
+
    LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab);

    LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);
@@ -609,7 +621,81 @@ extern "C" {
                    int32_t   il_end);

    //
-   // KV cache
+   // Memory
+   //
+
+   // Clear the memory contents
+   // If data == true, the data buffers will also be cleared together with the metadata
+   LLAMA_API void llama_memory_clear(
+           llama_memory_t mem,
+                     bool data);
+
+   // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
+   // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
+   // seq_id < 0 : match any sequence
+   // p0 < 0 : [0, p1]
+   // p1 < 0 : [p0, inf)
+   LLAMA_API bool llama_memory_seq_rm(
+           llama_memory_t mem,
+             llama_seq_id seq_id,
+                llama_pos p0,
+                llama_pos p1);
+
+   // Copy all tokens that belong to the specified sequence to another sequence
+   // p0 < 0 : [0, p1]
+   // p1 < 0 : [p0, inf)
+   LLAMA_API void llama_memory_seq_cp(
+           llama_memory_t mem,
+             llama_seq_id seq_id_src,
+             llama_seq_id seq_id_dst,
+                llama_pos p0,
+                llama_pos p1);
+
+   // Removes all tokens that do not belong to the specified sequence
+   LLAMA_API void llama_memory_seq_keep(
+           llama_memory_t mem,
+             llama_seq_id seq_id);
+
+   // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
+   // p0 < 0 : [0, p1]
+   // p1 < 0 : [p0, inf)
+   LLAMA_API void llama_memory_seq_add(
+           llama_memory_t mem,
+             llama_seq_id seq_id,
+                llama_pos p0,
+                llama_pos p1,
+                llama_pos delta);
+
+   // Integer division of the positions by factor of `d > 1`
+   // p0 < 0 : [0, p1]
+   // p1 < 0 : [p0, inf)
+   LLAMA_API void llama_memory_seq_div(
+           llama_memory_t mem,
+             llama_seq_id seq_id,
+                llama_pos p0,
+                llama_pos p1,
+                      int d);
+
+   // Returns the smallest position present in the memory for the specified sequence
+   // This is typically non-zero only for SWA caches
+   // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
+   // Return -1 if the sequence is empty
+   LLAMA_API llama_pos llama_memory_seq_pos_min(
+           llama_memory_t mem,
+             llama_seq_id seq_id);
+
+   // Returns the largest position present in the memory for the specified sequence
+   // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
+   // Return -1 if the sequence is empty
+   LLAMA_API llama_pos llama_memory_seq_pos_max(
+           llama_memory_t mem,
+             llama_seq_id seq_id);
+
+   // Check if the memory supports shifting
+   LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
+
+   //
+   // KV cache for self-attention (TODO: deprecate in favor of llama_memory)
    //

    // Returns the number of tokens in the KV cache (slow, use only for debug)
@@ -622,86 +708,95 @@ extern "C" {
        "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");

    // Clear the KV cache - both cell info is erased and KV data is zeroed
-   LLAMA_API void llama_kv_self_clear(
-           struct llama_context * ctx);
+   DEPRECATED(LLAMA_API void llama_kv_self_clear(
+           struct llama_context * ctx),
+       "Use llama_memory_clear() instead");

    // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
    // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
    // seq_id < 0 : match any sequence
    // p0 < 0 : [0, p1]
    // p1 < 0 : [p0, inf)
-   LLAMA_API bool llama_kv_self_seq_rm(
+   DEPRECATED(LLAMA_API bool llama_kv_self_seq_rm(
           struct llama_context * ctx,
                   llama_seq_id   seq_id,
                      llama_pos   p0,
-                     llama_pos   p1);
+                     llama_pos   p1),
+       "Use llama_memory_seq_rm() instead");

    // Copy all tokens that belong to the specified sequence to another sequence
    // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
    // p0 < 0 : [0, p1]
    // p1 < 0 : [p0, inf)
-   LLAMA_API void llama_kv_self_seq_cp(
+   DEPRECATED(LLAMA_API void llama_kv_self_seq_cp(
           struct llama_context * ctx,
                   llama_seq_id   seq_id_src,
                   llama_seq_id   seq_id_dst,
                      llama_pos   p0,
-                     llama_pos   p1);
+                     llama_pos   p1),
+       "Use llama_memory_seq_cp() instead");

    // Removes all tokens that do not belong to the specified sequence
-   LLAMA_API void llama_kv_self_seq_keep(
+   DEPRECATED(LLAMA_API void llama_kv_self_seq_keep(
           struct llama_context * ctx,
-                  llama_seq_id   seq_id);
+                  llama_seq_id   seq_id),
+       "Use llama_memory_seq_keep() instead");

    // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
    // If the KV cache is RoPEd, the KV data is updated accordingly:
    // - lazily on next llama_decode()
    // p0 < 0 : [0, p1]
    // p1 < 0 : [p0, inf)
-   LLAMA_API void llama_kv_self_seq_add(
+   DEPRECATED(LLAMA_API void llama_kv_self_seq_add(
           struct llama_context * ctx,
                   llama_seq_id   seq_id,
                      llama_pos   p0,
                      llama_pos   p1,
-                     llama_pos   delta);
+                     llama_pos   delta),
+       "Use llama_memory_seq_add() instead");

    // Integer division of the positions by factor of `d > 1`
    // If the KV cache is RoPEd, the KV data is updated accordingly:
    // - lazily on next llama_decode()
    // p0 < 0 : [0, p1]
    // p1 < 0 : [p0, inf)
-   LLAMA_API void llama_kv_self_seq_div(
+   DEPRECATED(void llama_kv_self_seq_div(
           struct llama_context * ctx,
                   llama_seq_id   seq_id,
                      llama_pos   p0,
                      llama_pos   p1,
-                           int   d);
+                           int   d),
+       "Use llama_memory_seq_div() instead");

    // Returns the smallest position present in the KV cache for the specified sequence
    // This is typically non-zero only for SWA caches
    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
    // Return -1 if the sequence is empty
-   LLAMA_API llama_pos llama_kv_self_seq_pos_min(
+   DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_min(
           struct llama_context * ctx,
-                  llama_seq_id   seq_id);
+                  llama_seq_id   seq_id),
+       "Use llama_memory_seq_pos_min() instead");

    // Returns the largest position present in the KV cache for the specified sequence
    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
    // Return -1 if the sequence is empty
-   LLAMA_API llama_pos llama_kv_self_seq_pos_max(
+   DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_max(
           struct llama_context * ctx,
-                  llama_seq_id   seq_id);
+                  llama_seq_id   seq_id),
+       "Use llama_memory_seq_pos_max() instead");

    // Defragment the KV cache
    // This will be applied:
    // - lazily on next llama_decode()
-   LLAMA_API DEPRECATED(void llama_kv_self_defrag(struct llama_context * ctx),
+   DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx),
       "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");

    // Check if the context supports KV cache shifting
-   LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
+   DEPRECATED(LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx),
+      "use llama_memory_can_shift() instead");

    // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
-   LLAMA_API DEPRECATED(void llama_kv_self_update(struct llama_context * ctx),
+   DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx),
       "simply remove this call, updates are applied lazily on the next llama_decode()");

    //
@@ -709,7 +804,7 @@ extern "C" {
    //

    // Returns the *actual* size in bytes of the state
-   // (logits, embedding and kv_cache)
+   // (logits, embedding and memory)
    // Only use when saving the state, not when restoring it, otherwise the size may be too small.
    LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
    LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
@@ -765,12 +860,12 @@ extern "C" {
                          size_t   n_token_count),
       "use llama_state_save_file instead");

-   // Get the exact size needed to copy the KV cache of a single sequence
+   // Get the exact size needed to copy the state of a single sequence
    LLAMA_API size_t llama_state_seq_get_size(
           struct llama_context * ctx,
                   llama_seq_id   seq_id);

-   // Copy the KV cache of a single sequence into the specified buffer
+   // Copy the state of a single sequence into the specified buffer
    LLAMA_API size_t llama_state_seq_get_data(
           struct llama_context * ctx,
                        uint8_t * dst,
@@ -836,16 +931,16 @@ extern "C" {
    // For encode-decoder contexts, processes the batch using the encoder.
    // Can store the encoder output internally for later use by the decoder's cross-attention layers.
    //   0 - success
-   // < 0 - error. the KV cache state is restored to the state before this call
+   // < 0 - error. the memory state is restored to the state before this call
    LLAMA_API int32_t llama_encode(
           struct llama_context * ctx,
             struct llama_batch   batch);

    // Process a batch of tokens.
-   // Requires KV cache.
+   // Requires the context to have a memory.
    // For encode-decoder contexts, processes the batch using the decoder.
    // Positive return values does not mean a fatal error, but rather a warning.
-   // Upon non-zero return values, the KV cache state is restored to the state before this call
+   // Upon non-zero return values, the memory state is restored to the state before this call
    //   0 - success
    //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
    //   2 - aborted
@@ -916,7 +1011,7 @@ extern "C" {

    // Get the embeddings for a sequence id
    // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
-   // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
+   // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[n_cls_out] with the rank(s) of the sequence
    // otherwise: float[n_embd] (1-dimensional)
    LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
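The practical effect of the llama.h changes is a migration path: callers obtain a llama_memory_t handle with llama_get_memory() and use the llama_memory_* functions instead of the now-deprecated llama_kv_self_* calls on the context. A minimal migration sketch for the common context-shift pattern; the shift_context helper and its n_past/n_discard parameters are hypothetical, and error handling is omitted:

    #include "llama.h"

    // Discard the first n_discard tokens of a sequence and shift the remaining
    // positions back, using the new llama_memory_* API from this sync.
    // n_past is the current number of tokens stored for the sequence.
    static void shift_context(struct llama_context * ctx, llama_seq_id seq_id, llama_pos n_past, llama_pos n_discard) {
        llama_memory_t mem = llama_get_memory(ctx);

        // previously: llama_kv_self_seq_rm(ctx, seq_id, 0, n_discard);
        llama_memory_seq_rm(mem, seq_id, 0, n_discard);

        // previously: llama_kv_self_seq_add(ctx, seq_id, n_discard, n_past, -n_discard);
        // only valid if the underlying memory supports shifting
        if (llama_memory_can_shift(mem)) {
            llama_memory_seq_add(mem, seq_id, n_discard, n_past, -n_discard);
        }
    }

llama_memory_seq_rm() still returns false when a partial sequence cannot be removed, so callers that checked the llama_kv_self_seq_rm() return value can keep the same logic after switching over.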