Neil Chudleigh committed

Commit c587102 · unverified · 1 parent: 30cdb60

extra: Add benchmark script implemented in Python (#1298)


* Create bench.py

* Various benchmark results

* Update benchmark script with hardware name and file checks

* Remove old benchmark results

* Add git shorthash

* Round to 2 digits on calculated floats

* Fix the header reference when sorting results

* Fix order of models

* Parse file name

* Simplify file check

* Improve run print statement

* Use simplified model name

* Update benchmark_results.csv

* Process single values or lists of processors and threads

* Ignore benchmark results, don't check in

* Move bench.py to extra folder

* Add README section on how to use

* Move command to correct location

* Use separate list for models that exist

* Handle subprocess error in git short hash check

* Fix filtered models list initialization

Files changed (3)
  1. .gitignore +2 -0
  2. README.md +13 -0
  3. extra/bench.py +222 -0
.gitignore CHANGED
@@ -46,3 +46,5 @@ models/*.mlpackage
 bindings/java/.gradle/
 bindings/java/.idea/
 .idea/
+
+benchmark_results.csv
README.md CHANGED
@@ -709,6 +709,19 @@ took to execute it. The results are summarized in the following Github issue:
 
 [Benchmark results](https://github.com/ggerganov/whisper.cpp/issues/89)
 
+Additionally, a script to run whisper.cpp with different models and audio files is provided: [bench.py](extra/bench.py).
+
+You can run it with the following command; by default it will run against any standard model in the models folder.
+
+```bash
+python3 extra/bench.py -f samples/jfk.wav -t 2,4,8 -p 1,2
+```
+
+It is written in Python with the intention of being easy to modify and extend for your benchmarking use case.
+
+It outputs a CSV file with the results of the benchmarking.
+
+
 ## ggml format
 
 The original models are converted to a custom binary format. This allows to pack everything needed into a single file:
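
For reference, the header row of the generated `benchmark_results.csv` follows directly from the field names defined in the script below; the rows beneath it depend on your hardware, models, and arguments:

```csv
Commit,Model,Hardware,Recording Length (seconds),Thread,Processor Count,Load Time (ms),Sample Time (ms),Encode Time (ms),Decode Time (ms),Sample Time per Run (ms),Encode Time per Run (ms),Decode Time per Run (ms),Total Time (ms)
```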
extra/bench.py ADDED
@@ -0,0 +1,222 @@
+import os
+import subprocess
+import re
+import csv
+import wave
+import contextlib
+import argparse
+
+
+# Custom action to handle comma-separated list
+class ListAction(argparse.Action):
+    def __call__(self, parser, namespace, values, option_string=None):
+        setattr(namespace, self.dest, [int(val) for val in values.split(",")])
+
+
+parser = argparse.ArgumentParser(description="Benchmark the speech recognition model")
+
+# Define the argument to accept a list
+parser.add_argument(
+    "-t",
+    "--threads",
+    dest="threads",
+    action=ListAction,
+    default=[4],
+    help="List of thread counts to benchmark (comma-separated, default: 4)",
+)
+
+parser.add_argument(
+    "-p",
+    "--processors",
+    dest="processors",
+    action=ListAction,
+    default=[1],
+    help="List of processor counts to benchmark (comma-separated, default: 1)",
+)
+
+
+parser.add_argument(
+    "-f",
+    "--filename",
+    type=str,
+    default="./samples/jfk.wav",
+    help="Relative path of the file to transcribe (default: ./samples/jfk.wav)",
+)
+
+# Parse the command line arguments
+args = parser.parse_args()
+
+sample_file = args.filename
+
+threads = args.threads
+processors = args.processors
+
+# Define the models, threads, and processor counts to benchmark
+models = [
+    "ggml-tiny.en.bin",
+    "ggml-tiny.bin",
+    "ggml-base.en.bin",
+    "ggml-base.bin",
+    "ggml-small.en.bin",
+    "ggml-small.bin",
+    "ggml-medium.en.bin",
+    "ggml-medium.bin",
+    "ggml-large.bin",
+]
+
+
+metal_device = ""
+
+# Initialize a dictionary to hold the results
+results = {}
+
+gitHashHeader = "Commit"
+modelHeader = "Model"
+hardwareHeader = "Hardware"
+recordingLengthHeader = "Recording Length (seconds)"
+threadHeader = "Thread"
+processorCountHeader = "Processor Count"
+loadTimeHeader = "Load Time (ms)"
+sampleTimeHeader = "Sample Time (ms)"
+encodeTimeHeader = "Encode Time (ms)"
+decodeTimeHeader = "Decode Time (ms)"
+sampleTimePerRunHeader = "Sample Time per Run (ms)"
+encodeTimePerRunHeader = "Encode Time per Run (ms)"
+decodeTimePerRunHeader = "Decode Time per Run (ms)"
+totalTimeHeader = "Total Time (ms)"
+
+
+def check_file_exists(file: str) -> bool:
+    return os.path.isfile(file)
+
+
+def get_git_short_hash() -> str:
+    try:
+        return (
+            subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
+            .decode()
+            .strip()
+        )
+    except subprocess.CalledProcessError as e:
+        return ""
+
+
+def wav_file_length(file: str = sample_file) -> float:
+    with contextlib.closing(wave.open(file, "r")) as f:
+        frames = f.getnframes()
+        rate = f.getframerate()
+        duration = frames / float(rate)
+        return duration
+
+
+def extract_metrics(output: str, label: str) -> tuple[float, float]:
+    match = re.search(rf"{label} \s*=\s*(\d+\.\d+)\s*ms\s*/\s*(\d+)\s*runs", output)
+    time = float(match.group(1)) if match else None
+    runs = float(match.group(2)) if match else None
+    return time, runs
+
+
+def extract_device(output: str) -> str:
+    match = re.search(r"picking default device: (.*)", output)
+    device = match.group(1) if match else "Not found"
+    return device
+
+
+# Check if the sample file exists
+if not check_file_exists(sample_file):
+    raise FileNotFoundError(f"Sample file {sample_file} not found")
+
+recording_length = wav_file_length()
+
+
+# Check that all models exist
+# Filter out models from list that are not downloaded
+filtered_models = []
+for model in models:
+    if check_file_exists(f"models/{model}"):
+        filtered_models.append(model)
+    else:
+        print(f"Model {model} not found, removing from list")
+
+models = filtered_models
+
+# Loop over each combination of parameters
+for model in filtered_models:
+    for thread in threads:
+        for processor_count in processors:
+            # Construct the command to run
+            cmd = f"./main -m models/{model} -t {thread} -p {processor_count} -f {sample_file}"
+            # Run the command and get the output
+            process = subprocess.Popen(
+                cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
+            )
+
+            output = ""
+            while process.poll() is None:
+                output += process.stdout.read().decode()
+
+            # Parse the output
+            load_time_match = re.search(r"load time\s*=\s*(\d+\.\d+)\s*ms", output)
+            load_time = float(load_time_match.group(1)) if load_time_match else None
+
+            metal_device = extract_device(output)
+            sample_time, sample_runs = extract_metrics(output, "sample time")
+            encode_time, encode_runs = extract_metrics(output, "encode time")
+            decode_time, decode_runs = extract_metrics(output, "decode time")
+
+            total_time_match = re.search(r"total time\s*=\s*(\d+\.\d+)\s*ms", output)
+            total_time = float(total_time_match.group(1)) if total_time_match else None
+
+            model_name = model.replace("ggml-", "").replace(".bin", "")
+
+            print(
+                f"Ran model={model_name} threads={thread} processor_count={processor_count}, took {total_time}ms"
+            )
+            # Store the times in the results dictionary
+            results[(model_name, thread, processor_count)] = {
+                loadTimeHeader: load_time,
+                sampleTimeHeader: sample_time,
+                encodeTimeHeader: encode_time,
+                decodeTimeHeader: decode_time,
+                sampleTimePerRunHeader: round(sample_time / sample_runs, 2),
+                encodeTimePerRunHeader: round(encode_time / encode_runs, 2),
+                decodeTimePerRunHeader: round(decode_time / decode_runs, 2),
+                totalTimeHeader: total_time,
+            }
+
+# Write the results to a CSV file
+with open("benchmark_results.csv", "w", newline="") as csvfile:
+    fieldnames = [
+        gitHashHeader,
+        modelHeader,
+        hardwareHeader,
+        recordingLengthHeader,
+        threadHeader,
+        processorCountHeader,
+        loadTimeHeader,
+        sampleTimeHeader,
+        encodeTimeHeader,
+        decodeTimeHeader,
+        sampleTimePerRunHeader,
+        encodeTimePerRunHeader,
+        decodeTimePerRunHeader,
+        totalTimeHeader,
+    ]
+    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+
+    writer.writeheader()
+
+    shortHash = get_git_short_hash()
+    # Sort the results by total time in ascending order
+    sorted_results = sorted(results.items(), key=lambda x: x[1].get(totalTimeHeader, 0))
+    for params, times in sorted_results:
+        row = {
+            gitHashHeader: shortHash,
+            modelHeader: params[0],
+            hardwareHeader: metal_device,
+            recordingLengthHeader: recording_length,
+            threadHeader: params[1],
+            processorCountHeader: params[2],
+        }
+        row.update(times)
+        writer.writerow(row)