badrex committed
Commit df23ecf · verified · 1 Parent(s): 2a3e1bb

Update app.py

Files changed (1): app.py (+41 -26)
app.py CHANGED
@@ -12,42 +12,57 @@ HF_TOKEN = os.environ.get("HF_TOKEN")
 if HF_TOKEN:
     login(token=HF_TOKEN)
 
-MODEL_ID = "badrex/w2v-bert-2.0-swahili-asr"
-transcriber = pipeline("automatic-speech-recognition", model=MODEL_ID)
+#MODEL_ID = "badrex/w2v-bert-2.0-swahili-asr"
+#transcriber = pipeline("automatic-speech-recognition", model=MODEL_ID)
 
 
-@spaces.GPU
-def transcribe(audio):
-    sr, y = audio
-
-    # convert to mono if stereo
-    if y.ndim > 1:
-        y = y.mean(axis=1)
-
-    # ensure it's float32
-    y = y.astype(np.float32)
-
-    # normalize audio
-    if np.max(np.abs(y)) > 0:
-        y /= np.max(np.abs(y))
-
-    # convert to tensor for torchaudio
-    y_tensor = torch.from_numpy(y)
-
-    # add batch dimension if missing
-    if y_tensor.ndim == 1:
-        y_tensor = y_tensor.unsqueeze(0)
-
-    # resample to 16kHz if needed
-    if sr != 16000:
-        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
-        y_tensor = resampler(y_tensor)
-        sr = 16000
-
-    y = y.astype(np.float32)
-    y /= np.max(np.abs(y))
-
-    return transcriber({"sampling_rate": sr, "raw": y})["text"]
+# Load model and processor
+MODEL_PATH = "badrex/w2v-bert-2.0-swahili-asr"
+processor = AutoProcessor.from_pretrained(MODEL_PATH)
+model = AutoModelForCTC.from_pretrained(MODEL_PATH)
+
+# move model to device
+model = model.to(device)
+#processor = processor.to(device)
+
+@spaces.GPU()
+def transcribe(audio_path):
+    """Process audio and return the generated transcription.
+
+    Args:
+        audio_path: Path to the audio file to be transcribed.
+    Returns:
+        String containing the transcribed text from the audio file,
+        or an error message if the audio file is missing.
+    """
+    if not audio_path:
+        return "Please upload an audio file."
+
+    # load audio as a (channels, samples) tensor
+    audio_array, sample_rate = torchaudio.load(audio_path)
+
+    # resample to 16 kHz if needed
+    if sample_rate != 16000:
+        audio_array = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio_array)
+
+    #audio_array = audio_array.to(device)
+
+    inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+
+    #inputs = inputs.to(device, dtype=torch.bfloat16)
+
+    with torch.no_grad():
+        logits = model(**inputs).logits
+
+    # greedy CTC decoding: most likely token per frame
+    outputs = torch.argmax(logits, dim=-1)
+
+    decoded_outputs = processor.batch_decode(
+        outputs,
+        skip_special_tokens=True
+    )
+
+    return decoded_outputs[0].strip()
 
 examples = []
 examples_dir = "examples"
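
For context, a minimal standalone sketch of the inference path this commit switches to: explicit AutoProcessor/AutoModelForCTC loading with greedy CTC decoding, replacing the transformers pipeline wrapper. This is an illustrative sketch, not the Space's code: the device fallback, the mono downmix (which the old pipeline-based version handled with y.mean(axis=1) and the new code does not), and the transcribe_file name are assumptions.

import torch
import torchaudio
from transformers import AutoProcessor, AutoModelForCTC

MODEL_PATH = "badrex/w2v-bert-2.0-swahili-asr"
device = "cuda" if torch.cuda.is_available() else "cpu"  # assumed; the Space defines device elsewhere

processor = AutoProcessor.from_pretrained(MODEL_PATH)
model = AutoModelForCTC.from_pretrained(MODEL_PATH).to(device)
model.eval()

def transcribe_file(audio_path: str) -> str:  # hypothetical helper name
    # torchaudio.load returns a (channels, samples) float tensor
    waveform, sample_rate = torchaudio.load(audio_path)

    # downmix to mono; the commit's new code skips this step
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # the model expects 16 kHz input
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

    # feature extraction; the processor expects one 1-D audio array per example
    inputs = processor(waveform.squeeze(0).numpy(), sampling_rate=16000, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # greedy CTC decoding: pick the most likely token per frame,
    # then let the tokenizer collapse repeats and blanks
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()

Compared with the one-line pipeline call, this path exposes the logits and the decoding step directly, at the cost of handling resampling and channel layout manually.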