Luis Herrera committed (unverified)
Commit a7b3aa5 · 1 parent: d75ae65

talk-llama : add --session support (#845)


* feat: adding session support

* readme: adding --session info in examples/talk-llama

* llama: adding session fixes

* readme: updating session doc

* talk-llama: set need_to_save_session to true so the session is saved in the subsequent interaction

* talk-llama: adding missing function which updates session_tokens

examples/talk-llama/README.md CHANGED
@@ -25,6 +25,20 @@ make talk-llama
 - The `-mw` argument specifies the Whisper model that you would like to use. Recommended `base` or `small` for real-time experience
 - The `-ml` argument specifies the LLaMA model that you would like to use. Read the instructions in https://github.com/ggerganov/llama.cpp for information about how to obtain a `ggml` compatible LLaMA model
 
+## Session
+
+The `talk-llama` tool supports session management to enable more coherent and continuous conversations. By maintaining context from previous interactions, it can better understand and respond to user requests in a more natural way.
+
+To enable session support, use the `--session FILE` command line option when running the program. The `talk-llama` model state will be saved to the specified file after each interaction. If the file does not exist, it will be created. If the file exists, the model state will be loaded from it, allowing you to resume a previous session.
+
+This feature is especially helpful for maintaining context in long conversations or when interacting with the AI assistant across multiple sessions. It ensures that the assistant remembers the previous interactions and can provide more relevant and contextual responses.
+
+Example usage:
+
+```bash
+./talk-llama --session ./my-session-file -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/13B/ggml-model-q4_0.bin -p "Georgi" -t 8
+```
+
 ## TTS
 
 For best experience, this example needs a TTS tool to convert the generated text responses to voice.
examples/talk-llama/llama.cpp CHANGED
@@ -2695,56 +2695,81 @@ std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_te
     return ctx->model.tensors_by_name;
 }
 
-size_t llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    // TODO leverage mmap
+bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
     llama_file file(path_session, "rb");
-    const uint32_t magic = file.read_u32();
-    const uint32_t version = file.read_u32();
 
-    if (!(magic == 'ggsn' && version == 0)) {
-        fprintf(stderr, "%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
-        return 0;
+    // sanity checks
+    {
+        const uint32_t magic   = file.read_u32();
+        const uint32_t version = file.read_u32();
+
+        if (!(magic == LLAMA_SESSION_MAGIC && version == LLAMA_SESSION_VERSION)) {
+            fprintf(stderr, "%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
+            return false;
+        }
+
+        llama_hparams session_hparams;
+        file.read_raw(&session_hparams, sizeof(llama_hparams));
+
+        if (session_hparams != ctx->model.hparams) {
+            fprintf(stderr, "%s : model hparams didn't match from session file!\n", __func__);
+            return false;
+        }
     }
 
-    llama_hparams session_hparams;
-    file.read_raw(&session_hparams, sizeof(llama_hparams));
-
-    // REVIEW
-    if (session_hparams != ctx->model.hparams) {
-        fprintf(stderr, "%s : model hparams didn't match from session file!\n", __func__);
-        return 0;
+    // load the prompt
+    {
+        const uint32_t n_token_count = file.read_u32();
+
+        if (n_token_count > n_token_capacity) {
+            fprintf(stderr, "%s : token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
+            return false;
+        }
+
+        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
+        *n_token_count_out = n_token_count;
     }
 
-    const uint32_t n_token_count = file.read_u32();
-    LLAMA_ASSERT(n_token_capacity >= n_token_count);
-    file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
-    *n_token_count_out = n_token_count;
+    // restore the context state
+    {
+        const size_t n_state_size_cur = file.size - file.tell();
+        const size_t n_state_size_exp = llama_get_state_size(ctx);
 
-    const size_t n_state_size = file.size - file.tell();
-    const size_t n_orig_state_size = llama_get_state_size(ctx);
-    if (n_state_size != n_orig_state_size) {
-        fprintf(stderr, "%s : failed to validate state size\n", __func__);
+        if (n_state_size_cur != n_state_size_exp) {
+            fprintf(stderr, "%s : the state size in session file didn't match! expected %zu, got %zu\n", __func__, n_state_size_exp, n_state_size_cur);
+            return false;
+        }
+
+        std::vector<uint8_t> state_data(n_state_size_cur);
+        file.read_raw(state_data.data(), n_state_size_cur);
+
+        llama_set_state_data(ctx, state_data.data());
     }
-    std::unique_ptr<uint8_t[]> state_data(new uint8_t[n_state_size]);
-    file.read_raw(state_data.get(), n_state_size);
-    return llama_set_state_data(ctx, state_data.get());
+
+    return true;
 }
 
-size_t llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
-    // TODO save temp & swap
+bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
     llama_file file(path_session, "wb");
 
-    const size_t n_state_size = llama_get_state_size(ctx);
-    std::unique_ptr<uint8_t[]> state_data(new uint8_t[n_state_size]);
-    llama_copy_state_data(ctx, state_data.get());
-
-    file.write_u32('ggsn'); // magic
-    file.write_u32(0); // version
+    file.write_u32(LLAMA_SESSION_MAGIC);
+    file.write_u32(LLAMA_SESSION_VERSION);
+
     file.write_raw(&ctx->model.hparams, sizeof(llama_hparams));
 
-    file.write_u32((uint32_t) n_token_count); // REVIEW
+    // save the prompt
+    file.write_u32((uint32_t) n_token_count);
     file.write_raw(tokens, sizeof(llama_token) * n_token_count);
 
-    file.write_raw(state_data.get(), n_state_size);
-    return n_state_size; // REVIEW
-}
+    // save the context state
+    {
+        const size_t n_state_size = llama_get_state_size(ctx);
+
+        std::vector<uint8_t> state_data(n_state_size);
+        llama_copy_state_data(ctx, state_data.data());
+
+        file.write_raw(state_data.data(), n_state_size);
+    }
+
+    return true;
+}
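For reference, the functions above define a simple on-disk layout for session files: a `'ggsn'` magic word and a version, followed by the raw `llama_hparams` struct, a `uint32_t` prompt-token count, the prompt tokens, and finally the serialized context state. The sketch below is not part of the commit; it is a hypothetical checker that only validates the first two header fields, since the size of the remaining sections depends on the build's `llama_hparams` and state layout, and it assumes the file was written on a machine with the same endianness and compiler convention for multi-character constants.

```cpp
// session_check.cpp - hypothetical helper, not part of this commit.
// Reads the first 8 bytes of a talk-llama session file and verifies the
// 'ggsn' magic and version 0 header written by llama_save_session_file.
#include <cstdint>
#include <cstdio>

int main(int argc, char ** argv) {
    if (argc < 2) {
        std::fprintf(stderr, "usage: %s <session-file>\n", argv[0]);
        return 1;
    }

    std::FILE * fp = std::fopen(argv[1], "rb");
    if (!fp) {
        std::fprintf(stderr, "failed to open '%s'\n", argv[1]);
        return 1;
    }

    uint32_t magic   = 0;
    uint32_t version = 0;
    if (std::fread(&magic,   sizeof(magic),   1, fp) != 1 ||
        std::fread(&version, sizeof(version), 1, fp) != 1) {
        std::fprintf(stderr, "file too short to contain a session header\n");
        std::fclose(fp);
        return 1;
    }
    std::fclose(fp);

    // 0x6767736e is 'ggsn' under the same convention that makes
    // 'ggjt' equal 0x67676a74 (see the comment removed from llama.h).
    if (magic != 0x6767736e || version != 0) {
        std::fprintf(stderr, "unexpected header: magic %08x, version %u\n", magic, version);
        return 1;
    }

    std::printf("looks like a talk-llama session file (magic 'ggsn', version %u)\n", version);
    return 0;
}
```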
examples/talk-llama/llama.h CHANGED
@@ -19,9 +19,11 @@
 # define LLAMA_API
 #endif
 
-#define LLAMA_FILE_VERSION 1
-#define LLAMA_FILE_MAGIC 0x67676a74 // 'ggjt' in hex
-#define LLAMA_FILE_MAGIC_UNVERSIONED 0x67676d6c // pre-versioned files
+#define LLAMA_FILE_VERSION 1
+#define LLAMA_FILE_MAGIC 'ggjt'
+#define LLAMA_FILE_MAGIC_UNVERSIONED 'ggml'
+#define LLAMA_SESSION_MAGIC 'ggsn'
+#define LLAMA_SESSION_VERSION 0
 
 #ifdef __cplusplus
 extern "C" {
@@ -138,9 +140,8 @@ extern "C" {
     LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src);
 
     // Save/load session file
-    LLAMA_API size_t llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
-    LLAMA_API size_t llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count);
-
+    LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
+    LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count);
     // Run the llama inference to obtain the logits and probabilities for the next token.
     // tokens + n_tokens is the provided batch of new tokens to process
    // n_past is the number of tokens to use from previous eval calls
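Taken together, the two declarations give callers a load-or-create pattern: probe for an existing session file, restore it into a token buffer sized to the context, and save once the prompt has been evaluated. The fragment below is only a sketch of that pattern, mirroring what talk-llama.cpp does further down; the context pointer, context size, and prompt tokens are assumed to come from the usual llama.cpp setup code, and the helper names are hypothetical.

```cpp
#include "llama.h"

#include <cstdio>
#include <string>
#include <vector>

// Hypothetical helper sketch: restore a previous session if the file exists,
// otherwise return an empty token cache so the caller evaluates from scratch.
static std::vector<llama_token> restore_session(struct llama_context * ctx, const std::string & path, int n_ctx) {
    std::vector<llama_token> session_tokens;

    // mirror talk-llama: probe for the file first, since llama_load_session_file expects it to exist
    std::FILE * fp = std::fopen(path.c_str(), "rb");
    if (fp == NULL) {
        return session_tokens; // nothing cached yet - the first save will create the file
    }
    std::fclose(fp);

    session_tokens.resize(n_ctx);
    size_t n_token_count = 0;
    if (!llama_load_session_file(ctx, path.c_str(), session_tokens.data(), session_tokens.size(), &n_token_count)) {
        std::fprintf(stderr, "failed to load session file '%s'\n", path.c_str());
        session_tokens.clear();
        return session_tokens;
    }

    session_tokens.resize(n_token_count);
    std::fprintf(stderr, "restored %zu cached prompt tokens\n", session_tokens.size());
    return session_tokens;
}

// Hypothetical helper sketch: persist the evaluated tokens so the next run
// can reuse the matching prefix instead of re-evaluating it.
static void persist_session(struct llama_context * ctx, const std::string & path, const std::vector<llama_token> & tokens) {
    if (!path.empty()) {
        llama_save_session_file(ctx, path.c_str(), tokens.data(), tokens.size());
    }
}
```

In talk-llama itself the save is additionally gated by `need_to_save_session`, so a session that already matched most of the prompt is not rewritten.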
examples/talk-llama/talk-llama.cpp CHANGED
@@ -52,6 +52,7 @@ struct whisper_params {
     std::string speak = "./examples/talk-llama/speak.sh";
     std::string prompt = "";
     std::string fname_out;
+    std::string path_session = ""; // path to file for saving/loading model eval state
 };
 
 void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -78,6 +79,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
        else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
        else if (arg == "--verbose-prompt") { params.verbose_prompt = true; }
        else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
+       else if (arg == "--session") { params.path_session = argv[++i]; }
        else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
        else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
        else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; }
@@ -124,6 +126,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
    fprintf(stderr, " --n-parts-llama N [%-7d] num parts in llama model file\n", params.n_parts_llama);
    fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str());
    fprintf(stderr, " --prompt-file FNAME [%-7s] file with custom prompt to start dialog\n", "");
+   fprintf(stderr, " --session FNAME file to cache model state in (may be large!) (default: none)\n");
    fprintf(stderr, " --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false");
    fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
    fprintf(stderr, "\n");
@@ -348,6 +351,57 @@ int main(int argc, char ** argv) {
        fflush(stdout);
    }
 
+   // init session
+   std::string path_session = params.path_session;
+   std::vector<llama_token> session_tokens;
+
+   if (!path_session.empty()) {
+       fprintf(stderr, "%s: attempting to load saved session from %s\n", __func__, path_session.c_str());
+
+       // fopen to check for existing session
+       FILE * fp = std::fopen(path_session.c_str(), "rb");
+       if (fp != NULL) {
+           std::fclose(fp);
+
+           session_tokens.resize(lparams.n_ctx);
+           size_t n_token_count_out = 0;
+           if (!llama_load_session_file(ctx_llama, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
+               fprintf(stderr, "%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
+               return 1;
+           }
+           session_tokens.resize(n_token_count_out);
+
+           fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
+       } else {
+           fprintf(stderr, "%s: session file does not exist, will create\n", __func__);
+       }
+   }
+
+   // debug message about similarity of saved session, if applicable
+   size_t n_matching_session_tokens = 0;
+   if (session_tokens.size()) {
+       for (llama_token id : session_tokens) {
+           if (n_matching_session_tokens >= embd_inp.size() || id != embd_inp[n_matching_session_tokens]) {
+               break;
+           }
+           n_matching_session_tokens++;
+       }
+       if (n_matching_session_tokens >= embd_inp.size()) {
+           fprintf(stderr, "%s: session file has exact match for prompt!\n", __func__);
+       } else if (n_matching_session_tokens < (embd_inp.size() / 2)) {
+           fprintf(stderr, "%s: warning: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",
+               __func__, n_matching_session_tokens, embd_inp.size());
+       } else {
+           fprintf(stderr, "%s: session file matches %zu / %zu tokens of prompt\n",
+               __func__, n_matching_session_tokens, embd_inp.size());
+       }
+   }
+
+   // HACK - because session saving incurs a non-negligible delay, for now skip re-saving session
+   // if we loaded a session with at least 75% similarity. It's currently just used to speed up the
+   // initial prompt so it doesn't need to be an exact match.
+   bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < (embd_inp.size() * 3 / 4);
+
    printf("%s : done! start speaking in the microphone\n", __func__);
    printf("\n");
    printf("%s%s", params.person.c_str(), chat_symb.c_str());
@@ -363,6 +417,7 @@
 
    int n_past = n_keep;
    int n_prev = 64; // TODO arg
+   int n_session_consumed = 0;
 
    std::vector<llama_token> embd;
 
@@ -450,7 +505,8 @@
 
                // insert n_left/2 tokens at the start of embd from last_n_tokens
                embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end());
-
+               // stop saving session if we run out of context
+               path_session = "";
                //printf("\n---\n");
                //printf("resetting: '");
                //for (int i = 0; i < (int) embd.size(); i++) {
@@ -460,6 +516,29 @@
                //printf("\n---\n");
            }
 
+           // try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
+           // REVIEW
+           if (n_session_consumed < (int) session_tokens.size()) {
+               size_t i = 0;
+               for ( ; i < embd.size(); i++) {
+                   if (embd[i] != session_tokens[n_session_consumed]) {
+                       session_tokens.resize(n_session_consumed);
+                       break;
+                   }
+
+                   n_past++;
+                   n_session_consumed++;
+
+                   if (n_session_consumed >= (int) session_tokens.size()) {
+                       i++;
+                       break;
+                   }
+               }
+               if (i > 0) {
+                   embd.erase(embd.begin(), embd.begin() + i);
+               }
+           }
+
            if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) {
                fprintf(stderr, "%s : failed to eval\n", __func__);
                return 1;
@@ -470,6 +549,10 @@
 
            embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
            n_past += embd.size();
+           if (embd.size() > 0 && !path_session.empty()) {
+               session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
+               n_session_consumed = session_tokens.size();
+           }
            embd.clear();
 
            if (done) break;
@@ -483,6 +566,11 @@
 
        const int repeat_last_n = 256;
 
+       if (!path_session.empty() && need_to_save_session) {
+           need_to_save_session = false;
+           llama_save_session_file(ctx_llama, path_session.c_str(), session_tokens.data(), session_tokens.size());
+       }
+
        llama_token id = 0;
 
        {
@@ -542,6 +630,7 @@
                    done = true;
                    text_to_speak = ::replace(text_to_speak, antiprompt, "");
                    fflush(stdout);
+                   need_to_save_session = true;
                    break;
                }
            }
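The subtlest part of the change is the prefix-reuse loop added before `llama_eval`: tokens that already match the loaded session advance `n_past` (so they are not re-evaluated), while the first mismatch truncates the cached session from that point on. The following self-contained sketch (a hypothetical helper, not part of the commit, using plain `int` tokens) isolates that logic so it can be read and tested on its own.

```cpp
#include <cstdio>
#include <vector>

// Sketch of the prefix-reuse step: consume tokens from 'session_tokens' that
// match the front of 'embd', advancing n_past / n_session_consumed, and drop
// the consumed tokens from 'embd' so only the unmatched tail is re-evaluated.
// A mismatch truncates the cached session at the point of divergence.
static void reuse_matching_prefix(std::vector<int> & embd,
                                  std::vector<int> & session_tokens,
                                  int & n_past,
                                  int & n_session_consumed) {
    if (n_session_consumed >= (int) session_tokens.size()) {
        return;
    }

    size_t i = 0;
    for ( ; i < embd.size(); i++) {
        if (embd[i] != session_tokens[n_session_consumed]) {
            session_tokens.resize(n_session_consumed); // diverged: forget the stale tail
            break;
        }

        n_past++;
        n_session_consumed++;

        if (n_session_consumed >= (int) session_tokens.size()) {
            i++;
            break;
        }
    }

    if (i > 0) {
        embd.erase(embd.begin(), embd.begin() + i);
    }
}

int main() {
    // cached session from a previous run vs. the tokens about to be evaluated
    std::vector<int> session_tokens = { 1, 2, 3, 4 };
    std::vector<int> embd           = { 1, 2, 3, 9 };

    int n_past             = 0;
    int n_session_consumed = 0;

    reuse_matching_prefix(embd, session_tokens, n_past, n_session_consumed);

    // expected: n_past == 3, embd == { 9 }, cached session truncated to { 1, 2, 3 }
    std::printf("n_past = %d, remaining embd = %zu, cached = %zu\n",
                n_past, embd.size(), session_tokens.size());
    return 0;
}
```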