import os
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from typing import Dict, Any
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set cache directories
os.environ['HF_HOME'] = '/app/.cache'
os.environ['TRANSFORMERS_CACHE'] = '/app/.cache/transformers'
os.environ['HF_HUB_CACHE'] = '/app/.cache/hub'
# Initialize the API
app = FastAPI(
    title="Lyon28 Multi-Model API",
    description="General-purpose API for the 11 Lyon28 models"
)
# --- Model registry: model IDs and their tasks ---
MODEL_MAPPING = {
    # Generative Models (Text Generation)
    "Tinny-Llama": {"id": "Lyon28/Tinny-Llama", "task": "text-generation"},
    "Pythia": {"id": "Lyon28/Pythia", "task": "text-generation"},
    "GPT-2": {"id": "Lyon28/GPT-2", "task": "text-generation"},
    "GPT-Neo": {"id": "Lyon28/GPT-Neo", "task": "text-generation"},
    "Distil_GPT-2": {"id": "Lyon28/Distil_GPT-2", "task": "text-generation"},
    "GPT-2-Tinny": {"id": "Lyon28/GPT-2-Tinny", "task": "text-generation"},
    # Text-to-Text Model
    "T5-Small": {"id": "Lyon28/T5-Small", "task": "text2text-generation"},
    # Fill-Mask Models
    "Bert-Tinny": {"id": "Lyon28/Bert-Tinny", "task": "fill-mask"},
    "Albert-Base-V2": {"id": "Lyon28/Albert-Base-V2", "task": "fill-mask"},
    "Distilbert-Base-Uncased": {"id": "Lyon28/Distilbert-Base-Uncased", "task": "fill-mask"},
    "Electra-Small": {"id": "Lyon28/Electra-Small", "task": "fill-mask"},
}
# --- In-memory cache for pipelines that have already been loaded ---
PIPELINE_CACHE = {}
def ensure_cache_directory():
    """Make sure the cache directories exist and have the correct permissions."""
    cache_dirs = [
        '/app/.cache',
        '/app/.cache/transformers',
        '/app/.cache/hub'
    ]
    for cache_dir in cache_dirs:
        try:
            os.makedirs(cache_dir, exist_ok=True)
            os.chmod(cache_dir, 0o755)
            logger.info(f"Cache directory {cache_dir} ready")
        except Exception as e:
            logger.error(f"Failed to create cache directory {cache_dir}: {e}")
def get_pipeline(model_name: str):
    """Return the pipeline from the in-memory cache, or load it from the Hub if it is not cached yet."""
    if model_name in PIPELINE_CACHE:
        logger.info(f"Fetching model '{model_name}' from the cache.")
        return PIPELINE_CACHE[model_name]

    if model_name not in MODEL_MAPPING:
        raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found.")

    model_info = MODEL_MAPPING[model_name]
    model_id = model_info["id"]
    task = model_info["task"]
    logger.info(f"Loading model '{model_name}' ({model_id}) for task '{task}'...")

    try:
        # Make sure the cache directories are ready
        ensure_cache_directory()

        # Load the model with better error handling
        pipe = pipeline(
            task,
            model=model_id,
            device_map="auto",
            # cache_dir must be forwarded to from_pretrained via model_kwargs
            model_kwargs={"cache_dir": "/app/.cache/transformers"},
            trust_remote_code=True  # for custom model code
        )
        PIPELINE_CACHE[model_name] = pipe
        logger.info(f"Model '{model_name}' loaded successfully and stored in the cache.")
        return pipe
    except PermissionError as e:
        error_msg = f"Permission error while loading model '{model_name}': {str(e)}. Check cache directory permissions."
        logger.error(error_msg)
        raise HTTPException(status_code=500, detail=error_msg)
    except Exception as e:
        error_msg = f"Failed to load model '{model_name}': {str(e)}. Common causes: 1) another user is downloading the same model (please wait); 2) a previous download was canceled and the lock file needs manual removal."
        logger.error(error_msg)
        raise HTTPException(status_code=500, detail=error_msg)
# --- Request schema expected from the user ---
class InferenceRequest(BaseModel):
    model_name: str  # key from MODEL_MAPPING, e.g. "Tinny-Llama"
    prompt: str
    parameters: Dict[str, Any] = {}  # extra parameters such as max_length, temperature, etc.
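# Illustrative request body matching InferenceRequest (a sketch, not part of the
# original file): the generation parameter names shown are standard transformers
# kwargs and are assumptions, not values fixed by this API.
# {
#     "model_name": "GPT-2",
#     "prompt": "Once upon a time",
#     "parameters": {"max_new_tokens": 50, "temperature": 0.7}
# }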
@app.get("/")  # route path assumed; a decorator is required for FastAPI to register this endpoint
def read_root():
    """Endpoint to check the API status and list the available models."""
    return {
        "status": "API is running!",
        "available_models": list(MODEL_MAPPING.keys()),
        "cached_models": list(PIPELINE_CACHE.keys()),
        "cache_info": {
            "HF_HOME": os.environ.get('HF_HOME'),
            "TRANSFORMERS_CACHE": os.environ.get('TRANSFORMERS_CACHE'),
            "HF_HUB_CACHE": os.environ.get('HF_HUB_CACHE')
        }
    }
@app.get("/health")  # route path assumed; a decorator is required to expose this endpoint
def health_check():
    """Health check endpoint."""
    return {"status": "healthy", "cached_models": len(PIPELINE_CACHE)}
@app.post("/invoke")  # route path assumed; a decorator is required to expose this endpoint
def invoke_model(request: InferenceRequest):
    """Main endpoint for running inference on the selected model."""
    try:
        # Fetch or load the model pipeline
        pipe = get_pipeline(request.model_name)
        # Run the prompt through the pipeline with the extra parameters
        result = pipe(request.prompt, **request.parameters)
        return {
            "model_used": request.model_name,
            "prompt": request.prompt,
            "parameters": request.parameters,
            "result": result
        }
    except HTTPException as e:
        # Re-raise the errors we defined ourselves
        raise e
    except Exception as e:
        # Catch any other error that may occur during inference
        logger.error(f"Inference error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An error occurred during inference: {str(e)}")
@app.delete("/cache/{model_name}")  # route path assumed; a decorator is required to expose this endpoint
def clear_model_cache(model_name: str):
    """Endpoint for removing a model from the in-memory cache."""
    if model_name in PIPELINE_CACHE:
        del PIPELINE_CACHE[model_name]
        logger.info(f"Model '{model_name}' removed from cache")
        return {"status": "success", "message": f"Model '{model_name}' removed from cache"}
    else:
        raise HTTPException(status_code=404, detail=f"Model '{model_name}' is not in the cache")
# Startup event with better error handling
@app.on_event("startup")
async def startup_event():
    logger.info("API startup: warming up by loading one initial model...")
    # Make sure the cache directories are ready
    ensure_cache_directory()
    try:
        # Try the smallest model first
        get_pipeline("GPT-2-Tinny")
        logger.info("Warm-up successful!")
    except Exception as e:
        logger.warning(f"Warm-up failed: {e}")
        logger.info("The API keeps running; models will be loaded on demand.")
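A minimal client sketch for exercising the service, assuming the decorators restored above (so the inference route is POST /invoke) and a local run on port 7860, the default port for Hugging Face Spaces; the base URL, route path, and generation parameters are illustrative assumptions, not part of the original file.

import requests

# Hypothetical base URL, e.g. after `uvicorn app:app --host 0.0.0.0 --port 7860`
BASE_URL = "http://localhost:7860"

payload = {
    "model_name": "GPT-2-Tinny",            # key from MODEL_MAPPING
    "prompt": "Once upon a time",
    "parameters": {"max_new_tokens": 40},   # forwarded as **kwargs to the pipeline
}

# Call the inference endpoint (path assumed from the restored decorator above)
resp = requests.post(f"{BASE_URL}/invoke", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["result"])

# Listing models and clearing a cached one would then be plain GET / and DELETE /cache/GPT-2-Tinny calls.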