ggerganov committed · Commit d5d2d41 · Parent(s): ae956e7

parallel : working

examples/CMakeLists.txt CHANGED
@@ -22,6 +22,7 @@ if (EMSCRIPTEN)
     add_subdirectory(whisper.wasm)
 else()
     add_subdirectory(main)
+    add_subdirectory(parallel)
     add_subdirectory(stream)
     add_subdirectory(bench)
 endif()
examples/main/main.cpp CHANGED
@@ -384,7 +384,6 @@ int main(int argc, char ** argv) {
         wparams.translate       = params.translate;
         wparams.language        = params.language.c_str();
         wparams.n_threads       = params.n_threads;
-        wparams.n_processors    = 1;
         wparams.n_max_text_ctx  = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
         wparams.offset_ms       = params.offset_t_ms;
 
examples/parallel/parallel.cpp CHANGED
@@ -38,10 +38,12 @@ std::string to_timestamp(int64_t t, bool comma = false) {
 
 // command-line parameters
 struct whisper_params {
-    int32_t seed        = -1; // RNG seed, not used currently
-    int32_t n_threads   = std::min(4, (int32_t) std::thread::hardware_concurrency());
-    int32_t offset_t_ms = 0;
-    int32_t offset_n    = 0;
+    int32_t seed         = -1; // RNG seed, not used currently
+    int32_t n_threads    = std::max(std::min(4, (int32_t) std::thread::hardware_concurrency()) / 2, 1);
+    int32_t n_processors = 2;
+    int32_t offset_t_ms  = 0;
+    int32_t offset_n     = 0;
+    int32_t max_context  = -1;
 
     bool verbose   = false;
     bool translate = false;
@@ -73,10 +75,14 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             params.seed = std::stoi(argv[++i]);
         } else if (arg == "-t" || arg == "--threads") {
             params.n_threads = std::stoi(argv[++i]);
+        } else if (arg == "-p" || arg == "--processors") {
+            params.n_processors = std::stoi(argv[++i]);
         } else if (arg == "-ot" || arg == "--offset-t") {
             params.offset_t_ms = std::stoi(argv[++i]);
         } else if (arg == "-on" || arg == "--offset-n") {
             params.offset_n = std::stoi(argv[++i]);
+        } else if (arg == "-mc" || arg == "--max-context") {
+            params.max_context = std::stoi(argv[++i]);
         } else if (arg == "-v" || arg == "--verbose") {
             params.verbose = true;
         } else if (arg == "--translate") {
@@ -125,8 +131,10 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "  -h,       --help          show this help message and exit\n");
     fprintf(stderr, "  -s SEED,  --seed SEED     RNG seed (default: -1)\n");
     fprintf(stderr, "  -t N,     --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
+    fprintf(stderr, "  -p N,     --processors N  number of processors to use during computation (default: %d)\n", params.n_processors);
     fprintf(stderr, "  -ot N,    --offset-t N    time offset in milliseconds (default: %d)\n", params.offset_t_ms);
     fprintf(stderr, "  -on N,    --offset-n N    segment index offset (default: %d)\n", params.offset_n);
+    fprintf(stderr, "  -mc N,    --max-context N maximum number of text context tokens to store (default: max)\n");
     fprintf(stderr, "  -v,       --verbose       verbose output\n");
     fprintf(stderr, "            --translate     translate from source language to english\n");
     fprintf(stderr, "  -otxt,    --output-txt    output result in a text file\n");
@@ -359,8 +367,9 @@ int main(int argc, char ** argv) {
             fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
         }
     }
-    fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
-            __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.n_threads,
+    fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
+            __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
+            params.n_threads, params.n_processors,
             params.language.c_str(),
             params.translate ? "translate" : "transcribe",
             params.no_timestamps ? 0 : 1);
@@ -380,6 +389,7 @@ int main(int argc, char ** argv) {
         wparams.translate       = params.translate;
         wparams.language        = params.language.c_str();
         wparams.n_threads       = params.n_threads;
+        wparams.n_max_text_ctx  = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
         wparams.offset_ms       = params.offset_t_ms;
 
         // this callback is called on each new segment
@@ -388,7 +398,7 @@ int main(int argc, char ** argv) {
             wparams.new_segment_callback_user_data = &params;
         }
 
-        if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
+        if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
             fprintf(stderr, "%s: failed to process audio\n", argv[0]);
             return 8;
         }
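A side note on the defaults above: the per-processor thread count in examples/parallel is halved relative to main (and clamped to at least 1) while n_processors defaults to 2, presumably so the total number of worker threads stays roughly the same. A minimal sketch of that arithmetic follows; it is illustrative only and not part of the commit.

```cpp
// Illustrative sketch (not part of the commit): with the defaults in examples/parallel,
// each of the n_processors chunks is transcribed with n_threads threads, so the
// effective worker count is roughly n_threads * n_processors.
#include <algorithm>
#include <cstdio>
#include <thread>

int main() {
    const int n_threads    = std::max(std::min(4, (int) std::thread::hardware_concurrency()) / 2, 1);
    const int n_processors = 2; // default in examples/parallel

    std::printf("threads per processor: %d, total workers: ~%d\n", n_threads, n_threads*n_processors);
    return 0;
}
```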
ggml.h CHANGED
@@ -11,7 +11,7 @@ extern "C" {
 #define GGML_MAX_DIMS     4
 #define GGML_MAX_NODES    4096
 #define GGML_MAX_PARAMS   16
-#define GGML_MAX_CONTEXTS 16
+#define GGML_MAX_CONTEXTS 64
 #define GGML_MAX_OPT      4
 
 #ifdef __ARM_NEON
whisper.cpp CHANGED
@@ -379,6 +379,7 @@ struct whisper_model {
 
     // context
     struct ggml_context * ctx;
+    struct ggml_context * ctx_mem;
 
     // tensors
     int n_loaded;
@@ -393,9 +394,10 @@ struct whisper_context {
     int64_t t_decode_us = 0;
     int64_t t_start_us = 0;
 
-    std::vector<uint8_t> buf_model;
-    std::vector<uint8_t> buf_compute;
-    std::vector<uint8_t> buf_compute_layer;
+    std::vector<uint8_t> * buf_model; // the model buffer is read-only and can be shared between processors
+    std::vector<uint8_t>   buf_memory;
+    std::vector<uint8_t>   buf_compute;
+    std::vector<uint8_t>   buf_compute_layer;
 
     whisper_model model;
     whisper_vocab vocab;
@@ -421,7 +423,7 @@ struct whisper_context {
 //
 // see the convert-pt-to-ggml.py script for details
 //
-bool whisper_model_load(const std::string & fname, const int n_processors, whisper_context & wctx) {
+bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
     fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str());
 
     auto & model = wctx.model;
@@ -494,13 +496,16 @@ bool whisper_model_load(const std::string & fname, const int n_processors, whisp
     fprintf(stderr, "%s: f16           = %d\n", __func__, hparams.f16);
     fprintf(stderr, "%s: type          = %d\n", __func__, model.type);
 
-    wctx.buf_model.resize(MEM_REQ_MODEL.at(model.type));
+    wctx.buf_model = new std::vector<uint8_t>();
+    wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type));
+    wctx.buf_memory.resize(std::max(MEM_REQ_MODEL.at(model.type), MEM_REQ_MODEL.at(model.type))); // TODO: TMP !!!
     wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
     wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
 
     // this is the total memory required to run the inference
     const size_t mem_required =
-        wctx.buf_model.size() +
+        wctx.buf_model->size() +
+        wctx.buf_memory.size() +
         wctx.buf_compute.size() +
         wctx.buf_compute_layer.size();
 
@@ -583,6 +588,7 @@ bool whisper_model_load(const std::string & fname, const int n_processors, whisp
 
 
     size_t ctx_size = 0;
+    size_t ctx_mem_size = 0;
 
     {
         const auto & hparams = model.hparams;
@@ -691,11 +697,11 @@ bool whisper_model_load(const std::string & fname, const int n_processors, whisp
             ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
         }
 
-        ctx_size += n_processors*n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_k
-        ctx_size += n_processors*n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_v
+        ctx_mem_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_k
+        ctx_mem_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_v
 
-        ctx_size += n_processors*n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_k
-        ctx_size += n_processors*n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_v
+        ctx_mem_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_k
+        ctx_mem_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_v
 
         ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
 
@@ -705,8 +711,8 @@ bool whisper_model_load(const std::string & fname, const int n_processors, whisp
     // create the ggml context
     {
         struct ggml_init_params params = {
-            .mem_size   = wctx.buf_model.size(),
-            .mem_buffer = wctx.buf_model.data(),
+            .mem_size   = wctx.buf_model->size(),
+            .mem_buffer = wctx.buf_model->data(),
         };
 
         model.ctx = ggml_init(params);
@@ -716,6 +722,20 @@ bool whisper_model_load(const std::string & fname, const int n_processors, whisp
         }
     }
 
+    // create the ggml memory context
+    {
+        struct ggml_init_params params = {
+            .mem_size   = wctx.buf_memory.size(),
+            .mem_buffer = wctx.buf_memory.data(),
+        };
+
+        model.ctx_mem = ggml_init(params);
+        if (!model.ctx_mem) {
+            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+            return false;
+        }
+    }
+
     // prepare memory for the weights
     {
         auto & ctx = model.ctx;
@@ -914,7 +934,7 @@ bool whisper_model_load(const std::string & fname, const int n_processors, whisp
 
     // key + value memory
     {
-        auto & ctx = model.ctx;
+        auto & ctx = model.ctx_mem;
 
         const auto & hparams = model.hparams;
 
@@ -925,7 +945,7 @@ bool whisper_model_load(const std::string & fname, const int n_processors, whisp
         // key/value memory for the self-attention layer
        {
            const int n_mem      = n_text_layer*n_text_ctx;
-            const int n_elements = n_text_state*n_mem*n_processors;
+            const int n_elements = n_text_state*n_mem;
 
            model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
            model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
@@ -936,7 +956,7 @@ bool whisper_model_load(const std::string & fname, const int n_processors, whisp
            const int n_audio_ctx = hparams.n_audio_ctx;
 
            const int n_mem      = n_text_layer*n_audio_ctx;
-            const int n_elements = n_text_state*n_mem*n_processors;
+            const int n_elements = n_text_state*n_mem;
 
            model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
            model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
@@ -946,7 +966,7 @@ bool whisper_model_load(const std::string & fname, const int n_processors, whisp
            ggml_nbytes(model.memory_k)       + ggml_nbytes(model.memory_v) +
            ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
 
-        fprintf(stderr, "%s: memory size = %8.2f MB (%d processors)\n", __func__, memory_size/1024.0/1024.0, n_processors);
+        fprintf(stderr, "%s: memory size = %8.2f MB\n", __func__, memory_size/1024.0/1024.0);
     }
 
     // load weights
@@ -1037,8 +1057,7 @@ bool whisper_model_load(const std::string & fname, const int n_processors, whisp
 bool whisper_encode(
         whisper_context & wctx,
         const int n_threads,
-        const int mel_offset,
-        const int processor_id) {
+        const int mel_offset) {
     const auto & model   = wctx.model;
     const auto & mel_inp = wctx.mel;
     const auto & hparams = model.hparams;
@@ -1392,11 +1411,8 @@ bool whisper_encode(
                         Vcross),
                 Vcross);
 
-        const size_t offset_k = processor_id*(ggml_element_size(model.memory_cross_k)*n_state)*(model.hparams.n_text_layer*n_ctx);
-        const size_t offset_v = processor_id*(ggml_element_size(model.memory_cross_v)*n_state)*(model.hparams.n_text_layer*n_ctx);
-
-        struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, offset_k + (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx));
-        struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, offset_v + (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx));
+        struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx));
+        struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx));
 
         ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
         ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
@@ -1429,8 +1445,7 @@ bool whisper_decode(
         const int n_threads,
         const whisper_token * tokens,
         const int n_tokens,
-        const int n_past,
-        const int processor_id) {
+        const int n_past) {
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
@@ -1525,13 +1540,10 @@ bool whisper_decode(
                         Vcur),
                 Vcur);
 
-            const size_t offset_k = processor_id*(ggml_element_size(model.memory_k)*n_state)*(n_layer*n_ctx);
-            const size_t offset_v = processor_id*(ggml_element_size(model.memory_v)*n_state)*(n_layer*n_ctx);
-
             // store key and value to memory
             {
-                struct ggml_tensor * k = ggml_view_1d(ctxL, model.memory_k, N*n_state, offset_k + (ggml_element_size(model.memory_k)*n_state)*(il*n_ctx + n_past));
-                struct ggml_tensor * v = ggml_view_1d(ctxL, model.memory_v, N*n_state, offset_v + (ggml_element_size(model.memory_v)*n_state)*(il*n_ctx + n_past));
+                struct ggml_tensor * k = ggml_view_1d(ctxL, model.memory_k, N*n_state, (ggml_element_size(model.memory_k)*n_state)*(il*n_ctx + n_past));
+                struct ggml_tensor * v = ggml_view_1d(ctxL, model.memory_v, N*n_state, (ggml_element_size(model.memory_v)*n_state)*(il*n_ctx + n_past));
 
                 ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k));
                 ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v));
@@ -1549,7 +1561,7 @@ bool whisper_decode(
             struct ggml_tensor * K =
                 ggml_permute(ctxL,
                         ggml_reshape_3d(ctxL,
-                            ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, offset_k + il*n_ctx*ggml_element_size(model.memory_k)*n_state),
+                            ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
                             n_state/n_head, n_head, n_past + N),
                         0, 2, 1, 3);
 
@@ -1569,7 +1581,7 @@ bool whisper_decode(
             struct ggml_tensor * V_trans =
                 ggml_permute(ctxL,
                         ggml_reshape_3d(ctxL,
-                            ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, offset_v + il*n_ctx*ggml_element_size(model.memory_v)*n_state),
+                            ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
                             n_state/n_head, n_head, n_past + N),
                         1, 2, 0, 3);
 
@@ -1621,18 +1633,15 @@ bool whisper_decode(
 
             Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
 
-            const size_t offset_k = processor_id*(ggml_element_size(model.memory_cross_k)*n_state)*(n_layer*M);
-            const size_t offset_v = processor_id*(ggml_element_size(model.memory_cross_v)*n_state)*(n_layer*M);
-
             // Kcross is already scaled
             struct ggml_tensor * Kcross =
                 ggml_reshape_3d(ctxL,
-                        ggml_view_1d(ctxL, model.memory_cross_k, M*n_state, offset_k + il*M*ggml_element_size(model.memory_cross_k)*n_state),
+                        ggml_view_1d(ctxL, model.memory_cross_k, M*n_state, il*M*ggml_element_size(model.memory_cross_k)*n_state),
                         n_state/n_head, n_head, M);
 
             struct ggml_tensor * Vcross =
                 ggml_reshape_3d(ctxL,
-                        ggml_view_1d(ctxL, model.memory_cross_v, M*n_state, offset_v + il*M*ggml_element_size(model.memory_cross_v)*n_state),
+                        ggml_view_1d(ctxL, model.memory_cross_v, M*n_state, il*M*ggml_element_size(model.memory_cross_v)*n_state),
                         n_state/n_head, n_head, M);
 
             // ------
@@ -2118,26 +2127,7 @@ struct whisper_context * whisper_init(const char * path_model) {
 
     ctx->t_start_us = t_start_us;
 
-    if (!whisper_model_load(path_model, 1, *ctx)) {
-        fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
-        return NULL;
-    }
-
-    ctx->t_load_us = ggml_time_us() - t_start_us;
-
-    return ctx;
-}
-
-struct whisper_context * whisper_init_parallel(const char * path_model, int n_processors) {
-    ggml_time_init();
-
-    whisper_context * ctx = new whisper_context;
-
-    const int64_t t_start_us = ggml_time_us();
-
-    ctx->t_start_us = t_start_us;
-
-    if (!whisper_model_load(path_model, n_processors, *ctx)) {
+    if (!whisper_model_load(path_model, *ctx)) {
         fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
         return NULL;
     }
@@ -2149,6 +2139,9 @@ struct whisper_context * whisper_init_parallel(const char * path_model, int n_pr
 
 void whisper_free(struct whisper_context * ctx) {
     if (ctx) {
+        if (ctx->buf_model) {
+            delete ctx->buf_model;
+        }
         delete ctx;
     }
 }
@@ -2188,7 +2181,7 @@ int whisper_set_mel(
 int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
     const int64_t t_start_us = ggml_time_us();
 
-    if (!whisper_encode(*ctx, n_threads, offset, 0)) {
+    if (!whisper_encode(*ctx, n_threads, offset)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return -1;
     }
@@ -2201,7 +2194,7 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
 int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
     const int64_t t_start_us = ggml_time_us();
 
-    if (!whisper_decode(*ctx, n_threads, tokens, n_tokens, n_past, 0)) {
+    if (!whisper_decode(*ctx, n_threads, tokens, n_tokens, n_past)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return 1;
     }
@@ -2322,7 +2315,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
                     /*.strategy        =*/ WHISPER_SAMPLING_GREEDY,
 
                     /*.n_threads       =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
-                    /*.n_processors    =*/ 1,
                     /*.n_max_text_ctx  =*/ 16384,
                     /*.offset_ms       =*/ 0,
 
@@ -2355,7 +2347,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
                     /*.strategy        =*/ WHISPER_SAMPLING_BEAM_SEARCH,
 
                     /*.n_threads       =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
-                    /*.n_processors    =*/ 1,
                    /*.n_max_text_ctx  =*/ 16384,
                    /*.offset_ms       =*/ 0,
 
@@ -2629,6 +2620,135 @@ int whisper_full(
     return 0;
 }
 
+int whisper_full_parallel(
+        struct whisper_context * ctx,
+        struct whisper_full_params params,
+        const float * samples,
+        int n_samples,
+        const int n_processors) {
+    if (n_processors == 1) {
+        return whisper_full(ctx, params, samples, n_samples);
+    }
+
+    int ret = 0;
+
+    // prepare separate contexts for each thread
+    std::vector<struct whisper_context> ctxs(n_processors - 1);
+
+    for (int i = 0; i < n_processors - 1; ++i) {
+        ctxs[i] = *ctx;
+
+        auto & model = ctxs[i].model;
+
+        // create the ggml memory context
+        {
+            struct ggml_init_params params = {
+                .mem_size   = ctxs[i].buf_memory.size(),
+                .mem_buffer = ctxs[i].buf_memory.data(),
+            };
+
+            model.ctx_mem = ggml_init(params);
+            if (!model.ctx_mem) {
+                fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+                return false;
+            }
+        }
+
+        // separate key + value memory for each processor
+        {
+            auto & ctx = model.ctx_mem;
+
+            const auto & hparams = model.hparams;
+
+            const int n_text_state = hparams.n_text_state;
+            const int n_text_layer = hparams.n_text_layer;
+            const int n_text_ctx   = hparams.n_text_ctx;
+
+            // key/value memory for the self-attention layer
+            {
+                const int n_mem      = n_text_layer*n_text_ctx;
+                const int n_elements = n_text_state*n_mem;
+
+                model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+                model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+            }
+
+            // key/value memory for the cross-attention layer
+            {
+                const int n_audio_ctx = hparams.n_audio_ctx;
+
+                const int n_mem      = n_text_layer*n_audio_ctx;
+                const int n_elements = n_text_state*n_mem;
+
+                model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+                model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+            }
+
+            const size_t memory_size =
+                ggml_nbytes(model.memory_k)       + ggml_nbytes(model.memory_v) +
+                ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
+        }
+    }
+
+    const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000;
+    const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
+
+    // the calling thread will process the first chunk
+    // while the other threads will process the remaining chunks
+
+    std::vector<std::thread> workers(n_processors - 1);
+    for (int i = 0; i < n_processors - 1; ++i) {
+        const int start_samples = offset_samples + (i + 1)*n_samples_per_processor;
+        const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor;
+
+        auto params_cur = params;
+
+        params_cur.offset_ms = 0;
+        params_cur.print_progress = false;
+        params_cur.print_realtime = false;
+
+        params_cur.new_segment_callback = nullptr;
+        params_cur.new_segment_callback_user_data = nullptr;
+
+        workers[i] = std::thread(whisper_full, &ctxs[i], std::move(params_cur), samples + start_samples, n_samples_cur);
+    }
+
+    {
+        auto params_cur = params;
+
+        ret = whisper_full(ctx, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
+    }
+
+    for (int i = 0; i < n_processors - 1; ++i) {
+        workers[i].join();
+    }
+
+    const int64_t offset_t = (int64_t) params.offset_ms/10.0;
+
+    // combine results into ctx->result_all
+    for (int i = 0; i < n_processors - 1; ++i) {
+        auto & result_all = ctxs[i].result_all;
+
+        for (int j = 0; j < (int) result_all.size(); ++j) {
+            result_all[j].t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
+            result_all[j].t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
+
+            if (ctx->result_all.size() > 0) {
+                result_all[j].t0 = std::max(result_all[j].t0, ctx->result_all.back().t1);
+            }
+
+            ctx->result_all.push_back(std::move(result_all[j]));
+
+            // call the new_segment_callback for each segment
+            if (params.new_segment_callback) {
+                params.new_segment_callback(ctx, params.new_segment_callback_user_data);
+            }
+        }
+    }
+
+    return ret;
+}
+
 int whisper_full_n_segments(struct whisper_context * ctx) {
     return ctx->result_all.size();
 }
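A brief note on the merge step in whisper_full_parallel above: segment timestamps t0/t1 are stored in units of 10 ms, so each chunk's results are shifted by 100*start_samples/WHISPER_SAMPLE_RATE before being appended to ctx->result_all. A tiny worked example of that arithmetic is sketched below; the concrete numbers are illustrative and assume the usual 16 kHz sample rate.

```cpp
// Worked example of the timestamp shift applied when merging per-processor results.
// Assumes WHISPER_SAMPLE_RATE == 16000; the chunk size is an illustrative value.
#include <cstdint>
#include <cstdio>

int main() {
    const int     sample_rate         = 16000;   // WHISPER_SAMPLE_RATE
    const int64_t chunk_start_samples = 960000;  // second chunk of a 2-minute file split across 2 processors

    // t0/t1 are in 10 ms units, so 960000 samples -> 6000 units -> 60.00 s
    const int64_t shift = 100*chunk_start_samples/sample_rate;

    std::printf("shift = %lld units (%.2f s)\n", (long long) shift, shift/100.0);
    return 0;
}
```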
whisper.h CHANGED
@@ -80,8 +80,6 @@ extern "C" {
     // Returns NULL on failure.
     WHISPER_API struct whisper_context * whisper_init(const char * path_model);
 
-    WHISPER_API struct whisper_context * whisper_init_parallel(const char * path_model, int n_processors);
-
     // Frees all memory allocated by the model.
     WHISPER_API void whisper_free(struct whisper_context * ctx);
 
@@ -179,7 +177,6 @@ extern "C" {
         enum whisper_sampling_strategy strategy;
 
         int n_threads;
-        int n_processors;
         int n_max_text_ctx;
         int offset_ms;
 
@@ -216,6 +213,13 @@ extern "C" {
             const float * samples,
             int n_samples);
 
+    WHISPER_API int whisper_full_parallel(
+            struct whisper_context * ctx,
+            struct whisper_full_params params,
+            const float * samples,
+            int n_samples,
+            const int n_processors);
+
     // Number of generated text segments.
     // A segment can be a few words, a sentence, or even a paragraph.
     WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx);
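For reference, a minimal sketch of how the new API might be driven from user code, mirroring the flow of examples/parallel; the model path is a placeholder, the audio buffer must be filled by the caller, and error handling is trimmed.

```cpp
// Minimal usage sketch for whisper_full_parallel (assumptions: the model path is only
// an example, and pcmf32 must be filled with 16 kHz mono float samples by the caller).
#include <cstdio>
#include <vector>

#include "whisper.h"

int main() {
    struct whisper_context * ctx = whisper_init("models/ggml-base.en.bin"); // example path
    if (!ctx) {
        return 1;
    }

    std::vector<float> pcmf32; // fill with 16 kHz mono audio samples

    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    wparams.n_threads = 2;

    // split the audio into 4 chunks and transcribe them on separate threads
    if (whisper_full_parallel(ctx, wparams, pcmf32.data(), (int) pcmf32.size(), 4) != 0) {
        whisper_free(ctx);
        return 1;
    }

    for (int i = 0; i < whisper_full_n_segments(ctx); ++i) {
        std::printf("%s\n", whisper_full_get_segment_text(ctx, i));
    }

    whisper_free(ctx);
    return 0;
}
```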