Spaces:
Runtime error
Runtime error
Hugo Flores
committed on
Commit
·
260b46d
1
Parent(s):
3d08285
add a coarse2fine eval script
Browse files- requirements.txt +1 -0
- scripts/exp/c2f_eval.py +100 -0
- setup.py +1 -0
requirements.txt
CHANGED
|
@@ -27,3 +27,4 @@ tensorboardX
|
|
| 27 |
gradio
|
| 28 |
einops
|
| 29 |
flash-attn
|
|
|
|
|
|
| 27 |
gradio
|
| 28 |
einops
|
| 29 |
flash-attn
|
| 30 |
+
frechet_audio_distance
|
scripts/exp/c2f_eval.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import os
|
| 3 |
+
from functools import partial
|
| 4 |
+
|
| 5 |
+
from frechet_audio_distance import FrechetAudioDistance
|
| 6 |
+
import pandas
|
| 7 |
+
import argbind
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
import audiotools
|
| 11 |
+
from audiotools import AudioSignal
|
| 12 |
+
|
| 13 |
+
@argbind.bind(without_prefix=True)
def eval(
    exp_dir: str = None,
    baseline_key: str = "reconstructed",
    audio_ext: str = ".wav",
):
    """Compare every condition folder in an experiment against a baseline.

    ``exp_dir`` is expected to contain one sub-directory per condition, each
    holding audio files with identical stems. The sub-directory named
    ``baseline_key`` is used as the reference; every other sub-directory is
    scored against it.

    For each (baseline, condition) file pair the SI-SDR, multi-scale STFT,
    mel-spectrogram and ViSQOL metrics are computed. The Frechet audio
    distance is computed once per condition directory and repeated on every
    row of that condition.

    Two files are written into ``exp_dir``:

    * ``metrics-all.csv``   -- one row per (condition, file) pair.
    * ``metrics-stats.csv`` -- per-condition mean / count / std for each
      metric.

    Args:
        exp_dir: Path to the experiment directory. Required.
        baseline_key: Name of the sub-directory holding the reference audio.
        audio_ext: File extension (including the dot) of the audio files.

    Raises:
        ValueError: If ``exp_dir`` is missing, ``baseline_key`` is not one of
            the sub-directories, the file counts or stems of a condition do
            not match the baseline, or there is nothing to evaluate.
        FileNotFoundError: If ``exp_dir`` does not exist.
    """
    # Explicit raises instead of ``assert`` so validation survives ``python -O``.
    if exp_dir is None:
        raise ValueError("exp_dir must be provided")
    exp_dir = Path(exp_dir)
    if not exp_dir.exists():
        raise FileNotFoundError(f"exp_dir {exp_dir} does not exist")

    # set up our metrics
    sisdr_loss = audiotools.metrics.distance.SISDRLoss()
    stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
    mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
    frechet = FrechetAudioDistance(
        use_pca=False,
        use_activation=False,
        verbose=False,
    )
    visqol = partial(audiotools.metrics.quality.visqol, mode="audio")

    # figure out what conditions we have (sorted for a reproducible run order)
    conditions = sorted(d.name for d in exp_dir.iterdir() if d.is_dir())

    if baseline_key not in conditions:
        raise ValueError(f"baseline_key {baseline_key} not found in {exp_dir}")
    conditions.remove(baseline_key)

    print(f"Found {len(conditions)} conditions in {exp_dir}")
    print(f"conditions: {conditions}")

    baseline_dir = exp_dir / baseline_key
    # glob() yields files in arbitrary order; sort so that the baseline and
    # condition lists line up by file name when zipped below.
    baseline_files = sorted(baseline_dir.glob(f"*{audio_ext}"))

    metrics = []
    for condition in conditions:
        cond_dir = exp_dir / condition
        cond_files = sorted(cond_dir.glob(f"*{audio_ext}"))

        # make sure we have the same number of files -- checked before the
        # FAD computation so a broken condition fails fast.
        if len(baseline_files) != len(cond_files):
            raise ValueError(
                f"number of files in {baseline_dir} and {cond_dir} do not match. "
                f"{len(baseline_files)} vs {len(cond_files)}"
            )

        print(f"computing fad")
        frechet_score = frechet.score(baseline_dir, cond_dir)

        pbar = tqdm(zip(baseline_files, cond_files), total=len(baseline_files))
        for baseline_file, cond_file in pbar:
            if baseline_file.stem != cond_file.stem:
                raise ValueError(
                    f"baseline file {baseline_file} and cond file {cond_file} do not match"
                )
            pbar.set_description(baseline_file.stem)

            # load the files
            baseline_sig = AudioSignal(baseline_file)
            cond_sig = AudioSignal(cond_file)

            # compute the metrics
            # ViSQOL stays best-effort: a failure is reported and scored as 0.0
            # (note that this pulls the per-condition mean down). Only
            # ``Exception`` is caught so Ctrl-C still aborts the run.
            try:
                vsq = visqol(baseline_sig, cond_sig)
            except Exception as err:
                tqdm.write(
                    f"visqol failed for {baseline_file.stem} ({condition}): {err!r}"
                )
                vsq = 0.0

            # float() turns the metric outputs into plain numbers so pandas can
            # aggregate them and the CSV holds numeric values.
            # NOTE(review): assumes each loss returns a scalar (0-d / single
            # element) value -- confirm against audiotools.
            metrics.append({
                "sisdr": float(sisdr_loss(baseline_sig, cond_sig)),
                "stft": float(stft_loss(baseline_sig, cond_sig)),
                "mel": float(mel_loss(baseline_sig, cond_sig)),
                "frechet": float(frechet_score),
                "visqol": float(vsq),
                "condition": condition,
                "file": baseline_file.stem,
            })

    if not metrics:
        # Without this guard ``metrics[0]`` below fails with a bare IndexError.
        raise ValueError(
            f"nothing to evaluate in {exp_dir}: no conditions besides "
            f"{baseline_key}, or no *{audio_ext} files found"
        )

    metric_keys = [k for k in metrics[0].keys() if k not in ("condition", "file")]

    # Build the frame once and reuse it for both the summary and the raw dump.
    df = pandas.DataFrame(metrics)

    # Per-condition summary. The mean/count/std column triple repeats once per
    # metric, in ``metric_keys`` order.
    stats = pandas.concat(
        [
            df.groupby(["condition"])[mk].agg(["mean", "count", "std"])
            for mk in metric_keys
        ],
        axis=1,
    )
    stats.to_csv(exp_dir / "metrics-stats.csv")

    df.to_csv(exp_dir / "metrics-all.csv", index=False)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
if __name__ == "__main__":
    # Parse the command-line arguments and run the evaluation with those
    # values bound to eval()'s parameters for the duration of the scope.
    with argbind.scope(argbind.parse_args()):
        eval()
|
setup.py
CHANGED
|
@@ -38,5 +38,6 @@ setup(
|
|
| 38 |
"torchmetrics>=0.7.3",
|
| 39 |
"einops",
|
| 40 |
"flash-attn",
|
|
|
|
| 41 |
],
|
| 42 |
)
|
|
|
|
| 38 |
"torchmetrics>=0.7.3",
|
| 39 |
"einops",
|
| 40 |
"flash-attn",
|
| 41 |
+
"frechet_audio_distance"
|
| 42 |
],
|
| 43 |
)
|