Jaocs commited on
Commit
ae6c706
1 Parent(s): c16a263

restore old version

Browse files
Files changed (1) hide show
  1. main.py +33 -54
main.py CHANGED
@@ -24,7 +24,6 @@ try:
24
  import sys
25
  sys.path.append('../submodules/RoMa') # Ajusta esta ruta si es necesario
26
  from romatch import roma_indoor
27
- import trimesh # <-- A脩ADIDO: Importaci贸n necesaria para la conversi贸n
28
  except ImportError as e:
29
  print(f"Error: No se pudieron importar los m贸dulos del proyecto EDGS. Aseg煤rate de que las rutas y la instalaci贸n son correctas. {e}")
30
  sys.exit(1)
@@ -33,12 +32,15 @@ except ImportError as e:
33
  # 1. Inicializaci贸n de la App FastAPI
34
  app = FastAPI(
35
  title="EDGS Training API",
36
- description="Una API para preprocesar videos, entrenar modelos 3DGS con EDGS y exportar a GLB.",
37
- version="1.1.0"
38
  )
39
 
40
  # 2. Variables Globales y Almacenamiento de Estado
 
41
  roma_model = None
 
 
42
  tasks_db = {}
43
 
44
  # 3. Modelos Pydantic para la validaci贸n de datos
@@ -50,20 +52,11 @@ class PreprocessResponse(BaseModel):
50
  task_id: str
51
  message: str
52
  selected_frames_count: int
 
 
53
 
54
  # --- L贸gica de Negocio (Adaptada del script de Gradio) ---
55
 
56
- def convert_ply_to_glb(ply_path: str) -> str:
57
- """
58
- Carga el PLY con trimesh y lo exporta como GLB (glTF binario).
59
- """
60
- # Generar ruta .glb basada en .ply
61
- glb_path = os.path.splitext(ply_path)[0] + ".glb"
62
- # Cargar y exportar
63
- mesh = trimesh.load(ply_path, force='mesh')
64
- mesh.export(glb_path)
65
- return glb_path
66
-
67
  # Esta funci贸n se ejecutar谩 en un hilo separado para no bloquear el servidor
68
  def run_preprocessing_sync(input_path: str, num_ref_views: int):
69
  """
@@ -87,10 +80,12 @@ def run_preprocessing_sync(input_path: str, num_ref_views: int):
87
 
88
  async def training_log_generator(scene_dir: str, num_ref_views: int, params: TrainParams, task_id: str):
89
  """
90
- Un generador as铆ncrono que ejecuta el entrenamiento y la conversi贸n a GLB.
 
91
  """
92
  def training_pipeline():
93
  try:
 
94
  with initialize(config_path="./configs", version_base="1.1"):
95
  cfg = compose(config_name="train")
96
 
@@ -123,30 +118,32 @@ async def training_log_generator(scene_dir: str, num_ref_views: int, params: Tra
123
  trainer.evaluate_iterations = []
124
  trainer.timer.start()
125
 
 
126
  yield "data: Inicializando modelo...\n\n"
127
  trainer.init_with_corr(cfg.init_wC, roma_model=roma_model)
128
 
 
129
  for step in range(int(params.num_steps // 10)):
130
  cfg.train.gs_epochs = 10
 
131
  trainer.train(cfg.train)
 
 
 
132
  yield f"data: Progreso: {step*10+10}/{params.num_steps} pasos completados.\n\n"
133
 
134
  trainer.save_model()
135
  ply_path = os.path.join(cfg.gs.dataset.model_path, f"point_cloud/iteration_{trainer.gs_step}/point_cloud.ply")
136
 
137
- # --- CAMBIO CLAVE: Conversi贸n a GLB y almacenamiento de ambas rutas ---
138
- yield "data: Convirtiendo modelo a formato GLB...\n\n"
139
- glb_path = convert_ply_to_glb(ply_path)
140
-
141
  tasks_db[task_id]['result_ply_path'] = ply_path
142
- tasks_db[task_id]['result_glb_path'] = glb_path
143
-
144
- final_message = "Entrenamiento y conversi贸n completados. El modelo est谩 listo para descargar."
145
  yield f"data: {final_message}\n\n"
146
 
147
  except Exception as e:
148
  yield f"data: ERROR: {repr(e)}\n\n"
149
 
 
150
  training_gen = training_pipeline()
151
  for log_message in training_gen:
152
  yield log_message
@@ -186,22 +183,24 @@ async def preprocess_video(
186
  if not video.filename.lower().endswith(('.mp4', '.avi', '.mov')):
187
  raise HTTPException(status_code=400, detail="Formato de archivo no soportado. Usa .mp4, .avi, o .mov.")
188
 
 
189
  with tempfile.NamedTemporaryFile(delete=False, suffix=video.filename) as tmp_video:
190
  shutil.copyfileobj(video.file, tmp_video)
191
  tmp_video_path = tmp_video.name
192
 
193
  try:
194
  loop = asyncio.get_running_loop()
 
195
  scene_dir, selected_frames = await loop.run_in_executor(
196
  None, run_preprocessing_sync, tmp_video_path, num_ref_views
197
  )
198
 
 
199
  task_id = str(uuid.uuid4())
200
  tasks_db[task_id] = {
201
  "scene_dir": scene_dir,
202
  "num_ref_views": len(selected_frames),
203
- "result_ply_path": None,
204
- "result_glb_path": None # <-- A脩ADIDO: Inicializar ruta GLB
205
  }
206
 
207
  return JSONResponse(
@@ -215,15 +214,17 @@ async def preprocess_video(
215
  except Exception as e:
216
  raise HTTPException(status_code=500, detail=f"Error durante el preprocesamiento: {e}")
217
  finally:
218
- os.unlink(tmp_video_path)
 
219
 
220
  @app.post("/train/{task_id}")
221
  async def train_model(task_id: str, params: TrainParams):
222
  """
223
- Inicia el entrenamiento para una tarea preprocesada. Devuelve un stream de logs.
 
224
  """
225
  if task_id not in tasks_db:
226
- raise HTTPException(status_code=404, detail="Task ID no encontrado. Ejecuta el preprocesamiento primero.")
227
 
228
  task_info = tasks_db[task_id]
229
  scene_dir = task_info["scene_dir"]
@@ -234,7 +235,7 @@ async def train_model(task_id: str, params: TrainParams):
234
  media_type="text/event-stream"
235
  )
236
 
237
- @app.get("/download-ply/{task_id}")
238
  async def download_ply_file(task_id: str):
239
  """
240
  Permite descargar el archivo .ply resultante de un entrenamiento completado.
@@ -246,42 +247,20 @@ async def download_ply_file(task_id: str):
246
  ply_path = task_info.get("result_ply_path")
247
 
248
  if not ply_path:
249
- raise HTTPException(status_code=404, detail="El entrenamiento no ha finalizado o el archivo PLY a煤n no est谩 disponible.")
250
 
251
  if not os.path.exists(ply_path):
252
- raise HTTPException(status_code=500, detail="Error: El archivo del modelo PLY no se encuentra en el servidor.")
253
 
 
254
  file_name = f"model_{task_id[:8]}.ply"
 
255
  return FileResponse(
256
  path=ply_path,
257
  media_type='application/octet-stream',
258
  filename=file_name
259
  )
260
 
261
- @app.get("/download-glb/{task_id}")
262
- async def download_glb_file(task_id: str):
263
- """
264
- Permite descargar el archivo .glb resultante de un entrenamiento completado.
265
- """
266
- if task_id not in tasks_db:
267
- raise HTTPException(status_code=404, detail="Task ID no encontrado.")
268
-
269
- task_info = tasks_db[task_id]
270
- glb_path = task_info.get("result_glb_path")
271
-
272
- if not glb_path:
273
- raise HTTPException(status_code=404, detail="El entrenamiento no ha finalizado o el archivo GLB a煤n no est谩 disponible.")
274
-
275
- if not os.path.exists(glb_path):
276
- raise HTTPException(status_code=500, detail="Error: El archivo del modelo GLB no se encuentra en el servidor.")
277
-
278
- file_name = f"model_{task_id[:8]}.glb"
279
- return FileResponse(
280
- path=glb_path,
281
- media_type='model/gltf-binary',
282
- filename=file_name
283
- )
284
-
285
  if __name__ == "__main__":
286
  import uvicorn
287
  # Para ejecutar: uvicorn main:app --reload
 
24
  import sys
25
  sys.path.append('../submodules/RoMa') # Ajusta esta ruta si es necesario
26
  from romatch import roma_indoor
 
27
  except ImportError as e:
28
  print(f"Error: No se pudieron importar los m贸dulos del proyecto EDGS. Aseg煤rate de que las rutas y la instalaci贸n son correctas. {e}")
29
  sys.exit(1)
 
32
  # 1. Inicializaci贸n de la App FastAPI
33
  app = FastAPI(
34
  title="EDGS Training API",
35
+ description="Una API para preprocesar videos y entrenar modelos 3DGS con EDGS.",
36
+ version="1.0.0"
37
  )
38
 
39
  # 2. Variables Globales y Almacenamiento de Estado
40
+ # El modelo se cargar谩 en el evento 'startup'
41
  roma_model = None
42
+
43
+ # Base de datos en memoria para gestionar el estado de las tareas entre endpoints
44
  tasks_db = {}
45
 
46
  # 3. Modelos Pydantic para la validaci贸n de datos
 
52
  task_id: str
53
  message: str
54
  selected_frames_count: int
55
+ # Opcional: podr铆as devolver las im谩genes en base64 si el cliente las necesita visualizar
56
+ # frames: list[str]
57
 
58
  # --- L贸gica de Negocio (Adaptada del script de Gradio) ---
59
 
 
 
 
 
 
 
 
 
 
 
 
60
  # Esta funci贸n se ejecutar谩 en un hilo separado para no bloquear el servidor
61
  def run_preprocessing_sync(input_path: str, num_ref_views: int):
62
  """
 
80
 
81
  async def training_log_generator(scene_dir: str, num_ref_views: int, params: TrainParams, task_id: str):
82
  """
83
+ Un generador as铆ncrono que ejecuta el entrenamiento. Los logs detallados se muestran
84
+ en la terminal del servidor, mientras que el cliente recibe un stream de progreso simple.
85
  """
86
  def training_pipeline():
87
  try:
88
+ # La inicializaci贸n y configuraci贸n de Hydra se mantienen igual
89
  with initialize(config_path="./configs", version_base="1.1"):
90
  cfg = compose(config_name="train")
91
 
 
118
  trainer.evaluate_iterations = []
119
  trainer.timer.start()
120
 
121
+ # Mensaje de progreso para el cliente antes de la inicializaci贸n
122
  yield "data: Inicializando modelo...\n\n"
123
  trainer.init_with_corr(cfg.init_wC, roma_model=roma_model)
124
 
125
+ # El bucle de entrenamiento principal
126
  for step in range(int(params.num_steps // 10)):
127
  cfg.train.gs_epochs = 10
128
+ # trainer.train() ahora imprimir谩 sus logs detallados directamente en la terminal
129
  trainer.train(cfg.train)
130
+
131
+ # --- CAMBIO CLAVE ---
132
+ # Env铆a un mensaje de progreso simple al cliente en lugar de los logs capturados.
133
  yield f"data: Progreso: {step*10+10}/{params.num_steps} pasos completados.\n\n"
134
 
135
  trainer.save_model()
136
  ply_path = os.path.join(cfg.gs.dataset.model_path, f"point_cloud/iteration_{trainer.gs_step}/point_cloud.ply")
137
 
 
 
 
 
138
  tasks_db[task_id]['result_ply_path'] = ply_path
139
+
140
+ final_message = "Entrenamiento completado. El modelo est谩 listo para descargar."
 
141
  yield f"data: {final_message}\n\n"
142
 
143
  except Exception as e:
144
  yield f"data: ERROR: {repr(e)}\n\n"
145
 
146
+ # El bucle que llama a la pipeline se mantiene igual
147
  training_gen = training_pipeline()
148
  for log_message in training_gen:
149
  yield log_message
 
183
  if not video.filename.lower().endswith(('.mp4', '.avi', '.mov')):
184
  raise HTTPException(status_code=400, detail="Formato de archivo no soportado. Usa .mp4, .avi, o .mov.")
185
 
186
+ # Guarda el video temporalmente para que la librer铆a pueda procesarlo
187
  with tempfile.NamedTemporaryFile(delete=False, suffix=video.filename) as tmp_video:
188
  shutil.copyfileobj(video.file, tmp_video)
189
  tmp_video_path = tmp_video.name
190
 
191
  try:
192
  loop = asyncio.get_running_loop()
193
+ # Ejecuta la funci贸n s铆ncrona y bloqueante en un executor para no bloquear el servidor
194
  scene_dir, selected_frames = await loop.run_in_executor(
195
  None, run_preprocessing_sync, tmp_video_path, num_ref_views
196
  )
197
 
198
+ # Genera un ID 煤nico para esta tarea y guarda la ruta
199
  task_id = str(uuid.uuid4())
200
  tasks_db[task_id] = {
201
  "scene_dir": scene_dir,
202
  "num_ref_views": len(selected_frames),
203
+ "result_ply_path": None
 
204
  }
205
 
206
  return JSONResponse(
 
214
  except Exception as e:
215
  raise HTTPException(status_code=500, detail=f"Error durante el preprocesamiento: {e}")
216
  finally:
217
+ os.unlink(tmp_video_path) # Limpia el archivo de video temporal
218
+
219
 
220
  @app.post("/train/{task_id}")
221
  async def train_model(task_id: str, params: TrainParams):
222
  """
223
+ Inicia el entrenamiento para una tarea preprocesada.
224
+ Devuelve un stream de logs en tiempo real.
225
  """
226
  if task_id not in tasks_db:
227
+ raise HTTPException(status_code=404, detail="Task ID no encontrado. Por favor, ejecuta el preprocesamiento primero.")
228
 
229
  task_info = tasks_db[task_id]
230
  scene_dir = task_info["scene_dir"]
 
235
  media_type="text/event-stream"
236
  )
237
 
238
+ @app.get("/download/{task_id}")
239
  async def download_ply_file(task_id: str):
240
  """
241
  Permite descargar el archivo .ply resultante de un entrenamiento completado.
 
247
  ply_path = task_info.get("result_ply_path")
248
 
249
  if not ply_path:
250
+ raise HTTPException(status_code=404, detail="El entrenamiento no ha finalizado o el archivo a煤n no est谩 disponible.")
251
 
252
  if not os.path.exists(ply_path):
253
+ raise HTTPException(status_code=500, detail="Error: El archivo del modelo no se encuentra en el servidor.")
254
 
255
+ # Generamos un nombre de archivo amigable para el usuario
256
  file_name = f"model_{task_id[:8]}.ply"
257
+
258
  return FileResponse(
259
  path=ply_path,
260
  media_type='application/octet-stream',
261
  filename=file_name
262
  )
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  if __name__ == "__main__":
265
  import uvicorn
266
  # Para ejecutar: uvicorn main:app --reload