ggerganov committed on
Commit 5a9250a · unverified · 1 Parent(s): 1f06c59

ref #10 : quick-and-dirty attempt for real-time audio transcription


- Processes input in chunks of 3 seconds (see the sketch below)
- Pads the audio with silence
- Reuses 1 second of audio from the previous pass
- No text context
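
A minimal standalone sketch of that sliding-window loop (the names and the fake tone generator below are illustrative assumptions, not code from stream.cpp):

    // sliding-window chunking: process 3 s per pass, carry the last 1 s over
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        const int sample_rate = 16000;          // Whisper expects 16 kHz mono
        const int n_chunk     = 3*sample_rate;  // 3 s processed per pass
        const int n_keep      = 1*sample_rate;  // 1 s kept from the previous pass
        const int n_new       = n_chunk - n_keep;

        // zero-initialized buffer == silence padding
        std::vector<float> window(n_chunk, 0.0f);

        const double pi = 3.14159265358979;
        long t = 0;

        for (int pass = 0; pass < 3; ++pass) {
            // stand-in for the microphone capture: a 440 Hz tone
            for (int i = 0; i < n_new; ++i, ++t) {
                window[n_keep + i] = 0.1f*(float) std::sin(2.0*pi*440.0*t/sample_rate);
            }

            // the real code computes the mel spectrogram of `window` here and
            // runs the encoder/decoder on it, with no text context carried over
            printf("pass %d: %zu samples ready\n", pass, window.size());

            // slide: move the last 1 s to the front for the next pass
            std::copy(window.end() - n_keep, window.end(), window.begin());
        }

        return 0;
    }

In the real tool the printf is replaced by the mel-spectrogram + encode/decode pipeline added in stream.cpp below.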

Files changed (3)
  1. .gitignore +1 -0
  2. Makefile +5 -0
  3. stream.cpp +2511 -0
.gitignore CHANGED
@@ -1,3 +1,4 @@
 sync.sh
 main
+stream
 *.o
Makefile CHANGED
@@ -1,3 +1,5 @@
+CC_SDL=`sdl2-config --cflags --libs`
+
 main: ggml.o main.o
 	g++ -pthread -o main ggml.o main.o
 	./main -h
@@ -8,6 +10,9 @@ ggml.o: ggml.c ggml.h
 main.o: main.cpp ggml.h
 	g++ -pthread -O3 -std=c++11 -c main.cpp
 
+stream: stream.cpp
+	g++ -pthread -O3 -std=c++11 -o stream stream.cpp ggml.o $(CC_SDL)
+
 # clean up the directory
 clean:
 	rm -f *.o main
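
The new stream target is the only one that needs SDL2. For context, opening a 16 kHz mono capture device with SDL2 generally looks like the sketch below; this is a hedged illustration of the SDL2 capture API, not an excerpt from stream.cpp:

    // sdl_capture_sketch.cpp -- minimal SDL2 microphone capture at the
    // 16 kHz mono float format Whisper expects (requires SDL >= 2.0.5)
    #include <SDL.h>
    #include <cstdio>
    #include <vector>

    static std::vector<float> g_pcm; // filled by SDL's audio thread; a real
                                     // implementation would guard this with
                                     // SDL_LockAudioDevice or a mutex

    // SDL invokes this from its own thread with freshly captured samples
    static void audio_cb(void * /*userdata*/, Uint8 * stream, int len) {
        const float * src = reinterpret_cast<const float *>(stream);
        g_pcm.insert(g_pcm.end(), src, src + len/sizeof(float));
    }

    int main() {
        if (SDL_Init(SDL_INIT_AUDIO) != 0) {
            fprintf(stderr, "SDL_Init failed: %s\n", SDL_GetError());
            return 1;
        }

        SDL_AudioSpec want{}, have{};
        want.freq     = 16000;     // Whisper's expected sample rate
        want.format   = AUDIO_F32; // 32-bit float samples
        want.channels = 1;         // mono
        want.samples  = 1024;      // callback granularity
        want.callback = audio_cb;

        // second argument = 1 requests a capture (recording) device
        SDL_AudioDeviceID dev = SDL_OpenAudioDevice(nullptr, 1, &want, &have, 0);
        if (dev == 0) {
            fprintf(stderr, "SDL_OpenAudioDevice failed: %s\n", SDL_GetError());
            return 1;
        }

        SDL_PauseAudioDevice(dev, 0); // start capturing
        SDL_Delay(3000);              // record ~3 s
        SDL_PauseAudioDevice(dev, 1); // stop; callback no longer runs

        printf("captured %zu samples\n", g_pcm.size());

        SDL_CloseAudioDevice(dev);
        SDL_Quit();
        return 0;
    }

It builds with the same flags the Makefile uses: g++ -std=c++11 sdl_capture_sketch.cpp `sdl2-config --cflags --libs`.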
stream.cpp ADDED
@@ -0,0 +1,2511 @@
1
+ // Real-time speech recognition of input from a microphone
2
+ //
3
+ // A very quick-n-dirty implementation serving mainly as a proof of concept.
4
+
5
+ #include "ggml.h"
6
+
7
+ #define USE_FLASH_ATTN
8
+ #define USE_FLASH_FF
9
+
10
+ // third-party utilities
11
+ // use your favorite implementations
12
+ #define DR_WAV_IMPLEMENTATION
13
+ #include "dr_wav.h"
14
+
15
+ #include <SDL.h>
16
+ #include <SDL_audio.h>
17
+
18
+ #include <algorithm>
19
+ #include <cassert>
20
+ #include <cmath>
21
+ #include <cstdio>
22
+ #include <cstring>
23
+ #include <fstream>
24
+ #include <map>
25
+ #include <string>
26
+ #include <thread>
27
+ #include <vector>
28
+
29
+ // available whisper models
30
+ enum e_model {
31
+ MODEL_UNKNOWN,
32
+ MODEL_TINY,
33
+ MODEL_BASE,
34
+ MODEL_SMALL,
35
+ MODEL_MEDIUM,
36
+ MODEL_LARGE,
37
+ };
38
+
39
+ const std::map<std::string, std::pair<int, std::string>> g_lang = {
40
+ { "en", { 0, "english", } },
41
+ { "zh", { 1, "chinese", } },
42
+ { "de", { 2, "german", } },
43
+ { "es", { 3, "spanish", } },
44
+ { "ru", { 4, "russian", } },
45
+ { "ko", { 5, "korean", } },
46
+ { "fr", { 6, "french", } },
47
+ { "ja", { 7, "japanese", } },
48
+ { "pt", { 8, "portuguese", } },
49
+ { "tr", { 9, "turkish", } },
50
+ { "pl", { 10, "polish", } },
51
+ { "ca", { 11, "catalan", } },
52
+ { "nl", { 12, "dutch", } },
53
+ { "ar", { 13, "arabic", } },
54
+ { "sv", { 14, "swedish", } },
55
+ { "it", { 15, "italian", } },
56
+ { "id", { 16, "indonesian", } },
57
+ { "hi", { 17, "hindi", } },
58
+ { "fi", { 18, "finnish", } },
59
+ { "vi", { 19, "vietnamese", } },
60
+ { "iw", { 20, "hebrew", } },
61
+ { "uk", { 21, "ukrainian", } },
62
+ { "el", { 22, "greek", } },
63
+ { "ms", { 23, "malay", } },
64
+ { "cs", { 24, "czech", } },
65
+ { "ro", { 25, "romanian", } },
66
+ { "da", { 26, "danish", } },
67
+ { "hu", { 27, "hungarian", } },
68
+ { "ta", { 28, "tamil", } },
69
+ { "no", { 29, "norwegian", } },
70
+ { "th", { 30, "thai", } },
71
+ { "ur", { 31, "urdu", } },
72
+ { "hr", { 32, "croatian", } },
73
+ { "bg", { 33, "bulgarian", } },
74
+ { "lt", { 34, "lithuanian", } },
75
+ { "la", { 35, "latin", } },
76
+ { "mi", { 36, "maori", } },
77
+ { "ml", { 37, "malayalam", } },
78
+ { "cy", { 38, "welsh", } },
79
+ { "sk", { 39, "slovak", } },
80
+ { "te", { 40, "telugu", } },
81
+ { "fa", { 41, "persian", } },
82
+ { "lv", { 42, "latvian", } },
83
+ { "bn", { 43, "bengali", } },
84
+ { "sr", { 44, "serbian", } },
85
+ { "az", { 45, "azerbaijani", } },
86
+ { "sl", { 46, "slovenian", } },
87
+ { "kn", { 47, "kannada", } },
88
+ { "et", { 48, "estonian", } },
89
+ { "mk", { 49, "macedonian", } },
90
+ { "br", { 50, "breton", } },
91
+ { "eu", { 51, "basque", } },
92
+ { "is", { 52, "icelandic", } },
93
+ { "hy", { 53, "armenian", } },
94
+ { "ne", { 54, "nepali", } },
95
+ { "mn", { 55, "mongolian", } },
96
+ { "bs", { 56, "bosnian", } },
97
+ { "kk", { 57, "kazakh", } },
98
+ { "sq", { 58, "albanian", } },
99
+ { "sw", { 59, "swahili", } },
100
+ { "gl", { 60, "galician", } },
101
+ { "mr", { 61, "marathi", } },
102
+ { "pa", { 62, "punjabi", } },
103
+ { "si", { 63, "sinhala", } },
104
+ { "km", { 64, "khmer", } },
105
+ { "sn", { 65, "shona", } },
106
+ { "yo", { 66, "yoruba", } },
107
+ { "so", { 67, "somali", } },
108
+ { "af", { 68, "afrikaans", } },
109
+ { "oc", { 69, "occitan", } },
110
+ { "ka", { 70, "georgian", } },
111
+ { "be", { 71, "belarusian", } },
112
+ { "tg", { 72, "tajik", } },
113
+ { "sd", { 73, "sindhi", } },
114
+ { "gu", { 74, "gujarati", } },
115
+ { "am", { 75, "amharic", } },
116
+ { "yi", { 76, "yiddish", } },
117
+ { "lo", { 77, "lao", } },
118
+ { "uz", { 78, "uzbek", } },
119
+ { "fo", { 79, "faroese", } },
120
+ { "ht", { 80, "haitian creole", } },
121
+ { "ps", { 81, "pashto", } },
122
+ { "tk", { 82, "turkmen", } },
123
+ { "nn", { 83, "nynorsk", } },
124
+ { "mt", { 84, "maltese", } },
125
+ { "sa", { 85, "sanskrit", } },
126
+ { "lb", { 86, "luxembourgish", } },
127
+ { "my", { 87, "myanmar", } },
128
+ { "bo", { 88, "tibetan", } },
129
+ { "tl", { 89, "tagalog", } },
130
+ { "mg", { 90, "malagasy", } },
131
+ { "as", { 91, "assamese", } },
132
+ { "tt", { 92, "tatar", } },
133
+ { "haw", { 93, "hawaiian", } },
134
+ { "ln", { 94, "lingala", } },
135
+ { "ha", { 95, "hausa", } },
136
+ { "ba", { 96, "bashkir", } },
137
+ { "jw", { 97, "javanese", } },
138
+ { "su", { 98, "sundanese", } },
139
+ };
140
+
141
+ const size_t MB = 1024*1024;
142
+
143
+ const std::map<e_model, size_t> MEM_REQ_MODEL = {
144
+ { MODEL_TINY, 86ull*MB },
145
+ { MODEL_BASE, 165ull*MB },
146
+ { MODEL_SMALL, 540ull*MB },
147
+ { MODEL_MEDIUM, 1650ull*MB },
148
+ { MODEL_LARGE, 3260ull*MB },
149
+ };
150
+
151
+ const std::map<e_model, size_t> MEM_REQ_ENCODE = {
152
+ { MODEL_TINY, 80ull*MB },
153
+ { MODEL_BASE, 128ull*MB },
154
+ { MODEL_SMALL, 300ull*MB },
155
+ { MODEL_MEDIUM, 680ull*MB },
156
+ { MODEL_LARGE, 1100ull*MB },
157
+ };
158
+
159
+ const std::map<e_model, size_t> MEM_REQ_ENCODE_LAYER = {
160
+ { MODEL_TINY, 64ull*MB },
161
+ { MODEL_BASE, 84ull*MB },
162
+ { MODEL_SMALL, 128ull*MB },
163
+ { MODEL_MEDIUM, 172ull*MB },
164
+ { MODEL_LARGE, 216ull*MB },
165
+ };
166
+
167
+ const std::map<e_model, size_t> MEM_REQ_DECODE = {
168
+ { MODEL_TINY, 94ull*MB },
169
+ { MODEL_BASE, 96ull*MB },
170
+ { MODEL_SMALL, 98ull*MB },
171
+ { MODEL_MEDIUM, 100ull*MB },
172
+ { MODEL_LARGE, 102ull*MB },
173
+ };
174
+
175
+ const std::map<e_model, size_t> MEM_REQ_DECODE_LAYER = {
176
+ { MODEL_TINY, 32ull*MB },
177
+ { MODEL_BASE, 44ull*MB },
178
+ { MODEL_SMALL, 64ull*MB },
179
+ { MODEL_MEDIUM, 84ull*MB },
180
+ { MODEL_LARGE, 110ull*MB },
181
+ };
182
+
183
+ // the memory buffers used to store the model in memory and perform the inference computations
184
+ std::vector<uint8_t> g_buf_model;
185
+ std::vector<uint8_t> g_buf_compute;
186
+ std::vector<uint8_t> g_buf_compute_layer;
187
+
188
+ const int SAMPLE_RATE = 16000;
189
+ const int N_FFT = 400;
190
+ const int N_MEL = 80;
191
+ const int HOP_LENGTH = 160;
192
+ const int CHUNK_SIZE = 30; // seconds
193
+
194
+ struct whisper_mel {
195
+ int n_len;
196
+ int n_mel;
197
+
198
+ std::vector<float> data;
199
+ };
200
+
201
+ struct whisper_filters {
202
+ int32_t n_mel;
203
+ int32_t n_fft;
204
+
205
+ std::vector<float> data;
206
+ };
207
+
208
+ struct whisper_vocab {
209
+ using id = int32_t;
210
+ using token = std::string;
211
+
212
+ int n_vocab = 51864;
213
+
214
+ std::map<token, id> token_to_id;
215
+ std::map<id, token> id_to_token;
216
+
217
+ id token_eot = 50256;
218
+ id token_sot = 50257;
219
+ id token_prev = 50360;
220
+ id token_solm = 50361; // ??
221
+ id token_not = 50362; // no timestamps
222
+ id token_beg = 50363;
223
+
224
+ // available tasks
225
+ const id token_translate = 50358;
226
+ const id token_transcribe = 50359;
227
+
228
+ bool is_multilingual() const {
229
+ return n_vocab == 51865;
230
+ }
231
+ };
232
+
233
+ struct whisper_result {
234
+ whisper_vocab::id id;
235
+ int64_t t;
236
+ };
237
+
238
+ // command-line parameters
239
+ struct whisper_params {
240
+ int32_t seed = -1; // RNG seed, not used currently
241
+ int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
242
+
243
+ bool verbose = false;
244
+ bool translate = false;
245
+ bool print_special_tokens = false;
246
+ bool no_timestamps = true;
247
+
248
+ std::string language = "en";
249
+ std::string model = "models/ggml-base.en.bin";
250
+ std::string fname_inp = "samples/jfk.wav";
251
+ };
252
+
253
+ void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
254
+
255
+ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
256
+ for (int i = 1; i < argc; i++) {
257
+ std::string arg = argv[i];
258
+
259
+ if (arg == "-s" || arg == "--seed") {
260
+ params.seed = std::stoi(argv[++i]);
261
+ } else if (arg == "-t" || arg == "--threads") {
262
+ params.n_threads = std::stoi(argv[++i]);
263
+ } else if (arg == "-v" || arg == "--verbose") {
264
+ params.verbose = true;
265
+ } else if (arg == "--translate") {
266
+ params.translate = true;
267
+ } else if (arg == "-l" || arg == "--language") {
268
+ params.language = argv[++i];
269
+ if (g_lang.find(params.language) == g_lang.end()) {
270
+ fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
271
+ whisper_print_usage(argc, argv, params);
272
+ exit(0);
273
+ }
274
+ } else if (arg == "-ps" || arg == "--print_special") {
275
+ params.print_special_tokens = true;
276
+ } else if (arg == "-nt" || arg == "--no_timestamps") {
277
+ params.no_timestamps = true;
278
+ } else if (arg == "-m" || arg == "--model") {
279
+ params.model = argv[++i];
280
+ } else if (arg == "-f" || arg == "--file") {
281
+ params.fname_inp = argv[++i];
282
+ } else if (arg == "-h" || arg == "--help") {
283
+ whisper_print_usage(argc, argv, params);
284
+ exit(0);
285
+ } else {
286
+ fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
287
+ whisper_print_usage(argc, argv, params);
288
+ exit(0);
289
+ }
290
+ }
291
+
292
+ return true;
293
+ }
294
+
295
+ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
296
+ fprintf(stderr, "\n");
297
+ fprintf(stderr, "usage: %s [options]\n", argv[0]);
298
+ fprintf(stderr, "\n");
299
+ fprintf(stderr, "options:\n");
300
+ fprintf(stderr, " -h, --help show this help message and exit\n");
301
+ fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
302
+ fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
303
+ fprintf(stderr, " -v, --verbose verbose output\n");
304
+ fprintf(stderr, " --translate translate from source language to english\n");
305
+ fprintf(stderr, " -ps, --print_special print special tokens\n");
306
+ fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
307
+ fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str());
308
+ fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str());
309
+ fprintf(stderr, " -f FNAME, --file FNAME input WAV file path (default: %s)\n", params.fname_inp.c_str());
310
+ fprintf(stderr, "\n");
311
+ }
312
+
313
+
314
+ // medium
315
+ // hparams: {
316
+ // 'n_mels': 80,
317
+ // 'n_vocab': 51864,
318
+ // 'n_audio_ctx': 1500,
319
+ // 'n_audio_state': 1024,
320
+ // 'n_audio_head': 16,
321
+ // 'n_audio_layer': 24,
322
+ // 'n_text_ctx': 448,
323
+ // 'n_text_state': 1024,
324
+ // 'n_text_head': 16,
325
+ // 'n_text_layer': 24
326
+ // }
327
+ //
328
+ // default hparams (Whisper tiny)
329
+ struct whisper_hparams {
330
+ int32_t n_vocab = 51864;
331
+ int32_t n_audio_ctx = 1500;
332
+ int32_t n_audio_state = 384;
333
+ int32_t n_audio_head = 6;
334
+ int32_t n_audio_layer = 4;
335
+ int32_t n_text_ctx = 448;
336
+ int32_t n_text_state = 384;
337
+ int32_t n_text_head = 6;
338
+ int32_t n_text_layer = 4;
339
+ int32_t n_mels = 80;
340
+ int32_t f16 = 1;
341
+ };
342
+
343
+ // audio encoding layer
344
+ struct whisper_layer_encoder {
345
+ // encoder.blocks.*.attn_ln
346
+ struct ggml_tensor * attn_ln_0_w;
347
+ struct ggml_tensor * attn_ln_0_b;
348
+
349
+ // encoder.blocks.*.attn.out
350
+ struct ggml_tensor * attn_ln_1_w;
351
+ struct ggml_tensor * attn_ln_1_b;
352
+
353
+ // encoder.blocks.*.attn.query
354
+ struct ggml_tensor * attn_q_w;
355
+ struct ggml_tensor * attn_q_b;
356
+
357
+ // encoder.blocks.*.attn.key
358
+ struct ggml_tensor * attn_k_w;
359
+
360
+ // encoder.blocks.*.attn.value
361
+ struct ggml_tensor * attn_v_w;
362
+ struct ggml_tensor * attn_v_b;
363
+
364
+ // encoder.blocks.*.mlp_ln
365
+ struct ggml_tensor * mlp_ln_w;
366
+ struct ggml_tensor * mlp_ln_b;
367
+
368
+ // encoder.blocks.*.mlp.0
369
+ struct ggml_tensor * mlp_0_w;
370
+ struct ggml_tensor * mlp_0_b;
371
+
372
+ // encoder.blocks.*.mlp.2
373
+ struct ggml_tensor * mlp_1_w;
374
+ struct ggml_tensor * mlp_1_b;
375
+ };
376
+
377
+ // token decoding layer
378
+ struct whisper_layer_decoder {
379
+ // decoder.blocks.*.attn_ln
380
+ struct ggml_tensor * attn_ln_0_w;
381
+ struct ggml_tensor * attn_ln_0_b;
382
+
383
+ // decoder.blocks.*.attn.out
384
+ struct ggml_tensor * attn_ln_1_w;
385
+ struct ggml_tensor * attn_ln_1_b;
386
+
387
+ // decoder.blocks.*.attn.query
388
+ struct ggml_tensor * attn_q_w;
389
+ struct ggml_tensor * attn_q_b;
390
+
391
+ // decoder.blocks.*.attn.key
392
+ struct ggml_tensor * attn_k_w;
393
+
394
+ // decoder.blocks.*.attn.value
395
+ struct ggml_tensor * attn_v_w;
396
+ struct ggml_tensor * attn_v_b;
397
+
398
+ // decoder.blocks.*.cross_attn_ln
399
+ struct ggml_tensor * cross_attn_ln_0_w;
400
+ struct ggml_tensor * cross_attn_ln_0_b;
401
+
402
+ // decoder.blocks.*.cross_attn.out
403
+ struct ggml_tensor * cross_attn_ln_1_w;
404
+ struct ggml_tensor * cross_attn_ln_1_b;
405
+
406
+ // decoder.blocks.*.cross_attn.query
407
+ struct ggml_tensor * cross_attn_q_w;
408
+ struct ggml_tensor * cross_attn_q_b;
409
+
410
+ // decoder.blocks.*.cross_attn.key
411
+ struct ggml_tensor * cross_attn_k_w;
412
+
413
+ // decoder.blocks.*.cross_attn.value
414
+ struct ggml_tensor * cross_attn_v_w;
415
+ struct ggml_tensor * cross_attn_v_b;
416
+
417
+ // decoder.blocks.*.mlp_ln
418
+ struct ggml_tensor * mlp_ln_w;
419
+ struct ggml_tensor * mlp_ln_b;
420
+
421
+ // decoder.blocks.*.mlp.0
422
+ struct ggml_tensor * mlp_0_w;
423
+ struct ggml_tensor * mlp_0_b;
424
+
425
+ // decoder.blocks.*.mlp.2
426
+ struct ggml_tensor * mlp_1_w;
427
+ struct ggml_tensor * mlp_1_b;
428
+ };
429
+
430
+ struct whisper_model {
431
+ e_model type = MODEL_UNKNOWN;
432
+
433
+ whisper_hparams hparams;
434
+ whisper_filters filters;
435
+
436
+ // encoder.positional_embedding
437
+ struct ggml_tensor * e_pe;
438
+
439
+ // encoder.conv1
440
+ struct ggml_tensor * e_conv_1_w;
441
+ struct ggml_tensor * e_conv_1_b;
442
+
443
+ // encoder.conv2
444
+ struct ggml_tensor * e_conv_2_w;
445
+ struct ggml_tensor * e_conv_2_b;
446
+
447
+ // encoder.ln_post
448
+ struct ggml_tensor * e_ln_w;
449
+ struct ggml_tensor * e_ln_b;
450
+
451
+ // decoder.positional_embedding
452
+ struct ggml_tensor * d_pe; // DD
453
+
454
+ // decoder.token_embedding
455
+ struct ggml_tensor * d_te; // DD
456
+
457
+ // decoder.ln
458
+ struct ggml_tensor * d_ln_w; // DD
459
+ struct ggml_tensor * d_ln_b; // DD
460
+
461
+ std::vector<whisper_layer_encoder> layers_encoder;
462
+ std::vector<whisper_layer_decoder> layers_decoder;
463
+
464
+ // key + value memory
465
+ struct ggml_tensor * memory_k;
466
+ struct ggml_tensor * memory_v;
467
+
468
+ struct ggml_tensor * memory_cross_k;
469
+ struct ggml_tensor * memory_cross_v;
470
+
471
+ //
472
+ struct ggml_context * ctx;
473
+ std::map<std::string, struct ggml_tensor *> tensors;
474
+ };
475
+
476
+ // load the model from a ggml file
477
+ //
478
+ // file format:
479
+ //
480
+ // - hparams
481
+ // - pre-computed mel filters
482
+ // - vocab
483
+ // - weights
484
+ //
485
+ // see the convert-pt-to-ggml.py script for details
486
+ //
487
+ bool whisper_model_load(const std::string & fname, whisper_model & model, whisper_vocab & vocab) {
488
+ printf("%s: loading model from '%s'\n", __func__, fname.c_str());
489
+
490
+ auto fin = std::ifstream(fname, std::ios::binary);
491
+ if (!fin) {
492
+ fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
493
+ return false;
494
+ }
495
+
496
+ // verify magic
497
+ {
498
+ uint32_t magic;
499
+ fin.read((char *) &magic, sizeof(magic));
500
+ if (magic != 0x67676d6c) {
501
+ fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
502
+ return false;
503
+ }
504
+ }
505
+
506
+ //load hparams
507
+ {
508
+ auto & hparams = model.hparams;
509
+
510
+ fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
511
+ fin.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
512
+ fin.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
513
+ fin.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
514
+ fin.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
515
+ fin.read((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
516
+ fin.read((char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
517
+ fin.read((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
518
+ fin.read((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
519
+ fin.read((char *) &hparams.n_mels, sizeof(hparams.n_mels));
520
+ fin.read((char *) &hparams.f16, sizeof(hparams.f16));
521
+
522
+ assert(hparams.n_text_state == hparams.n_audio_state);
523
+
524
+ if (hparams.n_audio_layer == 4) {
525
+ model.type = e_model::MODEL_TINY;
526
+ }
527
+
528
+ if (hparams.n_audio_layer == 6) {
529
+ model.type = e_model::MODEL_BASE;
530
+ }
531
+
532
+ if (hparams.n_audio_layer == 12) {
533
+ model.type = e_model::MODEL_SMALL;
534
+ }
535
+
536
+ if (hparams.n_audio_layer == 24) {
537
+ model.type = e_model::MODEL_MEDIUM;
538
+ }
539
+
540
+ if (hparams.n_audio_layer == 32) {
541
+ model.type = e_model::MODEL_LARGE;
542
+ }
543
+
544
+ printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
545
+ printf("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
546
+ printf("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
547
+ printf("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head);
548
+ printf("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
549
+ printf("%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx);
550
+ printf("%s: n_text_state = %d\n", __func__, hparams.n_text_state);
551
+ printf("%s: n_text_head = %d\n", __func__, hparams.n_text_head);
552
+ printf("%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
553
+ printf("%s: n_mels = %d\n", __func__, hparams.n_mels);
554
+ printf("%s: f16 = %d\n", __func__, hparams.f16);
555
+ printf("%s: type = %d\n", __func__, model.type);
556
+
557
+ g_buf_model.resize(MEM_REQ_MODEL.at(model.type));
558
+ g_buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
559
+ g_buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
560
+
561
+ // this is the total memory required to run the inference
562
+ const size_t mem_required =
563
+ g_buf_model.size() +
564
+ g_buf_compute.size() +
565
+ g_buf_compute_layer.size();
566
+
567
+ printf("%s: mem_required = %.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
568
+ }
569
+
570
+ // load mel filters
571
+ {
572
+ auto & filters = model.filters;
573
+
574
+ fin.read((char *) &filters.n_mel, sizeof(filters.n_mel));
575
+ fin.read((char *) &filters.n_fft, sizeof(filters.n_fft));
576
+
577
+ filters.data.resize(filters.n_mel * filters.n_fft);
578
+ fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
579
+ }
580
+
581
+ // load vocab
582
+ {
583
+ int32_t n_vocab = 0;
584
+ fin.read((char *) &n_vocab, sizeof(n_vocab));
585
+
586
+ //if (n_vocab != model.hparams.n_vocab) {
587
+ // fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
588
+ // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
589
+ // return false;
590
+ //}
591
+
592
+ std::string word;
593
+ for (int i = 0; i < n_vocab; i++) {
594
+ uint32_t len;
595
+ fin.read((char *) &len, sizeof(len));
596
+
597
+ word.resize(len);
598
+ fin.read((char *) word.data(), len);
599
+
600
+ vocab.token_to_id[word] = i;
601
+ vocab.id_to_token[i] = word;
602
+
603
+ //printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
604
+ }
605
+
606
+ vocab.n_vocab = model.hparams.n_vocab;
607
+ if (vocab.is_multilingual()) {
608
+ vocab.token_eot++;
609
+ vocab.token_sot++;
610
+ vocab.token_prev++;
611
+ vocab.token_solm++;
612
+ vocab.token_not++;
613
+ vocab.token_beg++;
614
+ }
615
+
616
+ if (n_vocab < model.hparams.n_vocab) {
617
+ printf("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
618
+ for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
619
+ if (i > vocab.token_beg) {
620
+ word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
621
+ } else if (i == vocab.token_eot) {
622
+ word = "[_EOT_]";
623
+ } else if (i == vocab.token_sot) {
624
+ word = "[_SOT_]";
625
+ } else if (i == vocab.token_prev) {
626
+ word = "[_PREV_]";
627
+ } else if (i == vocab.token_not) {
628
+ word = "[_NOT_]";
629
+ } else if (i == vocab.token_beg) {
630
+ word = "[_BEG_]";
631
+ } else {
632
+ word = "[_extra_token_" + std::to_string(i) + "]";
633
+ }
634
+ vocab.token_to_id[word] = i;
635
+ vocab.id_to_token[i] = word;
636
+ }
637
+ }
638
+ }
639
+
640
+ // for the big tensors, we have the option to store the data in 16-bit floats
641
+ // in order to save memory and also to speed up the computation
642
+ const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
643
+
644
+ auto & ctx = model.ctx;
645
+
646
+ size_t ctx_size = 0;
647
+
648
+ {
649
+ const auto & hparams = model.hparams;
650
+
651
+ const int n_vocab = hparams.n_vocab;
652
+
653
+ const int n_audio_ctx = hparams.n_audio_ctx;
654
+ const int n_audio_state = hparams.n_audio_state;
655
+ const int n_audio_layer = hparams.n_audio_layer;
656
+
657
+ const int n_text_ctx = hparams.n_text_ctx;
658
+ const int n_text_state = hparams.n_text_state;
659
+ const int n_text_layer = hparams.n_text_layer;
660
+
661
+ const int n_mels = hparams.n_mels;
662
+
663
+ // encoder
664
+ {
665
+ // TODO: F16 .. maybe not?
666
+ ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
667
+
668
+ ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
669
+ ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_1_b
670
+
671
+ ctx_size += 3*n_audio_state*n_audio_state*ggml_type_size(wtype); // e_conv_2_w
672
+ ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_2_b
673
+
674
+ ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_w;
675
+ ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_b;
676
+ }
677
+
678
+ // decoder
679
+ {
680
+ // TODO: F16 .. maybe not?
681
+ ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
682
+
683
+ ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
684
+
685
+ ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_w;
686
+ ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_b;
687
+ }
688
+
689
+ // encoder layers
690
+ {
691
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
692
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
693
+
694
+ ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_0_w
695
+ ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
696
+
697
+ ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_1_w
698
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
699
+
700
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
701
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
702
+
703
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_q_w
704
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
705
+
706
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_k_w
707
+
708
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_v_w
709
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
710
+
711
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_ln_1_w
712
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
713
+ }
714
+
715
+ // decoder layers
716
+ {
717
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
718
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
719
+
720
+ ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_0_w
721
+ ctx_size += n_text_layer*( 4*n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
722
+
723
+ ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_1_w
724
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
725
+
726
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
727
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
728
+
729
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_q_w
730
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
731
+
732
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_k_w
733
+
734
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_v_w
735
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
736
+
737
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_ln_1_w
738
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
739
+ //
740
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_w
741
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_b
742
+
743
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_q_w
744
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_q_b
745
+
746
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_k_w
747
+
748
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_v_w
749
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_v_b
750
+
751
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_ln_1_w
752
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
753
+ }
754
+
755
+ ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_k
756
+ ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_v
757
+
758
+ ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_k
759
+ ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_v
760
+
761
+ ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
762
+
763
+ printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
764
+ }
765
+
766
+ // create the ggml context
767
+ {
768
+ struct ggml_init_params params = {
769
+ .mem_size = g_buf_model.size(),
770
+ .mem_buffer = g_buf_model.data(),
771
+ };
772
+
773
+ model.ctx = ggml_init(params);
774
+ if (!model.ctx) {
775
+ fprintf(stderr, "%s: ggml_init() failed\n", __func__);
776
+ return false;
777
+ }
778
+ }
779
+
780
+ // prepare memory for the weights
781
+ {
782
+ const auto & hparams = model.hparams;
783
+
784
+ const int n_vocab = hparams.n_vocab;
785
+
786
+ const int n_audio_ctx = hparams.n_audio_ctx;
787
+ const int n_audio_state = hparams.n_audio_state;
788
+ const int n_audio_layer = hparams.n_audio_layer;
789
+
790
+ const int n_text_ctx = hparams.n_text_ctx;
791
+ const int n_text_state = hparams.n_text_state;
792
+ const int n_text_layer = hparams.n_text_layer;
793
+
794
+ const int n_mels = hparams.n_mels;
795
+
796
+ model.layers_encoder.resize(n_audio_layer);
797
+ model.layers_decoder.resize(n_text_layer);
798
+
799
+ // encoder
800
+ {
801
+ model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
802
+
803
+ model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype, 3, n_mels, n_audio_state);
804
+ model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
805
+
806
+ model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype, 3, n_audio_state, n_audio_state);
807
+ model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
808
+
809
+ model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
810
+ model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
811
+
812
+ // map by name
813
+ model.tensors["encoder.positional_embedding"] = model.e_pe;
814
+
815
+ model.tensors["encoder.conv1.weight"] = model.e_conv_1_w;
816
+ model.tensors["encoder.conv1.bias"] = model.e_conv_1_b;
817
+
818
+ model.tensors["encoder.conv2.weight"] = model.e_conv_2_w;
819
+ model.tensors["encoder.conv2.bias"] = model.e_conv_2_b;
820
+
821
+ model.tensors["encoder.ln_post.weight"] = model.e_ln_w;
822
+ model.tensors["encoder.ln_post.bias"] = model.e_ln_b;
823
+
824
+ for (int i = 0; i < n_audio_layer; ++i) {
825
+ auto & layer = model.layers_encoder[i];
826
+
827
+ layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
828
+ layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
829
+
830
+ layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
831
+ layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
832
+
833
+ layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state);
834
+ layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
835
+
836
+ layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
837
+ layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
838
+
839
+ layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
840
+ layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
841
+
842
+ layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
843
+
844
+ layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
845
+ layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
846
+
847
+ layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
848
+ layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
849
+
850
+ // map by name
851
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
852
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
853
+
854
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
855
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
856
+
857
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
858
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
859
+
860
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
861
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
862
+
863
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
864
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
865
+
866
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
867
+
868
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
869
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
870
+
871
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
872
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
873
+ }
874
+ }
875
+
876
+ // decoder
877
+ {
878
+ model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx);
879
+
880
+ model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
881
+
882
+ model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
883
+ model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
884
+
885
+ // map by name
886
+ model.tensors["decoder.positional_embedding"] = model.d_pe;
887
+
888
+ model.tensors["decoder.token_embedding.weight"] = model.d_te;
889
+
890
+ model.tensors["decoder.ln.weight"] = model.d_ln_w;
891
+ model.tensors["decoder.ln.bias"] = model.d_ln_b;
892
+
893
+ for (int i = 0; i < n_text_layer; ++i) {
894
+ auto & layer = model.layers_decoder[i];
895
+
896
+ layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
897
+ layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
898
+
899
+ layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state);
900
+ layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state);
901
+
902
+ layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state);
903
+ layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
904
+
905
+ layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
906
+ layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
907
+
908
+ layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
909
+ layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
910
+
911
+ layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
912
+
913
+ layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
914
+ layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
915
+
916
+ layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
917
+ layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
918
+
919
+ layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
920
+ layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
921
+
922
+ layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
923
+ layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
924
+
925
+ layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
926
+
927
+ layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
928
+ layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
929
+
930
+ layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
931
+ layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
932
+
933
+ // map by name
934
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
935
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
936
+
937
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
938
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
939
+
940
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
941
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
942
+
943
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
944
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
945
+
946
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
947
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
948
+
949
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
950
+
951
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
952
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
953
+
954
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
955
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
956
+
957
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
958
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b;
959
+
960
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
961
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b;
962
+
963
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
964
+
965
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
966
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b;
967
+
968
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
969
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b;
970
+ }
971
+ }
972
+ }
973
+
974
+ // key + value memory
975
+ {
976
+ const auto & hparams = model.hparams;
977
+
978
+ const int n_text_state = hparams.n_text_state;
979
+ const int n_text_layer = hparams.n_text_layer;
980
+ const int n_text_ctx = hparams.n_text_ctx;
981
+
982
+ // key/value memory for the self-attention layer
983
+ {
984
+ const int n_mem = n_text_layer*n_text_ctx;
985
+ const int n_elements = n_text_state*n_mem;
986
+
987
+ model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
988
+ model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
989
+ }
990
+
991
+ // key/value memory for the cross-attention layer
992
+ {
993
+ const int n_audio_ctx = hparams.n_audio_ctx;
994
+
995
+ const int n_mem = n_text_layer*n_audio_ctx;
996
+ const int n_elements = n_text_state*n_mem;
997
+
998
+ model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
999
+ model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
1000
+ }
1001
+
1002
+ const size_t memory_size =
1003
+ ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) +
1004
+ ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
1005
+
1006
+ printf("%s: memory size = %8.2f MB \n", __func__, memory_size/1024.0/1024.0);
1007
+ }
1008
+
1009
+ // load weights
1010
+ {
1011
+ size_t total_size = 0;
1012
+
1013
+ while (true) {
1014
+ int32_t n_dims;
1015
+ int32_t length;
1016
+ int32_t ftype;
1017
+
1018
+ fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
1019
+ fin.read(reinterpret_cast<char *>(&length), sizeof(length));
1020
+ fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
1021
+
1022
+ if (fin.eof()) {
1023
+ break;
1024
+ }
1025
+
1026
+ int32_t nelements = 1;
1027
+ int32_t ne[3] = { 1, 1, 1 };
1028
+ for (int i = 0; i < n_dims; ++i) {
1029
+ fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
1030
+ nelements *= ne[i];
1031
+ }
1032
+
1033
+ std::string name(length, 0);
1034
+ fin.read(&name[0], length);
1035
+
1036
+ if (model.tensors.find(name.data()) == model.tensors.end()) {
1037
+ fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
1038
+ return false;
1039
+ }
1040
+
1041
+ auto tensor = model.tensors[name.data()];
1042
+ if (ggml_nelements(tensor) != nelements) {
1043
+ fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
1044
+ return false;
1045
+ }
1046
+
1047
+ if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
1048
+ fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
1049
+ __func__, name.data(), tensor->ne[0], tensor->ne[1], tensor->ne[2], ne[0], ne[1], ne[2]);
1050
+ return false;
1051
+ }
1052
+
1053
+ const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
1054
+
1055
+ if (nelements*bpe != ggml_nbytes(tensor)) {
1056
+ fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
1057
+ __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
1058
+ return false;
1059
+ }
1060
+
1061
+ fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
1062
+
1063
+ //printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
1064
+ total_size += ggml_nbytes(tensor);
1065
+ }
1066
+
1067
+ printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
1068
+ }
1069
+
1070
+ fin.close();
1071
+
1072
+ return true;
1073
+ }
1074
+
1075
+ // evaluate the encoder
1076
+ //
1077
+ // given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
1078
+ // part of the transformer model and returns the encoded features
1079
+ //
1080
+ // - model: the model
1081
+ // - n_threads: number of threads to use
1082
+ // - mel_offset: offset in the mel spectrogram (i.e. audio offset)
1083
+ // - mel_inp: input mel spectrogram
1084
+ // - features: output encoded features
1085
+ //
1086
+ bool whisper_encode(
1087
+ const whisper_model & model,
1088
+ const int n_threads,
1089
+ const int mel_offset,
1090
+ const whisper_mel & mel_inp,
1091
+ std::vector<float> & features) {
1092
+ const auto & hparams = model.hparams;
1093
+
1094
+ const int n_vocab = hparams.n_vocab;
1095
+
1096
+ const int n_ctx = hparams.n_audio_ctx;
1097
+ const int n_state = hparams.n_audio_state;
1098
+ const int n_head = hparams.n_audio_head;
1099
+ const int n_layer = hparams.n_audio_layer;
1100
+
1101
+ const int N = n_ctx;
1102
+
1103
+ const int n_mels = hparams.n_mels;
1104
+ assert(mel_inp.n_mel == n_mels);
1105
+
1106
+ struct ggml_init_params params = {
1107
+ .mem_size = g_buf_compute.size(),
1108
+ .mem_buffer = g_buf_compute.data(),
1109
+ };
1110
+
1111
+ struct ggml_context * ctx0 = ggml_init(params);
1112
+
1113
+ struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
1114
+ assert(mel->type == GGML_TYPE_F32);
1115
+ {
1116
+ float * dst = (float *) mel->data;
1117
+ memset(dst, 0, ggml_nbytes(mel));
1118
+
1119
+ const int i0 = std::min(mel_offset, mel_inp.n_len);
1120
+ const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
1121
+
1122
+ for (int j = 0; j < mel_inp.n_mel; ++j) {
1123
+ for (int i = i0; i < i1; ++i) {
1124
+ dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
1125
+ }
1126
+ }
1127
+ }
1128
+
1129
+ struct ggml_tensor * cur;
1130
+
1131
+ // convolution + gelu
1132
+ {
1133
+ cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
1134
+ cur = ggml_add(ctx0,
1135
+ ggml_repeat(ctx0,
1136
+ model.e_conv_1_b,
1137
+ cur),
1138
+ cur);
1139
+
1140
+ cur = ggml_gelu(ctx0, cur);
1141
+
1142
+ cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
1143
+ cur = ggml_add(ctx0,
1144
+ ggml_repeat(ctx0,
1145
+ model.e_conv_2_b,
1146
+ cur),
1147
+ cur);
1148
+
1149
+ cur = ggml_gelu(ctx0, cur);
1150
+ }
1151
+
1152
+ cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
1153
+
1154
+ struct ggml_tensor * inpL = cur;
1155
+
1156
+ for (int il = 0; il < n_layer; ++il) {
1157
+ const auto & layer = model.layers_encoder[il];
1158
+
1159
+ // create separate context for each layer to reduce memory usage
1160
+
1161
+ struct ggml_init_params paramsL = {
1162
+ .mem_size = g_buf_compute_layer.size(),
1163
+ .mem_buffer = g_buf_compute_layer.data(),
1164
+ };
1165
+
1166
+ struct ggml_context * ctxL = ggml_init(paramsL);
1167
+
1168
+ // norm
1169
+ {
1170
+ cur = ggml_norm(ctxL, inpL);
1171
+
1172
+ // cur = ln_0_w*cur + ln_0_b
1173
+ cur = ggml_add(ctxL,
1174
+ ggml_mul(ctxL,
1175
+ ggml_repeat(ctxL, layer.attn_ln_0_w, cur),
1176
+ cur),
1177
+ ggml_repeat(ctxL, layer.attn_ln_0_b, cur));
1178
+ }
1179
+
1180
+ // self-attention
1181
+ {
1182
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
1183
+ layer.attn_q_w,
1184
+ cur);
1185
+
1186
+ Qcur = ggml_add(ctxL,
1187
+ ggml_repeat(ctxL,
1188
+ layer.attn_q_b,
1189
+ Qcur),
1190
+ Qcur);
1191
+
1192
+ //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1193
+
1194
+ // note: no bias for Key
1195
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
1196
+ layer.attn_k_w,
1197
+ cur);
1198
+
1199
+ //Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1200
+
1201
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
1202
+ layer.attn_v_w,
1203
+ cur);
1204
+
1205
+ Vcur = ggml_add(ctxL,
1206
+ ggml_repeat(ctxL,
1207
+ layer.attn_v_b,
1208
+ Vcur),
1209
+ Vcur);
1210
+
1211
+ // ------
1212
+
1213
+ #ifdef USE_FLASH_ATTN
1214
+ struct ggml_tensor * Q =
1215
+ ggml_permute(ctxL,
1216
+ ggml_cpy(ctxL,
1217
+ Qcur,
1218
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
1219
+ 0, 2, 1, 3);
1220
+
1221
+ struct ggml_tensor * K =
1222
+ ggml_permute(ctxL,
1223
+ ggml_cpy(ctxL,
1224
+ Kcur,
1225
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
1226
+ 0, 2, 1, 3);
1227
+
1228
+ struct ggml_tensor * V =
1229
+ ggml_cpy(ctxL,
1230
+ ggml_permute(ctxL,
1231
+ ggml_reshape_3d(ctxL,
1232
+ Vcur,
1233
+ n_state/n_head, n_head, N),
1234
+ 1, 2, 0, 3),
1235
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head)
1236
+ );
1237
+
1238
+ struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
1239
+ #else
1240
+ struct ggml_tensor * Q =
1241
+ ggml_permute(ctxL,
1242
+ ggml_cpy(ctxL,
1243
+ Qcur,
1244
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
1245
+ 0, 2, 1, 3);
1246
+
1247
+ struct ggml_tensor * K =
1248
+ ggml_permute(ctxL,
1249
+ ggml_cpy(ctxL,
1250
+ Kcur,
1251
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
1252
+ 0, 2, 1, 3);
1253
+
1254
+ // K * Q
1255
+ struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
1256
+
1257
+ struct ggml_tensor * KQ_scaled =
1258
+ ggml_scale(ctxL,
1259
+ KQ,
1260
+ ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
1261
+ );
1262
+
1263
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_scaled);
1264
+
1265
+ //struct ggml_tensor * V_trans =
1266
+ // ggml_permute(ctxL,
1267
+ // ggml_cpy(ctxL,
1268
+ // Vcur,
1269
+ // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
1270
+ // 1, 2, 0, 3);
1271
+
1272
+ //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
1273
+
1274
+ struct ggml_tensor * V =
1275
+ ggml_cpy(ctxL,
1276
+ ggml_permute(ctxL,
1277
+ ggml_reshape_3d(ctxL,
1278
+ Vcur,
1279
+ n_state/n_head, n_head, N),
1280
+ 0, 2, 1, 3),
1281
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head)
1282
+ );
1283
+
1284
+ struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
1285
+ #endif
1286
+
1287
+ struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
1288
+
1289
+ cur = ggml_cpy(ctxL,
1290
+ KQV_merged,
1291
+ ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
1292
+ }
1293
+
1294
+ // projection
1295
+ {
1296
+ cur = ggml_mul_mat(ctxL,
1297
+ layer.attn_ln_1_w,
1298
+ cur);
1299
+
1300
+ cur = ggml_add(ctxL,
1301
+ ggml_repeat(ctxL, layer.attn_ln_1_b, cur),
1302
+ cur);
1303
+ }
1304
+
1305
+ // add the input
1306
+ cur = ggml_add(ctxL, cur, inpL);
1307
+
1308
+ struct ggml_tensor * inpFF = cur;
1309
+
1310
+ // feed-forward network
1311
+ {
1312
+ // norm
1313
+ {
1314
+ cur = ggml_norm(ctxL, inpFF);
1315
+
1316
+ // cur = mlp_ln_w*cur + mlp_ln_b
1317
+ cur = ggml_add(ctxL,
1318
+ ggml_mul(ctxL,
1319
+ ggml_repeat(ctxL, layer.mlp_ln_w, cur),
1320
+ cur),
1321
+ ggml_repeat(ctxL, layer.mlp_ln_b, cur));
1322
+ }
1323
+
1324
+ #ifdef USE_FLASH_FF
1325
+ cur = ggml_flash_ff(ctxL,
1326
+ ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)),
1327
+ layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
1328
+ #else
1329
+ // fully connected
1330
+ cur = ggml_mul_mat(ctxL,
1331
+ layer.mlp_0_w,
1332
+ cur);
1333
+
1334
+ cur = ggml_add(ctxL,
1335
+ ggml_repeat(ctxL, layer.mlp_0_b, cur),
1336
+ cur);
1337
+
1338
+ // GELU activation
1339
+ cur = ggml_gelu(ctxL, cur);
1340
+
1341
+ // projection
1342
+ cur = ggml_mul_mat(ctxL,
1343
+ layer.mlp_1_w,
1344
+ cur);
1345
+
1346
+ cur = ggml_add(ctxL,
1347
+ ggml_repeat(ctxL, layer.mlp_1_b, cur),
1348
+ cur);
1349
+ #endif
1350
+ }
1351
+
1352
+ // output from this layer
1353
+ struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);
1354
+
1355
+ {
1356
+ struct ggml_cgraph gf = { .n_threads = n_threads };
1357
+
1358
+ ggml_build_forward_expand(&gf, inpO);
1359
+ ggml_graph_compute (ctxL, &gf);
1360
+
1361
+ //ggml_graph_print(&gf);
1362
+ }
1363
+
1364
+ // TODO: this is a hack to have per-layer computation graphs - need to come up with something better
1365
+ // input for next layer (inpO -> inpL)
1366
+ memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
1367
+ inpL->op = GGML_OP_NONE;
1368
+ inpL->src0 = NULL;
1369
+ inpL->src1 = NULL;
1370
+
1371
+ //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
1372
+
1373
+ ggml_free(ctxL);
1374
+ }
1375
+
1376
+ cur = inpL;
1377
+
1378
+ // norm
1379
+ {
1380
+ cur = ggml_norm(ctx0, cur);
1381
+
1382
+ // cur = ln_f_g*cur + ln_f_b
1383
+ cur = ggml_add(ctx0,
1384
+ ggml_mul(ctx0,
1385
+ ggml_repeat(ctx0, model.e_ln_w, cur),
1386
+ cur),
1387
+ ggml_repeat(ctx0, model.e_ln_b, cur));
1388
+ }
1389
+
1390
+ // run the computation
1391
+ {
1392
+ struct ggml_cgraph gf = { .n_threads = n_threads };
1393
+
1394
+ ggml_build_forward_expand(&gf, cur);
1395
+ ggml_graph_compute (ctx0, &gf);
1396
+
1397
+ //ggml_graph_print(&gf);
1398
+ }
1399
+
1400
+ // cur
1401
+ //{
1402
+ // printf("ne0 = %d\n", cur->ne[0]);
1403
+ // printf("ne1 = %d\n", cur->ne[1]);
1404
+ // for (int i = 0; i < 10; ++i) {
1405
+ // printf("%8.4f ", ((float *)(cur->data))[i]);
1406
+ // }
1407
+ // printf("... ");
1408
+ // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
1409
+ // printf("%8.4f ", ((float *)(cur->data))[i]);
1410
+ // }
1411
+ // printf("\n");
1412
+ //}
1413
+
1414
+ // pre-compute cross-attention memory
+ {
+ struct ggml_cgraph gf = { .n_threads = n_threads };
+
+ // TODO: hack to disconnect the encoded features from the previous graph
+ cur->op = GGML_OP_NONE;
+ cur->src0 = NULL;
+ cur->src1 = NULL;
+
+ for (int il = 0; il < model.hparams.n_text_layer; ++il) {
+ auto & layer = model.layers_decoder[il];
+
+ struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
+ layer.cross_attn_k_w,
+ cur);
+
+ Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+
+ struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
+ layer.cross_attn_v_w,
+ cur);
+
+ Vcross = ggml_add(ctx0,
+ ggml_repeat(ctx0,
+ layer.cross_attn_v_b,
+ Vcross),
+ Vcross);
+
+ struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx));
+ struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx));
+
+ ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
+ ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
+ }
+
+ ggml_graph_compute(ctx0, &gf);
+ }
+
+ ////////////////////////////////////////////////////////////////////////////
+
+ // output the features
+ assert(cur->type == GGML_TYPE_F32);
+ features.resize(cur->ne[0]*cur->ne[1]);
+ memcpy(features.data(), cur->data, features.size()*sizeof(float));
+
+ //printf("%s: used_mem = %f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0);
+
+ ggml_free(ctx0);
+
+ return true;
+ }
+
+ // evaluate the decoder
+ //
+ // given a text prompt + the audio features -> predicts the logits/probabilities for the next token
+ //
+ // - model: the model
+ // - n_threads: number of threads to use
+ // - n_past: number of tokens already processed (offset into the KV memory)
+ // - prompt: text prompt
+ // - logits_out: output logits
+ // - probs_out: output probabilities
+ //
+ bool whisper_decode(
+ const whisper_model & model,
+ const int n_threads,
+ const int n_past,
+ const std::vector<whisper_vocab::id> & prompt,
+ std::vector<float> & logits_out,
+ std::vector<float> & probs_out) {
+ const auto & hparams = model.hparams;
+
+ const int n_vocab = hparams.n_vocab;
+
+ const int n_ctx = hparams.n_text_ctx;
+ const int n_state = hparams.n_text_state;
+ const int n_head = hparams.n_text_head;
+ const int n_layer = hparams.n_text_layer;
+
+ const int N = prompt.size();
+ const int M = hparams.n_audio_ctx;
+
+ struct ggml_init_params params = {
+ .mem_size = g_buf_compute.size(),
+ .mem_buffer = g_buf_compute.data(),
+ };
+
+ struct ggml_context * ctx0 = ggml_init(params);
+
+ struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+ memcpy(embd->data, prompt.data(), N*ggml_element_size(embd));
+
+ struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+ for (int i = 0; i < N; ++i) {
+ ((int32_t *) position->data)[i] = n_past + i;
+ }
+
+ // token encoding + position encoding
+ struct ggml_tensor * cur =
+ ggml_add(ctx0,
+ ggml_get_rows(ctx0, model.d_te, embd),
+ ggml_get_rows(ctx0, model.d_pe, position));
+
+ struct ggml_tensor * inpL = cur;
+
+ for (int il = 0; il < n_layer; ++il) {
+ const auto & layer = model.layers_decoder[il];
+
+ struct ggml_init_params paramsL = {
+ .mem_size = g_buf_compute_layer.size(),
+ .mem_buffer = g_buf_compute_layer.data(),
+ };
+
+ struct ggml_context * ctxL = ggml_init(paramsL);
+ struct ggml_cgraph gf = { .n_threads = n_threads };
+
+ // norm
+ {
+ cur = ggml_norm(ctxL, inpL);
+
+ // cur = ln_0_w*cur + ln_0_b
+ cur = ggml_add(ctxL,
+ ggml_mul(ctxL,
+ ggml_repeat(ctxL, layer.attn_ln_0_w, cur),
+ cur),
+ ggml_repeat(ctxL, layer.attn_ln_0_b, cur));
+ }
+
+ // self-attention
+ {
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
+ layer.attn_q_w,
+ cur);
+
+ Qcur = ggml_add(ctxL,
+ ggml_repeat(ctxL,
+ layer.attn_q_b,
+ Qcur),
+ Qcur);
+
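+ // Q and K are each scaled by (n_state/n_head)^-0.25, so their product
+ // carries the usual 1/sqrt(d_head) attention scaling without a separate
+ // KQ_scaled step (see the commented-out block below)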
+ Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+
+ // note: no bias for Key
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
+ layer.attn_k_w,
+ cur);
+
+ Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
+ layer.attn_v_w,
+ cur);
+
+ Vcur = ggml_add(ctxL,
+ ggml_repeat(ctxL,
+ layer.attn_v_b,
+ Vcur),
+ Vcur);
+
+ // store key and value to memory
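+ // each layer owns a slab of n_ctx*n_state elements in memory_k/memory_v;
+ // the N new tokens are appended at offset n_past within that slab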
+ {
+ struct ggml_tensor * k = ggml_view_1d(ctxL, model.memory_k, N*n_state, (ggml_element_size(model.memory_k)*n_state)*(il*n_ctx + n_past));
+ struct ggml_tensor * v = ggml_view_1d(ctxL, model.memory_v, N*n_state, (ggml_element_size(model.memory_v)*n_state)*(il*n_ctx + n_past));
+
+ ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k));
+ ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v));
+ }
+
+ // ------
+
+ struct ggml_tensor * Q =
+ ggml_permute(ctxL,
+ ggml_cpy(ctxL,
+ Qcur,
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
+ 0, 2, 1, 3);
+
+ struct ggml_tensor * K =
+ ggml_permute(ctxL,
+ ggml_reshape_3d(ctxL,
+ ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
+ n_state/n_head, n_head, n_past + N),
+ 0, 2, 1, 3);
+
+ // K * Q
+ struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
+
+ //struct ggml_tensor * KQ_scaled =
+ // ggml_scale(ctxL,
+ // KQ,
+ // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
+ // );
+
+ struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past);
+
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked);
+
+ struct ggml_tensor * V_trans =
+ ggml_permute(ctxL,
+ ggml_reshape_3d(ctxL,
+ ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
+ n_state/n_head, n_head, n_past + N),
+ 1, 2, 0, 3);
+
+ struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
+
+ struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
+
+ cur = ggml_cpy(ctxL,
+ KQV_merged,
+ ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
+ }
+
+ {
+ cur = ggml_mul_mat(ctxL,
+ layer.attn_ln_1_w,
+ cur);
+
+ cur = ggml_add(ctxL,
+ ggml_repeat(ctxL, layer.attn_ln_1_b, cur),
+ cur);
+ }
+
+ // add the input
+ struct ggml_tensor * inpCA = ggml_add(ctxL, cur, inpL);
+
+ // norm
+ {
+ cur = ggml_norm(ctxL, inpCA); // note: we use inpCA here
+
+ // cur = ln_0_w*cur + ln_0_b
+ cur = ggml_add(ctxL,
+ ggml_mul(ctxL,
+ ggml_repeat(ctxL, layer.cross_attn_ln_0_w, cur),
+ cur),
+ ggml_repeat(ctxL, layer.cross_attn_ln_0_b, cur));
+ }
+
+ // cross-attention
+ {
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
+ layer.cross_attn_q_w,
+ cur);
+
+ Qcur = ggml_add(ctxL,
+ ggml_repeat(ctxL,
+ layer.cross_attn_q_b,
+ Qcur),
+ Qcur);
+
+ Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+
+ // Kcross is already scaled
+ struct ggml_tensor * Kcross =
+ ggml_reshape_3d(ctxL,
+ ggml_view_1d(ctxL, model.memory_cross_k, M*n_state, il*M*ggml_element_size(model.memory_cross_k)*n_state),
+ n_state/n_head, n_head, M);
+
+ struct ggml_tensor * Vcross =
+ ggml_reshape_3d(ctxL,
+ ggml_view_1d(ctxL, model.memory_cross_v, M*n_state, il*M*ggml_element_size(model.memory_cross_v)*n_state),
+ n_state/n_head, n_head, M);
+
+ // ------
+
+ struct ggml_tensor * Q =
+ ggml_permute(ctxL,
+ ggml_cpy(ctxL,
+ Qcur,
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
+ 0, 2, 1, 3);
+
+ struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3);
+
+ // K * Q
+ struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
+
+ //struct ggml_tensor * KQ_scaled =
+ // ggml_scale(ctxL,
+ // KQ,
+ // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
+ // );
+
+ // no masking for cross-attention
+ //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past);
+
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ);
+
+ struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3);
+
+ struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
+
+ struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
+
+ // cur = KQV_merged.contiguous().view(n_state, N)
+ cur = ggml_cpy(ctxL,
+ KQV_merged,
+ ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
+ }
+
+ // projection
+ {
+ cur = ggml_mul_mat(ctxL,
+ layer.cross_attn_ln_1_w,
+ cur);
+
+ cur = ggml_add(ctxL,
+ ggml_repeat(ctxL, layer.cross_attn_ln_1_b, cur),
+ cur);
+ }
+
+ // add the input
+ cur = ggml_add(ctxL, cur, inpCA);
+
+ struct ggml_tensor * inpFF = cur;
+
+ // feed-forward network
+ {
+ // norm
+ {
+ cur = ggml_norm(ctxL, inpFF);
+
+ // cur = mlp_ln_w*cur + mlp_ln_b
+ cur = ggml_add(ctxL,
+ ggml_mul(ctxL,
+ ggml_repeat(ctxL, layer.mlp_ln_w, cur),
+ cur),
+ ggml_repeat(ctxL, layer.mlp_ln_b, cur));
+ }
+
+ // fully connected
+ cur = ggml_mul_mat(ctxL,
+ layer.mlp_0_w,
+ cur);
+
+ cur = ggml_add(ctxL,
+ ggml_repeat(ctxL, layer.mlp_0_b, cur),
+ cur);
+
+ // GELU activation
+ cur = ggml_gelu(ctxL, cur);
+
+ // projection
+ cur = ggml_mul_mat(ctxL,
+ layer.mlp_1_w,
+ cur);
+
+ cur = ggml_add(ctxL,
+ ggml_repeat(ctxL, layer.mlp_1_b, cur),
+ cur);
+ }
+
+ // output from this layer
+ struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);
+
+ {
+ ggml_build_forward_expand(&gf, inpO);
+ ggml_graph_compute (ctxL, &gf);
+
+ //ggml_graph_print(&gf);
+ }
+
+ // TODO: this is a hack to have per-layer computation graphs - need to come up with something better
+ // input for next layer (inpO -> inpL)
+ memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
+ inpL->op = GGML_OP_NONE;
+ inpL->src0 = NULL;
+ inpL->src1 = NULL;
+
+ if (N > 1) {
+ //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
+ }
+
+ ggml_free(ctxL);
+ }
+
+ cur = inpL;
+
+ // norm
+ {
+ cur = ggml_norm(ctx0, cur);
+
+ cur = ggml_add(ctx0,
+ ggml_mul(ctx0,
+ ggml_repeat(ctx0, model.d_ln_w, cur),
+ cur),
+ ggml_repeat(ctx0, model.d_ln_b, cur));
+ }
+
+ struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
+
+ // logits -> probs
+ cur = ggml_dup(ctx0, logits);
+ cur = ggml_soft_max(ctx0, cur); // in-place
+
+ // run the computation
+ {
+ struct ggml_cgraph gf = { .n_threads = n_threads };
+
+ ggml_build_forward_expand(&gf, cur);
+ ggml_graph_compute (ctx0, &gf);
+ }
+
+ logits_out.resize(N*n_vocab);
+ memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
+
+ probs_out.resize(N*n_vocab);
+ memcpy(probs_out.data(), ggml_get_data(cur), sizeof(float)*N*n_vocab);
+
+ if (N > 1) {
+ //const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N;
+ //printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token);
+ //printf("%s: max mem = %f MB\n", __func__, mem_per_token*model.hparams.n_text_ctx);
+ }
+
+ ggml_free(ctx0);
+
+ return true;
+ }
+
+ // the most basic sampling scheme - select the top token
+ // TODO: beam search
+ // TODO: temperature
+ whisper_vocab::id whisper_sample_best(
+ const whisper_vocab & vocab,
+ const float * probs, bool need_timestamp) {
+ int n_logits = vocab.id_to_token.size();
+
+ std::vector<std::pair<double, whisper_vocab::id>> probs_id;
+ probs_id.reserve(n_logits);
+
+ for (int i = 0; i < n_logits; i++) {
+ probs_id.push_back(std::make_pair(probs[i], i));
+ }
+
+ const int top_k = 4;
+
+ // find the top K tokens
+ std::partial_sort(
+ probs_id.begin(),
+ probs_id.begin() + top_k, probs_id.end(),
+ [](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) {
+ return a.first > b.first;
+ });
+
+ probs_id.resize(top_k);
+
+ //printf("\n");
+ //for (int i = 0; i < (int) probs_id.size(); i++) {
+ // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
+ //}
+
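+ // timestamp tokens are spaced 20 ms apart, so token_beg + 1300 corresponds
+ // to ~26 s into the 30 s window - near the end of the segment we prefer a
+ // reasonably probable timestamp token over the top text token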
+ if (need_timestamp) {
+ // at the end of the 30-second audio segment, we start giving preference to time tokens
+ for (int i = 0; i < top_k; i++) {
+ if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > 0.01*probs_id[0].first) {
+ return probs_id[i].second;
+ }
+ }
+ }
+
+ int res = 0;
+ while ((probs_id[res].second == vocab.token_sot ||
+ probs_id[res].second == vocab.token_solm ||
+ probs_id[res].second == vocab.token_not) &&
+ res < (int) probs_id.size() - 1) {
+ res++;
+ }
+
+ return probs_id[res].second;
+ }
+
+ // samples only from the timestamp tokens
+ whisper_vocab::id whisper_sample_timestamp(
+ const whisper_vocab & vocab,
+ const float * probs) {
+ int n_logits = vocab.id_to_token.size();
+
+ std::vector<std::pair<double, whisper_vocab::id>> probs_id;
+ probs_id.reserve(n_logits);
+
+ for (int i = vocab.token_beg + 1; i < n_logits; i++) {
+ probs_id.push_back(std::make_pair(probs[i], i));
+ }
+
+ const int top_k = 10;
+
+ // find the top K tokens
+ std::partial_sort(
+ probs_id.begin(),
+ probs_id.begin() + top_k, probs_id.end(),
+ [](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) {
+ return a.first > b.first;
+ });
+
+ probs_id.resize(top_k);
+
+ //printf("\n");
+ //for (int i = 0; i < (int) probs_id.size(); i++) {
+ // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
+ //}
+
+ return probs_id[0].second;
+ }
+
+ // naive Discrete Fourier Transform
+ // input is real-valued
+ // output is complex-valued
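+ // out has 2*N elements: out[2*k + 0] = Re(X_k), out[2*k + 1] = Im(X_k)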
+ void dft(const std::vector<float> & in, std::vector<float> & out) {
+ int N = in.size();
+
+ out.resize(N*2);
+
+ for (int k = 0; k < N; k++) {
+ float re = 0;
+ float im = 0;
+
+ for (int n = 0; n < N; n++) {
+ float angle = 2*M_PI*k*n/N;
+ re += in[n]*cos(angle);
+ im -= in[n]*sin(angle);
+ }
+
+ out[k*2 + 0] = re;
+ out[k*2 + 1] = im;
+ }
+ }
+
+ // Cooley-Tukey FFT
+ // poor man's implementation - use something better
+ // input is real-valued
+ // output is complex-valued
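+ // runs in O(N log N) for power-of-two sizes; sizes with an odd factor fall
+ // back to the O(N^2) DFT above at that level of the recursion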
+ void fft(const std::vector<float> & in, std::vector<float> & out) {
+ out.resize(in.size()*2);
+
+ int N = in.size();
+
+ if (N == 1) {
+ out[0] = in[0];
+ out[1] = 0;
+ return;
+ }
+
+ if (N%2 == 1) {
+ dft(in, out);
+ return;
+ }
+
+ std::vector<float> even;
+ std::vector<float> odd;
+
+ for (int i = 0; i < N; i++) {
+ if (i % 2 == 0) {
+ even.push_back(in[i]);
+ } else {
+ odd.push_back(in[i]);
+ }
+ }
+
+ std::vector<float> even_fft;
+ std::vector<float> odd_fft;
+
+ fft(even, even_fft);
+ fft(odd, odd_fft);
+
+ for (int k = 0; k < N/2; k++) {
+ float theta = 2*M_PI*k/N;
+
+ float re = cos(theta);
+ float im = -sin(theta);
+
+ float re_odd = odd_fft[2*k + 0];
+ float im_odd = odd_fft[2*k + 1];
+
+ out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
+ out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
+
+ out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
+ out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
+ }
+ }
+
+ // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
+ bool log_mel_spectrogram(
+ const std::vector<float> & sf32,
+ const int sample_rate,
+ const int fft_size,
+ const int fft_step,
+ const int n_mel,
+ const int n_threads,
+ const whisper_filters & filters,
+ whisper_mel & mel) {
+ const int n_sample = sf32.size();
+ const float * samples = sf32.data();
+
+ // Hanning window
+ std::vector<float> hann;
+ hann.resize(fft_size);
+ for (int i = 0; i < fft_size; i++) {
+ hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
+ }
+
+ mel.n_mel = n_mel;
+ mel.n_len = (n_sample)/fft_step;
+ mel.data.resize(mel.n_mel*mel.n_len);
+
+ const int n_fft = 1 + fft_size/2;
+
+ //printf("%s: n_sample = %d, n_len = %d\n", __func__, n_sample, mel.n_len);
+ //printf("%s: recording length: %f s\n", __func__, (float) n_sample/sample_rate);
+
+ std::vector<std::thread> workers(n_threads);
+ for (int iw = 0; iw < n_threads; ++iw) {
+ workers[iw] = std::thread([&](int ith) {
+ std::vector<float> fft_in;
+ fft_in.resize(fft_size);
+ for (int i = 0; i < fft_size; i++) {
+ fft_in[i] = 0.0;
+ }
+
+ std::vector<float> fft_out;
+ fft_out.resize(2*fft_size);
+
+ for (int i = ith; i < mel.n_len; i += n_threads) {
+ const int offset = i*fft_step;
+
+ // apply Hanning window
+ for (int j = 0; j < fft_size; j++) {
+ if (offset + j < n_sample) {
+ fft_in[j] = hann[j]*samples[offset + j];
+ } else {
+ fft_in[j] = 0.0;
+ }
+ }
+
+ // FFT -> mag^2
+ fft(fft_in, fft_out);
+
+ for (int j = 0; j < fft_size; j++) {
+ fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]);
+ }
+ for (int j = 1; j < fft_size/2; j++) {
+ fft_out[j] += fft_out[fft_size - j];
+ }
+
+ // mel spectrogram
+ for (int j = 0; j < mel.n_mel; j++) {
+ double sum = 0.0;
+
+ for (int k = 0; k < n_fft; k++) {
+ sum += fft_out[k]*filters.data[j*n_fft + k];
+ }
+ if (sum < 1e-10) {
+ sum = 1e-10;
+ }
+
+ sum = log10(sum);
+
+ mel.data[j*mel.n_len + i] = sum;
+ }
+ }
+ }, iw);
+ }
+
+ for (int iw = 0; iw < n_threads; ++iw) {
+ workers[iw].join();
+ }
+
+ // clamping and normalization
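+ // (matches the reference implementation linked above: clamp to max - 8.0,
+ // then rescale via (x + 4.0)/4.0)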
+ double mmax = -1e20;
+ for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
+ if (mel.data[i] > mmax) {
+ mmax = mel.data[i];
+ }
+ }
+
+ mmax -= 8.0;
+
+ for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
+ if (mel.data[i] < mmax) {
+ mel.data[i] = mmax;
+ }
+
+ mel.data[i] = (mel.data[i] + 4.0)/4.0;
+ }
+
+ return true;
+ }
+
+ // 500 -> 00:05.000
+ // 6000 -> 01:00.000
+ std::string to_timestamp(int64_t t) {
+ int64_t sec = t/100;
+ int64_t msec = 10*(t - sec*100); // t is in units of 10 ms - convert the remainder to milliseconds
+ int64_t min = sec/60;
+ sec = sec - min*60;
+
+ char buf[32];
+ snprintf(buf, sizeof(buf), "%02d:%02d.%03d", (int) min, (int) sec, (int) msec);
+
+ return std::string(buf);
+ }
+
+ //
+ // SDL Audio capture
+ //
+
+ SDL_AudioDeviceID g_dev_id_in = 0;
+
+ bool audio_sdl_init(const int capture_id) {
+ if (g_dev_id_in) {
+ fprintf(stderr, "%s: already initialized\n", __func__);
+ return false;
+ }
+
+ if (g_dev_id_in == 0) {
+ SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
+
+ if (SDL_Init(SDL_INIT_AUDIO) < 0) {
+ SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
+ return false;
+ }
+
+ SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
+
+ {
+ int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
+ printf("%s: found %d capture devices:\n", __func__, nDevices);
+ for (int i = 0; i < nDevices; i++) {
+ printf("%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
+ }
+ }
+ }
+
+ if (g_dev_id_in == 0) {
+ SDL_AudioSpec capture_spec_requested;
+ SDL_AudioSpec capture_spec_obtained;
+
+ SDL_zero(capture_spec_requested);
+ SDL_zero(capture_spec_obtained);
+
+ capture_spec_requested.freq = SAMPLE_RATE;
+ capture_spec_requested.format = AUDIO_F32;
+ capture_spec_requested.channels = 1;
+ capture_spec_requested.samples = 1024;
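+ // 1024 samples per capture frame at 16 kHz is ~64 ms of audio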
+
+ if (capture_id >= 0) {
+ printf("%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
+ g_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
+ } else {
+ printf("%s: attempt to open default capture device ...\n", __func__);
+ g_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
+ }
+ if (!g_dev_id_in) {
+ printf("%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
+ g_dev_id_in = 0;
+ } else {
+ printf("%s: obtained spec for input device (SDL Id = %d):\n", __func__, g_dev_id_in);
+ printf("%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
+ printf("%s: - format: %d (requested: %d)\n", __func__, capture_spec_obtained.format, capture_spec_requested.format);
+ printf("%s: - channels: %d (requested: %d)\n", __func__, capture_spec_obtained.channels, capture_spec_requested.channels);
+ printf("%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
+ }
+ }
+
+ return true;
+ }
+
+ ///////////////////////////
+
+ int main(int argc, char ** argv) {
+ const int64_t t_main_start_us = ggml_time_us();
+
+ whisper_params params;
+
+ if (whisper_params_parse(argc, argv, params) == false) {
+ return 1;
+ }
+
+ if (params.seed < 0) {
+ params.seed = time(NULL);
+ }
+
+ // init audio
+
+ if (!audio_sdl_init(-1)) {
+ fprintf(stderr, "%s: audio_sdl_init() failed!\n", __func__);
+ return 1;
+ }
+
+ // model loading
+
+ //printf("%s: seed = %d\n", __func__, params.seed);
+
+ int64_t t_load_us = 0;
+ int64_t t_mel_us = 0;
+ int64_t t_sample_us = 0;
+ int64_t t_encode_us = 0;
+ int64_t t_decode_us = 0;
+
+ whisper_vocab vocab;
+ whisper_model model;
+
+ // load the model
+ {
+ const int64_t t_start_us = ggml_time_us();
+
+ if (!whisper_model_load(params.model, model, vocab)) {
+ fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+ whisper_print_usage(argc, argv, {});
+ return 1;
+ }
+
+ t_load_us = ggml_time_us() - t_start_us;
+ }
+
+ const int n_samples_30s = 30*SAMPLE_RATE;
+ std::vector<float> pcmf32(n_samples_30s, 0.0f);
+ std::vector<float> pcmf32_old;
+
+ // print some info about the processing
+ {
+ printf("\n");
+ if (!vocab.is_multilingual()) {
+ if (params.language != "en" || params.translate) {
+ params.language = "en";
+ params.translate = false;
+ printf("%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
+ }
+ }
+ printf("%s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
+ __func__, int(pcmf32.size()), float(pcmf32.size())/SAMPLE_RATE, params.n_threads,
+ g_lang.at(params.language).second.c_str(),
+ params.translate ? "translate" : "transcribe",
+ params.no_timestamps ? 0 : 1);
+ printf("\n");
+ }
+
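+ // SDL opens capture devices in the paused state - unpause to start capturing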
+ SDL_PauseAudioDevice(g_dev_id_in, 0);
+
+ // main audio loop
+ while (true) {
+ // process 3 seconds of new audio
+ while ((int) SDL_GetQueuedAudioSize(g_dev_id_in) < 3*SAMPLE_RATE*sizeof(float)) {
+ SDL_Delay(1);
+ }
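+ // drain everything queued so far (can be more than 3 s if processing lagged behind)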
+ const int n_samples_new = SDL_GetQueuedAudioSize(g_dev_id_in)/sizeof(float);
+
+ // take one second from previous iteration
+ // TODO: better strategy
+ const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_30s/30 - n_samples_new));
+
+ //printf("processing: take = %d, new = %d, old = %d\n", n_samples_take, n_samples_new, (int) pcmf32_old.size());
+
+ pcmf32.resize(n_samples_new + n_samples_take);
+
+ for (int i = 0; i < n_samples_take; i++) {
+ pcmf32[i] = pcmf32_old[pcmf32_old.size() - n_samples_take + i];
+ }
+
+ SDL_DequeueAudio(g_dev_id_in, pcmf32.data() + n_samples_take, n_samples_new*sizeof(float));
+
+ pcmf32_old = pcmf32;
+
+ // compute log mel spectrogram
+ whisper_mel mel_inp;
+ {
+ const int64_t t_start_us = ggml_time_us();
+
+ log_mel_spectrogram(pcmf32, SAMPLE_RATE, N_FFT, HOP_LENGTH, N_MEL, params.n_threads, model.filters, mel_inp);
+
+ t_mel_us = ggml_time_us() - t_start_us;
+ }
+
+ // the accumulated text context so far
+ std::vector<whisper_vocab::id> prompt_past = { };
+
+ // these tokens determine the task that will be performed
+ std::vector<whisper_vocab::id> prompt_init = { vocab.token_sot };
+ if (vocab.is_multilingual()) {
+ prompt_init.push_back(vocab.token_sot + 1 + g_lang.at(params.language).first);
+ if (params.translate) {
+ prompt_init.push_back(vocab.token_translate);
+ } else {
+ prompt_init.push_back(vocab.token_transcribe);
+ }
+ }
+
+ // the generated text including timestamps
+ //std::vector<whisper_result> result_all;
+
+ // main loop
+ int seek = 0;
+ while (true) {
+ if (seek >= mel_inp.n_len) {
+ break;
+ }
+
+ // encode audio features starting at offset seek
+ std::vector<float> features;
+ {
+ const int64_t t_start_us = ggml_time_us();
+
+ if (!whisper_encode(model, params.n_threads, seek, mel_inp, features)) {
+ fprintf(stderr, "%s: failed to eval\n", __func__);
+ return 1;
+ }
+
+ t_encode_us += ggml_time_us() - t_start_us;
+ }
+
+ std::vector<float> probs;
+ std::vector<float> logits;
+
+ std::vector<whisper_vocab::id> prompt;
+
+ int n_past = 0;
+
+ // if we have already generated some text, use it as a prompt to condition the next generation
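+ // (at most n_text_ctx/2 of the most recent past tokens are kept, prefixed with the special token_prev)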
+ if (prompt_past.size() > 0) {
+ int n_take = std::min(model.hparams.n_text_ctx/2, int(prompt_past.size()));
+
+ prompt = { vocab.token_prev };
+ prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
+
+ prompt_past.clear();
+ prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
+ }
+
+ prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
+
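+ // seek and seek_delta count mel frames (10 ms each, given the 16 kHz sample
+ // rate and 160-sample hop); timestamp tokens advance in 20 ms steps, hence
+ // the factors of 2 below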
+ bool done = false;
+ int seek_delta = 100*CHUNK_SIZE;
+ whisper_vocab::id last_id = 0;
+
+ // print the prompt
+ //printf("\n\n");
+ //for (int i = 0; i < prompt.size(); i++) {
+ // printf("%s: prompt[%d] = %s\n", __func__, i, vocab.id_to_token[prompt[i]].c_str());
+ //}
+ //printf("\n\n");
+
+ // the accumulated transcription in the current iteration
+ int result_len = 0;
+ std::vector<whisper_result> result_cur;
+
+ for (int i = 0; i < model.hparams.n_text_ctx/2 - 4; ++i) {
+ // decode
+ if (prompt.size() > 0) {
+ const int64_t t_start_us = ggml_time_us();
+
+ if (!whisper_decode(model, params.n_threads, n_past, prompt, logits, probs)) {
+ fprintf(stderr, "%s: failed to eval\n", __func__);
+ return 1;
+ }
+
+ t_decode_us += ggml_time_us() - t_start_us;
+ }
+
+ n_past += prompt.size();
+ prompt.clear();
+
+ // very basic greedy sampling strategy:
+ //
+ // - always take the most probable token
+ //
+ // more sophisticated sampling strategies could be implemented here, but we keep it simple
+ // feel free to experiment!
+ //
+ {
+ const int n_vocab = model.hparams.n_vocab;
+
+ whisper_vocab::id id = 0;
+ whisper_vocab::id tid = vocab.token_beg;
+
+ {
+ const int64_t t_start_sample_us = ggml_time_us();
+
+ id = whisper_sample_best(vocab, probs.data() + (probs.size() - n_vocab), result_len == 0);
+ if (i > 0) {
+ tid = whisper_sample_timestamp(vocab, probs.data() + (probs.size() - n_vocab));
+ }
+
+ t_sample_us += ggml_time_us() - t_start_sample_us;
+ }
+
+ // update sliding window
+ if (id > vocab.token_beg) {
+ seek_delta = 2*(id - vocab.token_beg);
+ result_len = i + 1;
+ }
+ last_id = id;
+
+ // add it to the context
+ prompt.push_back(id);
+ result_cur.push_back({ id, seek + 2*(tid - vocab.token_beg) });
+
+ //printf("%s: %s\n", __func__, vocab.id_to_token[id].c_str());
+
+ // end of text token
+ if (id == vocab.token_eot) {
+ break;
+ }
+ }
+
+ if (done) {
+ break;
+ }
+ }
+
+ result_cur.resize(result_len);
+ //result_all.insert(result_all.end(), result_cur.begin(), result_cur.end());
+
+ for (const auto & r : result_cur) {
+ prompt_past.push_back(r.id);
+ }
+
+ // print the text from this iteration
+ if (result_cur.size() > 0) {
+ auto t0 = result_cur.front().t;
+
+ std::string text = "";
+ for (int i = 0; i < (int) result_cur.size(); i++) {
+ if (params.print_special_tokens || result_cur[i].id < vocab.token_eot) {
+ text += vocab.id_to_token[result_cur[i].id];
+ }
+ if (result_cur[i].id > vocab.token_beg) {
+ const auto t1 = result_cur[i].t;
+ if (!text.empty()) {
+ if (params.no_timestamps) {
+ printf ("%s", text.c_str());
+ fflush(stdout);
+ } else {
+ printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text.c_str());
+ }
+ }
+ text = "";
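+ // skip consecutive timestamp tokens; the last one becomes the start time of the next segment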
+ while (i < (int) result_cur.size() && result_cur[i].id > vocab.token_beg) {
+ i++;
+ }
+ i--;
+ t0 = result_cur[i].t;
+ }
+ }
+
+ if (!text.empty()) {
+ printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(seek + seek_delta).c_str(), text.c_str());
+ }
+ }
+
+ seek += seek_delta;
+ }
+ }
+
+ // WIP: attempt for per-token timestamps
+ //if (!params.no_timestamps && result_all.size() > 0) {
+ // const int64_t dt = 500; // 5 second intervals
+
+ // int i0 = 0;
+
+ // int64_t t0 = result_all[0].t;
+ // int64_t t1 = t0;
+
+ // printf("\n\n");
+ // for (int i = 0; i < result_all.size(); ++i) {
+ // printf("'%s' -> %lld\n", vocab.id_to_token[result_all[i].id].c_str(), result_all[i].t);
+ // if (result_all[i].t - t0 > dt) {
+ // t1 = result_all[i - 1].t;
+ // printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
+ // for (int j = i0; j < i; ++j) {
+ // printf("%s", vocab.id_to_token.at(result_all[j].id).c_str());
+ // }
+ // printf("\n");
+ // i0 = i;
+ // t0 = result_all[i].t;
+ // }
+ // }
+ //}
+
+ // report timing
+ {
+ const int64_t t_main_end_us = ggml_time_us();
+
+ printf("\n\n");
+ printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
+ printf("%s: mel time = %8.2f ms\n", __func__, t_mel_us/1000.0f);
+ printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
+ printf("%s: encode time = %8.2f ms / %.2f ms per layer\n", __func__, t_encode_us/1000.0f, t_encode_us/1000.0f/model.hparams.n_audio_layer);
+ printf("%s: decode time = %8.2f ms\n", __func__, t_decode_us/1000.0f);
+ printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+ }
+
+ ggml_free(model.ctx);
+
+ return 0;
+ }