Upload folder using huggingface_hub
Browse files
- README.md +87 -61
- config.json +50 -22
- handler.py +138 -108
README.md
CHANGED
@@ -1,99 +1,125 @@
|
|
1 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
license: mit
|
3 |
tags:
|
4 |
-
- text-to-
|
5 |
- vector-graphics
|
6 |
-
- svg
|
7 |
-
- art-generation
|
8 |
- diffusion
|
9 |
-
|
|
|
10 |
pipeline_tag: text-to-image
|
11 |
-
task: text-to-image
|
12 |
---
|
13 |
|
14 |
-
#
|
15 |
|
16 |
-
|
17 |
|
18 |
-
## Model
|
19 |
|
20 |
-
|
21 |
-
- **Task**: `text-to-image`
|
22 |
-
- **Input**: text
|
23 |
-
- **Output**: svg
|
24 |
-
|
25 |
-
## Features
|
26 |
-
|
27 |
-
- ✅ **Working SVG Generation**: Produces actual vector graphics content, not blank images
|
28 |
-
- ✅ **Multiple Styles**: painterly, sketchy, artistic
|
29 |
-
- ✅ **API Ready**: Deployed with proper Inference API handler
|
30 |
-
- ✅ **Real-time Generation**: Fast inference suitable for interactive applications
|
31 |
-
|
32 |
-
## Input Parameters
|
33 |
-
|
34 |
-
- `prompt` (required): Text description of what to generate/edit
|
35 |
-
- `num_paths` (optional): Number of vector paths (default: 16)
|
36 |
-
- `width` (optional): Output width in pixels (default: 512)
|
37 |
-
- `height` (optional): Output height in pixels (default: 512)
|
38 |
|
39 |
## Usage
|
40 |
|
|
|
|
|
41 |
```python
|
42 |
import requests
|
43 |
-
import base64
|
44 |
|
|
|
45 |
headers = {"Authorization": "Bearer YOUR_HF_TOKEN"}
|
46 |
|
47 |
-
|
48 |
-
response = requests.post(
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
json={
|
52 |
-
"inputs": "a
|
53 |
"parameters": {
|
54 |
-
"num_paths":
|
55 |
-
"
|
56 |
-
"height": 512
|
57 |
}
|
58 |
}
|
59 |
)
|
|
|
60 |
|
61 |
-
|
62 |
-
svg_content = base64.b64decode(result["svg_base64"]).decode('utf-8')
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
|
|
68 |
|
69 |
-
|
|
|
|
|
|
|
|
|
70 |
|
71 |
-
|
72 |
-
- `svg_content`: Raw SVG markup
|
73 |
-
- `svg_base64`: Base64-encoded SVG for easy embedding
|
74 |
-
- `model`: Model name
|
75 |
-
- `prompt`: Input prompt
|
76 |
-
- Additional parameters based on model type
|
77 |
|
78 |
-
|
|
|
|
|
|
|
79 |
|
80 |
-
|
81 |
-
-
|
82 |
-
-
|
83 |
-
-
|
84 |
-
|
|
|
|
|
|
|
|
|
85 |
|
86 |
## Technical Details
|
87 |
|
88 |
-
- **
|
89 |
-
- **
|
90 |
-
- **
|
91 |
-
- **
|
92 |
|
93 |
-
##
|
94 |
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
## License
|
98 |
|
99 |
-
|
|
|
1 |
---
|
2 |
+
title: DiffSketcher
|
3 |
+
emoji: 🎨
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: purple
|
6 |
+
sdk: custom
|
7 |
+
app_file: handler.py
|
8 |
+
pinned: false
|
9 |
license: mit
|
10 |
tags:
|
11 |
+
- text-to-svg
|
12 |
- vector-graphics
|
|
|
|
|
13 |
- diffusion
|
14 |
+
- sketch
|
15 |
+
- art
|
16 |
pipeline_tag: text-to-image
|
|
|
17 |
---
|
18 |
|
19 |
+
# DiffSketcher: Text Guided Vector Sketch Synthesis
|
20 |
|
21 |
+
DiffSketcher is a novel method for generating high-quality vector sketches from text prompts using latent diffusion models. This model can create artistic SVG representations based on natural language descriptions.
|
22 |
|
23 |
+
## Model Description
|
24 |
|
25 |
+
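DiffSketcher leverages the power of Stable Diffusion to guide the generation of vector graphics. The model optimizes SVG paths to match the semantic content described in the input text while maintaining the artistic quality of hand-drawn sketches. Note: the `handler.py` in this repository currently returns a placeholder SVG made of random shapes; the DiffVG-based optimization is not yet used for generation.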
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
## Usage
|
28 |
|
29 |
+
### Direct API Call
|
30 |
+
|
31 |
```python
|
32 |
import requests
|
|
|
33 |
|
34 |
+
API_URL = "https://api-inference.huggingface.co/models/jree423/diffsketcher"
|
35 |
headers = {"Authorization": "Bearer YOUR_HF_TOKEN"}
|
36 |
|
37 |
+
def query(payload):
|
38 |
+
response = requests.post(API_URL, headers=headers, json=payload)
|
39 |
+
return response.json()
|
40 |
+
|
41 |
+
output = query({
|
42 |
+
"inputs": "a beautiful mountain landscape",
|
43 |
+
"parameters": {
|
44 |
+
"num_paths": 96,
|
45 |
+
"num_iter": 500,
|
46 |
+
"guidance_scale": 7.5,
|
47 |
+
"width": 224,
|
48 |
+
"height": 224,
|
49 |
+
"seed": 42
|
50 |
+
}
|
51 |
+
})
|
52 |
+
```
|
53 |
+
|
54 |
+
### Using the Inference Client
|
55 |
+
|
56 |
+
```python
|
57 |
+
from huggingface_hub import InferenceClient
|
58 |
+
|
59 |
+
client = InferenceClient("jree423/diffsketcher")
|
60 |
+
result = client.post(
|
61 |
json={
|
62 |
+
"inputs": "a cat sitting on a windowsill",
|
63 |
"parameters": {
|
64 |
+
"num_paths": 128,
|
65 |
+
"guidance_scale": 8.0
|
|
|
66 |
}
|
67 |
}
|
68 |
)
|
69 |
+
```
|
70 |
|
71 |
+
## Parameters
|
|
|
72 |
|
73 |
+
- **num_paths** (int, default: 96): Number of SVG paths to generate. More paths create more detailed sketches.
|
74 |
+
- **num_iter** (int, default: 500): Number of optimization iterations. More iterations improve quality but take longer.
|
75 |
+
- **guidance_scale** (float, default: 7.5): Controls how closely the generation follows the text prompt.
|
76 |
+
- **width** (int, default: 224): Output SVG width in pixels.
|
77 |
+
- **height** (int, default: 224): Output SVG height in pixels.
|
78 |
+
- **seed** (int, default: 42): Random seed for reproducible results.
|
79 |
+
|
80 |
+
## Output Format
|
81 |
|
82 |
+
The model returns a JSON list holding a single object that contains:
|
83 |
+
- `svg`: The generated SVG content as a string
|
84 |
+
- `svg_base64`: Base64-encoded SVG for easy transmission
|
85 |
+
- `prompt`: The input text prompt
|
86 |
+
- `parameters`: The parameters used for generation
|
87 |
|
88 |
+
## Examples
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
+
### Simple Objects
|
91 |
+
- "a red apple"
|
92 |
+
- "a flying bird"
|
93 |
+
- "a vintage car"
|
94 |
|
95 |
+
### Complex Scenes
|
96 |
+
- "a mountain landscape with trees"
|
97 |
+
- "a city skyline at sunset"
|
98 |
+
- "a garden with flowers and butterflies"
|
99 |
+
|
100 |
+
### Artistic Styles
|
101 |
+
- "a portrait in the style of Van Gogh"
|
102 |
+
- "minimalist line drawing of a face"
|
103 |
+
- "abstract geometric patterns"
|
104 |
|
105 |
## Technical Details
|
106 |
|
107 |
+
- **Base Model**: Stable Diffusion 2.1 base (`stabilityai/stable-diffusion-2-1-base`)
|
108 |
+
- **Framework**: PyTorch + Diffusers
|
109 |
+
- **Vector Rendering**: DiffVG (differentiable vector graphics)
|
110 |
+
- **Optimization**: Adam optimizer with custom learning rates for different SVG parameters
|
111 |
|
112 |
+
## Citation
|
113 |
|
114 |
+
```bibtex
|
115 |
+
@inproceedings{xing2023diffsketcher,
|
116 |
+
title={DiffSketcher: Text Guided Vector Sketch Synthesis through Latent Diffusion Models},
|
117 |
+
author={Xing, XiMing and others},
|
118 |
+
booktitle={NeurIPS},
|
119 |
+
year={2023}
|
120 |
+
}
|
121 |
+
```
|
122 |
|
123 |
## License
|
124 |
|
125 |
+
This model is released under the MIT License.
|
config.json
CHANGED
@@ -1,26 +1,54 @@
|
|
1 |
{
|
|
|
2 |
"model_type": "diffsketcher",
|
3 |
-
"task": "text-to-
|
4 |
-
"pipeline_tag": "text-to-image",
|
5 |
"framework": "pytorch",
|
6 |
-
"
|
7 |
-
"
|
8 |
-
"
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
}
|
|
|
1 |
{
|
2 |
+
"architectures": ["DiffSketcherModel"],
|
3 |
"model_type": "diffsketcher",
|
4 |
+
"task": "text-to-svg",
|
|
|
5 |
"framework": "pytorch",
|
6 |
+
"pipeline_tag": "text-to-image",
|
7 |
+
"library_name": "diffusers",
|
8 |
+
"inference": {
|
9 |
+
"parameters": {
|
10 |
+
"num_paths": {
|
11 |
+
"type": "integer",
|
12 |
+
"default": 96,
|
13 |
+
"minimum": 1,
|
14 |
+
"maximum": 1000,
|
15 |
+
"description": "Number of SVG paths to generate"
|
16 |
+
},
|
17 |
+
"num_iter": {
|
18 |
+
"type": "integer",
|
19 |
+
"default": 500,
|
20 |
+
"minimum": 10,
|
21 |
+
"maximum": 2000,
|
22 |
+
"description": "Number of optimization iterations"
|
23 |
+
},
|
24 |
+
"guidance_scale": {
|
25 |
+
"type": "number",
|
26 |
+
"default": 7.5,
|
27 |
+
"minimum": 1.0,
|
28 |
+
"maximum": 20.0,
|
29 |
+
"description": "Guidance scale for diffusion"
|
30 |
+
},
|
31 |
+
"width": {
|
32 |
+
"type": "integer",
|
33 |
+
"default": 224,
|
34 |
+
"minimum": 64,
|
35 |
+
"maximum": 1024,
|
36 |
+
"description": "Output SVG width"
|
37 |
+
},
|
38 |
+
"height": {
|
39 |
+
"type": "integer",
|
40 |
+
"default": 224,
|
41 |
+
"minimum": 64,
|
42 |
+
"maximum": 1024,
|
43 |
+
"description": "Output SVG height"
|
44 |
+
},
|
45 |
+
"seed": {
|
46 |
+
"type": "integer",
|
47 |
+
"default": 42,
|
48 |
+
"minimum": 0,
|
49 |
+
"maximum": 2147483647,
|
50 |
+
"description": "Random seed for reproducibility"
|
51 |
+
}
|
52 |
+
}
|
53 |
+
}
|
54 |
}
|
handler.py
CHANGED
@@ -1,16 +1,90 @@
|
|
1 |
-
import
|
|
|
2 |
import json
|
3 |
-
import
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
|
|
|
|
|
|
|
7 |
def __init__(self, path=""):
|
8 |
-
""
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
-
def
|
12 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
try:
|
|
|
|
|
|
|
|
|
|
|
14 |
# Extract inputs
|
15 |
if isinstance(data, dict):
|
16 |
prompt = data.get("inputs", "")
|
@@ -20,120 +94,76 @@ class EndpointHandler:
|
|
20 |
parameters = {}
|
21 |
|
22 |
if not prompt:
|
23 |
-
return {"error": "No prompt provided"}
|
24 |
|
25 |
# Extract parameters
|
26 |
-
num_paths = parameters.get("num_paths",
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
29 |
|
30 |
-
#
|
31 |
-
|
|
|
32 |
|
33 |
-
#
|
34 |
-
|
|
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
|
|
|
|
40 |
"prompt": prompt,
|
41 |
"parameters": {
|
42 |
"num_paths": num_paths,
|
|
|
|
|
43 |
"width": width,
|
44 |
-
"height": height
|
|
|
45 |
}
|
46 |
-
}
|
47 |
|
48 |
except Exception as e:
|
49 |
-
return {"error": f"Generation failed: {str(e)}"}
|
50 |
|
51 |
-
def
|
52 |
-
"""
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
|
|
|
|
57 |
|
58 |
-
# Generate
|
59 |
-
|
60 |
-
|
61 |
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
|
|
|
|
68 |
else:
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
|
|
|
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
f'<circle cx="{cx}" cy="{cy-20}" r="60" fill="none" stroke="black" stroke-width="3" />',
|
81 |
-
f'<polygon points="{cx-40},{cy-60} {cx-20},{cy-80} {cx-10},{cy-50}" fill="none" stroke="black" stroke-width="2" />',
|
82 |
-
f'<polygon points="{cx+40},{cy-60} {cx+20},{cy-80} {cx+10},{cy-50}" fill="none" stroke="black" stroke-width="2" />',
|
83 |
-
f'<circle cx="{cx-20}" cy="{cy-10}" r="8" fill="black" />',
|
84 |
-
f'<circle cx="{cx+20}" cy="{cy-10}" r="8" fill="black" />',
|
85 |
-
f'<polygon points="{cx-5},{cy+10} {cx+5},{cy+10} {cx},{cy+20}" fill="pink" />',
|
86 |
-
f'<line x1="{cx-50}" y1="{cy}" x2="{cx-70}" y2="{cy-5}" stroke="black" stroke-width="1" />',
|
87 |
-
f'<line x1="{cx+50}" y1="{cy}" x2="{cx+70}" y2="{cy-5}" stroke="black" stroke-width="1" />',
|
88 |
-
f'<ellipse cx="{cx}" cy="{cy+80}" rx="40" ry="60" fill="none" stroke="black" stroke-width="3" />',
|
89 |
-
]
|
90 |
-
|
91 |
-
def _draw_flower_sketch(self, cx, cy):
|
92 |
-
"""Draw a sketchy flower"""
|
93 |
-
petals = []
|
94 |
-
for i in range(8):
|
95 |
-
angle = i * 45
|
96 |
-
petal_x = cx + 50 * math.cos(math.radians(angle))
|
97 |
-
petal_y = cy + 50 * math.sin(math.radians(angle))
|
98 |
-
petals.append(f'<ellipse cx="{petal_x}" cy="{petal_y}" rx="20" ry="35" fill="pink" stroke="red" stroke-width="2" transform="rotate({angle} {petal_x} {petal_y})" />')
|
99 |
-
|
100 |
-
return petals + [
|
101 |
-
f'<circle cx="{cx}" cy="{cy}" r="15" fill="yellow" stroke="orange" stroke-width="2" />',
|
102 |
-
f'<line x1="{cx}" y1="{cy+15}" x2="{cx}" y2="{cy+120}" stroke="green" stroke-width="4" />',
|
103 |
-
f'<ellipse cx="{cx-20}" cy="{cy+80}" rx="15" ry="25" fill="lightgreen" stroke="green" stroke-width="2" />',
|
104 |
-
f'<ellipse cx="{cx+20}" cy="{cy+90}" rx="15" ry="25" fill="lightgreen" stroke="green" stroke-width="2" />',
|
105 |
-
]
|
106 |
-
|
107 |
-
def _draw_house_sketch(self, cx, cy):
|
108 |
-
"""Draw a sketchy house"""
|
109 |
-
return [
|
110 |
-
f'<rect x="{cx-50}" y="{cy}" width="100" height="60" fill="lightblue" stroke="blue" stroke-width="3" />',
|
111 |
-
f'<polygon points="{cx-60},{cy} {cx},{cy-50} {cx+60},{cy}" fill="red" stroke="darkred" stroke-width="2" />',
|
112 |
-
f'<rect x="{cx-15}" y="{cy+20}" width="30" height="40" fill="brown" />',
|
113 |
-
f'<rect x="{cx-40}" y="{cy+15}" width="20" height="20" fill="lightblue" stroke="blue" stroke-width="2" />',
|
114 |
-
f'<rect x="{cx+20}" y="{cy+15}" width="20" height="20" fill="lightblue" stroke="blue" stroke-width="2" />',
|
115 |
-
]
|
116 |
-
|
117 |
-
def _draw_abstract_sketch(self, cx, cy, num_paths):
|
118 |
-
"""Draw abstract sketchy shapes"""
|
119 |
-
import random
|
120 |
-
random.seed(42) # For consistent results
|
121 |
-
|
122 |
-
shapes = []
|
123 |
-
colors = ["red", "blue", "green", "orange", "purple", "pink", "yellow"]
|
124 |
-
|
125 |
-
for i in range(min(num_paths, 12)):
|
126 |
-
x = cx + random.randint(-150, 150)
|
127 |
-
y = cy + random.randint(-150, 150)
|
128 |
-
r = random.randint(20, 60)
|
129 |
-
color = random.choice(colors)
|
130 |
-
|
131 |
-
if i % 3 == 0:
|
132 |
-
shapes.append(f'<circle cx="{x}" cy="{y}" r="{r}" fill="none" stroke="{color}" stroke-width="3" />')
|
133 |
-
elif i % 3 == 1:
|
134 |
-
shapes.append(f'<rect x="{x-r//2}" y="{y-r//2}" width="{r}" height="{r}" fill="none" stroke="{color}" stroke-width="2" />')
|
135 |
-
else:
|
136 |
-
points = f"{x},{y-r} {x+r},{y+r} {x-r},{y+r}"
|
137 |
-
shapes.append(f'<polygon points="{points}" fill="none" stroke="{color}" stroke-width="2" />')
|
138 |
-
|
139 |
-
return shapes
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
import json
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
import io
|
8 |
+
import base64
|
9 |
+
from typing import Dict, Any, List
|
10 |
+
import tempfile
|
11 |
+
import subprocess
|
12 |
|
13 |
+
# Add the DiffSketcher path to sys.path
|
14 |
+
sys.path.append('/workspace/DiffSketcher')
|
15 |
+
|
16 |
+
class DiffSketcherHandler:
|
17 |
def __init__(self, path=""):
|
18 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
19 |
+
self.model_loaded = False
|
20 |
+
|
21 |
+
def load_model(self):
|
22 |
+
"""Load the DiffSketcher model and dependencies"""
|
23 |
+
try:
|
24 |
+
# Import DiffSketcher modules
|
25 |
+
from methods.painter.diffsketcher import Painter
|
26 |
+
from methods.diffusers_warp import StableDiffusionPipeline
|
27 |
+
|
28 |
+
# Load the diffusion model
|
29 |
+
self.pipe = StableDiffusionPipeline.from_pretrained(
|
30 |
+
"stabilityai/stable-diffusion-2-1-base",
|
31 |
+
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
32 |
+
safety_checker=None,
|
33 |
+
requires_safety_checker=False
|
34 |
+
).to(self.device)
|
35 |
+
|
36 |
+
# Initialize the painter
|
37 |
+
self.painter = Painter(
|
38 |
+
args=self._get_default_args(),
|
39 |
+
pipe=self.pipe
|
40 |
+
)
|
41 |
+
|
42 |
+
self.model_loaded = True
|
43 |
+
return True
|
44 |
+
|
45 |
+
except Exception as e:
|
46 |
+
print(f"Error loading model: {str(e)}")
|
47 |
+
return False
|
48 |
|
49 |
+
def _get_default_args(self):
|
50 |
+
"""Get default arguments for DiffSketcher"""
|
51 |
+
class Args:
|
52 |
+
def __init__(self):
|
53 |
+
self.token_ind = 4
|
54 |
+
self.num_paths = 96
|
55 |
+
self.num_iter = 500
|
56 |
+
self.guidance_scale = 7.5
|
57 |
+
self.lr_scheduler = True
|
58 |
+
self.lr = 1.0
|
59 |
+
self.color_lr = 0.01
|
60 |
+
self.width_lr = 0.1
|
61 |
+
self.opacity_lr = 0.01
|
62 |
+
self.width = 224
|
63 |
+
self.height = 224
|
64 |
+
self.seed = 42
|
65 |
+
self.eval_step = 10
|
66 |
+
self.save_step = 10
|
67 |
+
|
68 |
+
return Args()
|
69 |
+
|
70 |
+
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
71 |
+
"""
|
72 |
+
Process the input data and return SVG generation results
|
73 |
+
|
74 |
+
Args:
|
75 |
+
data: Dictionary containing:
|
76 |
+
- inputs: Text prompt for SVG generation
|
77 |
+
- parameters: Optional parameters for generation
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
List of dictionaries containing generated SVG and metadata
|
81 |
+
"""
|
82 |
try:
|
83 |
+
# Load model if not already loaded
|
84 |
+
if not self.model_loaded:
|
85 |
+
if not self.load_model():
|
86 |
+
return [{"error": "Failed to load model"}]
|
87 |
+
|
88 |
# Extract inputs
|
89 |
if isinstance(data, dict):
|
90 |
prompt = data.get("inputs", "")
|
|
|
94 |
parameters = {}
|
95 |
|
96 |
if not prompt:
|
97 |
+
return [{"error": "No prompt provided"}]
|
98 |
|
99 |
# Extract parameters
|
100 |
+
num_paths = parameters.get("num_paths", 96)
|
101 |
+
num_iter = parameters.get("num_iter", 500)
|
102 |
+
guidance_scale = parameters.get("guidance_scale", 7.5)
|
103 |
+
width = parameters.get("width", 224)
|
104 |
+
height = parameters.get("height", 224)
|
105 |
+
seed = parameters.get("seed", 42)
|
106 |
|
107 |
+
# Set random seed
|
108 |
+
torch.manual_seed(seed)
|
109 |
+
np.random.seed(seed)
|
110 |
|
111 |
+
# Create a simple SVG without diffvg for now
|
112 |
+
# This is a placeholder implementation
|
113 |
+
svg_content = self._generate_simple_svg(prompt, width, height, num_paths)
|
114 |
|
115 |
+
# Convert SVG to base64 for transmission
|
116 |
+
svg_b64 = base64.b64encode(svg_content.encode()).decode()
|
117 |
+
|
118 |
+
return [{
|
119 |
+
"svg": svg_content,
|
120 |
+
"svg_base64": svg_b64,
|
121 |
"prompt": prompt,
|
122 |
"parameters": {
|
123 |
"num_paths": num_paths,
|
124 |
+
"num_iter": num_iter,
|
125 |
+
"guidance_scale": guidance_scale,
|
126 |
"width": width,
|
127 |
+
"height": height,
|
128 |
+
"seed": seed
|
129 |
}
|
130 |
+
}]
|
131 |
|
132 |
except Exception as e:
|
133 |
+
return [{"error": f"Generation failed: {str(e)}"}]
|
134 |
|
135 |
+
def _generate_simple_svg(self, prompt: str, width: int, height: int, num_paths: int) -> str:
|
136 |
+
"""
|
137 |
+
Generate a simple SVG as placeholder
|
138 |
+
This should be replaced with actual DiffSketcher generation when diffvg is available
|
139 |
+
"""
|
140 |
+
# Create a simple SVG with random paths based on the prompt
|
141 |
+
svg_header = f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">'
|
142 |
+
svg_footer = '</svg>'
|
143 |
|
144 |
+
# Generate some simple paths based on prompt keywords
|
145 |
+
paths = []
|
146 |
+
colors = ["#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", "#FFEAA7", "#DDA0DD"]
|
147 |
|
148 |
+
# Simple heuristic based on prompt
|
149 |
+
if "circle" in prompt.lower() or "round" in prompt.lower():
|
150 |
+
for i in range(min(num_paths // 4, 10)):
|
151 |
+
cx = np.random.randint(20, width - 20)
|
152 |
+
cy = np.random.randint(20, height - 20)
|
153 |
+
r = np.random.randint(5, 30)
|
154 |
+
color = np.random.choice(colors)
|
155 |
+
paths.append(f'<circle cx="{cx}" cy="{cy}" r="{r}" fill="{color}" opacity="0.7"/>')
|
156 |
else:
|
157 |
+
# Generate random paths
|
158 |
+
for i in range(min(num_paths // 10, 20)):
|
159 |
+
x1, y1 = np.random.randint(0, width), np.random.randint(0, height)
|
160 |
+
x2, y2 = np.random.randint(0, width), np.random.randint(0, height)
|
161 |
+
color = np.random.choice(colors)
|
162 |
+
stroke_width = np.random.randint(1, 5)
|
163 |
+
paths.append(f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="{color}" stroke-width="{stroke_width}" opacity="0.7"/>')
|
164 |
|
165 |
+
svg_content = svg_header + '\n' + '\n'.join(paths) + '\n' + svg_footer
|
166 |
+
return svg_content
|
167 |
+
|
168 |
+
# Create handler instance
|
169 |
+
handler = DiffSketcherHandler()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|