Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, Request, UploadFile, Form | |
from fastapi.responses import RedirectResponse, HTMLResponse | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.templating import Jinja2Templates | |
import uuid | |
import shutil | |
import os | |
import cv2 | |
from pipeline.process_session import process_session_image | |
from database.db import init_db | |
from database.crud import get_session, delete_session | |
from models.sam import SamWrapper | |
from models.dino import DinoWrapper | |
from huggingface_hub import hf_hub_download | |
# --- Application object and database setup
app = FastAPI()
init_db()

# --- Create the working folders up front
for _folder in ("uploads", "outputs"):
    os.makedirs(_folder, exist_ok=True)

# --- Static file mounts and the HTML template loader
app.mount(
    "/outputs",
    StaticFiles(directory="outputs"),
    name="outputs",
)
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)
templates = Jinja2Templates(directory="templates")

# --- Fetch the SAM checkpoint file from the Hugging Face model repo
checkpoint_path = hf_hub_download(
    repo_id="stkrk/sam-vit-b-checkpoint",
    filename="sam_vit_b_01ec64.pth",
)

# --- Models are built once at import time and shared by the handlers below
sam = SamWrapper(model_type="vit_b", checkpoint_path=checkpoint_path)
dino = DinoWrapper()
def index(request: Request):
    """Render the ``index.html`` template for the incoming request.

    NOTE(review): no route decorator is visible on this function in this
    view of the file - confirm how it is registered on ``app``.
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
def show_results(request: Request, session_id: str):
    """Render the results page for a stored session.

    Looks the session up with ``get_session``. When one is found,
    ``results.html`` is rendered with the session id and the session's
    ``result_paths``; otherwise ``done.html`` is rendered with a
    "Session not found." message.

    NOTE(review): no route decorator is visible here; presumably this is
    served at ``/results/{session_id}``, which the upload handler redirects
    to - confirm.
    """
    record = get_session(session_id)

    if record:
        page_context = {
            "request": request,
            "session_id": session_id,
            "result_paths": record["result_paths"]
        }
        return templates.TemplateResponse("results.html", page_context)

    # Falsy lookup result: report the missing session instead of results.
    missing_context = {"request": request, "message": "Session not found."}
    return templates.TemplateResponse("done.html", missing_context)
def process_image(request: Request, image: UploadFile = Form(...), prompt: str = Form(...)):
    """Store an uploaded image, run the processing pipeline, and redirect.

    Args:
        request: The incoming request (not used by the body).
        image: The uploaded image; its content is copied into ``uploads/``.
        prompt: Text forwarded to the pipeline as ``prompt_text``.

    Returns:
        A 303 redirect to ``/results/{session_id}``, where ``session_id`` is
        a freshly generated UUID hex string.

    NOTE(review): no route decorator is visible on this function in this
    view of the file - confirm how it is registered on ``app``.
    """
    # 1. Save uploaded image
    session_id = uuid.uuid4().hex
    save_dir = "uploads"
    os.makedirs(save_dir, exist_ok=True)

    # The filename is supplied by the client and may be missing or contain
    # directory components ("../x.png", "C:\\dir\\x.png"). Keep only the final
    # component so the file always lands directly inside save_dir; plain
    # filenames produce exactly the same path as before.
    raw_name = (image.filename or "").replace("\\", "/")
    safe_name = os.path.basename(raw_name) or "upload"
    image_path = os.path.join(save_dir, f"{session_id}_{safe_name}")

    with open(image_path, "wb") as buffer:
        shutil.copyfileobj(image.file, buffer)

    # 2. Run main pipeline with the shared module-level model wrappers
    process_session_image(
        session_id=session_id,
        image_path=image_path,
        prompt_text=prompt,
        sam_wrapper=sam,
        dino_wrapper=dino
    )

    # 3. Redirect to results page (303 so the browser follows up with a GET)
    return RedirectResponse(f"/results/{session_id}", status_code=303)
async def finalize_selection(
    request: Request,
    session_id: str = Form(...),
    selected: list[str] = Form(default=[])
):
    """Keep the selected result files, delete the rest, and close the session.

    Args:
        request: The incoming request, passed through to the template.
        session_id: Identifier of the session to finalize.
        selected: Result paths to keep; every other entry of the session's
            ``result_paths`` is removed from disk.

    Returns:
        The rendered ``ready.html`` template, carrying either a
        "Session not found." message or a summary of the finalized session.

    NOTE(review): no route decorator is visible on this function in this
    view of the file - confirm how it is registered on ``app``.
    """
    session = get_session(session_id)
    if not session:
        return templates.TemplateResponse("ready.html", {"request": request, "message": "Session not found."})

    # Build the lookup once instead of scanning the list for every path.
    keep = set(selected)

    # Remove every result file that was not selected.
    for path in session["result_paths"]:
        if path in keep:
            continue
        try:
            os.remove(path)
        except FileNotFoundError:
            # Already gone (e.g. removed by a concurrent request); the goal
            # state is reached, so there is nothing left to do for this path.
            pass

    # Drop the stored record of the now-closed session.
    delete_session(session_id)

    return templates.TemplateResponse("ready.html", {
        "request": request,
        "message": f"Saved {len(selected)} file(s). Session {session_id} closed."
    })