ggerganov committed on
Commit
de68a7e
·
unverified ·
1 Parent(s): cad4b5d

talk : improve prompting

Browse files
examples/talk/README.md CHANGED
@@ -31,7 +31,7 @@ To run this, you will need a ggml GPT-2 model: [instructions](https://github.com
31
  Alternatively, you can simply download the smallest ggml GPT-2 117M model (240 MB) like this:
32
 
33
  ```
34
- wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://ggml.ggerganov.com/ggml-model-gpt-2-117M.bin
35
  ```
36
 
37
  ## TTS
 
31
  Alternatively, you can simply download the smallest ggml GPT-2 117M model (240 MB) like this:
32
 
33
  ```
34
+ wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://huggingface.co/datasets/ggerganov/ggml/raw/main/ggml-model-gpt-2-117M.bin
35
  ```
36
 
37
  ## TTS
examples/talk/gpt-2.cpp CHANGED
@@ -139,7 +139,7 @@ gpt_vocab::id gpt_sample_top_k_top_p(
139
  }
140
 
141
  //printf("\n");
142
- //for (int i = 0; i < (int)logits_id.size(); i++) {
143
  // printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), logits_id[i].first);
144
  //}
145
  //exit(0);
@@ -825,8 +825,8 @@ Me too.
825
  int32_t n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
826
 
827
  // sampling parameters
828
- int32_t top_k = 20;
829
- float top_p = 0.98f;
830
  float temp = 1.0f;
831
  };
832
 
@@ -840,7 +840,7 @@ struct gpt2_context * gpt2_init(const char * path_model) {
840
  const int64_t t_start_us = ggml_time_us();
841
 
842
  if (!gpt2_model_load(path_model, ctx->model, ctx->vocab)) {
843
- fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, "gpt-2.bin");
844
  return nullptr;
845
  }
846
 
@@ -913,10 +913,7 @@ std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens)
913
  result += ctx->vocab.id_to_token[embd[0]];
914
 
915
  // end of text token
916
- if (embd.back() == 50256 ||
917
- ctx->vocab.id_to_token[embd.back()] == "." ||
918
- ctx->vocab.id_to_token[embd.back()] == "!" ||
919
- ctx->vocab.id_to_token[embd.back()] == "?") {
920
  break;
921
  }
922
  }
 
139
  }
140
 
141
  //printf("\n");
142
+ //for (int i = 0; i < (int) logits_id.size(); i++) {
143
  // printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), logits_id[i].first);
144
  //}
145
  //exit(0);
 
825
  int32_t n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
826
 
827
  // sampling parameters
828
+ int32_t top_k = 5;
829
+ float top_p = 0.9f;
830
  float temp = 1.0f;
831
  };
832
 
 
840
  const int64_t t_start_us = ggml_time_us();
841
 
842
  if (!gpt2_model_load(path_model, ctx->model, ctx->vocab)) {
843
+ fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
844
  return nullptr;
845
  }
846
 
 
913
  result += ctx->vocab.id_to_token[embd[0]];
914
 
915
  // end of text token
916
+ if (embd.back() == 50256) {
 
 
 
917
  break;
918
  }
919
  }
examples/talk/talk.cpp CHANGED
@@ -473,56 +473,15 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
473
  return result;
474
  }
475
 
476
- // compute similarity between two strings using Levenshtein distance
477
- float similarity(const std::string & s0, const std::string & s1) {
478
- const size_t len0 = s0.size() + 1;
479
- const size_t len1 = s1.size() + 1;
480
 
481
- std::vector<int> col(len1, 0);
482
- std::vector<int> prevCol(len1, 0);
 
 
483
 
484
- for (size_t i = 0; i < len1; i++) {
485
- prevCol[i] = i;
486
- }
487
-
488
- for (size_t i = 0; i < len0; i++) {
489
- col[0] = i;
490
- for (size_t j = 1; j < len1; j++) {
491
- col[j] = std::min(std::min(1 + col[j - 1], 1 + prevCol[j]), prevCol[j - 1] + (s0[i - 1] == s1[j - 1] ? 0 : 1));
492
- }
493
- col.swap(prevCol);
494
- }
495
-
496
- const float dist = prevCol[len1 - 1];
497
-
498
- return 1.0f - (dist / std::max(s0.size(), s1.size()));
499
- }
500
-
501
- // generated with ChatGPT
502
- std::map<std::string, std::string> k_prompts = {
503
- { "Santa",
504
- R"(Kid: Hi Santa! Are you real?
505
- Santa: Of course I am, my dear! Ho ho ho!
506
- Kid: Can you please bring me a new toy for Christmas?
507
- Santa: I'll see what I can do, but you have to make sure to be a good boy or girl and listen to your parents.
508
- Kid: I will, Santa! Thank you!
509
- Santa: You're welcome, little one. Merry Christmas! Ho ho ho!
510
- Kid: Can you tell me how you deliver all the presents to all the kids in the world in one night?
511
- Santa: It's a secret, but I have a lot of help from my elves and my magical sleigh. And I have a special route that I follow to make sure I visit every child.
512
- Kid: Wow, that's amazing! Can I please have a ride in your sleigh sometime?
513
- Santa: I'm sorry, but only good boys and girls get to ride in my sleigh.
514
- )" },
515
- { "Kid",
516
- R"(Kid: Hi Santa! Are you real?
517
- Santa: Of course I am, my dear! Ho ho ho!
518
- Kid: Can you please bring me a new toy for Christmas?
519
- Santa: I'll see what I can do, but you have to make sure to be a good boy or girl and listen to your parents.
520
- Kid: I will, Santa! Thank you!
521
- Kid: Can you tell me how you deliver all the presents to all the kids in the world in one night?
522
- Santa: It's a secret, but I have a lot of help from my elves and my magical sleigh. And I have a special route that I follow to make sure I visit every child.
523
- Kid: Wow, that's amazing! Can I please have a ride in your sleigh sometime?
524
- )" },
525
- };
526
 
527
  int main(int argc, char ** argv) {
528
  whisper_params params;
@@ -579,7 +538,7 @@ int main(int argc, char ** argv) {
579
  int n_iter = 0;
580
 
581
  bool is_running = true;
582
- bool force_speak = params.person == "Kid";
583
 
584
  float prob0 = 0.0f;
585
  float prob = 0.0f;
@@ -587,19 +546,13 @@ int main(int argc, char ** argv) {
587
  std::vector<float> pcmf32_cur;
588
  std::vector<float> pcmf32_prompt;
589
 
590
- if (k_prompts.find(params.person) == k_prompts.end()) {
591
- fprintf(stderr, "%s: unknown person '%s'\n", __func__, params.person.c_str());
592
- return 1;
593
- }
594
-
595
- gpt2_set_prompt(ctx_gpt, k_prompts.at(params.person).c_str());
596
 
597
- const std::string person_other = params.person == "Santa" ? "Kid" : "Santa";
598
- const int voice_id = params.person == "Santa" ? 5 : 2;
599
 
600
- fprintf(stderr, "gpt-2: prompt_base:\n");
601
  fprintf(stderr, "========================\n\n");
602
- fprintf(stderr, "%s\n", gpt2_get_prompt(ctx_gpt));
603
  fprintf(stderr, "========================\n\n");
604
 
605
  // main loop
@@ -636,13 +589,12 @@ int main(int argc, char ** argv) {
636
 
637
  audio.get(params.voice_ms, pcmf32_cur);
638
 
639
- std::string text_heard = "Hey little one, what do you want for Christmas?";
 
640
  if (!force_speak) {
641
  text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prob0, t_ms));
642
  }
643
 
644
- force_speak = false;
645
-
646
  // remove text between brackets using regex
647
  {
648
  std::regex re("\\[.*?\\]");
@@ -667,13 +619,15 @@ int main(int argc, char ** argv) {
667
 
668
  const std::vector<gpt_vocab::id> tokens = gpt2_tokenize(ctx_gpt, text_heard.c_str());
669
 
670
- if (text_heard.empty() || tokens.empty()) {
671
  fprintf(stdout, "%s: Heard nothing, skipping ...\n", __func__);
672
  audio.clear();
673
 
674
  continue;
675
  }
676
 
 
 
677
  fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", text_heard.c_str(), "\033[0m", (int) t_ms);
678
 
679
  std::string prompt_base = gpt2_get_prompt(ctx_gpt);
@@ -681,9 +635,11 @@ int main(int argc, char ** argv) {
681
  std::string text_to_speak;
682
 
683
  {
684
- text_heard = person_other + ": " + text_heard;
685
 
686
- text_to_speak = gpt2_gen_text(ctx_gpt, (prompt_base + text_heard + "\n").c_str(), params.max_tokens);
 
 
687
  text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
688
  text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
689
 
@@ -703,13 +659,20 @@ int main(int argc, char ** argv) {
703
  }
704
  }
705
 
706
- prompt_base += text_heard + "\n" + text_to_speak + "\n";
707
- }
 
 
708
 
709
- printf("%s\n", text_to_speak.c_str());
 
 
 
 
 
710
 
711
  //printf("========================\n");
712
- //printf("gpt-2: prompt_base:\n'%s'\n", prompt_base.c_str());
713
  //printf("========================\n");
714
 
715
  gpt2_set_prompt(ctx_gpt, prompt_base.c_str());
 
473
  return result;
474
  }
475
 
476
+ const std::string k_prompt =
477
+ R"(This is a dialogue between {0} (A) and a person (B). The dialogue so far is:
 
 
478
 
479
+ B: Hello {0}, how are you?
480
+ A: I'm fine, thank you.
481
+ {1}
482
+ Here is how {0} (A) continues the dialogue:
483
 
484
+ A:)";
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
 
486
  int main(int argc, char ** argv) {
487
  whisper_params params;
 
538
  int n_iter = 0;
539
 
540
  bool is_running = true;
541
+ bool force_speak = false;
542
 
543
  float prob0 = 0.0f;
544
  float prob = 0.0f;
 
546
  std::vector<float> pcmf32_cur;
547
  std::vector<float> pcmf32_prompt;
548
 
549
+ gpt2_set_prompt(ctx_gpt, "");
 
 
 
 
 
550
 
551
+ const int voice_id = rand()%6;
 
552
 
553
+ fprintf(stderr, "gpt-2: prompt:\n");
554
  fprintf(stderr, "========================\n\n");
555
+ fprintf(stderr, "%s\n", ::replace(k_prompt, "{0}", params.person).c_str());
556
  fprintf(stderr, "========================\n\n");
557
 
558
  // main loop
 
589
 
590
  audio.get(params.voice_ms, pcmf32_cur);
591
 
592
+ std::string text_heard = "";
593
+
594
  if (!force_speak) {
595
  text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prob0, t_ms));
596
  }
597
 
 
 
598
  // remove text between brackets using regex
599
  {
600
  std::regex re("\\[.*?\\]");
 
619
 
620
  const std::vector<gpt_vocab::id> tokens = gpt2_tokenize(ctx_gpt, text_heard.c_str());
621
 
622
+ if (text_heard.empty() || tokens.empty() || force_speak) {
623
  fprintf(stdout, "%s: Heard nothing, skipping ...\n", __func__);
624
  audio.clear();
625
 
626
  continue;
627
  }
628
 
629
+ force_speak = false;
630
+
631
  fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", text_heard.c_str(), "\033[0m", (int) t_ms);
632
 
633
  std::string prompt_base = gpt2_get_prompt(ctx_gpt);
 
635
  std::string text_to_speak;
636
 
637
  {
638
+ prompt_base += "B: " + text_heard + "\n";
639
 
640
+ std::string prompt = ::replace(::replace(k_prompt, "{0}", params.person), "{1}", prompt_base);
641
+
642
+ text_to_speak = gpt2_gen_text(ctx_gpt, prompt.c_str(), params.max_tokens);
643
  text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
644
  text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
645
 
 
659
  }
660
  }
661
 
662
+ prompt_base += "A:" + text_to_speak + "\n";
663
+
664
+ {
665
+ prompt = ::replace(::replace(k_prompt, "{0}", params.person), "{1}", prompt_base);
666
 
667
+ printf("===============\n");
668
+ printf("prompt:\n");
669
+ printf("%s\n", prompt.c_str());
670
+ printf("===============\n");
671
+ }
672
+ }
673
 
674
  //printf("========================\n");
675
+ //printf("gpt-2: prompt_base:\n%s\n", prompt_base.c_str());
676
  //printf("========================\n");
677
 
678
  gpt2_set_prompt(ctx_gpt, prompt_base.c_str());
models/download-ggml-model.cmd CHANGED
@@ -40,7 +40,7 @@ if exist "ggml-%model%.bin" (
40
  goto :eof
41
  )
42
 
43
- PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://ggml.ggerganov.com/ggml-model-whisper-%model%.bin -OutFile ggml-%model%.bin"
44
 
45
  if %ERRORLEVEL% neq 0 (
46
  echo Failed to download ggml model %model%
 
40
  goto :eof
41
  )
42
 
43
+ PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://huggingface.co/datasets/ggerganov/whisper.cpp/raw/main/ggml-%model%.bin -OutFile ggml-%model%.bin"
44
 
45
  if %ERRORLEVEL% neq 0 (
46
  echo Failed to download ggml model %model%