import torch
import os
import shutil
import sys
import tempfile
import uuid
import asyncio
import io
import time
import contextlib
import base64
from PIL import Image
import numpy as np
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
from pydantic import BaseModel, Field

try:
    from source.utils_aux import set_seed
    from source.utils_preprocess import read_video_frames, preprocess_frames, select_optimal_frames, save_frames_to_scene_dir, run_colmap_on_scene
    from source.trainer import EDGSTrainer
    from hydra import initialize, compose
    import hydra
    from source.visualization import generate_fully_smooth_cameras_with_tsp, put_text_on_image

    # RoMa is vendored as a git submodule; make it importable before importing romatch.
    sys.path.append('../submodules/RoMa')
    from romatch import roma_indoor
except ImportError as e:
    print(f"Error: could not import the EDGS project modules. Make sure the paths and the installation are correct. {e}")
    sys.exit(1)


app = FastAPI(
    title="EDGS Training API",
    description="An API for preprocessing videos and training 3DGS models with EDGS.",
    version="1.0.0"
)

# RoMa matcher; loaded once in startup_event below.
roma_model = None

# In-memory task store keyed by the task_id returned by /preprocess.
tasks_db = {}
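# An entry looks like this (illustrative, matching what /preprocess stores below):
#   tasks_db[task_id] = {"scene_dir": "<tmp scene dir>", "num_ref_views": 16, "result_ply_path": None}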


class TrainParams(BaseModel):
    num_corrs_per_view: int = Field(20000, gt=0, description="Correspondences per reference view.")
    num_steps: int = Field(1000, gt=0, description="Number of optimization steps.")
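# Example JSON body for POST /train/{task_id} (the values shown are the defaults above):
# {"num_corrs_per_view": 20000, "num_steps": 1000}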


class PreprocessResponse(BaseModel):
    task_id: str
    message: str
    selected_frames_count: int


def run_preprocessing_sync(input_path: str, num_ref_views: int):
    """
    Runs preprocessing: frame selection and COLMAP reconstruction.
    """
    tmpdirname = tempfile.mkdtemp()
    scene_dir = os.path.join(tmpdirname, "scene")
    os.makedirs(scene_dir, exist_ok=True)

    # Read the video, score the frames, and keep the k most informative ones.
    frames = read_video_frames(video_input=input_path, max_size=1024)
    frames_scores = preprocess_frames(frames)
    selected_frames_indices = select_optimal_frames(scores=frames_scores, k=min(num_ref_views, len(frames)))
    selected_frames = [frames[frame_idx] for frame_idx in selected_frames_indices]

    # Write the selected frames to disk and estimate camera poses with COLMAP.
    save_frames_to_scene_dir(frames=selected_frames, scene_dir=scene_dir)
    run_colmap_on_scene(scene_dir)

    return scene_dir, selected_frames
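# Note: this helper is blocking (video decoding + COLMAP), so the /preprocess endpoint
# below runs it in a thread pool via loop.run_in_executor to keep the event loop responsive.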


async def training_log_generator(scene_dir: str, num_ref_views: int, params: TrainParams, task_id: str):
    """
    Asynchronous generator that runs the training. Detailed logs are printed to the
    server terminal, while the client receives a simple progress stream.
    """
    def training_pipeline():
        try:
            with initialize(config_path="./configs", version_base="1.1"):
                cfg = compose(config_name="train")

            # Configure the run: disable W&B, point the dataset at the preprocessed
            # scene, and enable correspondence-based initialization (init_wC).
            scene_name = os.path.basename(scene_dir)
            model_output_dir = f"./outputs/{scene_name}_trained"
            cfg.wandb.mode = "disabled"
            cfg.gs.dataset.model_path = model_output_dir
            cfg.gs.dataset.source_path = scene_dir
            cfg.gs.dataset.images = "images"
            cfg.train.gs_epochs = 30000
            cfg.gs.opt.opacity_reset_interval = 1_000_000
            cfg.train.reduce_opacity = True
            cfg.train.no_densify = True
            cfg.train.max_lr = True
            cfg.init_wC.use = True
            cfg.init_wC.matches_per_ref = params.num_corrs_per_view
            cfg.init_wC.nns_per_ref = 1
            cfg.init_wC.num_refs = num_ref_views
            cfg.init_wC.add_SfM_init = False
            cfg.init_wC.scaling_factor = 0.00077 * 2.

            set_seed(cfg.seed)
            os.makedirs(cfg.gs.dataset.model_path, exist_ok=True)

            device = cfg.device
            generator3dgs = hydra.utils.instantiate(cfg.gs, do_train_test_split=False)
            trainer = EDGSTrainer(GS=generator3dgs, training_config=cfg.gs.opt, device=device, log_wandb=False)
            trainer.saving_iterations = []
            trainer.evaluate_iterations = []
            trainer.timer.start()

            yield "data: Initializing model...\n\n"
            trainer.init_with_corr(cfg.init_wC, roma_model=roma_model)

            # Train in chunks of 10 steps so progress can be streamed to the client.
            for step in range(int(params.num_steps // 10)):
                cfg.train.gs_epochs = 10
                trainer.train(cfg.train)
                yield f"data: Progress: {step * 10 + 10}/{params.num_steps} steps completed.\n\n"

            trainer.save_model()
            ply_path = os.path.join(cfg.gs.dataset.model_path, f"point_cloud/iteration_{trainer.gs_step}/point_cloud.ply")
            tasks_db[task_id]['result_ply_path'] = ply_path

            final_message = "Training completed. The model is ready to download."
            yield f"data: {final_message}\n\n"

        except Exception as e:
            yield f"data: ERROR: {repr(e)}\n\n"

    # Drive the synchronous training generator from the async context, forwarding
    # each log line to the SSE stream.
    training_gen = training_pipeline()
    for log_message in training_gen:
        yield log_message
        await asyncio.sleep(0.1)


@app.on_event("startup")
async def startup_event():
    """
    Loads the RoMa model when the server starts.
    """
    global roma_model
    print("🚀 Starting FastAPI server...")
    if torch.cuda.is_available():
        device = "cuda:0"
        print("✅ GPU detected. Using CUDA.")
    else:
        device = "cpu"
        print("⚠️ No GPU detected. Using CPU (this can be very slow).")

    roma_model = roma_indoor(device=device)
    roma_model.upsample_preds = False
    roma_model.symmetric = False
    print("🤖 RoMa model loaded and ready.")


@app.post("/preprocess", response_model=PreprocessResponse)
async def preprocess_video(
    num_ref_views: int = Form(16, description="Number of reference views to extract from the video."),
    video: UploadFile = File(..., description="Video file to process (.mp4, .avi, .mov).")
):
    """
    Receives a video, preprocesses it (frame extraction + COLMAP), and prepares it for training.
    """
    if not video.filename.lower().endswith(('.mp4', '.avi', '.mov')):
        raise HTTPException(status_code=400, detail="Unsupported file format. Use .mp4, .avi, or .mov.")

    # Persist the upload to a temporary file so the preprocessing pipeline can read it from disk.
    with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(video.filename)[1]) as tmp_video:
        shutil.copyfileobj(video.file, tmp_video)
        tmp_video_path = tmp_video.name

    try:
        loop = asyncio.get_running_loop()
        # Run the blocking preprocessing in a worker thread so the event loop stays responsive.
        scene_dir, selected_frames = await loop.run_in_executor(
            None, run_preprocessing_sync, tmp_video_path, num_ref_views
        )

        task_id = str(uuid.uuid4())
        tasks_db[task_id] = {
            "scene_dir": scene_dir,
            "num_ref_views": len(selected_frames),
            "result_ply_path": None
        }

        return JSONResponse(
            status_code=200,
            content={
                "task_id": task_id,
                "message": "Preprocessing completed. The scene directory was generated. Ready to train.",
                "selected_frames_count": len(selected_frames)
            }
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during preprocessing: {e}")
    finally:
        os.unlink(tmp_video_path)
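# A minimal client sketch for POST /preprocess (illustrative only; assumes the server is
# reachable at http://localhost:7860 and that an input.mp4 exists locally):
#
#   import requests
#   with open("input.mp4", "rb") as f:
#       r = requests.post(
#           "http://localhost:7860/preprocess",
#           files={"video": ("input.mp4", f, "video/mp4")},
#           data={"num_ref_views": 16},
#       )
#   task_id = r.json()["task_id"]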


@app.post("/train/{task_id}")
async def train_model(task_id: str, params: TrainParams):
    """
    Starts training for a preprocessed task.
    Returns a real-time log stream (server-sent events).
    """
    if task_id not in tasks_db:
        raise HTTPException(status_code=404, detail="Task ID not found. Please run preprocessing first.")

    task_info = tasks_db[task_id]
    scene_dir = task_info["scene_dir"]
    num_ref_views = task_info["num_ref_views"]

    return StreamingResponse(
        training_log_generator(scene_dir, num_ref_views, params, task_id),
        media_type="text/event-stream"
    )
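# A minimal client sketch for consuming the /train log stream (illustrative only; assumes a
# task_id previously returned by /preprocess and a server on http://localhost:7860):
#
#   import requests
#   resp = requests.post(
#       f"http://localhost:7860/train/{task_id}",
#       json={"num_corrs_per_view": 20000, "num_steps": 1000},
#       stream=True,
#   )
#   for line in resp.iter_lines(decode_unicode=True):
#       if line:
#           print(line)  # each event arrives as a "data: ..." line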


@app.get("/download/{task_id}")
async def download_ply_file(task_id: str):
    """
    Allows downloading the .ply file produced by a completed training run.
    """
    if task_id not in tasks_db:
        raise HTTPException(status_code=404, detail="Task ID not found.")

    task_info = tasks_db[task_id]
    ply_path = task_info.get("result_ply_path")

    if not ply_path:
        raise HTTPException(status_code=404, detail="Training has not finished or the file is not available yet.")

    if not os.path.exists(ply_path):
        raise HTTPException(status_code=500, detail="Error: the model file was not found on the server.")

    file_name = f"model_{task_id[:8]}.ply"

    return FileResponse(
        path=ply_path,
        media_type='application/octet-stream',
        filename=file_name
    )
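# Illustrative download sketch (assumes a completed task and a server on http://localhost:7860):
#
#   import requests
#   r = requests.get(f"http://localhost:7860/download/{task_id}")
#   with open("model.ply", "wb") as f:
#       f.write(r.content)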


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)
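# Note: the "main:app" import string assumes this module is saved as main.py; adjust it
# (or launch directly with the uvicorn CLI, e.g. `uvicorn main:app --host 0.0.0.0 --port 7860`)
# if the file is named differently.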