jree423 commited on
Commit
6a2700f
·
verified ·
1 Parent(s): a96d3ab

Upload model with FastAPI app

Browse files
Files changed (3) hide show
  1. Dockerfile +10 -0
  2. app.py +71 -0
  3. requirements.txt +6 -1
Dockerfile ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements.txt .
6
+ RUN pip install --no-cache-dir -r requirements.txt
7
+
8
+ COPY . .
9
+
10
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any, Optional, Union
2
+ import torch
3
+ import base64
4
+ import io
5
+ from PIL import Image
6
+ import os
7
+ from diffusers import DiffusionPipeline
8
+ from fastapi import FastAPI, HTTPException
9
+ from pydantic import BaseModel
10
+
11
+ app = FastAPI()
12
+
13
+ class DiffSketcherInput(BaseModel):
14
+ prompt: str
15
+ negative_prompt: Optional[str] = None
16
+ num_paths: Optional[int] = 96
17
+ num_iter: Optional[int] = 800
18
+ guidance_scale: Optional[float] = 7.5
19
+ width: Optional[float] = 2.0
20
+ seed: Optional[int] = None
21
+
22
+ class DiffSketcherOutput(BaseModel):
23
+ svg: str
24
+ image: str
25
+
26
+ # Load the model
27
+ model = None
28
+
29
+ def load_model():
30
+ global model
31
+ if model is None:
32
+ model = DiffusionPipeline.from_pretrained(".")
33
+ model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
34
+ return model
35
+
36
+ @app.post("/", response_model=DiffSketcherOutput)
37
+ async def generate(input_data: DiffSketcherInput):
38
+ # Load the model
39
+ model = load_model()
40
+
41
+ # Set the seed if provided
42
+ if input_data.seed is not None:
43
+ torch.manual_seed(input_data.seed)
44
+
45
+ try:
46
+ # Generate the SVG
47
+ output = model(
48
+ prompt=input_data.prompt,
49
+ negative_prompt=input_data.negative_prompt,
50
+ num_paths=input_data.num_paths,
51
+ num_iter=input_data.num_iter,
52
+ guidance_scale=input_data.guidance_scale,
53
+ width=input_data.width
54
+ )
55
+
56
+ # Get the SVG and image
57
+ svg = output.svg
58
+ image = output.images[0]
59
+
60
+ # Convert the image to base64
61
+ buffered = io.BytesIO()
62
+ image.save(buffered, format="PNG")
63
+ img_str = base64.b64encode(buffered.getvalue()).decode()
64
+
65
+ # Return the results
66
+ return {
67
+ "svg": svg,
68
+ "image": img_str
69
+ }
70
+ except Exception as e:
71
+ raise HTTPException(status_code=500, detail=str(e))
requirements.txt CHANGED
@@ -2,4 +2,9 @@ diffusers>=0.26.3
2
  transformers>=4.36.2
3
  torch>=2.0.0
4
  numpy>=1.24.0
5
- pillow>=9.0.0
 
 
 
 
 
 
2
  transformers>=4.36.2
3
  torch>=2.0.0
4
  numpy>=1.24.0
5
+ pillow>=9.0.0
6
+ huggingface_hub
7
+ requests
8
+ fastapi
9
+ uvicorn
10
+ pydantic