Spaces:
Sleeping
Sleeping
File size: 3,136 Bytes
aa1c1e5 6c8bcbc aa1c1e5 6c8bcbc aa1c1e5 6c8bcbc aa1c1e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
from fastapi import FastAPI, Request, UploadFile, Form
from fastapi.responses import RedirectResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import uuid
import shutil
import os
import cv2
from pipeline.process_session import process_session_image
from database.db import init_db
from database.crud import get_session, delete_session
from models.sam import SamWrapper
from models.dino import DinoWrapper
from huggingface_hub import hf_hub_download
# --- Init app and database
app = FastAPI()
init_db()

# --- Static and templates
# StaticFiles refuses to start if its directory is missing. "outputs" is not a
# source directory (presumably the pipeline writes results there — verify), so
# a fresh checkout may not have it yet; create it before mounting.
os.makedirs("outputs", exist_ok=True)
app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# --- Download checkpoint from model-repo
# Fetches the SAM ViT-B weights from the Hugging Face Hub; the returned value is
# a local file path handed to SamWrapper below.
checkpoint_path = hf_hub_download(
    repo_id="stkrk/sam-vit-b-checkpoint",
    filename="sam_vit_b_01ec64.pth",
)

# --- Model initialization (once, at import time)
# These instances are shared by every /process request.
sam = SamWrapper(
    model_type="vit_b",
    checkpoint_path=checkpoint_path,
)
dino = DinoWrapper()
@app.get("/")
def index(request: Request):
    """Render the landing page (``index.html``) with the upload form."""
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", page_context)
@app.get("/results/{session_id}", response_class=HTMLResponse)
def show_results(request: Request, session_id: str):
    """Render the results page for a processing session.

    Args:
        request: Incoming request (required in the Jinja2 template context).
        session_id: Session id taken from the URL path (``/process`` redirects
            here with the hex id it generated).

    Returns:
        ``results.html`` populated with the session's ``result_paths``, or
        ``done.html`` with the message "Session not found." and HTTP 404 when
        ``get_session`` returns nothing for ``session_id``.
    """
    session = get_session(session_id)
    if not session:
        # Serve the page with 404 (not 200) so clients can tell it is missing.
        return templates.TemplateResponse(
            "done.html",
            {"request": request, "message": "Session not found."},
            status_code=404,
        )
    return templates.TemplateResponse("results.html", {
        "request": request,
        "session_id": session_id,
        "result_paths": session["result_paths"],
    })
@app.post("/process")
def process_image(request: Request, image: UploadFile = Form(...), prompt: str = Form(...)):
    """Save an uploaded image, run the processing pipeline, then redirect.

    Args:
        request: Incoming request (not used by the handler).
        image: Uploaded image from the multipart form.
        prompt: Text prompt forwarded to the pipeline as ``prompt_text``.

    Returns:
        A 303 redirect to ``/results/{session_id}``, so the browser follows up
        the POST with a GET of the results page.
    """
    # 1. Save uploaded image
    session_id = uuid.uuid4().hex
    save_dir = "uploads"
    os.makedirs(save_dir, exist_ok=True)

    # ``image.filename`` is client-controlled. It may be None/empty or contain
    # directory separators ("a/b.png", "..\\x.png"). Normalise backslashes,
    # keep only the final path component, and fall back to a fixed name, so
    # the file always lands directly inside ``save_dir``.
    raw_name = (image.filename or "").replace("\\", "/")
    safe_name = os.path.basename(raw_name) or "upload"
    image_path = os.path.join(save_dir, f"{session_id}_{safe_name}")
    with open(image_path, "wb") as buffer:
        shutil.copyfileobj(image.file, buffer)

    # 2. Run main pipeline with the module-level model instances
    process_session_image(
        session_id=session_id,
        image_path=image_path,
        prompt_text=prompt,
        sam_wrapper=sam,
        dino_wrapper=dino,
    )

    # 3. Redirect to results page
    return RedirectResponse(f"/results/{session_id}", status_code=303)
@app.post("/finalize", response_class=HTMLResponse)
async def finalize_selection(
    request: Request,
    session_id: str = Form(...),
    selected: list[str] = Form(default=[])
):
    """Keep the user-selected result files, delete the rest, and close the session.

    Args:
        request: Incoming request (required in the Jinja2 template context).
        session_id: Id of the session to finalize.
        selected: Result paths the user chose to keep (client-supplied; may
            contain duplicates or paths that are not part of this session).

    Returns:
        ``ready.html`` with a summary message. When the session does not
        exist, ``ready.html`` with "Session not found." and HTTP 404.
    """
    # NOTE(review): this handler is ``async`` but performs blocking file-system
    # calls (and presumably synchronous DB calls) on the event loop; consider
    # making it a plain ``def`` so FastAPI runs it off the loop.
    session = get_session(session_id)
    if not session:
        return templates.TemplateResponse(
            "ready.html",
            {"request": request, "message": "Session not found."},
            status_code=404,
        )

    result_paths = session["result_paths"]
    # Set for O(1) membership tests instead of scanning the list per path.
    keep = set(selected)

    # Remove every result file the user did not select. Only paths belonging
    # to this session are ever deleted; ``selected`` can only spare files.
    for path in result_paths:
        if path in keep:
            continue
        try:
            os.remove(path)
        except FileNotFoundError:
            pass  # Already gone: nothing left to clean up.

    # Count only this session's files that were actually kept, so duplicate or
    # foreign entries in ``selected`` do not inflate the reported number.
    saved_count = len(keep.intersection(result_paths))

    # Drop the session record now that it has been finalized.
    delete_session(session_id)
    return templates.TemplateResponse("ready.html", {
        "request": request,
        "message": f"Saved {saved_count} file(s). Session {session_id} closed."
    })
|