dino_sam_objects / run_fastapi.py
Stanislav
feat: dockerfile, yaml, requirements, readme
6c8bcbc
raw
history blame
3.25 kB
from fastapi import FastAPI, Request, UploadFile, Form
from fastapi.responses import RedirectResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import uuid
import shutil
import os
import cv2
from pipeline.process_session import process_session_image
from database.db import init_db
from database.crud import get_session, delete_session
from models.sam import SamWrapper
from models.dino import DinoWrapper
from huggingface_hub import hf_hub_download
# --- Application object and database schema/bootstrap
app = FastAPI()
init_db()

# --- Working directories: "uploads" receives raw user images, "outputs" is
# served statically below, so both are created before anything mounts them.
os.makedirs("uploads", exist_ok=True)
os.makedirs("outputs", exist_ok=True)

# --- Static file mounts and Jinja2 HTML templates
app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# --- SAM ViT-B checkpoint, fetched from the Hugging Face Hub model repo
checkpoint_path = hf_hub_download(
    repo_id="stkrk/sam-vit-b-checkpoint",
    filename="sam_vit_b_01ec64.pth",
)

# --- Model wrappers are built once at import time and reused by every
# request handled by this module (see the /process endpoint).
sam = SamWrapper(model_type="vit_b", checkpoint_path=checkpoint_path)
dino = DinoWrapper()
@app.get("/")
def index(request: Request):
    """Render the landing page template ``index.html``.

    Jinja2Templates requires the incoming request to be present in the
    template context under the ``"request"`` key.
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.get("/results/{session_id}", response_class=HTMLResponse)
def show_results(request: Request, session_id: str):
    """Show the result images stored for a processing session.

    Args:
        request: Incoming request, forwarded to the template context.
        session_id: Identifier of the session created by ``/process``.

    Returns:
        ``results.html`` rendered with the session's ``result_paths`` when the
        session exists; otherwise ``done.html`` with a "Session not found."
        message.
    """
    record = get_session(session_id)
    if record:
        context = {
            "request": request,
            "session_id": session_id,
            "result_paths": record["result_paths"],
        }
        return templates.TemplateResponse("results.html", context)

    # NOTE(review): /finalize renders "ready.html" for the same not-found case;
    # confirm which template is intended here.
    not_found = {"request": request, "message": "Session not found."}
    return templates.TemplateResponse("done.html", not_found)
@app.post("/process")
def process_image(request: Request, image: UploadFile = Form(...), prompt: str = Form(...)):
    """Store an uploaded image, run the session pipeline on it, and redirect.

    Args:
        request: Incoming request (unused, kept for the endpoint signature).
        image: Uploaded image file from the multipart form.
        prompt: Text prompt forwarded to the pipeline as ``prompt_text``.

    Returns:
        A 303 redirect to ``/results/<session_id>`` for the new session.
    """
    # 1. Save the uploaded image under a fresh, unguessable session id.
    session_id = uuid.uuid4().hex
    save_dir = "uploads"
    os.makedirs(save_dir, exist_ok=True)

    # ``image.filename`` is client-controlled and may be None or contain path
    # separators (either style). Keep only its last component so the file
    # always lands directly inside ``save_dir``; fall back to a fixed name if
    # nothing usable remains.
    raw_name = (image.filename or "").replace("\\", "/")
    safe_name = os.path.basename(raw_name) or "upload"
    image_path = os.path.join(save_dir, f"{session_id}_{safe_name}")
    with open(image_path, "wb") as buffer:
        shutil.copyfileobj(image.file, buffer)

    # 2. Run the main pipeline with the module-level SAM and DINO wrappers.
    process_session_image(
        session_id=session_id,
        image_path=image_path,
        prompt_text=prompt,
        sam_wrapper=sam,
        dino_wrapper=dino,
    )

    # 3. 303 See Other so the browser follows up with a GET on the results page.
    return RedirectResponse(f"/results/{session_id}", status_code=303)
@app.post("/finalize", response_class=HTMLResponse)
async def finalize_selection(
    request: Request,
    session_id: str = Form(...),
    selected: list[str] = Form(default=[]),
):
    """Keep the user-selected result files, delete the rest, and close the session.

    Args:
        request: Incoming request, forwarded to the template context.
        session_id: Identifier of the session to finalize.
        selected: Result paths the user chose to keep; entries that are not
            among the session's ``result_paths`` are ignored.

    Returns:
        ``ready.html`` with a summary message, or a "Session not found." message
        when the session does not exist.
    """
    session = get_session(session_id)
    if not session:
        return templates.TemplateResponse("ready.html", {"request": request, "message": "Session not found."})

    result_paths = session["result_paths"]
    keep = set(selected)  # built once: O(1) membership per result path

    # NOTE(review): get_session/os.remove are synchronous calls made inside an
    # async endpoint, so they run on the event loop thread.
    for path in result_paths:
        if path in keep:
            continue
        try:
            os.remove(path)
        except FileNotFoundError:
            # Already gone; removing is best-effort, as before.
            pass

    # Close the session record now that its outputs have been pruned.
    delete_session(session_id)

    # Report only files that really belonged to the session, so duplicated or
    # unknown submitted values do not inflate the count.
    saved_count = len(keep.intersection(result_paths))
    return templates.TemplateResponse("ready.html", {
        "request": request,
        "message": f"Saved {saved_count} file(s). Session {session_id} closed.",
    })