ggerganov committed on
Commit
3538ca9
·
unverified ·
1 Parent(s): f0a0087

whisper : fix external encoder (#1860)

Browse files
Files changed (1) hide show
  1. whisper.cpp +9 -32
whisper.cpp CHANGED
@@ -1659,22 +1659,9 @@ static struct ggml_cgraph * whisper_build_graph_conv(
1659
  ggml_set_name(cur, "embd_conv");
1660
  wstate.embd_conv = cur;
1661
  } else {
1662
- #ifdef WHISPER_USE_COREML
1663
- cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
1664
- ggml_allocr_alloc(alloc, cur);
1665
 
1666
- if (!ggml_allocr_is_measure(alloc)) {
1667
- whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) cur->data);
1668
- }
1669
- #endif
1670
- #ifdef WHISPER_USE_OPENVINO
1671
  cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
1672
- ggml_allocr_alloc(alloc, cur);
1673
-
1674
- if (!ggml_allocr_is_measure(alloc)) {
1675
- whisper_openvino_encode(wstate.ctx_openvino, mel, cur);
1676
- }
1677
- #endif
1678
 
1679
  ggml_set_name(cur, "embd_enc");
1680
  wstate.embd_enc = cur;
@@ -1708,14 +1695,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
1708
 
1709
  ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
1710
 
1711
- //ggml_allocr * alloc = wstate.alloc_encode.alloc;
1712
-
1713
- //struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_ctx, n_state);
1714
- //ggml_allocr_alloc(alloc, cur);
1715
-
1716
- //if (!ggml_allocr_is_measure(alloc)) {
1717
- // ggml_backend_tensor_copy(wstate.embd_conv, cur);
1718
- //}
1719
  struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
1720
 
1721
  const float KQscale = 1.0f/sqrtf(float(n_state)/n_head);
@@ -1957,14 +1936,6 @@ static struct ggml_cgraph * whisper_build_graph_cross(
1957
 
1958
  ggml_cgraph * gf = ggml_new_graph(ctx0);
1959
 
1960
- //ggml_allocr * alloc = wstate.alloc_cross.alloc;
1961
-
1962
- //struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
1963
- //ggml_allocr_alloc(alloc, cur);
1964
-
1965
- //if (!ggml_allocr_is_measure(alloc)) {
1966
- // ggml_backend_tensor_copy(wstate.embd_enc, cur);
1967
- //}
1968
  struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
1969
 
1970
  const float Kscale = pow(float(n_state) / n_head, -0.25);
@@ -2037,13 +2008,13 @@ static bool whisper_encode_internal(
2037
  return false;
2038
  }
2039
 
 
 
2040
  // set the input
2041
  {
2042
  const auto & mel_inp = wstate.mel;
2043
  const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx;
2044
 
2045
- struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
2046
-
2047
  assert(mel->type == GGML_TYPE_F32);
2048
  assert(mel_inp.n_mel == wctx.model.hparams.n_mels);
2049
 
@@ -2068,6 +2039,12 @@ static bool whisper_encode_internal(
2068
  if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
2069
  return false;
2070
  }
 
 
 
 
 
 
2071
  }
2072
  }
2073
 
 
1659
  ggml_set_name(cur, "embd_conv");
1660
  wstate.embd_conv = cur;
1661
  } else {
1662
+ ggml_build_forward_expand(gf, mel);
 
 
1663
 
 
 
 
 
 
1664
  cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
 
 
 
 
 
 
1665
 
1666
  ggml_set_name(cur, "embd_enc");
1667
  wstate.embd_enc = cur;
 
1695
 
1696
  ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
1697
 
 
 
 
 
 
 
 
 
1698
  struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
1699
 
1700
  const float KQscale = 1.0f/sqrtf(float(n_state)/n_head);
 
1936
 
1937
  ggml_cgraph * gf = ggml_new_graph(ctx0);
1938
 
 
 
 
 
 
 
 
 
1939
  struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
1940
 
1941
  const float Kscale = pow(float(n_state) / n_head, -0.25);
 
2008
  return false;
2009
  }
2010
 
2011
+ struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
2012
+
2013
  // set the input
2014
  {
2015
  const auto & mel_inp = wstate.mel;
2016
  const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx;
2017
 
 
 
2018
  assert(mel->type == GGML_TYPE_F32);
2019
  assert(mel_inp.n_mel == wctx.model.hparams.n_mels);
2020
 
 
2039
  if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
2040
  return false;
2041
  }
2042
+ } else {
2043
+ #if defined(WHISPER_USE_COREML)
2044
+ whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) wstate.embd_enc->data);
2045
+ #elif defined(WHISPER_USE_OPENVINO)
2046
+ whisper_openvino_encode(wstate.ctx_openvino, mel, wstate.embd_enc);
2047
+ #endif
2048
  }
2049
  }
2050