talk-llama : use llama_decode instead of llama_eval
examples/talk-llama/talk-llama.cpp CHANGED
@@ -391,6 +391,8 @@ int main(int argc, char ** argv) {
 
     prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);
 
+    llama_batch batch = llama_batch_init(llama_n_ctx(ctx_llama), 0, 1);
+
     // init session
     std::string path_session = params.path_session;
     std::vector<llama_token> session_tokens;
@@ -426,8 +428,21 @@ int main(int argc, char ** argv) {
     printf("\n");
     printf("%s : initializing - please wait ...\n", __func__);
 
-    if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0)) {
-        fprintf(stderr, "%s : failed to eval\n", __func__);
+    // prepare batch
+    {
+        batch.n_tokens = embd_inp.size();
+
+        for (int i = 0; i < batch.n_tokens; i++) {
+            batch.token[i]     = embd_inp[i];
+            batch.pos[i]       = i;
+            batch.n_seq_id[i]  = 1;
+            batch.seq_id[i][0] = 0;
+            batch.logits[i]    = i == batch.n_tokens - 1;
+        }
+    }
+
+    if (llama_decode(ctx_llama, batch)) {
+        fprintf(stderr, "%s : failed to decode\n", __func__);
         return 1;
     }
 
@@ -647,8 +662,21 @@ int main(int argc, char ** argv) {
             n_session_consumed = session_tokens.size();
         }
 
-        if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past)) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
+        // prepare batch
+        {
+            batch.n_tokens = embd.size();
+
+            for (int i = 0; i < batch.n_tokens; i++) {
+                batch.token[i]     = embd[i];
+                batch.pos[i]       = n_past + i;
+                batch.n_seq_id[i]  = 1;
+                batch.seq_id[i][0] = 0;
+                batch.logits[i]    = i == batch.n_tokens - 1;
+            }
+        }
+
+        if (llama_decode(ctx_llama, batch)) {
+            fprintf(stderr, "%s : failed to decode\n", __func__);
             return 1;
         }
     }
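In llama.cpp, llama_eval was deprecated in favor of llama_decode, which takes a llama_batch and carries per-token positions, sequence ids, and logits flags instead of a flat token pointer plus n_past. The two new call sites above fill the same five batch fields and differ only in the token source and the starting position. Below is a minimal sketch (not part of the commit) of that shared pattern as a helper; it assumes only the llama.h batch API already used in the diff, and the name decode_tokens is hypothetical.

#include "llama.h"

#include <vector>

// Hypothetical helper mirroring the commit's pattern: fill `batch` with
// `tokens` on sequence 0 starting at position n_past, request logits only
// for the last token, then decode. Returns llama_decode's status (0 = ok).
static int decode_tokens(llama_context * ctx, llama_batch & batch,
                         const std::vector<llama_token> & tokens, int n_past) {
    batch.n_tokens = (int) tokens.size();

    for (int i = 0; i < batch.n_tokens; i++) {
        batch.token[i]     = tokens[i];
        batch.pos[i]       = n_past + i;   // absolute position in the KV cache
        batch.n_seq_id[i]  = 1;            // each token belongs to one sequence
        batch.seq_id[i][0] = 0;            // the single chat sequence
        batch.logits[i]    = i == batch.n_tokens - 1;
    }

    return llama_decode(ctx, batch);
}

With such a helper, the prompt pass becomes decode_tokens(ctx_llama, batch, embd_inp, 0) and the per-turn pass becomes decode_tokens(ctx_llama, batch, embd, n_past); the batch itself is allocated once with llama_batch_init (sized to the context length, as in the diff) and should eventually be released with llama_batch_free(batch).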