gs_final / main.py
Jaocs's picture
restore old version
ae6c706
import torch
import os
import shutil
import tempfile
import uuid
import asyncio
import io
import time
import contextlib
import base64
from PIL import Image
import numpy as np
from fastapi import FastAPI, UploadFile, File, HTTPException, Body
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
from pydantic import BaseModel, Field
try:
from source.utils_aux import set_seed
from source.utils_preprocess import read_video_frames, preprocess_frames, select_optimal_frames, save_frames_to_scene_dir, run_colmap_on_scene
from source.trainer import EDGSTrainer
from hydra import initialize, compose
import hydra
from source.visualization import generate_fully_smooth_cameras_with_tsp, put_text_on_image
import sys
sys.path.append('../submodules/RoMa') # Ajusta esta ruta si es necesario
from romatch import roma_indoor
except ImportError as e:
print(f"Error: No se pudieron importar los módulos del proyecto EDGS. Asegúrate de que las rutas y la instalación son correctas. {e}")
sys.exit(1)
# --- Configuración Inicial ---
# 1. Inicialización de la App FastAPI
app = FastAPI(
title="EDGS Training API",
description="Una API para preprocesar videos y entrenar modelos 3DGS con EDGS.",
version="1.0.0"
)
# 2. Variables Globales y Almacenamiento de Estado
# El modelo se cargará en el evento 'startup'
roma_model = None
# Base de datos en memoria para gestionar el estado de las tareas entre endpoints
tasks_db = {}
# 3. Modelos Pydantic para la validación de datos
class TrainParams(BaseModel):
num_corrs_per_view: int = Field(20000, gt=0, description="Correspondencias por vista de referencia.")
num_steps: int = Field(1000, gt=0, description="Número de pasos de optimización.")
class PreprocessResponse(BaseModel):
task_id: str
message: str
selected_frames_count: int
# Opcional: podrías devolver las imágenes en base64 si el cliente las necesita visualizar
# frames: list[str]
# --- Lógica de Negocio (Adaptada del script de Gradio) ---
# Esta función se ejecutará en un hilo separado para no bloquear el servidor
def run_preprocessing_sync(input_path: str, num_ref_views: int):
"""
Ejecuta el preprocesamiento: selección de frames y ejecución de COLMAP.
"""
tmpdirname = tempfile.mkdtemp()
scene_dir = os.path.join(tmpdirname, "scene")
os.makedirs(scene_dir, exist_ok=True)
# 1. Lee y selecciona los mejores frames
frames = read_video_frames(video_input=input_path, max_size=1024)
frames_scores = preprocess_frames(frames)
selected_frames_indices = select_optimal_frames(scores=frames_scores, k=min(num_ref_views, len(frames)))
selected_frames = [frames[frame_idx] for frame_idx in selected_frames_indices]
# 2. Guarda los frames y ejecuta COLMAP
save_frames_to_scene_dir(frames=selected_frames, scene_dir=scene_dir)
run_colmap_on_scene(scene_dir)
return scene_dir, selected_frames
async def training_log_generator(scene_dir: str, num_ref_views: int, params: TrainParams, task_id: str):
"""
Un generador asíncrono que ejecuta el entrenamiento. Los logs detallados se muestran
en la terminal del servidor, mientras que el cliente recibe un stream de progreso simple.
"""
def training_pipeline():
try:
# La inicialización y configuración de Hydra se mantienen igual
with initialize(config_path="./configs", version_base="1.1"):
cfg = compose(config_name="train")
# --- CONFIGURACIÓN COMPLETA ---
scene_name = os.path.basename(scene_dir)
model_output_dir = f"./outputs/{scene_name}_trained"
cfg.wandb.mode = "disabled"
cfg.gs.dataset.model_path = model_output_dir
cfg.gs.dataset.source_path = scene_dir
cfg.gs.dataset.images = "images"
cfg.train.gs_epochs = 30000
cfg.gs.opt.opacity_reset_interval = 1_000_000
cfg.train.reduce_opacity = True
cfg.train.no_densify = True
cfg.train.max_lr = True
cfg.init_wC.use = True
cfg.init_wC.matches_per_ref = params.num_corrs_per_view
cfg.init_wC.nns_per_ref = 1
cfg.init_wC.num_refs = num_ref_views
cfg.init_wC.add_SfM_init = False
cfg.init_wC.scaling_factor = 0.00077 * 2.
set_seed(cfg.seed)
os.makedirs(cfg.gs.dataset.model_path, exist_ok=True)
device = cfg.device
generator3dgs = hydra.utils.instantiate(cfg.gs, do_train_test_split=False)
trainer = EDGSTrainer(GS=generator3dgs, training_config=cfg.gs.opt, device=device, log_wandb=False)
trainer.saving_iterations = []
trainer.evaluate_iterations = []
trainer.timer.start()
# Mensaje de progreso para el cliente antes de la inicialización
yield "data: Inicializando modelo...\n\n"
trainer.init_with_corr(cfg.init_wC, roma_model=roma_model)
# El bucle de entrenamiento principal
for step in range(int(params.num_steps // 10)):
cfg.train.gs_epochs = 10
# trainer.train() ahora imprimirá sus logs detallados directamente en la terminal
trainer.train(cfg.train)
# --- CAMBIO CLAVE ---
# Envía un mensaje de progreso simple al cliente en lugar de los logs capturados.
yield f"data: Progreso: {step*10+10}/{params.num_steps} pasos completados.\n\n"
trainer.save_model()
ply_path = os.path.join(cfg.gs.dataset.model_path, f"point_cloud/iteration_{trainer.gs_step}/point_cloud.ply")
tasks_db[task_id]['result_ply_path'] = ply_path
final_message = "Entrenamiento completado. El modelo está listo para descargar."
yield f"data: {final_message}\n\n"
except Exception as e:
yield f"data: ERROR: {repr(e)}\n\n"
# El bucle que llama a la pipeline se mantiene igual
training_gen = training_pipeline()
for log_message in training_gen:
yield log_message
await asyncio.sleep(0.1)
# --- Eventos de Ciclo de Vida de la App ---
@app.on_event("startup")
async def startup_event():
"""
Carga el modelo RoMa cuando el servidor se inicia.
"""
global roma_model
print("🚀 Iniciando servidor FastAPI...")
if torch.cuda.is_available():
device = "cuda:0"
print("✅ GPU detectada. Usando CUDA.")
else:
device = "cpu"
print("⚠️ No se detectó GPU. Usando CPU (puede ser muy lento).")
roma_model = roma_indoor(device=device)
roma_model.upsample_preds = False
roma_model.symmetric = False
print("🤖 Modelo RoMa cargado y listo.")
# --- Endpoints de la API ---
@app.post("/preprocess", response_model=PreprocessResponse)
async def preprocess_video(
num_ref_views: int = Body(16, embed=True, description="Número de vistas de referencia a extraer del video."),
video: UploadFile = File(..., description="Archivo de video a procesar (.mp4, .mov).")
):
"""
Recibe un video, lo preprocesa (extrae frames + COLMAP) y prepara para el entrenamiento.
"""
if not video.filename.lower().endswith(('.mp4', '.avi', '.mov')):
raise HTTPException(status_code=400, detail="Formato de archivo no soportado. Usa .mp4, .avi, o .mov.")
# Guarda el video temporalmente para que la librería pueda procesarlo
with tempfile.NamedTemporaryFile(delete=False, suffix=video.filename) as tmp_video:
shutil.copyfileobj(video.file, tmp_video)
tmp_video_path = tmp_video.name
try:
loop = asyncio.get_running_loop()
# Ejecuta la función síncrona y bloqueante en un executor para no bloquear el servidor
scene_dir, selected_frames = await loop.run_in_executor(
None, run_preprocessing_sync, tmp_video_path, num_ref_views
)
# Genera un ID único para esta tarea y guarda la ruta
task_id = str(uuid.uuid4())
tasks_db[task_id] = {
"scene_dir": scene_dir,
"num_ref_views": len(selected_frames),
"result_ply_path": None
}
return JSONResponse(
status_code=200,
content={
"task_id": task_id,
"message": f"Preprocesamiento completado. Se generó el directorio de la escena. Listo para entrenar.",
"selected_frames_count": len(selected_frames)
}
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error durante el preprocesamiento: {e}")
finally:
os.unlink(tmp_video_path) # Limpia el archivo de video temporal
@app.post("/train/{task_id}")
async def train_model(task_id: str, params: TrainParams):
"""
Inicia el entrenamiento para una tarea preprocesada.
Devuelve un stream de logs en tiempo real.
"""
if task_id not in tasks_db:
raise HTTPException(status_code=404, detail="Task ID no encontrado. Por favor, ejecuta el preprocesamiento primero.")
task_info = tasks_db[task_id]
scene_dir = task_info["scene_dir"]
num_ref_views = task_info["num_ref_views"]
return StreamingResponse(
training_log_generator(scene_dir, num_ref_views, params, task_id),
media_type="text/event-stream"
)
@app.get("/download/{task_id}")
async def download_ply_file(task_id: str):
"""
Permite descargar el archivo .ply resultante de un entrenamiento completado.
"""
if task_id not in tasks_db:
raise HTTPException(status_code=404, detail="Task ID no encontrado.")
task_info = tasks_db[task_id]
ply_path = task_info.get("result_ply_path")
if not ply_path:
raise HTTPException(status_code=404, detail="El entrenamiento no ha finalizado o el archivo aún no está disponible.")
if not os.path.exists(ply_path):
raise HTTPException(status_code=500, detail="Error: El archivo del modelo no se encuentra en el servidor.")
# Generamos un nombre de archivo amigable para el usuario
file_name = f"model_{task_id[:8]}.ply"
return FileResponse(
path=ply_path,
media_type='application/octet-stream',
filename=file_name
)
if __name__ == "__main__":
import uvicorn
# Para ejecutar: uvicorn main:app --reload
# El flag --reload es para desarrollo. Quítalo en producción.
uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)