ggerganov committed on
Commit
2bfec97
·
1 Parent(s): 7ba8c97

whisper : update FA call

Browse files
Files changed (1) hide show
  1. src/whisper.cpp +3 -3
src/whisper.cpp CHANGED
@@ -2124,7 +2124,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
2124
  ggml_element_size(kv_pad.v)*n_state_head,
2125
  0);
2126
 
2127
- cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f);
2128
 
2129
  cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
2130
  } else {
@@ -2563,7 +2563,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2563
  ggml_element_size(kv_self.v)*n_state_head,
2564
  ggml_element_size(kv_self.v)*n_state*n_ctx*il);
2565
 
2566
- cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f);
2567
 
2568
  cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
2569
  } else {
@@ -2645,7 +2645,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2645
  ggml_element_size(wstate.kv_cross.v)*n_state_head,
2646
  ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il);
2647
 
2648
- cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f);
2649
 
2650
  cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
2651
  } else {
 
2124
  ggml_element_size(kv_pad.v)*n_state_head,
2125
  0);
2126
 
2127
+ cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f, 0.0f);
2128
 
2129
  cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
2130
  } else {
 
2563
  ggml_element_size(kv_self.v)*n_state_head,
2564
  ggml_element_size(kv_self.v)*n_state*n_ctx*il);
2565
 
2566
+ cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f, 0.0f);
2567
 
2568
  cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
2569
  } else {
 
2645
  ggml_element_size(wstate.kv_cross.v)*n_state_head,
2646
  ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il);
2647
 
2648
+ cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f, 0.0f);
2649
 
2650
  cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
2651
  } else {