File size: 3,136 Bytes
aa1c1e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c8bcbc
 
aa1c1e5
 
 
 
 
 
 
 
 
6c8bcbc
 
 
aa1c1e5
 
 
6c8bcbc
aa1c1e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from fastapi import FastAPI, Request, UploadFile, Form
from fastapi.responses import RedirectResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import uuid
import shutil
import os
import cv2

from pipeline.process_session import process_session_image
from database.db import init_db
from database.crud import get_session, delete_session
from models.sam import SamWrapper
from models.dino import DinoWrapper

from huggingface_hub import hf_hub_download

# --- Init app and database
app = FastAPI()
init_db()

# --- Static and templates
# StaticFiles refuses to mount a directory that does not exist, and the
# outputs folder may be absent on a fresh checkout, so create it up front.
os.makedirs("outputs", exist_ok=True)
app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# --- Download checkpoint from model-repo
# The returned value is passed to SamWrapper below as the checkpoint path.
checkpoint_path = hf_hub_download(
    repo_id="stkrk/sam-vit-b-checkpoint",
    filename="sam_vit_b_01ec64.pth",
)

# --- Model initialization (once)
# Both wrappers are built at import time and reused by the /process handler.
sam = SamWrapper(model_type="vit_b", checkpoint_path=checkpoint_path)
dino = DinoWrapper()


@app.get("/")
def index(request: Request):
    """Serve the root URL by rendering the ``index.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)

@app.get("/results/{session_id}", response_class=HTMLResponse)
def show_results(request: Request, session_id: str):
    """Render the results page for a stored session.

    The session is looked up with ``get_session``. When the lookup result is
    falsy, ``done.html`` is rendered with a "Session not found." message;
    otherwise ``results.html`` is rendered with the session id and the
    session's ``result_paths`` entry.
    """
    session = get_session(session_id)

    if session:
        context = {
            "request": request,
            "session_id": session_id,
            "result_paths": session["result_paths"],
        }
        return templates.TemplateResponse("results.html", context)

    # NOTE(review): /finalize renders "ready.html" for this same message --
    # confirm that "done.html" is the intended template here.
    return templates.TemplateResponse(
        "done.html",
        {"request": request, "message": "Session not found."},
    )


@app.post("/process")
def process_image(request: Request, image: UploadFile = Form(...), prompt: str = Form(...)):
    """Store an uploaded image, run the pipeline on it, and redirect to the results.

    Args:
        request: Incoming request (not used by the handler body).
        image: Uploaded image file from the submitted form.
        prompt: Text prompt forwarded to the pipeline as ``prompt_text``.

    Returns:
        A 303 redirect to ``/results/{session_id}`` for the new session.
    """
    # NOTE(review): FastAPI documents ``File(...)`` for upload parameters;
    # ``Form(...)`` is kept here so the endpoint signature stays unchanged.

    # 1. Save uploaded image
    session_id = uuid.uuid4().hex
    save_dir = "uploads"
    os.makedirs(save_dir, exist_ok=True)

    # The filename comes from the client: keep only its last path component
    # (treating backslashes as separators too) so it cannot point outside
    # save_dir, and fall back to a fixed name when nothing usable remains.
    client_name = (image.filename or "").replace("\\", "/")
    safe_name = os.path.basename(client_name) or "upload"

    image_path = os.path.join(save_dir, f"{session_id}_{safe_name}")
    with open(image_path, "wb") as buffer:
        shutil.copyfileobj(image.file, buffer)

    # 2. Run main pipeline
    process_session_image(
        session_id=session_id,
        image_path=image_path,
        prompt_text=prompt,
        sam_wrapper=sam,
        dino_wrapper=dino,
    )

    # 3. Redirect to results page (303 so the browser follows up with a GET)
    return RedirectResponse(f"/results/{session_id}", status_code=303)


@app.post("/finalize", response_class=HTMLResponse)
async def finalize_selection(
    request: Request,
    session_id: str = Form(...),
    selected: list[str] = Form(default=[])
):
    """Keep the selected result files, delete the others, and close the session.

    Args:
        request: Incoming request, passed through to the template context.
        session_id: Identifier of the session to finalize.
        selected: Result paths the user chose to keep; empty when none were
            submitted.

    Returns:
        The rendered ``ready.html`` page, carrying either a "Session not
        found." message or a summary of the finalized session.
    """
    session = get_session(session_id)
    if not session:
        return templates.TemplateResponse("ready.html", {"request": request, "message": "Session not found."})

    # Build the lookup once instead of scanning the list for every path.
    keep = set(selected)

    # Remove all the rest of PNGs
    for path in session["result_paths"]:
        if path in keep:
            continue
        try:
            os.remove(path)
        except FileNotFoundError:
            # Already gone (e.g. the form was submitted twice); nothing to do.
            pass

    # Remove the closed session
    delete_session(session_id)

    return templates.TemplateResponse("ready.html", {
        "request": request,
        "message": f"Saved {len(selected)} file(s). Session {session_id} closed."
    })