import os
import uuid

import numpy as np
import librosa
import soundfile as sf
import torch
import transformers
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename

app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
app.config['MAX_CONTENT_LENGTH'] = 50 * 1024 * 1024  # 50MB max
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)


# --------- MODEL DEFINITIONS ---------

class AttentiveStatsPool(torch.nn.Module):
    """Attentive statistics pooling: attention-weighted mean (and std) over time."""

    def __init__(self, in_dim, use_std=True):
        super().__init__()
        self.use_std = use_std
        self.att = torch.nn.Sequential(
            torch.nn.Linear(in_dim, in_dim // 2),
            torch.nn.Tanh(),
            torch.nn.Linear(in_dim // 2, 1)
        )

    def forward(self, H, mask=None):
        if mask is not None:
            logits = self.att(H).squeeze(-1).masked_fill(~mask, float("-inf"))
            alpha = torch.softmax(logits, dim=1).unsqueeze(-1)
        else:
            alpha = torch.softmax(self.att(H), dim=1)
        mean = (alpha * H).sum(dim=1)
        if not self.use_std:
            return mean
        ex2 = (alpha * (H ** 2)).sum(dim=1)
        std = torch.sqrt(torch.clamp(ex2 - mean ** 2, min=1e-6))
        return torch.cat([mean, std], dim=-1)


class ASPMLPClassifier(torch.nn.Module):
    """Pools frame-level wav2vec2 features and classifies stutter / no-stutter."""

    def __init__(self, hidden_dim=768, output_dim=2, dropout=0.3):
        super().__init__()
        self.pool = AttentiveStatsPool(hidden_dim, use_std=True)
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim * 2, 256),
            torch.nn.BatchNorm1d(256),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(256, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(128, output_dim)
        )

    def forward(self, H):
        z = self.pool(H)
        return self.classifier(z)


class FullInferenceModel(torch.nn.Module):
    """wav2vec2 feature extractor followed by the ASP + MLP classification head."""

    def __init__(self, asp_mlp_classifier, wav2vec_model):
        super().__init__()
        self.wav2vec = wav2vec_model
        self.asp_mlp = asp_mlp_classifier

    def forward(self, input_values):
        with torch.no_grad():
            H = self.wav2vec(input_values).last_hidden_state
        return self.asp_mlp(H)


# Global model and processor (loaded lazily on first use)
model = None
processor = None


def load_model(model_path="stutter_detector_mio.pth"):
    global model, processor
    if model is not None:
        return model, processor

    print("Loading model...")
    checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)

    asp = ASPMLPClassifier()
    wav2vec = transformers.Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
    full = FullInferenceModel(asp, wav2vec)
    full.load_state_dict(checkpoint["full_model_state_dict"])
    full.eval()

    model = full
    processor = transformers.Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
    print("Model loaded successfully!")
    return model, processor


def preprocess_audio(audio, max_len=48000):
    """Truncate or zero-pad audio to a fixed length (3 s at 16 kHz)."""
    if len(audio) > max_len:
        audio = audio[:max_len]
    else:
        audio = np.pad(audio, (0, max_len - len(audio)))
    return audio


def predict(audio_np):
    """Run one fixed-length clip through the model; return (predicted class, confidence)."""
    mdl, proc = load_model()
    inputs = proc(audio_np, sampling_rate=16000, return_tensors="pt").input_values
    with torch.no_grad():
        logits = mdl(inputs)
        probs = torch.softmax(logits, dim=1)
        pred = torch.argmax(probs, dim=1).item()
        confidence = probs[0, pred].item()
    return pred, confidence


def analyze_segments(audio_path):
    """Mode 1: Return a label for each 3-second segment."""
    audio, sr = librosa.load(audio_path, sr=16000)
    segment_duration = 3
    segment_samples = segment_duration * sr
    total_samples = len(audio)
    total_duration = total_samples / sr

    results = []
    CONF_THRESHOLD = 0.5
    start = 0
    segment_idx = 0

    while start < total_samples:
        end = min(start + segment_samples, total_samples)
        segment = audio[start:end]
        if len(segment) >= sr:  # At least 1 second
            audio_np = preprocess_audio(segment)
            pred, conf = predict(audio_np)
            label = "stutter" if pred == 1 and conf >= CONF_THRESHOLD else "no_stutter"
            results.append({
                "segment": segment_idx,
                "start_time": round(start / sr, 2),
                "end_time": round(end / sr, 2),
                "label": label,
                "confidence": round(conf, 4)
            })
        start = end
        segment_idx += 1

    return {"duration": round(total_duration, 2), "segments": results}


def analyze_seconds(audio_path):
    """Mode 2: Return a label for each second using overlapping segments."""
    audio, sr = librosa.load(audio_path, sr=16000)
    segment_duration = 3
    segment_samples = segment_duration * sr
    hop = int(segment_samples * 0.5)  # 50% overlap
    total_samples = len(audio)
    total_seconds = int(total_samples // sr)

    # Get overlapping segment predictions
    segments = []
    start = 0
    while start + segment_samples <= total_samples:
        end = start + segment_samples
        segments.append((start, end))
        start += hop

    segment_predictions = []
    for start, end in segments:
        segment = audio[start:end]
        audio_np = preprocess_audio(segment)
        pred, conf = predict(audio_np)
        segment_predictions.append((pred, conf, start, end))

    # Vote per second (the first and last segments are excluded to reduce edge effects)
    votes = [[] for _ in range(total_seconds)]
    confs = [[] for _ in range(total_seconds)]
    for i, (pred, conf, start, end) in enumerate(segment_predictions):
        if i == 0 or i == len(segment_predictions) - 1:
            continue
        seg_start_sec = int(start // sr)
        seg_end_sec = int(end // sr)
        for sec in range(seg_start_sec, min(seg_end_sec, total_seconds)):
            votes[sec].append(pred)
            confs[sec].append(conf)

    results = []
    CONF_THRESHOLD = 0.6
    for sec in range(total_seconds):
        if len(votes[sec]) == 0:
            results.append({"second": sec, "label": "no_data", "confidence": None})
            continue
        majority = 1 if votes[sec].count(1) > votes[sec].count(0) else 0
        mean_conf = sum(confs[sec]) / len(confs[sec])
        label = "stutter" if majority == 1 and mean_conf >= CONF_THRESHOLD else "no_stutter"
        results.append({"second": sec, "label": label, "confidence": round(mean_conf, 4)})

    return {"duration": total_seconds, "seconds": results}


def analyze_percentage(audio_path):
    """Mode 3: Return only the stutter percentage."""
    data = analyze_seconds(audio_path)
    total_seconds = data["duration"]
    stutter_seconds = sum(1 for r in data["seconds"] if r["label"] == "stutter")
    stutter_percentage = (stutter_seconds / total_seconds) * 100 if total_seconds > 0 else 0
    return {
        "duration": total_seconds,
        "stutter_percentage": round(stutter_percentage, 2)
    }


@app.route('/', methods=['GET'])
def home():
    return jsonify({
        'service': 'Stutter Detection API',
        'status': 'running',
        'endpoints': {
            '/analyze': 'POST - Analyze audio file',
            '/health': 'GET - Health check'
        },
        'modes': ['segments', 'seconds', 'percentage'],
        'usage': 'POST an audio file to /analyze?mode=<segments|seconds|percentage>'
    })


@app.route('/analyze', methods=['POST'])
def analyze():
    if 'audio' not in request.files:
        return jsonify({'error': 'No audio file provided'}), 400

    file = request.files['audio']
    if file.filename == '':
        return jsonify({'error': 'No file selected'}), 400

    # Get mode parameter (default: segments)
    mode = request.args.get('mode', 'segments')
    if mode not in ['segments', 'seconds', 'percentage']:
        return jsonify({'error': 'Invalid mode. Choose: segments, seconds, or percentage'}), 400

    file_id = str(uuid.uuid4())
    original_ext = os.path.splitext(secure_filename(file.filename))[1] or '.webm'
    original_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{file_id}{original_ext}")
    wav_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{file_id}.wav")
    file.save(original_path)

    try:
        # Convert to 16 kHz mono WAV
        audio, sr = librosa.load(original_path, sr=16000, mono=True)
        sf.write(wav_path, audio, sr, format='WAV')

        # Analyze based on mode
        if mode == 'segments':
            result = analyze_segments(wav_path)
        elif mode == 'seconds':
            result = analyze_seconds(wav_path)
        else:  # percentage
            result = analyze_percentage(wav_path)

        # Cleanup
        if os.path.exists(original_path):
            os.remove(original_path)
        if os.path.exists(wav_path):
            os.remove(wav_path)

        return jsonify({'success': True, **result})

    except Exception as e:
        # Cleanup on error
        if os.path.exists(original_path):
            os.remove(original_path)
        if os.path.exists(wav_path):
            os.remove(wav_path)
        return jsonify({'error': str(e)}), 500


@app.route('/health', methods=['GET'])
def health():
    return jsonify({
        'status': 'ok',
        'model_loaded': model is not None
    })


if __name__ == '__main__':
    load_model()
    port = int(os.environ.get('PORT', 7860))  # Hugging Face Spaces uses port 7860
    app.run(host='0.0.0.0', port=port, debug=False)
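
# Example client call (a minimal sketch, not part of the service itself). It assumes
# the server is running locally on the default port 7860, that the `requests` package
# is installed, and that a file named "sample.wav" exists in the working directory;
# adjust the URL, mode, and file path as needed.
#
#   import requests
#
#   with open("sample.wav", "rb") as f:
#       resp = requests.post(
#           "http://localhost:7860/analyze",
#           params={"mode": "percentage"},  # or "segments" / "seconds"
#           files={"audio": f},             # form field name must be "audio"
#       )
#   print(resp.json())
#   # For mode=percentage the JSON contains the keys
#   # 'success', 'duration', and 'stutter_percentage'.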