jree423 commited on
Commit
f9f7f17
·
verified ·
1 Parent(s): 8de5d7e

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +9 -67
app.py CHANGED
@@ -1,71 +1,13 @@
1
- from typing import Dict, List, Any, Optional, Union
2
- import torch
3
- import base64
4
- import io
5
- from PIL import Image
6
  import os
7
- from diffusers import DiffusionPipeline
8
- from fastapi import FastAPI, HTTPException
9
- from pydantic import BaseModel
10
 
11
  app = FastAPI()
 
12
 
13
- class DiffSketcherInput(BaseModel):
14
- prompt: str
15
- negative_prompt: Optional[str] = None
16
- num_paths: Optional[int] = 96
17
- num_iter: Optional[int] = 800
18
- guidance_scale: Optional[float] = 7.5
19
- width: Optional[float] = 2.0
20
- seed: Optional[int] = None
21
-
22
- class DiffSketcherOutput(BaseModel):
23
- svg: str
24
- image: str
25
-
26
- # Load the model
27
- model = None
28
-
29
- def load_model():
30
- global model
31
- if model is None:
32
- model = DiffusionPipeline.from_pretrained(".")
33
- model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
34
- return model
35
-
36
- @app.post("/", response_model=DiffSketcherOutput)
37
- async def generate(input_data: DiffSketcherInput):
38
- # Load the model
39
- model = load_model()
40
-
41
- # Set the seed if provided
42
- if input_data.seed is not None:
43
- torch.manual_seed(input_data.seed)
44
-
45
- try:
46
- # Generate the SVG
47
- output = model(
48
- prompt=input_data.prompt,
49
- negative_prompt=input_data.negative_prompt,
50
- num_paths=input_data.num_paths,
51
- num_iter=input_data.num_iter,
52
- guidance_scale=input_data.guidance_scale,
53
- width=input_data.width
54
- )
55
-
56
- # Get the SVG and image
57
- svg = output.svg
58
- image = output.images[0]
59
-
60
- # Convert the image to base64
61
- buffered = io.BytesIO()
62
- image.save(buffered, format="PNG")
63
- img_str = base64.b64encode(buffered.getvalue()).decode()
64
-
65
- # Return the results
66
- return {
67
- "svg": svg,
68
- "image": img_str
69
- }
70
- except Exception as e:
71
- raise HTTPException(status_code=500, detail=str(e))
 
1
+
2
+ from fastapi import FastAPI, Request
3
+ from handler import EndpointHandler
 
 
4
  import os
 
 
 
5
 
6
  app = FastAPI()
7
+ handler = EndpointHandler(os.getcwd())
8
 
9
+ @app.post("/")
10
+ async def process_request(request: Request):
11
+ json_data = await request.json()
12
+ return handler(json_data)
13
+