Commit
·
3125b67
1
Parent(s):
fe9908b
add example files
Browse files- 00001.wav +0 -0
- 00002.wav +0 -0
- __pycache__/model.cpython-38.pyc +0 -0
- app.py +8 -4
- model.py +1 -1
00001.wav
ADDED
|
Binary file (268 kB). View file
|
|
|
00002.wav
ADDED
|
Binary file (238 kB). View file
|
|
|
__pycache__/model.cpython-38.pyc
CHANGED
|
Binary files a/__pycache__/model.cpython-38.pyc and b/__pycache__/model.cpython-38.pyc differ
|
|
|
app.py
CHANGED
|
@@ -8,10 +8,14 @@ model.load_state_dict(torch.load("gender_classifier.model", map_location="cpu"))
|
|
| 8 |
model.eval()
|
| 9 |
|
| 10 |
def predict_gender(filepath):
|
|
|
|
| 11 |
with torch.no_grad():
|
| 12 |
-
output = model.
|
| 13 |
-
|
|
|
|
|
|
|
| 14 |
|
| 15 |
-
audio_component = gr.Audio(type='filepath', label=
|
| 16 |
-
|
|
|
|
| 17 |
demo.launch()
|
|
|
|
| 8 |
model.eval()
|
| 9 |
|
| 10 |
def predict_gender(filepath):
|
| 11 |
+
audio = model.load_audio(filepath)
|
| 12 |
with torch.no_grad():
|
| 13 |
+
output = model.forward(audio)
|
| 14 |
+
probs = torch.softmax(output, dim=1)
|
| 15 |
+
prob_dict = {model.pred2gender[i]: float(prob) for i, prob in enumerate(probs[0])}
|
| 16 |
+
return prob_dict
|
| 17 |
|
| 18 |
+
audio_component = gr.Audio(type='filepath', label='Upload your audio file here')
|
| 19 |
+
label_component = gr.Label(label='Gender classification result')
|
| 20 |
+
demo = gr.Interface(fn=predict_gender, inputs=audio_component, outputs=label_component, examples=['00001.wav', '00002.wav'])
|
| 21 |
demo.launch()
|
model.py
CHANGED
|
@@ -121,7 +121,7 @@ class ECAPA_gender(nn.Module):
|
|
| 121 |
self.fc6 = nn.Linear(3072, 192)
|
| 122 |
self.bn6 = nn.BatchNorm1d(192)
|
| 123 |
self.fc7 = nn.Linear(192, 2)
|
| 124 |
-
self.pred2gender = {0 : '
|
| 125 |
|
| 126 |
def forward(self, x):
|
| 127 |
with torch.no_grad():
|
|
|
|
| 121 |
self.fc6 = nn.Linear(3072, 192)
|
| 122 |
self.bn6 = nn.BatchNorm1d(192)
|
| 123 |
self.fc7 = nn.Linear(192, 2)
|
| 124 |
+
self.pred2gender = {0 : 'Male', 1 : 'Female'}
|
| 125 |
|
| 126 |
def forward(self, x):
|
| 127 |
with torch.no_grad():
|