from fastapi import FastAPI, File, UploadFile, BackgroundTasks, HTTPException from fastapi.responses import FileResponse, JSONResponse,StreamingResponse from pydantic import BaseModel from transformers import AutoModelForCausalLM, AutoTokenizer import whisper import torch from gtts import gTTS import os, subprocess import uuid from fastapi.staticfiles import StaticFiles from pathlib import Path from pydub import AudioSegment # ========== CẤU HÌNH ========== hf_token = os.getenv("HF_TOKEN") app = FastAPI() TARGET_RATE = 16000 # Hz TARGET_CHANNELS = 1 # mono TARGET_SAMPLE_WIDTH = 2 # 16-bit # >>> Khai báo thư mục audio và mount StaticFiles (ĐẶT Ở ĐÂY, CHỈ 1 LẦN) <<< #AUDIO_DIR = os.getenv("AUDIO_DIR", "audio_out") # nếu có persistent, đổi thành "/data/audio_out" #os.makedirs(AUDIO_DIR, exist_ok=True) #app.mount("/get_audio", StaticFiles(directory=AUDIO_DIR), name="get_audio") # ========== MODEL: QWEN ========== model_name = "Qwen/Qwen3-4B-Instruct-2507" tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token) model = AutoModelForCausalLM.from_pretrained( model_name, token=hf_token, device_map="cpu", # chạy CPU torch_dtype=torch.float32 ) torch.set_num_threads(1) # tránh chiếm CPU # ========== MODEL: WHISPER ========== whisper_model = whisper.load_model("base") # CPU # ========== HỘI THOẠI ========== conversation = [{ "role": "system", "content": "Bạn là một trợ lý AI. Hãy trả lời ngắn gọn, súc tích, tối đa 2 câu." }] MAX_TURNS = 6 # chỉ giữ N lượt gần nhất def trim_conversation(conv, max_turns=MAX_TURNS): base = [conv[0]] if conv and conv[0]["role"] == "system" else [] tail = conv[-(max_turns * 2):] return base + tail # ========== TIỆN ÍCH ========== def extract_song_name(text: str): import re match = re.search(r"(?:bài hát|mở bài hát|nghe nhạc|mở nhạc)\s+(.+)", text.lower()) if match: return match.group(1).strip().strip('."\'!?,;:') return None def download_youtube_as_wav(song_name, output_path="song.wav"): import yt_dlp import os search_query = f"ytsearch1:{song_name}" ydl_opts = { 'format': 'bestaudio/best', 'outtmpl': 'temp_audio.%(ext)s', 'postprocessors': [{ 'key': 'FFmpegExtractAudio', 'preferredcodec': 'wav', 'preferredquality': '192', }], 'quiet': True, } with yt_dlp.YoutubeDL(ydl_opts) as ydl: ydl.download([search_query]) if os.path.exists("temp_audio.wav"): os.replace("temp_audio.wav", output_path) return output_path return None def mp3_bytes_to_wav_bytes(mp3_bytes: bytes) -> bytes: # Cách A: pydub + ffmpeg (đang có ffmpeg từ packages.txt) audio = AudioSegment.from_file(io.BytesIO(mp3_bytes), format="mp3") audio = audio.set_channels(TARGET_CHANNELS).set_frame_rate(TARGET_RATE).set_sample_width(TARGET_SAMPLE_WIDTH) out_buf = io.BytesIO() audio.export(out_buf, format="wav") return out_buf.getvalue() def ensure_pcm16_wav(wav_bytes: bytes) -> bytes: # Sửa sample rate/kênh nếu cần audio = AudioSegment.from_file(io.BytesIO(wav_bytes), format="wav") audio = audio.set_channels(TARGET_CHANNELS).set_frame_rate(TARGET_RATE).set_sample_width(TARGET_SAMPLE_WIDTH) out_buf = io.BytesIO() audio.export(out_buf, format="wav") return out_buf.getvalue() def tts_text_to_wav_bytes_vi(text: str) -> bytes: # gTTS -> MP3 -> WAV PCM16 mono 16kHz tmp_mp3 = f"tts_{uuid.uuid4().hex}.mp3" try: gTTS(text=text, lang="vi").save(tmp_mp3) with open(tmp_mp3, "rb") as f: mp3_data = f.read() return mp3_bytes_to_wav_bytes(mp3_data) finally: try: os.remove(tmp_mp3) except: pass # ========== SCHEMA ========== class ChatRequest(BaseModel): message: str # ========== ROUTES ========== @app.get("/") def read_root(): return {"message": "Ứng dụng đang chạy!"} # Chat văn bản: sinh phản hồi trực tiếp trong endpoint @app.post("/chat") async def chat(request: ChatRequest): global conversation conversation.append({"role": "user", "content": request.message}) conversation[:] = trim_conversation(conversation) # Áp template hội thoại text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True) model_inputs = tokenizer([text], return_tensors="pt").to(model.device) # >>> TÍCH HỢP SINH PHẢN HỒI TẠI CHỖ <<< with torch.inference_mode(): generated_ids = model.generate( **model_inputs, max_new_tokens=64, do_sample=True, temperature=0.7, top_p=0.9 ) output_ids = generated_ids[0][len(model_inputs.input_ids[0]):] response_text = tokenizer.decode(output_ids, skip_special_tokens=True).strip() conversation.append({"role": "assistant", "content": response_text}) conversation[:] = trim_conversation(conversation) return {"response": response_text} # Voice chat: nhận WAV từ ESP32-S3, ASR Whisper, LLM Qwen, TTS gTTS @app.post("/voice_chat") async def voice_chat(file: UploadFile = File(...), background_tasks: BackgroundTasks = None): """ Nhận: audio/wav (PCM16/mono/16kHz) từ ESP32-S3. Trả: audio/wav (PCM16/mono/16kHz) – trực tiếp trong body. """ try: # 1) Lưu file tạm tmp_in = f"input_{uuid.uuid4().hex}_{file.filename or 'audio.wav'}" with open(tmp_in, "wb") as f: f.write(await file.read()) try: in_wav = ensure_pcm16_wav(in_bytes) # đảm bảo đúng spec cho Whisper except Exception: # Nếu file thực sự đã là PCM16 mono 16kHz, vẫn cứ dùng nguyên in_wav = in_bytes # 2) Whisper nhận dạng (CPU friendly) # Ghi ra temp để whisper đọc tmp_norm = f"norm_{uuid.uuid4().hex}.wav" with open(tmp_norm, "wb") as f: f.write(in_wav) result = whisper_model.transcribe( tmp_norm, language="vi", fp16=False, temperature=0.0, beam_size=1, condition_on_previous_text=False ) user_text = (result.get("text") or "").strip() # 3) Kiểm tra intent "mở nhạc" import unicodedata, re, httpx def normalize_text(s: str) -> str: s = unicodedata.normalize("NFC", s or "") s = s.lower().strip() s = re.sub(r"\s+", " ", s) return s TRIGGERS = ["mở bài hát", "mở bài", "mở nhạc", "nghe nhạc", "phát nhạc", "bật nhạc", "bài hát", "bài"] u_norm = normalize_text(user_text) matched_kw = next((kw for kw in TRIGGERS if kw in u_norm), None) if matched_kw: # Trích tên bài m = re.search(r"(?:mở bài hát|mở bài|mở nhạc|nghe nhạc|phát nhạc|bật nhạc|bài hát|bài)\s+(.+)", u_norm) query = (m.group(1).strip().strip('."\'!?,;:,。;:“”‘’()[]{}') if m else user_text) # Deezer preview -> WAV -> trả trực tiếp async with httpx.AsyncClient(timeout=15) as client: r = await client.get("https://api.deezer.com/search", params={"q": query}) r.raise_for_status() data = r.json() for item in data.get("data", []): preview = item.get("preview") # MP3 30s if preview: # tải MP3 preview async with httpx.AsyncClient(timeout=30) as client: pr = await client.get(preview) pr.raise_for_status() mp3_bytes = pr.content wav_bytes = mp3_bytes_to_wav_bytes(mp3_bytes) # dọn file tạm if background_tasks: background_tasks.add_task(lambda p: os.path.exists(p) and os.remove(p), tmp_in) background_tasks.add_task(lambda p: os.path.exists(p) and os.remove(p), tmp_norm) # TRẢ WAV TRỰC TIẾP return StreamingResponse(io.BytesIO(wav_bytes), media_type="audio/wav", headers={"Content-Disposition": "inline; filename=song.wav"}) # Không có preview raise HTTPException(status_code=404, detail="Không có preview 30s từ Deezer.") # 4) LLM phản hồi: sinh trực tiếp trong endpoint global conversation conversation.append({"role": "user", "content": user_text}) conversation[:] = trim_conversation(conversation) text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True) model_inputs = tokenizer([text], return_tensors="pt").to(model.device) # >>> TÍCH HỢP SINH PHẢN HỒI TẠI CHỖ <<< with torch.inference_mode(): generated_ids = model.generate( **model_inputs, max_new_tokens=64, do_sample=True, temperature=0.7, top_p=0.9 ) output_ids = generated_ids[0][len(model_inputs.input_ids[0]):] response_text = tokenizer.decode(output_ids, skip_special_tokens=True).strip() conversation.append({"role": "assistant", "content": response_text}) conversation[:] = trim_conversation(conversation) wav_bytes = tts_text_to_wav_bytes_vi(response_text) # 6) Dọn file tạm: chỉ xóa tmp_in, KHÔNG xóa abs_path if background_tasks: background_tasks.add_task(lambda p: os.path.exists(p) and os.remove(p), tmp_in) background_tasks.add_task(lambda p: os.path.exists(p) and os.remove(p), tmp_norm) # 7) TRẢ WAV TRỰC TIẾP return StreamingResponse(io.BytesIO(wav_bytes), media_type="audio/wav", headers={"Content-Disposition": "inline; filename=response.wav"}) except HTTPException as he: return JSONResponse({"error": he.detail}, status_code=he.status_code) except Exception as e: return JSONResponse({"error": str(e)}, status_code=500)