ggerganov committed
Commit 301b000 (unverified)
Parent: fe602cb

talk-llama : use llama_decode instead of llama_eval
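
For context on the API change named above: the deprecated llama_eval took a flat pointer of tokens, a token count, and a single n_past offset, while llama_decode consumes a llama_batch in which each token carries an explicit position, sequence id, and a flag selecting whether logits are computed for it. The sketch below contrasts the two call shapes; it is illustrative only, not code from this commit, and the llama_eval prototype shown is the deprecated form, which may differ slightly across llama.cpp versions.

// Old API: positions are implicit (token order starting at n_past); there is
// no per-token control over which logits are returned.
//   int llama_eval(struct llama_context * ctx, llama_token * tokens, int32_t n_tokens, int n_past);

// New API: the caller fills a llama_batch explicitly, then calls llama_decode.
//   llama_batch batch = llama_batch_init(/*n_tokens_max=*/n_ctx, /*embd=*/0, /*n_seq_max=*/1);
//   ... fill batch.token[i], batch.pos[i], batch.n_seq_id[i], batch.seq_id[i][0], batch.logits[i] ...
//   if (llama_decode(ctx, batch) != 0) { /* handle error */ }
//   llama_batch_free(batch);   // release when done (teardown is not part of this diff)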

Files changed (1):
  1. examples/talk-llama/talk-llama.cpp  +32 -4
examples/talk-llama/talk-llama.cpp CHANGED
@@ -391,6 +391,8 @@ int main(int argc, char ** argv) {
 
     prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);
 
+    llama_batch batch = llama_batch_init(llama_n_ctx(ctx_llama), 0, 1);
+
     // init session
     std::string path_session = params.path_session;
     std::vector<llama_token> session_tokens;
@@ -426,8 +428,21 @@ int main(int argc, char ** argv) {
     printf("\n");
     printf("%s : initializing - please wait ...\n", __func__);
 
-    if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0)) {
-        fprintf(stderr, "%s : failed to eval\n", __func__);
+    // prepare batch
+    {
+        batch.n_tokens = embd_inp.size();
+
+        for (int i = 0; i < batch.n_tokens; i++) {
+            batch.token[i]     = embd_inp[i];
+            batch.pos[i]       = i;
+            batch.n_seq_id[i]  = 1;
+            batch.seq_id[i][0] = 0;
+            batch.logits[i]    = i == batch.n_tokens - 1;
+        }
+    }
+
+    if (llama_decode(ctx_llama, batch)) {
+        fprintf(stderr, "%s : failed to decode\n", __func__);
         return 1;
     }
 
@@ -647,8 +662,21 @@ int main(int argc, char ** argv) {
                     n_session_consumed = session_tokens.size();
                 }
 
-                if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past)) {
-                    fprintf(stderr, "%s : failed to eval\n", __func__);
+                // prepare batch
+                {
+                    batch.n_tokens = embd.size();
+
+                    for (int i = 0; i < batch.n_tokens; i++) {
+                        batch.token[i]     = embd[i];
+                        batch.pos[i]       = n_past + i;
+                        batch.n_seq_id[i]  = 1;
+                        batch.seq_id[i][0] = 0;
+                        batch.logits[i]    = i == batch.n_tokens - 1;
+                    }
+                }
+
+                if (llama_decode(ctx_llama, batch)) {
+                    fprintf(stderr, "%s : failed to decode\n", __func__);
                     return 1;
                 }
             }