Daniel Ziegenberg committed on
Commit
9a3f777
·
unverified ·
1 Parent(s): 3b7b90c

main : add options for temperature control (#2088)

Browse files

Add two options:

```
-tp, --temperature N [0.00 ] The sampling temperature, between 0 and 1
-tpi, --temperature-inc N [0.20 ] The increment of temperature, between 0 and 1
```

The sampling temperature, between 0 and 1. Higher values like 0.8 will
make the output more random, while lower values like 0.2 will make it
more focused and deterministic. If set to 0, the model will use log
probability to automatically increase the temperature until certain
thresholds are hit.

Signed-off-by: Daniel Ziegenberg <[email protected]>

Files changed (1) hide show
  1. examples/main/main.cpp +9 -1
examples/main/main.cpp CHANGED
@@ -44,6 +44,8 @@ struct whisper_params {
44
  float entropy_thold = 2.40f;
45
  float logprob_thold = -1.00f;
46
  float grammar_penalty = 100.0f;
 
 
47
 
48
  bool speed_up = false;
49
  bool debug_mode = false;
@@ -133,6 +135,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
133
  else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
134
  else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
135
  else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
 
 
136
  // else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
137
  else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; }
138
  else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
@@ -198,6 +202,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
198
  fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
199
  fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
200
  fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
 
 
201
  // fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
202
  fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
203
  fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
@@ -1107,7 +1113,9 @@ int main(int argc, char ** argv) {
1107
  wparams.greedy.best_of = params.best_of;
1108
  wparams.beam_search.beam_size = params.beam_size;
1109
 
1110
- wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
 
 
1111
  wparams.entropy_thold = params.entropy_thold;
1112
  wparams.logprob_thold = params.logprob_thold;
1113
 
 
44
  float entropy_thold = 2.40f;
45
  float logprob_thold = -1.00f;
46
  float grammar_penalty = 100.0f;
47
+ float temperature = 0.0f;
48
+ float temperature_inc = 0.2f;
49
 
50
  bool speed_up = false;
51
  bool debug_mode = false;
 
135
  else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
136
  else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
137
  else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
138
+ else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(argv[++i]); }
139
+ else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(argv[++i]); }
140
  // else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
141
  else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; }
142
  else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
 
202
  fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
203
  fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
204
  fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
205
+ fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature);
206
+ fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc);
207
  // fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
208
  fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
209
  fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
 
1113
  wparams.greedy.best_of = params.best_of;
1114
  wparams.beam_search.beam_size = params.beam_size;
1115
 
1116
+ wparams.temperature_inc = params.no_fallback ? 0.0f : params.temperature_inc;
1117
+ wparams.temperature = params.temperature;
1118
+
1119
  wparams.entropy_thold = params.entropy_thold;
1120
  wparams.logprob_thold = params.logprob_thold;
1121