File size: 9,465 Bytes
58220b6
 
 
 
 
5ef1601
58220b6
 
 
 
 
 
 
 
 
 
 
 
 
5ef1601
58220b6
 
 
 
 
 
5ef1601
 
 
 
 
 
 
 
 
 
 
 
 
58220b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ade9bc3
58220b6
fc04dc0
58220b6
 
 
5ef1601
58220b6
 
 
5ef1601
 
 
 
 
 
 
 
 
 
 
58220b6
 
 
 
 
 
 
 
 
 
 
5ef1601
 
58220b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ef1601
 
 
58220b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ef1601
58220b6
 
 
 
 
 
 
 
 
 
 
 
 
fc04dc0
 
58220b6
 
 
 
 
 
 
 
 
 
 
 
5ef1601
 
58220b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ef1601
 
58220b6
 
 
 
 
 
 
 
 
 
5ef1601
 
 
 
58220b6
 
 
 
 
 
 
5ef1601
 
 
 
 
 
 
 
58220b6
 
5ef1601
58220b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ef1601
58220b6
 
5ef1601
 
 
 
 
 
 
 
 
 
 
 
 
58220b6
 
 
 
5ef1601
 
58220b6
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
#pragma once

#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cells.h"
#include "llama-memory.h"

#include <cstdint>
#include <functional>
#include <unordered_map>
#include <utility>
#include <vector>

// forward declarations: these types are only used by reference or pointer in this header,
// so their full definitions are not needed here
struct llama_cparams;
struct llama_hparams;
struct llama_model;
struct llama_context;

//
// llama_kv_cache_unified
//

// llama_memory_i implementation that owns per-layer K and V tensors (see kv_layer) together with
// one llama_kv_cells_unified instance describing the cells
//
// the public API is split into sections:
//  - the llama_memory_i overrides
//  - graph_build API: tensor views/copies used while building a compute graph
//  - preparation API: locating slots for ubatches and applying them to the cells
//  - set_input API:   filling graph input tensors
class llama_kv_cache_unified : public llama_memory_i {
public:
    // NOTE(review): presumably the cell-count padding required for the given context params - confirm in the .cpp
    static uint32_t get_padding(const llama_cparams & cparams);

    // this callback is used to filter out layers that should not be included in the cache
    // it receives the layer index (il) and returns a bool
    // NOTE(review): presumably returning true keeps the layer - confirm against the constructor implementation
    using layer_filter_cb = std::function<bool(int32_t il)>;

    // head locations, one per ubatch (see prepare())
    using ubatch_heads = std::vector<uint32_t>;

    // description of a pending defragmentation, produced by defrag_prepare() and consumed by update()
    struct defrag_info {
        // true when there are no move instructions
        bool empty() const {
            return ids.empty();
        }

        // contains information about which cell moves where:
        //  - cell i moves to ids[i]
        //  - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
        std::vector<uint32_t> ids;
    };

    // NOTE(review): the parameter semantics are not visible in this header; the notes below are inferred
    //               from the parameter names and the matching members - confirm against the .cpp
    //  - filter:          layer filter callback (see layer_filter_cb)
    //  - type_k, type_v:  ggml types for the K and V tensors
    //  - v_trans:         whether the value tensor is transposed (see the v_trans member)
    //  - n_pad:           required padding (see the n_pad member)
    //  - n_swa, swa_type: SWA configuration (see the n_swa and swa_type members)
    llama_kv_cache_unified(
            const llama_model &  model,
              layer_filter_cb && filter,
                    ggml_type    type_k,
                    ggml_type    type_v,
                         bool    v_trans,
                         bool    offload,
                     uint32_t    kv_size,
                     uint32_t    n_seq_max,
                     uint32_t    n_pad,
                     uint32_t    n_swa,
               llama_swa_type    swa_type);

    ~llama_kv_cache_unified() = default;

    //
    // llama_memory_i
    //

    // each init_* call returns a memory state object (llama_memory_state_ptr)
    // see llama_kv_cache_unified_state for the state type declared alongside this cache

    llama_memory_state_ptr init_batch(
            llama_batch_allocr & balloc,
            uint32_t n_ubatch,
            bool embd_all) override;

    llama_memory_state_ptr init_full() override;

    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;

    bool get_can_shift() const override;

    void clear(bool data) override;

    // sequence operations; p0 and p1 delimit a range of positions
    // NOTE(review): the range bounds convention (inclusive/exclusive, meaning of negative values) is not
    //               visible here - see llama_memory_i
    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id)                                                          override;
    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    // state write/load
    // seq_id defaults to -1
    // NOTE(review): presumably -1 selects all sequences - confirm in the .cpp

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1)       override;

    //
    // llama_kv_cache_unified specific API
    //

    uint32_t get_size() const;

    bool get_has_shift() const;

    //
    // graph_build API
    //

    uint32_t get_n_kv() const;

    // get views of the current state of the cache
    // il is a layer index, n_kv is supplied by the caller
    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;

    // store k_cur and v_cur in the cache based on the provided head location
    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;

    //
    // preparation API
    //

    // find places for the provided ubatches in the cache, returns the head locations
    // return empty vector on failure
    ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);

    // NOTE(review): presumably applies a pending shift (do_shift) and/or the defrag described by dinfo using lctx;
    //               the meaning of the returned bool is not visible here - confirm in the .cpp
    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);

    // return the cell position where we can insert the ubatch
    // return -1 on failure to find a contiguous slot of kv cells
    int32_t find_slot(const llama_ubatch & ubatch) const;

    // emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens)
    void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch);

    //
    // set_input API
    //

    // each function writes into the dst tensor provided by the caller
    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_k_shift   (ggml_tensor * dst) const;
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

private:
    // held by reference: both objects must outlive this cache
    const llama_model & model;
    const llama_hparams & hparams;

    // the K and V tensors of a single cached layer
    struct kv_layer {
        // layer index in the model
        // note: can be different from the layer index in the KV cache
        uint32_t il;

        ggml_tensor * k;
        ggml_tensor * v;
    };

    bool v_trans = true;  // the value tensor is transposed

    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
    uint32_t head = 0;

    const uint32_t n_seq_max = 1;

    // required padding
    const uint32_t n_pad = 1;

    // SWA
    const uint32_t n_swa = 0;

    // NOTE(review): presumably a debug verbosity level - how it is set is not visible in this header
    int debug = 0;

    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

    // ggml contexts and backend buffers held through smart-pointer wrappers
    std::vector<ggml_context_ptr>        ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

    llama_kv_cells_unified cells;

    // the cached layers (see kv_layer)
    std::vector<kv_layer> layers;

    // model layer id -> KV cache layer id
    std::unordered_map<int32_t, int32_t> map_layer_ids;

    // return non-empty vector if cells have been moved
    defrag_info defrag_prepare(int32_t n_max_nodes) const;

    // size helpers
    // NOTE(review): presumably total_size() covers the allocated buffers and size_k_bytes()/size_v_bytes()
    //               the K and V tensors respectively - confirm in the .cpp
    size_t total_size() const;

    size_t size_k_bytes() const;
    size_t size_v_bytes() const;

    // NOTE(review): presumably tells whether position p0 is masked out by the SWA configuration when
    //               attending from position p1 - confirm the argument roles in the .cpp
    bool is_masked_swa(llama_pos p0, llama_pos p1) const;

    // graph construction helpers used for shifting and defragmenting the cache

    ggml_tensor * build_rope_shift(
            const llama_cparams & cparams,
                   ggml_context * ctx,
                    ggml_tensor * cur,
                    ggml_tensor * shift,
                    ggml_tensor * factors,
                          float   freq_base,
                          float   freq_scale) const;

    llm_graph_result_ptr build_graph_shift(
            const llama_cparams & cparams,
                   ggml_context * ctx,
                    ggml_cgraph * gf) const;

    llm_graph_result_ptr build_graph_defrag(
            const llama_cparams & cparams,
                   ggml_context * ctx,
                    ggml_cgraph * gf,
              const defrag_info & dinfo) const;

    // serialization helpers; cell_ranges is a list of (uint32_t, uint32_t) pairs
    // NOTE(review): presumably each pair is a range of cell indices - the bounds convention is not visible here
    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;

    // deserialization helpers
    // NOTE(review): presumably the returned bool reports success - confirm in the .cpp
    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};

// llama_memory_state_i implementation for llama_kv_cache_unified
//
// there is one constructor per kind of state: error, full-cache, update and batch decode
// the members that a given constructor does not receive keep their default member initializers,
// so every state object is fully initialized regardless of which constructor created it
class llama_kv_cache_unified_state : public llama_memory_state_i {
public:
    // some shorthands
    using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
    using defrag_info  = llama_kv_cache_unified::defrag_info;

    // used for errors
    llama_kv_cache_unified_state(llama_memory_status status);

    // used to create a full-cache state
    llama_kv_cache_unified_state(
            llama_kv_cache_unified * kv);

    // used to create an update state
    llama_kv_cache_unified_state(
            llama_kv_cache_unified * kv,
            llama_context * lctx,
            bool do_shift,
            defrag_info dinfo);

    // used to create a decode state from a batch
    // heads and ubatches are taken by value
    // NOTE(review): presumably heads[i] is the head location for ubatches[i] (see llama_kv_cache_unified::prepare())
    llama_kv_cache_unified_state(
            llama_kv_cache_unified * kv,
            ubatch_heads heads,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_unified_state();

    //
    // llama_memory_state_i
    //

    bool next()  override;
    bool apply() override;

    llama_memory_status  get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_kv_cache_unified_state specific API
    //

    uint32_t get_n_kv() const;

    // get views of the current state of the cache
    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;

    // store k_cur and v_cur in the cache
    // note: unlike llama_kv_cache_unified::cpy_k/cpy_v, no head location is passed in
    // NOTE(review): presumably the `head` member of this state is used instead - confirm in the .cpp
    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;

    void set_input_k_shift(ggml_tensor * dst) const;

    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

private:
    // always provided by the constructors
    llama_memory_status status;

    // non-owning pointers
    // default to nullptr because not every constructor receives them (e.g. the error constructor only gets a status)
    llama_kv_cache_unified * kv = nullptr;
    llama_context * lctx = nullptr;

    //
    // update state
    //

    bool do_shift = false;

    defrag_info dinfo;

    //
    // batch processing state
    //

    // the index of the next ubatch to process
    size_t i_next = 0;

    ubatch_heads heads;

    std::vector<llama_ubatch> ubatches;

    //
    // data needed for building the compute graph for the current ubatch:
    //

    // a heuristic, to avoid attending the full cache if it is not yet utilized
    // as the cache gets filled, the benefit from this heuristic disappears
    // zero-initialized so that reading it before a ubatch has been applied is well-defined
    int32_t n_kv = 0;

    // the beginning of the current slot in which the ubatch will be inserted
    // zero-initialized for the same reason as n_kv
    int32_t head = 0;
};