whisper : fix bug in prompt processing (close #705)
- examples/main/main.cpp +2 -2
- whisper.cpp +23 -21
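For context, the code path this commit repairs is the one exercised when a caller sets an initial prompt before running full transcription. The sketch below is not part of the commit; it only illustrates that path through the public whisper.h API, with a placeholder model path and a silent dummy buffer standing in for real audio.

// Minimal sketch (not from the commit): driving the initial_prompt path via whisper.h.
// The model path and the silent PCM buffer are placeholders for illustration only.
#include "whisper.h"

#include <vector>

int main() {
    struct whisper_context * ctx = whisper_init_from_file("models/ggml-base.en.bin"); // placeholder path
    if (ctx == nullptr) {
        return 1;
    }

    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    // the text prompt is tokenized inside whisper_full_with_state and fed to the decoder
    wparams.initial_prompt = "glossary: whisper.cpp, ggml, quantization";

    std::vector<float> pcmf32(16000, 0.0f); // 1 second of silence at 16 kHz, placeholder input

    if (whisper_full(ctx, wparams, pcmf32.data(), (int) pcmf32.size()) != 0) {
        whisper_free(ctx);
        return 1;
    }

    whisper_free(ctx);
    return 0;
}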
examples/main/main.cpp

@@ -208,8 +208,8 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
 
     std::string speaker = "";
 
-    int64_t t0;
-    int64_t t1;
+    int64_t t0 = 0;
+    int64_t t1 = 0;
 
     // print the last n_new segments
     const int s0 = n_segments - n_new;
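The examples/main/main.cpp change gives t0 and t1 a defined value at declaration. The sketch below (a hypothetical helper, not the actual callback) shows the general hazard this guards against: a local that is only assigned inside a conditional branch but read later on every path is an uninitialized read, which is undefined behavior in C++.

// Illustrative sketch only: initialize up front so every path sees a defined value.
#include <cstdint>
#include <cstdio>

void print_segment(bool print_timestamps, int64_t seg_t0, int64_t seg_t1) {
    int64_t t0 = 0; // initialized at declaration, as in the fix
    int64_t t1 = 0;

    if (print_timestamps) {
        t0 = seg_t0;
        t1 = seg_t1;
    }

    // t0/t1 are used later regardless of the flag, so they must hold a defined
    // value on every path
    std::printf("segment window: [%lld, %lld]\n", (long long) t0, (long long) t1);
}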
whisper.cpp

@@ -1260,12 +1260,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 break;
             }
 
-            int64_t nelements = 1;
-            int64_t ne[3] = { 1, 1, 1 };
+            int32_t nelements = 1;
+            int32_t ne[3] = { 1, 1, 1 };
             for (int i = 0; i < n_dims; ++i) {
-                int32_t ne_cur;
-                read_safe(loader, ne_cur);
-                ne[i] = ne_cur;
+                read_safe(loader, ne[i]);
                 nelements *= ne[i];
             }
 
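The first whisper.cpp hunk reads each tensor dimension straight into ne[i] instead of going through a temporary. Below is a minimal, self-contained sketch of that pattern; toy_loader and read_value are illustrative stand-ins, not the actual whisper_model_loader or read_safe definitions.

// Simplified sketch of the pattern above: a read_safe-style helper that fills a
// trivially copyable value from a loader, so the dimension loop can read directly
// into ne[i]. Little-endian byte order is assumed for the example blob.
#include <cstdint>
#include <cstdio>
#include <cstring>

struct toy_loader {
    const uint8_t * data;
    size_t          size;
    size_t          offs;
};

template<typename T>
static bool read_value(toy_loader & loader, T & dest) {
    if (loader.offs + sizeof(T) > loader.size) {
        return false; // truncated file
    }
    std::memcpy(&dest, loader.data + loader.offs, sizeof(T));
    loader.offs += sizeof(T);
    return true;
}

int main() {
    // a fake tensor header: n_dims = 2, followed by the two dimensions (64, 80)
    const uint8_t blob[] = { 2,0,0,0,  64,0,0,0,  80,0,0,0 };
    toy_loader loader = { blob, sizeof(blob), 0 };

    int32_t n_dims = 0;
    read_value(loader, n_dims);

    int32_t nelements = 1;
    int32_t ne[3] = { 1, 1, 1 };
    for (int i = 0; i < n_dims && i < 3; ++i) {
        read_value(loader, ne[i]);   // read each dimension straight into ne[i]
        nelements *= ne[i];
    }

    std::printf("dims: %d, ne = [%d, %d, %d], nelements = %d\n", n_dims, ne[0], ne[1], ne[2], nelements);
    return 0;
}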
@@ -1286,15 +1284,15 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             }
 
             if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
-                fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%lld, %lld, %lld], expected [%lld, %lld, %lld]\n",
-                        __func__, name.data(), tensor->ne[0], tensor->ne[1], tensor->ne[2], ne[0], ne[1], ne[2]);
+                fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
+                        __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
                 return false;
             }
 
             const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
 
             if (nelements*bpe != ggml_nbytes(tensor)) {
-                fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %llu\n",
+                fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
                         __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
                 return false;
             }
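The second hunk only adjusts the diagnostics around two existing sanity checks: the dimensions read from the file must match the expected shape, and the element count times the per-element size must equal the tensor's allocated byte count. The toy check below restates that logic with plain values and no ggml types; the names and numbers are illustrative only.

// Toy restatement of the two sanity checks above.
#include <cstddef>
#include <cstdint>
#include <cstdio>

static bool check_tensor(const int32_t ne[3], const int32_t ne_expected[3],
                         int32_t nelements, size_t bpe, size_t nbytes_allocated) {
    if (ne[0] != ne_expected[0] || ne[1] != ne_expected[1] || ne[2] != ne_expected[2]) {
        std::fprintf(stderr, "wrong shape: got [%d, %d, %d], expected [%d, %d, %d]\n",
                ne[0], ne[1], ne[2], ne_expected[0], ne_expected[1], ne_expected[2]);
        return false;
    }
    if (nelements*bpe != nbytes_allocated) {
        std::fprintf(stderr, "wrong size: got %zu, expected %zu\n", nbytes_allocated, nelements*bpe);
        return false;
    }
    return true;
}

int main() {
    const int32_t ne[3]          = { 64, 80, 1 };
    const int32_t ne_expected[3] = { 64, 80, 1 };
    // 64*80 fp16 values -> 2 bytes per element
    const bool ok = check_tensor(ne, ne_expected, 64*80, sizeof(uint16_t), 64*80*2);
    std::printf("tensor ok: %s\n", ok ? "yes" : "no");
    return 0;
}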
@@ -3819,22 +3817,26 @@ int whisper_full_with_state(
         prompt_past.clear();
     }
 
-    // initial prompt
-    if (!params.prompt_tokens && params.initial_prompt) {
+    // prepare prompt
+    {
         std::vector<whisper_token> prompt_tokens;
-        prompt_tokens.resize(1024);
-        prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
-        params.prompt_tokens = prompt_tokens.data();
-        params.prompt_n_tokens = prompt_tokens.size();
-    }
 
-    // prepend the prompt tokens to the prompt_past
-    if (params.prompt_tokens && params.prompt_n_tokens > 0) {
-        // parse tokens from the pointer
-        for (int i = 0; i < params.prompt_n_tokens; i++) {
-            prompt_past.push_back(params.prompt_tokens[i]);
+        // initial prompt
+        if (!params.prompt_tokens && params.initial_prompt) {
+            prompt_tokens.resize(1024);
+            prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
+            params.prompt_tokens = prompt_tokens.data();
+            params.prompt_n_tokens = prompt_tokens.size();
+        }
+
+        // prepend the prompt tokens to the prompt_past
+        if (params.prompt_tokens && params.prompt_n_tokens > 0) {
+            // parse tokens from the pointer
+            for (int i = 0; i < params.prompt_n_tokens; i++) {
+                prompt_past.push_back(params.prompt_tokens[i]);
+            }
+            std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
         }
-        std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
     }
 
     // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
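The last hunk is the actual prompt-processing fix: both prompt steps now live inside one enclosing scope, so the prompt_tokens vector whose data() pointer is stored in params.prompt_tokens is still alive when the second step copies those tokens into prompt_past. In the previous layout the vector was declared inside the first if block, so the pointer could dangle by the time it was read back. The self-contained sketch below (hypothetical toy_params, not the whisper.cpp internals) shows the corrected lifetime pattern.

// Self-contained sketch of the lifetime issue the enclosing scope addresses:
// storing vec.data() in a struct and reading it after the vector is destroyed
// would be a dangling-pointer bug; keeping the vector alive in an outer scope
// makes the later read valid.
#include <cstdio>
#include <vector>

struct toy_params {
    const int * prompt_tokens   = nullptr;
    int         prompt_n_tokens = 0;
};

int main() {
    std::vector<int> prompt_past;
    toy_params params;

    // prepare prompt: one scope that outlives both steps, as in the fix
    {
        std::vector<int> prompt_tokens;

        // step 1: "tokenize" an initial prompt and point params at the buffer
        prompt_tokens = { 101, 102, 103 };
        params.prompt_tokens   = prompt_tokens.data();
        params.prompt_n_tokens = (int) prompt_tokens.size();

        // step 2: copy the tokens into prompt_past; prompt_tokens is still alive here
        for (int i = 0; i < params.prompt_n_tokens; i++) {
            prompt_past.push_back(params.prompt_tokens[i]);
        }
    } // prompt_tokens is destroyed only after both steps are done

    for (int t : prompt_past) {
        std::printf("%d ", t);
    }
    std::printf("\n");
    return 0;
}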