File size: 10,953 Bytes
c096a7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae6c706
 
c096a7a
 
 
ae6c706
c096a7a
ae6c706
 
c096a7a
 
 
 
 
 
 
 
 
 
 
ae6c706
 
c096a7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae6c706
 
c096a7a
 
 
ae6c706
c096a7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae6c706
c096a7a
 
 
ae6c706
c096a7a
 
ae6c706
c096a7a
ae6c706
 
 
c096a7a
 
 
 
 
c16a263
ae6c706
 
c096a7a
 
 
 
 
ae6c706
c096a7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae6c706
c096a7a
 
 
 
 
 
ae6c706
c096a7a
 
 
 
ae6c706
c096a7a
 
 
 
ae6c706
c096a7a
 
 
 
 
 
 
 
 
 
 
 
 
ae6c706
 
c096a7a
 
 
 
ae6c706
 
c096a7a
 
ae6c706
c096a7a
 
 
 
 
 
 
 
 
 
ae6c706
c096a7a
 
 
 
 
 
 
 
 
 
 
ae6c706
c096a7a
 
ae6c706
c096a7a
ae6c706
c096a7a
ae6c706
c096a7a
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
import torch
import os
import shutil
import tempfile
import uuid
import asyncio
import io
import time
import contextlib
import base64
from PIL import Image
import numpy as np
from fastapi import FastAPI, UploadFile, File, HTTPException, Body
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
from pydantic import BaseModel, Field

try:
    from source.utils_aux import set_seed
    from source.utils_preprocess import read_video_frames, preprocess_frames, select_optimal_frames, save_frames_to_scene_dir, run_colmap_on_scene
    from source.trainer import EDGSTrainer
    from hydra import initialize, compose
    import hydra
    from source.visualization import generate_fully_smooth_cameras_with_tsp, put_text_on_image
    import sys
    sys.path.append('../submodules/RoMa') # Ajusta esta ruta si es necesario
    from romatch import roma_indoor
except ImportError as e:
    print(f"Error: No se pudieron importar los módulos del proyecto EDGS. Asegúrate de que las rutas y la instalación son correctas. {e}")
    sys.exit(1)

# --- Configuración Inicial ---
# 1. Inicialización de la App FastAPI
app = FastAPI(
    title="EDGS Training API",
    description="Una API para preprocesar videos y entrenar modelos 3DGS con EDGS.",
    version="1.0.0"
)

# 2. Variables Globales y Almacenamiento de Estado
# El modelo se cargará en el evento 'startup'
roma_model = None 

# Base de datos en memoria para gestionar el estado de las tareas entre endpoints
tasks_db = {} 

# 3. Modelos Pydantic para la validación de datos
class TrainParams(BaseModel):
    num_corrs_per_view: int = Field(20000, gt=0, description="Correspondencias por vista de referencia.")
    num_steps: int = Field(1000, gt=0, description="Número de pasos de optimización.")

class PreprocessResponse(BaseModel):
    task_id: str
    message: str
    selected_frames_count: int
    # Opcional: podrías devolver las imágenes en base64 si el cliente las necesita visualizar
    # frames: list[str] 

# --- Lógica de Negocio (Adaptada del script de Gradio) ---

# Esta función se ejecutará en un hilo separado para no bloquear el servidor
def run_preprocessing_sync(input_path: str, num_ref_views: int):
    """
    Ejecuta el preprocesamiento: selección de frames y ejecución de COLMAP.
    """
    tmpdirname = tempfile.mkdtemp()
    scene_dir = os.path.join(tmpdirname, "scene")
    os.makedirs(scene_dir, exist_ok=True)
    
    # 1. Lee y selecciona los mejores frames
    frames = read_video_frames(video_input=input_path, max_size=1024)
    frames_scores = preprocess_frames(frames)
    selected_frames_indices = select_optimal_frames(scores=frames_scores, k=min(num_ref_views, len(frames)))
    selected_frames = [frames[frame_idx] for frame_idx in selected_frames_indices]
    
    # 2. Guarda los frames y ejecuta COLMAP
    save_frames_to_scene_dir(frames=selected_frames, scene_dir=scene_dir)
    run_colmap_on_scene(scene_dir)
    
    return scene_dir, selected_frames

async def training_log_generator(scene_dir: str, num_ref_views: int, params: TrainParams, task_id: str):
    """
    Un generador asíncrono que ejecuta el entrenamiento. Los logs detallados se muestran
    en la terminal del servidor, mientras que el cliente recibe un stream de progreso simple.
    """
    def training_pipeline():
        try:
            # La inicialización y configuración de Hydra se mantienen igual
            with initialize(config_path="./configs", version_base="1.1"):
                cfg = compose(config_name="train")

            # --- CONFIGURACIÓN COMPLETA ---
            scene_name = os.path.basename(scene_dir)
            model_output_dir = f"./outputs/{scene_name}_trained"
            cfg.wandb.mode = "disabled"
            cfg.gs.dataset.model_path = model_output_dir
            cfg.gs.dataset.source_path = scene_dir
            cfg.gs.dataset.images = "images"
            cfg.train.gs_epochs = 30000                                 
            cfg.gs.opt.opacity_reset_interval = 1_000_000               
            cfg.train.reduce_opacity = True                             
            cfg.train.no_densify = True                                 
            cfg.train.max_lr = True                                     
            cfg.init_wC.use = True
            cfg.init_wC.matches_per_ref = params.num_corrs_per_view
            cfg.init_wC.nns_per_ref = 1                                 
            cfg.init_wC.num_refs = num_ref_views
            cfg.init_wC.add_SfM_init = False                            
            cfg.init_wC.scaling_factor = 0.00077 * 2.                   
            
            set_seed(cfg.seed)
            os.makedirs(cfg.gs.dataset.model_path, exist_ok=True)

            device = cfg.device                                         
            generator3dgs = hydra.utils.instantiate(cfg.gs, do_train_test_split=False)
            trainer = EDGSTrainer(GS=generator3dgs, training_config=cfg.gs.opt, device=device, log_wandb=False)
            trainer.saving_iterations = []
            trainer.evaluate_iterations = []
            trainer.timer.start()

            # Mensaje de progreso para el cliente antes de la inicialización
            yield "data: Inicializando modelo...\n\n"
            trainer.init_with_corr(cfg.init_wC, roma_model=roma_model)

            # El bucle de entrenamiento principal
            for step in range(int(params.num_steps // 10)):
                cfg.train.gs_epochs = 10
                # trainer.train() ahora imprimirá sus logs detallados directamente en la terminal
                trainer.train(cfg.train)
                
                # --- CAMBIO CLAVE ---
                # Envía un mensaje de progreso simple al cliente en lugar de los logs capturados.
                yield f"data: Progreso: {step*10+10}/{params.num_steps} pasos completados.\n\n"
            
            trainer.save_model()
            ply_path = os.path.join(cfg.gs.dataset.model_path, f"point_cloud/iteration_{trainer.gs_step}/point_cloud.ply")
            
            tasks_db[task_id]['result_ply_path'] = ply_path

            final_message = "Entrenamiento completado. El modelo está listo para descargar."
            yield f"data: {final_message}\n\n"

        except Exception as e:
            yield f"data: ERROR: {repr(e)}\n\n"
            
    # El bucle que llama a la pipeline se mantiene igual
    training_gen = training_pipeline()
    for log_message in training_gen:
        yield log_message
        await asyncio.sleep(0.1)

# --- Eventos de Ciclo de Vida de la App ---

@app.on_event("startup")
async def startup_event():
    """
    Carga el modelo RoMa cuando el servidor se inicia.
    """
    global roma_model
    print("🚀 Iniciando servidor FastAPI...")
    if torch.cuda.is_available():
        device = "cuda:0"
        print("✅ GPU detectada. Usando CUDA.")
    else:
        device = "cpu"
        print("⚠️ No se detectó GPU. Usando CPU (puede ser muy lento).")
    
    roma_model = roma_indoor(device=device)
    roma_model.upsample_preds = False
    roma_model.symmetric = False
    print("🤖 Modelo RoMa cargado y listo.")

# --- Endpoints de la API ---

@app.post("/preprocess", response_model=PreprocessResponse)
async def preprocess_video(
    num_ref_views: int = Body(16, embed=True, description="Número de vistas de referencia a extraer del video."),
    video: UploadFile = File(..., description="Archivo de video a procesar (.mp4, .mov).")
):
    """
    Recibe un video, lo preprocesa (extrae frames + COLMAP) y prepara para el entrenamiento.
    """
    if not video.filename.lower().endswith(('.mp4', '.avi', '.mov')):
        raise HTTPException(status_code=400, detail="Formato de archivo no soportado. Usa .mp4, .avi, o .mov.")

    # Guarda el video temporalmente para que la librería pueda procesarlo
    with tempfile.NamedTemporaryFile(delete=False, suffix=video.filename) as tmp_video:
        shutil.copyfileobj(video.file, tmp_video)
        tmp_video_path = tmp_video.name
    
    try:
        loop = asyncio.get_running_loop()
        # Ejecuta la función síncrona y bloqueante en un executor para no bloquear el servidor
        scene_dir, selected_frames = await loop.run_in_executor(
            None, run_preprocessing_sync, tmp_video_path, num_ref_views
        )
        
        # Genera un ID único para esta tarea y guarda la ruta
        task_id = str(uuid.uuid4())
        tasks_db[task_id] = {
            "scene_dir": scene_dir,
            "num_ref_views": len(selected_frames),
            "result_ply_path": None
        }
        
        return JSONResponse(
            status_code=200,
            content={
                "task_id": task_id,
                "message": f"Preprocesamiento completado. Se generó el directorio de la escena. Listo para entrenar.",
                "selected_frames_count": len(selected_frames)
            }
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error durante el preprocesamiento: {e}")
    finally:
        os.unlink(tmp_video_path) # Limpia el archivo de video temporal


@app.post("/train/{task_id}")
async def train_model(task_id: str, params: TrainParams):
    """
    Inicia el entrenamiento para una tarea preprocesada.
    Devuelve un stream de logs en tiempo real.
    """
    if task_id not in tasks_db:
        raise HTTPException(status_code=404, detail="Task ID no encontrado. Por favor, ejecuta el preprocesamiento primero.")
    
    task_info = tasks_db[task_id]
    scene_dir = task_info["scene_dir"]
    num_ref_views = task_info["num_ref_views"]
    
    return StreamingResponse(
        training_log_generator(scene_dir, num_ref_views, params, task_id),
        media_type="text/event-stream"
    )

@app.get("/download/{task_id}")
async def download_ply_file(task_id: str):
    """
    Permite descargar el archivo .ply resultante de un entrenamiento completado.
    """
    if task_id not in tasks_db:
        raise HTTPException(status_code=404, detail="Task ID no encontrado.")
    
    task_info = tasks_db[task_id]
    ply_path = task_info.get("result_ply_path")

    if not ply_path:
        raise HTTPException(status_code=404, detail="El entrenamiento no ha finalizado o el archivo aún no está disponible.")

    if not os.path.exists(ply_path):
        raise HTTPException(status_code=500, detail="Error: El archivo del modelo no se encuentra en el servidor.")

    # Generamos un nombre de archivo amigable para el usuario
    file_name = f"model_{task_id[:8]}.ply"

    return FileResponse(
        path=ply_path,
        media_type='application/octet-stream',
        filename=file_name
    )

if __name__ == "__main__":
    import uvicorn
    # Para ejecutar: uvicorn main:app --reload
    # El flag --reload es para desarrollo. Quítalo en producción.
    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)