vladakoz committed on
Commit
7192b75
·
1 Parent(s): 875b365

Add application file

Browse files
Files changed (3) hide show
  1. Dockerfile +17 -0
  2. app.py +49 -0
  3. requirements.txt +8 -0
Dockerfile ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Use an official Python image
FROM python:3.9

# Set the working directory
WORKDIR /app

# app.py loads the "facebook/wav2vec2-lv-60-espeak-cv-ft" checkpoint, whose
# processor uses a phoneme tokenizer backed by the `phonemizer` package with
# the espeak backend; that backend needs the espeak-ng system library.
RUN apt-get update \
    && apt-get install -y --no-install-recommends espeak-ng \
    && rm -rf /var/lib/apt/lists/*

# Install dependencies before copying the sources so this (slow) layer stays
# cached when only application code changes.
# - python-multipart: required by FastAPI for UploadFile / File(...) routes
# - phonemizer: required by the phoneme tokenizer of the checkpoint above
RUN pip install --no-cache-dir \
    fastapi \
    uvicorn \
    python-multipart \
    phonemizer \
    torch \
    torchaudio \
    transformers

# Copy all files from the repo to the container
COPY . .

# Port the server below listens on.
# NOTE(review): if this image is deployed as a Hugging Face Docker Space, the
# Space's configured app port must match this value - confirm in README.md.
EXPOSE 8000

# Start the FastAPI server
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""FastAPI service exposing a Wav2Vec2 CTC model behind ``POST /transcribe/``."""

import io
import logging

import torch
import torchaudio
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

logger = logging.getLogger(__name__)

MODEL_NAME = "facebook/wav2vec2-lv-60-espeak-cv-ft"
# Sample rate passed to the processor; uploads at other rates are resampled.
TARGET_SAMPLE_RATE = 16000

app = FastAPI()

# Load Wav2Vec2 model and processor once at import time so every request
# reuses the same instances.
processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)


def _error_response(status_code: int, message: str) -> JSONResponse:
    """Build an ``{"error": message}`` body with a non-200 status code."""
    return JSONResponse(status_code=status_code, content={"error": message})


def _load_mono_16k(audio_bytes: bytes) -> torch.Tensor:
    """Decode *audio_bytes* into a 1-D waveform at ``TARGET_SAMPLE_RATE``."""
    waveform, sample_rate = torchaudio.load(io.BytesIO(audio_bytes))

    # Down-mix multi-channel audio by averaging over the channel dimension.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if sample_rate != TARGET_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=TARGET_SAMPLE_RATE
        )
        waveform = resampler(waveform)

    # Drop the (now single) channel dimension, not a batch dimension.
    return waveform.squeeze(0)


def _run_model(waveform: torch.Tensor) -> str:
    """Run the CTC model on a 1-D waveform and return the decoded string."""
    # The feature extractor is documented for numpy / list input.
    input_values = processor(
        waveform.numpy(),
        sampling_rate=TARGET_SAMPLE_RATE,
        return_tensors="pt",
    ).input_values

    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy decoding: pick the highest-scoring token at every frame.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)[0]


@app.post("/transcribe/")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file.

    Returns ``{"transcription": <str>}`` on success. On failure the body is
    still ``{"error": <message>}``, but the status code now reflects the
    failure: 400 for an empty or undecodable upload, 500 when inference fails.
    """
    audio_bytes = await file.read()
    if not audio_bytes:
        return _error_response(400, "Uploaded file is empty")

    # Decoding and inference are blocking, CPU-bound calls; run them in the
    # thread pool so the event loop keeps serving other requests.
    try:
        waveform = await run_in_threadpool(_load_mono_16k, audio_bytes)
    except Exception as exc:  # torchaudio raises backend-specific error types
        logger.warning("Could not decode uploaded audio: %s", exc)
        return _error_response(400, str(exc))

    try:
        transcription = await run_in_threadpool(_run_model, waveform)
    except Exception as exc:  # top-level boundary: log and report
        logger.exception("Transcription failed")
        return _error_response(500, str(exc))

    return {"transcription": transcription}
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
# Web server
fastapi
uvicorn[standard]
# Required by FastAPI to parse multipart uploads (UploadFile / File(...)).
python-multipart

# Model and audio stack
phonemizer
torch
# Listed once and left unpinned: the earlier extra "transformers==4.4.0" line
# duplicated this requirement and pinned a release that predates the Wav2Vec2
# phoneme tokenizer used by app.py.
# TODO(review): pin a tested minimum version once confirmed.
transformers
torchaudio
datasets