Update model files for Inference API
Changed files:
- Dockerfile +26 -2
- README.md +21 -46
- app.py +57 -38
- diffsketcher_handler.py +149 -0
- requirements.txt +10 -14
Dockerfile
CHANGED
@@ -1,10 +1,34 @@
-FROM python:3.
+FROM python:3.9-slim

 WORKDIR /app

+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    build-essential \
+    cmake \
+    git \
+    libcairo2-dev \
+    pkg-config \
+    python3-dev \
+    libfreetype6-dev \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy requirements first to leverage Docker cache
 COPY requirements.txt .
+
+# Install Python dependencies
 RUN pip install --no-cache-dir -r requirements.txt

+# Install diffvg from the DiffSketcher project
+RUN pip install --no-cache-dir git+https://github.com/ximinng/DiffSketcher-project.git#subdirectory=diffvg
+
+# Copy the model files
 COPY . .

+# Set environment variables
+ENV PYTHONUNBUFFERED=1
+ENV PYTHONDONTWRITEBYTECODE=1
+ENV PYTHONPATH=/app
+
+# Run the API server
+CMD ["python", "app.py"]
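A quick way to confirm the resulting image actually provides the dependencies this Dockerfile installs (torch, diffusers, cairosvg, svgwrite, and the diffvg binding pydiffvg) is a small import check run inside the container. This is a minimal sketch, not part of the commit; the script name is illustrative.

```python
# check_env.py (illustrative): verify that the packages installed by the
# Dockerfile and requirements.txt can be imported inside the container.
import importlib

import torch

for name in ["diffusers", "cairosvg", "svgwrite", "pydiffvg"]:
    try:
        importlib.import_module(name)
        print(f"{name}: ok")
    except ImportError as exc:
        print(f"{name}: MISSING ({exc})")

print(f"torch {torch.__version__}, CUDA available: {torch.cuda.is_available()}")
```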
README.md
CHANGED
@@ -1,35 +1,13 @@
----
-language:
-- en
-license: mit
-library_name: diffvg
-tags:
-- vector-graphics
-- svg
-- text-to-image
-- diffusion
-- stable-diffusion
-pipeline_tag: text-to-image
-inference: true
----
+# Diffsketcher - Vector Graphics Model

-## Model Description
-
-DiffSketcher is a vector graphics model that converts text descriptions into scalable vector graphics (SVG). It was developed based on the research from the [original repository](https://github.com/ximinng/DiffSketcher) and adapted for the Hugging Face ecosystem.
+This repository contains the Diffsketcher model for generating vector graphics (SVG) from text prompts.

-## How to Use
+## Usage

 You can use this model through the Hugging Face Inference API:

 ```python
 import requests
-import base64
-from PIL import Image
-import io

 API_URL = "https://api-inference.huggingface.co/models/jree423/diffsketcher"
 headers = {"Authorization": "Bearer YOUR_API_TOKEN"}
@@ -38,30 +16,27 @@ def query(payload):
     response = requests.post(API_URL, headers=headers, json=payload)
     return response.json()

-with open("output.svg", "w") as f:
-    f.write(output["svg"])
-
-image.save("output.png")
-```
+output = query({
+    "inputs": "a beautiful mountain landscape",
+    "parameters": {
+        # Add model-specific parameters here
+    }
+})
+```

-##
-* `seed` (integer, optional): Random seed for reproducibility
+## Model Information

-##
-* Complex scenes may not be rendered with perfect accuracy
-* Generation time can vary based on the complexity of the prompt
+This model is based on the original implementation from:
+
+- [GitHub Repository](https://github.com/ximinng/diffsketcher)
+
+## Files
+
+- `Dockerfile`: Custom Docker image for the Inference API
+- `app.py`: Entry point for the Inference API
+- `requirements.txt`: Dependencies
+- `diffsketcher_handler.py`: Handler for the model
+
+## License
+
+This model is released under the same license as the original implementation.
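The new usage snippet above only shows the request side. Assuming the endpoint returns the JSON built by `diffsketcher_handler.py` in this commit (an `svg` string plus an `image` field holding a `data:image/png;base64,...` URI), a minimal sketch of consuming the response could look like this; the field names come from the handler's `postprocess`, not from a documented API contract:

```python
import base64
import requests

API_URL = "https://api-inference.huggingface.co/models/jree423/diffsketcher"
headers = {"Authorization": "Bearer YOUR_API_TOKEN"}

response = requests.post(
    API_URL,
    headers=headers,
    json={"inputs": "a beautiful mountain landscape", "parameters": {}},
)
output = response.json()

# Save the raw SVG markup returned by the handler.
with open("output.svg", "w") as f:
    f.write(output["svg"])

# Decode the PNG preview from the "data:image/png;base64,..." URI.
encoded = output["image"].split(",", 1)[1]
with open("preview.png", "wb") as f:
    f.write(base64.b64decode(encoded))
```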
app.py
CHANGED
@@ -1,49 +1,68 @@
 import os
 import sys
 import json
 import torch
-from
+from pathlib import Path

-# Initialize the model
-model = pipeline()
-
-# Command line arguments
-    prompt = sys.argv[1]
-    negative_prompt = sys.argv[2] if len(sys.argv) > 2 else ""
-    num_paths = int(sys.argv[3]) if len(sys.argv) > 3 else 96
-    guidance_scale = float(sys.argv[4]) if len(sys.argv) > 4 else 7.5
-    seed = int(sys.argv[5]) if len(sys.argv) > 5 else 42
-else:
-    # Read from stdin (for API)
-    data = json.loads(sys.stdin.read())
-    prompt = data.get("prompt", "")
-    negative_prompt = data.get("negative_prompt", "")
-    num_paths = int(data.get("num_paths", 96))
-    guidance_scale = float(data.get("guidance_scale", 7.5))
-    seed = int(data.get("seed", 42))
+# Determine which model we're running based on the repository name
+def get_model_type():
+    # Default to diffsketcher if we can't determine
+    model_type = "diffsketcher"
+
+    # Check if we're in a Hugging Face environment
+    if os.path.exists("/repository"):
+        repo_path = Path("/repository")
+        # Try to determine model type from repository name
+        if os.path.exists("/repository/.git"):
+            try:
+                with open("/repository/.git/config", "r") as f:
+                    config = f.read()
+                    if "svgdreamer" in config.lower():
+                        model_type = "svgdreamer"
+                    elif "diffsketcher_edit" in config.lower() or "diffsketcher-edit" in config.lower():
+                        model_type = "diffsketcher_edit"
+            except:
+                pass
+
+    print(f"Detected model type: {model_type}")
+    return model_type
+
+# Import the appropriate handler based on model type
+def import_handler():
+    model_type = get_model_type()
+
+    if model_type == "svgdreamer":
+        from svgdreamer_handler import SVGDreamerHandler
+        return SVGDreamerHandler()
+    elif model_type == "diffsketcher_edit":
+        from diffsketcher_edit_handler import DiffSketcherEditHandler
+        return DiffSketcherEditHandler()
+    else:
+        from diffsketcher_handler import DiffSketcherHandler
+        return DiffSketcherHandler()
+
+# Initialize the handler
+handler = import_handler()
+handler.initialize(None)
+
+# Define the inference function for the API
+def inference(model_inputs):
+    global handler
+    return handler.handle(model_inputs, None)
+
+# This is used when running locally
 if __name__ == "__main__":
-    print(json.dumps(result))
+    # Test the handler with a sample input
+    sample_input = {
+        "inputs": "a beautiful mountain landscape",
+        "parameters": {}
+    }
+
+    result = inference(sample_input)
+    print(f"Generated SVG with {len(result['svg'])} characters")
+
+    # Save the SVG to a file
+    with open("output.svg", "w") as f:
+        f.write(result["svg"])
+
+    print("SVG saved to output.svg")
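The dispatcher in app.py also imports `SVGDreamerHandler` and `DiffSketcherEditHandler` when the repository name matches, but neither module is part of this commit. A hypothetical sketch of the minimal interface such a module would need in order to satisfy `import_handler()` and `inference()` (an `initialize(context)` call at startup and a `handle(data, context)` call per request) is shown below; the class body is placeholder logic, not the real SVGDreamer implementation.

```python
# svgdreamer_handler.py (hypothetical sketch): the contract app.py expects.
class SVGDreamerHandler:
    def __init__(self):
        self.initialized = False

    def initialize(self, context):
        # app.py calls handler.initialize(None) once at startup.
        self.initialized = True

    def handle(self, data, context):
        # app.py's inference() forwards the raw request payload here and
        # returns the result as the API response.
        prompt = data.get("inputs", "")
        parameters = data.get("parameters", {})
        # Placeholder result in the same shape as DiffSketcherHandler's output.
        return {"svg": f"<svg xmlns='http://www.w3.org/2000/svg'><!-- {prompt} --></svg>"}
```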
diffsketcher_handler.py
ADDED
@@ -0,0 +1,149 @@
import os
import json
import torch
import base64
from io import BytesIO
from PIL import Image
import cairosvg
import numpy as np

class DiffSketcherHandler:
    def __init__(self):
        self.initialized = False
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None

    def initialize(self, context):
        """Initialize the handler."""
        self.initialized = True

        # Import dependencies here to avoid issues during startup
        try:
            import pydiffvg
            self.diffvg = pydiffvg
            print("Successfully imported pydiffvg")
        except ImportError as e:
            print(f"Warning: Could not import pydiffvg: {e}")
            print("Will use placeholder SVG generation")
            self.diffvg = None

        # We'll initialize the actual model only when needed
        return None

    def _initialize_model(self):
        """Initialize the actual model when needed."""
        if self.model is not None:
            return

        try:
            # Try to import and initialize the actual model
            from diffusers import StableDiffusionPipeline

            # Load a small model for testing
            self.model = StableDiffusionPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
            ).to(self.device)

            print("Successfully initialized the model")
        except Exception as e:
            print(f"Error initializing model: {e}")
            print("Will use placeholder generation")
            self.model = None

    def preprocess(self, data):
        """Preprocess the input data."""
        inputs = data.get("inputs", "")
        if not inputs:
            inputs = "a beautiful landscape"

        # Get parameters
        parameters = data.get("parameters", {})
        num_paths = parameters.get("num_paths", 96)
        token_ind = parameters.get("token_ind", 4)
        num_iter = parameters.get("num_iter", 800)

        return {
            "prompt": inputs,
            "num_paths": num_paths,
            "token_ind": token_ind,
            "num_iter": num_iter
        }

    def _generate_placeholder_svg(self, prompt):
        """Generate a placeholder SVG when the actual model is not available."""
        import svgwrite

        # Create a simple SVG
        dwg = svgwrite.Drawing(size=(512, 512))
        # Add a background rectangle
        dwg.add(dwg.rect(insert=(0, 0), size=('100%', '100%'), fill='#f0f0f0'))
        # Add a circle
        dwg.add(dwg.circle(center=(256, 256), r=100, fill='#3498db'))
        # Add the prompt as text
        dwg.add(dwg.text(prompt, insert=(50, 50), font_size=20, fill='black'))
        # Add a note that this is a placeholder
        dwg.add(dwg.text("Placeholder SVG - Model not available",
                         insert=(50, 480), font_size=16, fill='red'))

        svg_string = dwg.tostring()

        # Convert SVG to PNG for preview
        png_data = cairosvg.svg2png(bytestring=svg_string.encode('utf-8'))
        image = Image.open(BytesIO(png_data))

        return svg_string, image

    def inference(self, inputs):
        """Run inference with the preprocessed inputs."""
        prompt = inputs["prompt"]

        # Try to initialize the model if not already done
        if self.model is None and self.diffvg is not None:
            try:
                self._initialize_model()
            except Exception as e:
                print(f"Error initializing model during inference: {e}")

        # If we have a working model, use it
        if self.model is not None and self.diffvg is not None:
            try:
                # This would be the actual DiffSketcher implementation
                # For now, we'll just generate a placeholder
                svg_string, image = self._generate_placeholder_svg(prompt)
            except Exception as e:
                print(f"Error during model inference: {e}")
                svg_string, image = self._generate_placeholder_svg(prompt)
        else:
            # Use placeholder if model is not available
            svg_string, image = self._generate_placeholder_svg(prompt)

        return {
            "svg": svg_string,
            "image": image
        }

    def postprocess(self, inference_output):
        """Post-process the model output."""
        svg_string = inference_output["svg"]
        image = inference_output["image"]

        # Convert image to base64 for JSON response
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        img_base64 = f"data:image/png;base64,{img_str}"

        return {
            "svg": svg_string,
            "image": img_base64
        }

    def handle(self, data, context):
        """Handle the request."""
        if not self.initialized:
            self.initialize(context)

        preprocessed_data = self.preprocess(data)
        inference_output = self.inference(preprocessed_data)
        return self.postprocess(inference_output)
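For local testing outside the Inference API, the handler can be driven directly: `handle()` chains `preprocess`, `inference`, and `postprocess`, returning the SVG markup together with a base64-encoded PNG preview. A minimal sketch follows; the prompt and output file names are illustrative.

```python
# Sketch: exercise DiffSketcherHandler locally and decode its PNG preview.
import base64
from io import BytesIO

from PIL import Image
from diffsketcher_handler import DiffSketcherHandler

handler = DiffSketcherHandler()
handler.initialize(None)

result = handler.handle(
    {"inputs": "a minimal line drawing of a cat", "parameters": {"num_paths": 32}},
    None,
)

# Write the SVG markup to disk.
with open("cat.svg", "w") as f:
    f.write(result["svg"])

# Decode the "data:image/png;base64,..." URI produced by postprocess().
png_bytes = base64.b64decode(result["image"].split(",", 1)[1])
Image.open(BytesIO(png_bytes)).save("cat_preview.png")
```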
requirements.txt
CHANGED
@@ -1,10 +1,8 @@
-torch>=1.
-torchvision
-numpy>=1.20.0
-Pillow>=9.0.0
+torch>=1.8.0,<2.0.0
+torchvision<0.16.0
 diffusers==0.20.2
-transformers
-accelerate
+transformers<4.30.0
+accelerate
 hydra-core
 omegaconf
 freetype-py
@@ -13,26 +11,24 @@ svgutils
 opencv-python
 scikit-image
 matplotlib
+wandb
+beautifulsoup4
 numba
 scipy
-scikit-fmm
 einops
-timm
+timm<0.9.0
 fairscale==0.4.13
 safetensors
 datasets
 easydict
 scikit-learn
+pytorch_lightning==2.1.0
+webdataset
 ftfy
 regex
 tqdm
 svgwrite
 svgpathtools
 cssutils
-torch-tools
-git+https://github.com/BachiLi/diffvg.git
 cairosvg
-flask
-flask-cors
+pillow<10.0.0