File size: 10,953 Bytes
c096a7a ae6c706 c096a7a ae6c706 c096a7a ae6c706 c096a7a ae6c706 c096a7a ae6c706 c096a7a ae6c706 c096a7a ae6c706 c096a7a ae6c706 c096a7a ae6c706 c096a7a ae6c706 c096a7a c16a263 ae6c706 c096a7a ae6c706 c096a7a ae6c706 c096a7a ae6c706 c096a7a ae6c706 c096a7a ae6c706 c096a7a ae6c706 c096a7a ae6c706 c096a7a ae6c706 c096a7a ae6c706 c096a7a ae6c706 c096a7a ae6c706 c096a7a ae6c706 c096a7a ae6c706 c096a7a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 |
import torch
import os
import shutil
import tempfile
import uuid
import asyncio
import io
import time
import contextlib
import base64
from PIL import Image
import numpy as np
from fastapi import FastAPI, UploadFile, File, HTTPException, Body
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
from pydantic import BaseModel, Field
try:
from source.utils_aux import set_seed
from source.utils_preprocess import read_video_frames, preprocess_frames, select_optimal_frames, save_frames_to_scene_dir, run_colmap_on_scene
from source.trainer import EDGSTrainer
from hydra import initialize, compose
import hydra
from source.visualization import generate_fully_smooth_cameras_with_tsp, put_text_on_image
import sys
sys.path.append('../submodules/RoMa') # Ajusta esta ruta si es necesario
from romatch import roma_indoor
except ImportError as e:
print(f"Error: No se pudieron importar los módulos del proyecto EDGS. Asegúrate de que las rutas y la instalación son correctas. {e}")
sys.exit(1)
# --- Configuración Inicial ---
# 1. Inicialización de la App FastAPI
app = FastAPI(
title="EDGS Training API",
description="Una API para preprocesar videos y entrenar modelos 3DGS con EDGS.",
version="1.0.0"
)
# 2. Variables Globales y Almacenamiento de Estado
# El modelo se cargará en el evento 'startup'
roma_model = None
# Base de datos en memoria para gestionar el estado de las tareas entre endpoints
tasks_db = {}
# 3. Modelos Pydantic para la validación de datos
class TrainParams(BaseModel):
num_corrs_per_view: int = Field(20000, gt=0, description="Correspondencias por vista de referencia.")
num_steps: int = Field(1000, gt=0, description="Número de pasos de optimización.")
class PreprocessResponse(BaseModel):
task_id: str
message: str
selected_frames_count: int
# Opcional: podrías devolver las imágenes en base64 si el cliente las necesita visualizar
# frames: list[str]
# --- Lógica de Negocio (Adaptada del script de Gradio) ---
# Esta función se ejecutará en un hilo separado para no bloquear el servidor
def run_preprocessing_sync(input_path: str, num_ref_views: int):
"""
Ejecuta el preprocesamiento: selección de frames y ejecución de COLMAP.
"""
tmpdirname = tempfile.mkdtemp()
scene_dir = os.path.join(tmpdirname, "scene")
os.makedirs(scene_dir, exist_ok=True)
# 1. Lee y selecciona los mejores frames
frames = read_video_frames(video_input=input_path, max_size=1024)
frames_scores = preprocess_frames(frames)
selected_frames_indices = select_optimal_frames(scores=frames_scores, k=min(num_ref_views, len(frames)))
selected_frames = [frames[frame_idx] for frame_idx in selected_frames_indices]
# 2. Guarda los frames y ejecuta COLMAP
save_frames_to_scene_dir(frames=selected_frames, scene_dir=scene_dir)
run_colmap_on_scene(scene_dir)
return scene_dir, selected_frames
async def training_log_generator(scene_dir: str, num_ref_views: int, params: TrainParams, task_id: str):
"""
Un generador asíncrono que ejecuta el entrenamiento. Los logs detallados se muestran
en la terminal del servidor, mientras que el cliente recibe un stream de progreso simple.
"""
def training_pipeline():
try:
# La inicialización y configuración de Hydra se mantienen igual
with initialize(config_path="./configs", version_base="1.1"):
cfg = compose(config_name="train")
# --- CONFIGURACIÓN COMPLETA ---
scene_name = os.path.basename(scene_dir)
model_output_dir = f"./outputs/{scene_name}_trained"
cfg.wandb.mode = "disabled"
cfg.gs.dataset.model_path = model_output_dir
cfg.gs.dataset.source_path = scene_dir
cfg.gs.dataset.images = "images"
cfg.train.gs_epochs = 30000
cfg.gs.opt.opacity_reset_interval = 1_000_000
cfg.train.reduce_opacity = True
cfg.train.no_densify = True
cfg.train.max_lr = True
cfg.init_wC.use = True
cfg.init_wC.matches_per_ref = params.num_corrs_per_view
cfg.init_wC.nns_per_ref = 1
cfg.init_wC.num_refs = num_ref_views
cfg.init_wC.add_SfM_init = False
cfg.init_wC.scaling_factor = 0.00077 * 2.
set_seed(cfg.seed)
os.makedirs(cfg.gs.dataset.model_path, exist_ok=True)
device = cfg.device
generator3dgs = hydra.utils.instantiate(cfg.gs, do_train_test_split=False)
trainer = EDGSTrainer(GS=generator3dgs, training_config=cfg.gs.opt, device=device, log_wandb=False)
trainer.saving_iterations = []
trainer.evaluate_iterations = []
trainer.timer.start()
# Mensaje de progreso para el cliente antes de la inicialización
yield "data: Inicializando modelo...\n\n"
trainer.init_with_corr(cfg.init_wC, roma_model=roma_model)
# El bucle de entrenamiento principal
for step in range(int(params.num_steps // 10)):
cfg.train.gs_epochs = 10
# trainer.train() ahora imprimirá sus logs detallados directamente en la terminal
trainer.train(cfg.train)
# --- CAMBIO CLAVE ---
# Envía un mensaje de progreso simple al cliente en lugar de los logs capturados.
yield f"data: Progreso: {step*10+10}/{params.num_steps} pasos completados.\n\n"
trainer.save_model()
ply_path = os.path.join(cfg.gs.dataset.model_path, f"point_cloud/iteration_{trainer.gs_step}/point_cloud.ply")
tasks_db[task_id]['result_ply_path'] = ply_path
final_message = "Entrenamiento completado. El modelo está listo para descargar."
yield f"data: {final_message}\n\n"
except Exception as e:
yield f"data: ERROR: {repr(e)}\n\n"
# El bucle que llama a la pipeline se mantiene igual
training_gen = training_pipeline()
for log_message in training_gen:
yield log_message
await asyncio.sleep(0.1)
# --- Eventos de Ciclo de Vida de la App ---
@app.on_event("startup")
async def startup_event():
"""
Carga el modelo RoMa cuando el servidor se inicia.
"""
global roma_model
print("🚀 Iniciando servidor FastAPI...")
if torch.cuda.is_available():
device = "cuda:0"
print("✅ GPU detectada. Usando CUDA.")
else:
device = "cpu"
print("⚠️ No se detectó GPU. Usando CPU (puede ser muy lento).")
roma_model = roma_indoor(device=device)
roma_model.upsample_preds = False
roma_model.symmetric = False
print("🤖 Modelo RoMa cargado y listo.")
# --- Endpoints de la API ---
@app.post("/preprocess", response_model=PreprocessResponse)
async def preprocess_video(
num_ref_views: int = Body(16, embed=True, description="Número de vistas de referencia a extraer del video."),
video: UploadFile = File(..., description="Archivo de video a procesar (.mp4, .mov).")
):
"""
Recibe un video, lo preprocesa (extrae frames + COLMAP) y prepara para el entrenamiento.
"""
if not video.filename.lower().endswith(('.mp4', '.avi', '.mov')):
raise HTTPException(status_code=400, detail="Formato de archivo no soportado. Usa .mp4, .avi, o .mov.")
# Guarda el video temporalmente para que la librería pueda procesarlo
with tempfile.NamedTemporaryFile(delete=False, suffix=video.filename) as tmp_video:
shutil.copyfileobj(video.file, tmp_video)
tmp_video_path = tmp_video.name
try:
loop = asyncio.get_running_loop()
# Ejecuta la función síncrona y bloqueante en un executor para no bloquear el servidor
scene_dir, selected_frames = await loop.run_in_executor(
None, run_preprocessing_sync, tmp_video_path, num_ref_views
)
# Genera un ID único para esta tarea y guarda la ruta
task_id = str(uuid.uuid4())
tasks_db[task_id] = {
"scene_dir": scene_dir,
"num_ref_views": len(selected_frames),
"result_ply_path": None
}
return JSONResponse(
status_code=200,
content={
"task_id": task_id,
"message": f"Preprocesamiento completado. Se generó el directorio de la escena. Listo para entrenar.",
"selected_frames_count": len(selected_frames)
}
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error durante el preprocesamiento: {e}")
finally:
os.unlink(tmp_video_path) # Limpia el archivo de video temporal
@app.post("/train/{task_id}")
async def train_model(task_id: str, params: TrainParams):
"""
Inicia el entrenamiento para una tarea preprocesada.
Devuelve un stream de logs en tiempo real.
"""
if task_id not in tasks_db:
raise HTTPException(status_code=404, detail="Task ID no encontrado. Por favor, ejecuta el preprocesamiento primero.")
task_info = tasks_db[task_id]
scene_dir = task_info["scene_dir"]
num_ref_views = task_info["num_ref_views"]
return StreamingResponse(
training_log_generator(scene_dir, num_ref_views, params, task_id),
media_type="text/event-stream"
)
@app.get("/download/{task_id}")
async def download_ply_file(task_id: str):
"""
Permite descargar el archivo .ply resultante de un entrenamiento completado.
"""
if task_id not in tasks_db:
raise HTTPException(status_code=404, detail="Task ID no encontrado.")
task_info = tasks_db[task_id]
ply_path = task_info.get("result_ply_path")
if not ply_path:
raise HTTPException(status_code=404, detail="El entrenamiento no ha finalizado o el archivo aún no está disponible.")
if not os.path.exists(ply_path):
raise HTTPException(status_code=500, detail="Error: El archivo del modelo no se encuentra en el servidor.")
# Generamos un nombre de archivo amigable para el usuario
file_name = f"model_{task_id[:8]}.ply"
return FileResponse(
path=ply_path,
media_type='application/octet-stream',
filename=file_name
)
if __name__ == "__main__":
import uvicorn
# Para ejecutar: uvicorn main:app --reload
# El flag --reload es para desarrollo. Quítalo en producción.
uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False) |