Spaces:
Runtime error
Runtime error
| import asyncio | |
| import os | |
| import signal | |
| import sqlite3 | |
| from contextlib import asynccontextmanager | |
| import psutil | |
| from fastapi import FastAPI | |
| from loguru import logger | |
| from competitions.utils import run_evaluation | |
def get_process_status(pid):
    """Return the psutil status string for *pid*.

    When psutil reports that no such process exists, the miss is logged and
    the literal string "Completed" is returned instead.
    """
    try:
        return psutil.Process(pid).status()
    except psutil.NoSuchProcess:
        logger.info(f"No process found with PID: {pid}")
    # Only reached when the lookup above failed.
    return "Completed"
def kill_process_by_pid(pid):
    """Send SIGTERM to the process identified by *pid*.

    Any OSError raised by ``os.kill`` (for example when the PID does not
    exist) propagates to the caller.
    """
    termination_signal = signal.SIGTERM
    os.kill(pid, termination_signal)
class JobDB:
    """SQLite-backed registry of job process IDs.

    The database contains a single ``jobs`` table with an auto-assigned
    ``id`` and the job's ``pid``. One connection and one cursor are opened
    in the constructor and reused by every method.
    """

    def __init__(self, db_path):
        """Open (or create) the database at *db_path* and ensure the table exists."""
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        self.c = self.conn.cursor()
        self.create_jobs_table()

    def create_jobs_table(self):
        """Create the ``jobs`` table if it is not already present."""
        self.c.execute(
            """CREATE TABLE IF NOT EXISTS jobs
            (id INTEGER PRIMARY KEY, pid INTEGER)"""
        )
        self.conn.commit()

    def add_job(self, pid):
        """Record *pid* as a tracked job."""
        # Bound parameter instead of string interpolation: the value can
        # never be interpreted as SQL.
        self.c.execute("INSERT INTO jobs (pid) VALUES (?)", (pid,))
        self.conn.commit()

    def get_running_jobs(self):
        """Return the list of every PID currently stored in the table."""
        self.c.execute("SELECT pid FROM jobs")
        return [row[0] for row in self.c.fetchall()]

    def delete_job(self, pid):
        """Remove every row whose ``pid`` equals *pid*."""
        # Bound parameter: a value such as "5 OR 1=1" matches nothing rather
        # than widening the WHERE clause to the whole table.
        self.c.execute("DELETE FROM jobs WHERE pid = ?", (pid,))
        self.conn.commit()
# Raw value of the PARAMS environment variable (None when it is unset);
# passed to run_evaluation() when the application starts.
PARAMS = os.environ.get("PARAMS")
# Module-wide job registry backed by the "job.db" SQLite file. Note that the
# connection is opened (and the table created) as a side effect of import.
DB = JobDB("job.db")
class BackgroundRunner:
    """Background monitor that prunes finished jobs and stops the server when none remain."""

    async def run_main(self):
        """Poll the job registry every 30 seconds, forever.

        On each pass, every registered PID whose status (trimmed and
        lower-cased) is "completed", "error" or "zombie" is sent SIGTERM on a
        best-effort basis and removed from the registry. If the registry is
        empty afterwards, SIGINT is sent to this very process to request a
        server shutdown. The loop itself never returns.
        """
        finished_states = ("completed", "error", "zombie")
        while True:
            for pid in DB.get_running_jobs():
                state = get_process_status(pid).strip().lower()
                if state not in finished_states:
                    continue
                logger.info(f"Process {pid} is already completed. Skipping...")
                try:
                    kill_process_by_pid(pid)
                except Exception as exc:
                    # Best effort: the process is usually gone already.
                    logger.info(f"Error while killing process: {exc}")
                DB.delete_job(pid)

            # Re-read the registry so deletions made above are taken into account.
            if not DB.get_running_jobs():
                logger.info("No running jobs found. Shutting down the server.")
                os.kill(os.getpid(), signal.SIGINT)
            await asyncio.sleep(30)
# Single monitor instance; its run_main() coroutine is scheduled from lifespan().
runner = BackgroundRunner()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: start the evaluation and the job monitor.

    On startup, launches the evaluation via ``run_evaluation`` with the
    module-level ``PARAMS``, records the returned PID in the job registry and
    schedules ``runner.run_main()`` as a background task. On shutdown, the
    monitor task is cancelled.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    process_pid = run_evaluation(params=PARAMS)
    logger.info(f"Started training with PID {process_pid}")
    DB.add_job(process_pid)
    # Keep a strong reference for the application's lifetime: the event loop
    # only holds weak references to tasks, so an unreferenced task may be
    # garbage-collected before it finishes.
    monitor_task = asyncio.create_task(runner.run_main())
    try:
        yield
    finally:
        # The monitor loops forever; stop it explicitly when the app shuts down.
        monitor_task.cancel()
# ASGI application object; `lifespan` is handed to FastAPI as the
# application's lifespan handler.
api = FastAPI(lifespan=lifespan)
# NOTE(review): the handler was not registered on `api`, so it was never
# reachable; the "/" path is inferred from the function name - confirm.
@api.get("/")
async def root():
    """Return a fixed message telling the caller that evaluation is in progress."""
    return "Your model is being evaluated..."
# NOTE(review): the handler was not registered on `api`, so it was never
# reachable; the "/health" path is inferred from the function name - confirm.
@api.get("/health")
async def health():
    """Return the constant string "OK"."""
    return "OK"