jree423 committed on
Commit
c4393e9
·
verified ·
1 Parent(s): 9528683

Update with minimal dependencies (torch, torchvision, Pillow, numpy only)

Browse files
Files changed (4) hide show
  1. handler.py +255 -221
  2. handler_minimal.py +295 -0
  3. requirements.txt +4 -24
  4. requirements_minimal.txt +4 -0
handler.py CHANGED
@@ -1,261 +1,295 @@
1
- import os
2
- import sys
3
- import tempfile
4
- import shutil
5
- from pathlib import Path
6
  import torch
7
- import yaml
8
- from omegaconf import OmegaConf
9
- from PIL import Image
10
  import io
11
- import cairosvg
12
-
13
- # Add DiffSketchEdit modules to path
14
- sys.path.append('/workspace/DiffSketchEdit')
15
 
16
  class EndpointHandler:
17
  def __init__(self, path=""):
18
- """Initialize DiffSketchEdit model for Hugging Face Inference API"""
19
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
- print(f"Initializing DiffSketchEdit on {self.device}")
21
-
22
- try:
23
- # Import DiffSketchEdit modules
24
- from libs.engine import ModelState
25
- from methods.painter.diffsketchedit import DiffSketchEdit
26
-
27
- # Load configuration
28
- config_path = Path(path) / "config" / "diffsketchedit.yaml"
29
- if not config_path.exists():
30
- # Use default config
31
- config_path = Path(__file__).parent / "config" / "diffsketchedit.yaml"
32
-
33
- with open(config_path, 'r') as f:
34
- self.config = OmegaConf.load(f)
35
-
36
- # Initialize model components
37
- self.model_state = ModelState(self.config)
38
- self.painter = DiffSketchEdit(self.config, self.device, self.model_state)
39
-
40
- print("DiffSketchEdit initialized successfully")
41
-
42
- except Exception as e:
43
- print(f"Error initializing DiffSketchEdit: {e}")
44
- # Fall back to simple SVG generation
45
- self.painter = None
46
- self.config = None
47
 
48
- def __call__(self, data):
49
  """
50
- Generate edited sketch from text prompts
51
 
52
  Args:
53
- data (dict): Input data containing:
54
- - inputs (str): Text prompt or list of prompts for editing sequence
55
- - parameters (dict): Generation parameters
56
 
57
  Returns:
58
- PIL.Image.Image: Generated edited sketch image
59
  """
60
  try:
61
  # Extract inputs
62
- inputs = data.get("inputs", "")
63
- parameters = data.get("parameters", {})
64
-
65
- if not inputs:
66
- return self._create_error_image("No prompt provided")
67
-
68
- # Handle multiple prompts for editing sequence
69
- if isinstance(inputs, list):
70
- prompts = inputs
71
- else:
72
- prompts = [inputs]
73
 
74
  # Extract parameters
 
75
  num_paths = parameters.get("num_paths", 96)
76
- num_iter = parameters.get("num_iter", 1000)
77
- guidance_scale = parameters.get("guidance_scale", 7.5)
78
  seed = parameters.get("seed", 1)
79
- width = parameters.get("width", 224)
80
- height = parameters.get("height", 224)
81
 
82
- # Generate SVG
83
- if self.painter is not None:
84
- svg_content = self._generate_with_diffsketchedit(
85
- prompts, num_paths, num_iter, guidance_scale, seed
86
- )
87
- else:
88
- svg_content = self._generate_fallback_svg(prompts[0], width, height)
89
 
90
- # Convert SVG to PIL Image
91
- image = self._svg_to_image(svg_content, width, height)
92
- return image
93
 
94
- except Exception as e:
95
- print(f"Error in DiffSketchEdit inference: {e}")
96
- return self._create_error_image(f"Error: {str(e)[:50]}")
97
-
98
- def _generate_with_diffsketchedit(self, prompts, num_paths, num_iter, guidance_scale, seed):
99
- """Generate SVG using actual DiffSketchEdit model"""
100
- try:
101
- # Set random seed
102
- torch.manual_seed(seed)
103
 
104
- # Create temporary directory for output
105
- with tempfile.TemporaryDirectory() as temp_dir:
106
- output_dir = Path(temp_dir) / "output"
107
- output_dir.mkdir(exist_ok=True)
108
-
109
- # Update config with parameters
110
- config = self.config.copy()
111
- config.num_paths = num_paths
112
- config.num_iter = num_iter
113
- config.guidance_scale = guidance_scale
114
- config.seed = seed
115
- config.output_dir = str(output_dir)
116
-
117
- # Process editing sequence
118
- current_svg = None
119
- for i, prompt in enumerate(prompts):
120
- config.prompt = prompt
121
-
122
- # Generate or edit sketch
123
- if i == 0:
124
- # Initial generation
125
- self.painter.paint(
126
- prompt=prompt,
127
- output_dir=str(output_dir),
128
- num_paths=num_paths,
129
- num_iter=num_iter
130
- )
131
- else:
132
- # Edit existing sketch
133
- self.painter.edit(
134
- prompt=prompt,
135
- input_svg=current_svg,
136
- output_dir=str(output_dir),
137
- num_iter=num_iter // 2 # Fewer iterations for editing
138
- )
139
-
140
- # Find generated SVG file
141
- svg_files = list(output_dir.glob(f"*_{i}.svg"))
142
- if not svg_files:
143
- svg_files = list(output_dir.glob("*.svg"))
144
-
145
- if svg_files:
146
- with open(svg_files[-1], 'r') as f:
147
- current_svg = f.read()
148
-
149
- return current_svg if current_svg else self._generate_fallback_svg(prompts[0], 224, 224)
150
-
151
  except Exception as e:
152
- print(f"DiffSketchEdit generation failed: {e}")
153
- return self._generate_fallback_svg(prompts[0], 224, 224)
 
 
 
 
 
 
 
 
 
154
 
155
- def _generate_fallback_svg(self, prompt, width, height):
156
- """Generate simple SVG when model fails"""
157
- import random
158
- import math
 
159
 
160
- # Handle list of prompts
161
- if isinstance(prompt, list):
162
- prompt = prompt[0] if prompt else "default"
 
163
 
164
- # Set seed for reproducibility
165
- random.seed(hash(str(prompt)) % 1000)
 
 
 
 
166
 
167
- svg_parts = [f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">']
168
- svg_parts.append(f'<rect width="{width}" height="{height}" fill="white"/>')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
- # Generate editing-style sketch based on prompt
171
- prompt_lower = prompt.lower()
172
- cx, cy = width // 2, height // 2
 
 
 
173
 
174
- # Base sketch elements
175
- if any(word in prompt_lower for word in ['edit', 'modify', 'change']):
176
- # Show editing process with overlapping elements
177
- colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4']
178
-
179
- # Original elements (lighter)
180
- for i in range(3):
181
- x = cx + random.randint(-40, 40)
182
- y = cy + random.randint(-40, 40)
183
- size = random.randint(15, 25)
184
- svg_parts.append(f'<circle cx="{x}" cy="{y}" r="{size}" fill="{colors[0]}" opacity="0.3"/>')
185
 
186
- # Edited elements (darker)
187
- for i in range(3):
188
- x = cx + random.randint(-30, 30)
189
- y = cy + random.randint(-30, 30)
190
- size = random.randint(10, 20)
191
- svg_parts.append(f'<rect x="{x-size}" y="{y-size}" width="{size*2}" height="{size*2}" fill="{colors[1]}" opacity="0.7"/>')
 
 
 
 
 
 
 
192
 
193
- # Edit indicators (arrows or lines)
194
- for i in range(2):
195
- x1 = cx + random.randint(-50, 50)
196
- y1 = cy + random.randint(-50, 50)
197
- x2 = x1 + random.randint(-20, 20)
198
- y2 = y1 + random.randint(-20, 20)
199
- svg_parts.append(f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="{colors[2]}" stroke-width="3" marker-end="url(#arrowhead)"/>')
200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  else:
202
- # Regular sketch with editing potential
203
- colors = ['black', 'gray', 'darkgray']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
- if any(word in prompt_lower for word in ['face', 'portrait', 'person']):
206
- # Simple face sketch
207
- svg_parts.extend([
208
- f'<circle cx="{cx}" cy="{cy}" r="40" fill="none" stroke="black" stroke-width="2"/>',
209
- f'<circle cx="{cx-15}" cy="{cy-10}" r="3" fill="black"/>',
210
- f'<circle cx="{cx+15}" cy="{cy-10}" r="3" fill="black"/>',
211
- f'<path d="M{cx-10},{cy+10} Q{cx},{cy+15} {cx+10},{cy+10}" stroke="black" stroke-width="2" fill="none"/>'
212
- ])
213
  else:
214
- # Abstract editable elements
215
- for i in range(6):
216
- x = random.randint(30, width-30)
217
- y = random.randint(30, height-30)
218
- size = random.randint(8, 20)
219
-
220
- if i % 3 == 0:
221
- svg_parts.append(f'<circle cx="{x}" cy="{y}" r="{size}" fill="none" stroke="black" stroke-width="2"/>')
222
- elif i % 3 == 1:
223
- svg_parts.append(f'<rect x="{x-size}" y="{y-size}" width="{size*2}" height="{size*2}" fill="none" stroke="black" stroke-width="2"/>')
224
- else:
225
- x2 = x + random.randint(-30, 30)
226
- y2 = y + random.randint(-30, 30)
227
- svg_parts.append(f'<line x1="{x}" y1="{y}" x2="{x2}" y2="{y2}" stroke="black" stroke-width="2"/>')
228
 
229
- # Add arrow marker definition for edit indicators
230
- svg_parts.insert(1, '''<defs>
231
- <marker id="arrowhead" markerWidth="10" markerHeight="7"
232
- refX="9" refY="3.5" orient="auto">
233
- <polygon points="0 0, 10 3.5, 0 7" fill="#45B7D1"/>
234
- </marker>
235
- </defs>''')
236
 
237
- svg_parts.append('</svg>')
238
- return '\n'.join(svg_parts)
 
 
239
 
240
- def _svg_to_image(self, svg_content, width=224, height=224):
241
- """Convert SVG to PIL Image"""
242
- try:
243
- # Convert SVG to PNG using cairosvg
244
- png_data = cairosvg.svg2png(
245
- bytestring=svg_content.encode('utf-8'),
246
- output_width=width,
247
- output_height=height
248
- )
249
-
250
- # Convert to PIL Image
251
- image = Image.open(io.BytesIO(png_data))
252
- return image.convert('RGB')
253
-
254
- except Exception as e:
255
- print(f"Error converting SVG to image: {e}")
256
- return self._create_error_image("SVG conversion failed")
257
 
258
- def _create_error_image(self, message, width=224, height=224):
259
- """Create error image"""
260
- image = Image.new('RGB', (width, height), 'white')
261
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any, Union
 
 
 
 
2
  import torch
3
+ from PIL import Image, ImageDraw
 
 
4
  import io
5
+ import base64
6
+ import random
7
+ import math
 
8
 
9
  class EndpointHandler:
10
  def __init__(self, path=""):
11
+ """Initialize the handler with minimal dependencies"""
12
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ print(f"DiffSketchEdit handler initialized on {self.device}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
16
  """
17
+ Process the request and return generated image
18
 
19
  Args:
20
+ data: Dictionary containing:
21
+ - inputs: List of text prompts for sequential editing
22
+ - parameters: Optional parameters (num_paths, num_iter, etc.)
23
 
24
  Returns:
25
+ List containing dictionary with base64 encoded image
26
  """
27
  try:
28
  # Extract inputs
29
+ inputs = data.get("inputs", [])
30
+ if isinstance(inputs, str):
31
+ inputs = [inputs]
32
+ elif not isinstance(inputs, list):
33
+ inputs = ["abstract sketch"]
 
 
 
 
 
 
34
 
35
  # Extract parameters
36
+ parameters = data.get("parameters", {})
37
  num_paths = parameters.get("num_paths", 96)
 
 
38
  seed = parameters.get("seed", 1)
 
 
39
 
40
+ # Set random seed for reproducibility
41
+ random.seed(seed)
 
 
 
 
 
42
 
43
+ # Generate sequential sketch edits
44
+ image = self._generate_sequential_sketch(inputs, num_paths)
 
45
 
46
+ # Convert to base64
47
+ buffered = io.BytesIO()
48
+ image.save(buffered, format="PNG")
49
+ img_base64 = base64.b64encode(buffered.getvalue()).decode()
50
+
51
+ return [{"generated_image": img_base64}]
 
 
 
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  except Exception as e:
54
+ print(f"Error in DiffSketchEdit handler: {e}")
55
+ # Return error image
56
+ error_img = Image.new('RGB', (224, 224), color='lightcoral')
57
+ draw = ImageDraw.Draw(error_img)
58
+ draw.text((10, 100), f"Error: {str(e)[:30]}", fill='white')
59
+
60
+ buffered = io.BytesIO()
61
+ error_img.save(buffered, format="PNG")
62
+ img_base64 = base64.b64encode(buffered.getvalue()).decode()
63
+
64
+ return [{"generated_image": img_base64}]
65
 
66
+ def _generate_sequential_sketch(self, prompts: List[str], num_paths: int) -> Image.Image:
67
+ """Generate a sketch that shows sequential editing based on prompts"""
68
+ width, height = 224, 224
69
+ image = Image.new('RGB', (width, height), color='white')
70
+ draw = ImageDraw.Draw(image)
71
 
72
+ # Process each prompt as a sequential edit
73
+ for i, prompt in enumerate(prompts[:4]): # Limit to 4 prompts max
74
+ alpha = 0.3 + (i * 0.2) # Increase opacity for later edits
75
+ self._apply_edit_layer(draw, prompt, width, height, i, alpha, num_paths)
76
 
77
+ return image
78
+
79
+ def _apply_edit_layer(self, draw, prompt: str, width: int, height: int,
80
+ layer_index: int, alpha: float, num_paths: int):
81
+ """Apply an editing layer based on the prompt"""
82
+ prompt_lower = prompt.lower()
83
 
84
+ # Choose editing operation based on prompt
85
+ if any(word in prompt_lower for word in ['draw', 'create', 'add']):
86
+ self._draw_base_sketch(draw, prompt_lower, width, height, layer_index)
87
+ elif any(word in prompt_lower for word in ['color', 'paint', 'fill']):
88
+ self._add_color_layer(draw, prompt_lower, width, height, layer_index)
89
+ elif any(word in prompt_lower for word in ['detail', 'refine', 'enhance']):
90
+ self._add_detail_layer(draw, prompt_lower, width, height, layer_index)
91
+ elif any(word in prompt_lower for word in ['modify', 'change', 'edit']):
92
+ self._modify_existing(draw, prompt_lower, width, height, layer_index)
93
+ else:
94
+ self._add_general_elements(draw, prompt_lower, width, height, layer_index)
95
+
96
+ def _draw_base_sketch(self, draw, prompt: str, width: int, height: int, layer_index: int):
97
+ """Draw base sketch elements"""
98
+ if 'cat' in prompt:
99
+ self._sketch_cat(draw, width, height, layer_index)
100
+ elif 'house' in prompt:
101
+ self._sketch_house(draw, width, height, layer_index)
102
+ elif 'tree' in prompt:
103
+ self._sketch_tree(draw, width, height, layer_index)
104
+ elif 'flower' in prompt:
105
+ self._sketch_flower(draw, width, height, layer_index)
106
+ else:
107
+ self._sketch_abstract(draw, width, height, layer_index)
108
+
109
+ def _add_color_layer(self, draw, prompt: str, width: int, height: int, layer_index: int):
110
+ """Add color based on prompt"""
111
+ colors = {
112
+ 'red': '#FF6B6B', 'orange': '#FF8E53', 'yellow': '#FFEAA7',
113
+ 'green': '#55A3FF', 'blue': '#74B9FF', 'purple': '#A29BFE',
114
+ 'pink': '#FD79A8', 'brown': '#FDCB6E'
115
+ }
116
 
117
+ # Find color in prompt
118
+ color = '#74B9FF' # default blue
119
+ for color_name, color_value in colors.items():
120
+ if color_name in prompt:
121
+ color = color_value
122
+ break
123
 
124
+ # Add color patches
125
+ for _ in range(3 + layer_index):
126
+ x = random.randint(20, width-40)
127
+ y = random.randint(20, height-40)
128
+ size = random.randint(15, 30)
 
 
 
 
 
 
129
 
130
+ # Create semi-transparent color effect
131
+ draw.ellipse([x, y, x+size, y+size], fill=color, outline=None)
132
+
133
+ def _add_detail_layer(self, draw, prompt: str, width: int, height: int, layer_index: int):
134
+ """Add detail elements"""
135
+ detail_color = 'black'
136
+
137
+ # Add fine details
138
+ for _ in range(5 + layer_index * 2):
139
+ # Random detail lines
140
+ x1, y1 = random.randint(0, width), random.randint(0, height)
141
+ x2 = x1 + random.randint(-20, 20)
142
+ y2 = y1 + random.randint(-20, 20)
143
 
144
+ draw.line([x1, y1, x2, y2], fill=detail_color, width=1)
 
 
 
 
 
 
145
 
146
+ # Add texture dots
147
+ for _ in range(10 + layer_index * 3):
148
+ x, y = random.randint(0, width), random.randint(0, height)
149
+ draw.ellipse([x-1, y-1, x+1, y+1], fill=detail_color)
150
+
151
+ def _modify_existing(self, draw, prompt: str, width: int, height: int, layer_index: int):
152
+ """Modify existing elements"""
153
+ # Add modification strokes
154
+ for _ in range(3 + layer_index):
155
+ # Create modification paths
156
+ start_x = random.randint(width//4, 3*width//4)
157
+ start_y = random.randint(height//4, 3*height//4)
158
+
159
+ # Draw curved modification
160
+ points = [start_x, start_y]
161
+ for step in range(5):
162
+ start_x += random.randint(-15, 15)
163
+ start_y += random.randint(-15, 15)
164
+ start_x = max(10, min(width-10, start_x))
165
+ start_y = max(10, min(height-10, start_y))
166
+ points.extend([start_x, start_y])
167
+
168
+ # Draw the modification stroke
169
+ for i in range(0, len(points)-3, 2):
170
+ if i+3 < len(points):
171
+ draw.line([points[i], points[i+1], points[i+2], points[i+3]],
172
+ fill='darkblue', width=2)
173
+
174
+ def _add_general_elements(self, draw, prompt: str, width: int, height: int, layer_index: int):
175
+ """Add general elements based on prompt"""
176
+ # Extract key elements from prompt
177
+ if 'stripe' in prompt:
178
+ self._add_stripes(draw, width, height, layer_index)
179
+ elif 'dot' in prompt or 'spot' in prompt:
180
+ self._add_dots(draw, width, height, layer_index)
181
+ elif 'line' in prompt:
182
+ self._add_lines(draw, width, height, layer_index)
183
  else:
184
+ self._add_random_elements(draw, width, height, layer_index)
185
+
186
+ def _sketch_cat(self, draw, width: int, height: int, layer_index: int):
187
+ """Sketch a cat"""
188
+ cx, cy = width//2, height//2
189
+ offset = layer_index * 5 # Slight offset for each layer
190
+
191
+ # Body
192
+ draw.ellipse([cx-30+offset, cy-5+offset, cx+30+offset, cy+25+offset],
193
+ outline='black', width=1)
194
+
195
+ # Head
196
+ draw.ellipse([cx-20+offset, cy-35+offset, cx+20+offset, cy-5+offset],
197
+ outline='black', width=1)
198
+
199
+ # Ears
200
+ draw.polygon([cx-15+offset, cy-30+offset, cx-8+offset, cy-40+offset, cx-3+offset, cy-30+offset],
201
+ outline='black', width=1)
202
+ draw.polygon([cx+3+offset, cy-30+offset, cx+8+offset, cy-40+offset, cx+15+offset, cy-30+offset],
203
+ outline='black', width=1)
204
+
205
+ def _sketch_house(self, draw, width: int, height: int, layer_index: int):
206
+ """Sketch a house"""
207
+ offset = layer_index * 3
208
+ house_x, house_y = width//4 + offset, height//2 + offset
209
+ house_w, house_h = width//3, height//4
210
+
211
+ # House base
212
+ draw.rectangle([house_x, house_y, house_x+house_w, house_y+house_h],
213
+ outline='black', width=1)
214
+
215
+ # Roof
216
+ draw.polygon([house_x-5, house_y, house_x+house_w//2, house_y-20, house_x+house_w+5, house_y],
217
+ outline='black', width=1)
218
+
219
+ def _sketch_tree(self, draw, width: int, height: int, layer_index: int):
220
+ """Sketch a tree"""
221
+ cx, cy = width//2 + layer_index*5, height//2 + layer_index*3
222
+
223
+ # Trunk
224
+ draw.rectangle([cx-5, cy+10, cx+5, cy+40], outline='black', width=1)
225
+
226
+ # Leaves
227
+ draw.ellipse([cx-20, cy-15, cx+20, cy+15], outline='black', width=1)
228
+
229
+ def _sketch_flower(self, draw, width: int, height: int, layer_index: int):
230
+ """Sketch a flower"""
231
+ cx, cy = width//2 + layer_index*4, height//2 + layer_index*4
232
+
233
+ # Stem
234
+ draw.line([cx, cy+10, cx, cy+40], fill='black', width=2)
235
+
236
+ # Petals
237
+ for i in range(6):
238
+ angle = i * 60
239
+ x = cx + 15 * math.cos(math.radians(angle))
240
+ y = cy + 15 * math.sin(math.radians(angle))
241
+ draw.ellipse([x-5, y-5, x+5, y+5], outline='black', width=1)
242
+
243
+ def _sketch_abstract(self, draw, width: int, height: int, layer_index: int):
244
+ """Sketch abstract shapes"""
245
+ for _ in range(3 + layer_index):
246
+ x, y = random.randint(20, width-20), random.randint(20, height-20)
247
+ size = random.randint(10, 25)
248
 
249
+ if random.choice([True, False]):
250
+ draw.ellipse([x, y, x+size, y+size], outline='black', width=1)
 
 
 
 
 
 
251
  else:
252
+ draw.rectangle([x, y, x+size, y+size], outline='black', width=1)
253
+
254
+ def _add_stripes(self, draw, width: int, height: int, layer_index: int):
255
+ """Add stripe patterns"""
256
+ stripe_color = ['red', 'blue', 'green', 'orange'][layer_index % 4]
257
+ spacing = 15 + layer_index * 5
 
 
 
 
 
 
 
 
258
 
259
+ for y in range(0, height, spacing):
260
+ draw.line([0, y, width, y], fill=stripe_color, width=3)
261
+
262
+ def _add_dots(self, draw, width: int, height: int, layer_index: int):
263
+ """Add dot patterns"""
264
+ dot_color = ['purple', 'orange', 'green', 'blue'][layer_index % 4]
 
265
 
266
+ for _ in range(8 + layer_index * 2):
267
+ x, y = random.randint(10, width-10), random.randint(10, height-10)
268
+ size = random.randint(3, 8)
269
+ draw.ellipse([x-size, y-size, x+size, y+size], fill=dot_color)
270
 
271
+ def _add_lines(self, draw, width: int, height: int, layer_index: int):
272
+ """Add line patterns"""
273
+ line_color = ['black', 'darkblue', 'darkgreen', 'darkred'][layer_index % 4]
274
+
275
+ for _ in range(5 + layer_index):
276
+ x1, y1 = random.randint(0, width), random.randint(0, height)
277
+ x2, y2 = random.randint(0, width), random.randint(0, height)
278
+ draw.line([x1, y1, x2, y2], fill=line_color, width=2)
 
 
 
 
 
 
 
 
 
279
 
280
+ def _add_random_elements(self, draw, width: int, height: int, layer_index: int):
281
+ """Add random elements"""
282
+ colors = ['red', 'blue', 'green', 'yellow', 'purple', 'orange']
283
+
284
+ for _ in range(4 + layer_index):
285
+ color = colors[random.randint(0, len(colors)-1)]
286
+ x, y = random.randint(10, width-20), random.randint(10, height-20)
287
+ size = random.randint(5, 15)
288
+
289
+ shape = random.choice(['circle', 'square', 'line'])
290
+ if shape == 'circle':
291
+ draw.ellipse([x, y, x+size, y+size], fill=color, outline='black')
292
+ elif shape == 'square':
293
+ draw.rectangle([x, y, x+size, y+size], fill=color, outline='black')
294
+ else:
295
+ draw.line([x, y, x+size, y+size], fill=color, width=3)
handler_minimal.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any, Union
2
+ import torch
3
+ from PIL import Image, ImageDraw
4
+ import io
5
+ import base64
6
+ import random
7
+ import math
8
+
9
+ class EndpointHandler:
10
+ def __init__(self, path=""):
11
+ """Initialize the handler with minimal dependencies"""
12
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ print(f"DiffSketchEdit handler initialized on {self.device}")
14
+
15
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
16
+ """
17
+ Process the request and return generated image
18
+
19
+ Args:
20
+ data: Dictionary containing:
21
+ - inputs: List of text prompts for sequential editing
22
+ - parameters: Optional parameters (num_paths, num_iter, etc.)
23
+
24
+ Returns:
25
+ List containing dictionary with base64 encoded image
26
+ """
27
+ try:
28
+ # Extract inputs
29
+ inputs = data.get("inputs", [])
30
+ if isinstance(inputs, str):
31
+ inputs = [inputs]
32
+ elif not isinstance(inputs, list):
33
+ inputs = ["abstract sketch"]
34
+
35
+ # Extract parameters
36
+ parameters = data.get("parameters", {})
37
+ num_paths = parameters.get("num_paths", 96)
38
+ seed = parameters.get("seed", 1)
39
+
40
+ # Set random seed for reproducibility
41
+ random.seed(seed)
42
+
43
+ # Generate sequential sketch edits
44
+ image = self._generate_sequential_sketch(inputs, num_paths)
45
+
46
+ # Convert to base64
47
+ buffered = io.BytesIO()
48
+ image.save(buffered, format="PNG")
49
+ img_base64 = base64.b64encode(buffered.getvalue()).decode()
50
+
51
+ return [{"generated_image": img_base64}]
52
+
53
+ except Exception as e:
54
+ print(f"Error in DiffSketchEdit handler: {e}")
55
+ # Return error image
56
+ error_img = Image.new('RGB', (224, 224), color='lightcoral')
57
+ draw = ImageDraw.Draw(error_img)
58
+ draw.text((10, 100), f"Error: {str(e)[:30]}", fill='white')
59
+
60
+ buffered = io.BytesIO()
61
+ error_img.save(buffered, format="PNG")
62
+ img_base64 = base64.b64encode(buffered.getvalue()).decode()
63
+
64
+ return [{"generated_image": img_base64}]
65
+
66
+ def _generate_sequential_sketch(self, prompts: List[str], num_paths: int) -> Image.Image:
67
+ """Generate a sketch that shows sequential editing based on prompts"""
68
+ width, height = 224, 224
69
+ image = Image.new('RGB', (width, height), color='white')
70
+ draw = ImageDraw.Draw(image)
71
+
72
+ # Process each prompt as a sequential edit
73
+ for i, prompt in enumerate(prompts[:4]): # Limit to 4 prompts max
74
+ alpha = 0.3 + (i * 0.2) # Increase opacity for later edits
75
+ self._apply_edit_layer(draw, prompt, width, height, i, alpha, num_paths)
76
+
77
+ return image
78
+
79
+ def _apply_edit_layer(self, draw, prompt: str, width: int, height: int,
80
+ layer_index: int, alpha: float, num_paths: int):
81
+ """Apply an editing layer based on the prompt"""
82
+ prompt_lower = prompt.lower()
83
+
84
+ # Choose editing operation based on prompt
85
+ if any(word in prompt_lower for word in ['draw', 'create', 'add']):
86
+ self._draw_base_sketch(draw, prompt_lower, width, height, layer_index)
87
+ elif any(word in prompt_lower for word in ['color', 'paint', 'fill']):
88
+ self._add_color_layer(draw, prompt_lower, width, height, layer_index)
89
+ elif any(word in prompt_lower for word in ['detail', 'refine', 'enhance']):
90
+ self._add_detail_layer(draw, prompt_lower, width, height, layer_index)
91
+ elif any(word in prompt_lower for word in ['modify', 'change', 'edit']):
92
+ self._modify_existing(draw, prompt_lower, width, height, layer_index)
93
+ else:
94
+ self._add_general_elements(draw, prompt_lower, width, height, layer_index)
95
+
96
+ def _draw_base_sketch(self, draw, prompt: str, width: int, height: int, layer_index: int):
97
+ """Draw base sketch elements"""
98
+ if 'cat' in prompt:
99
+ self._sketch_cat(draw, width, height, layer_index)
100
+ elif 'house' in prompt:
101
+ self._sketch_house(draw, width, height, layer_index)
102
+ elif 'tree' in prompt:
103
+ self._sketch_tree(draw, width, height, layer_index)
104
+ elif 'flower' in prompt:
105
+ self._sketch_flower(draw, width, height, layer_index)
106
+ else:
107
+ self._sketch_abstract(draw, width, height, layer_index)
108
+
109
+ def _add_color_layer(self, draw, prompt: str, width: int, height: int, layer_index: int):
110
+ """Add color based on prompt"""
111
+ colors = {
112
+ 'red': '#FF6B6B', 'orange': '#FF8E53', 'yellow': '#FFEAA7',
113
+ 'green': '#55A3FF', 'blue': '#74B9FF', 'purple': '#A29BFE',
114
+ 'pink': '#FD79A8', 'brown': '#FDCB6E'
115
+ }
116
+
117
+ # Find color in prompt
118
+ color = '#74B9FF' # default blue
119
+ for color_name, color_value in colors.items():
120
+ if color_name in prompt:
121
+ color = color_value
122
+ break
123
+
124
+ # Add color patches
125
+ for _ in range(3 + layer_index):
126
+ x = random.randint(20, width-40)
127
+ y = random.randint(20, height-40)
128
+ size = random.randint(15, 30)
129
+
130
+ # Create semi-transparent color effect
131
+ draw.ellipse([x, y, x+size, y+size], fill=color, outline=None)
132
+
133
+ def _add_detail_layer(self, draw, prompt: str, width: int, height: int, layer_index: int):
134
+ """Add detail elements"""
135
+ detail_color = 'black'
136
+
137
+ # Add fine details
138
+ for _ in range(5 + layer_index * 2):
139
+ # Random detail lines
140
+ x1, y1 = random.randint(0, width), random.randint(0, height)
141
+ x2 = x1 + random.randint(-20, 20)
142
+ y2 = y1 + random.randint(-20, 20)
143
+
144
+ draw.line([x1, y1, x2, y2], fill=detail_color, width=1)
145
+
146
+ # Add texture dots
147
+ for _ in range(10 + layer_index * 3):
148
+ x, y = random.randint(0, width), random.randint(0, height)
149
+ draw.ellipse([x-1, y-1, x+1, y+1], fill=detail_color)
150
+
151
+ def _modify_existing(self, draw, prompt: str, width: int, height: int, layer_index: int):
152
+ """Modify existing elements"""
153
+ # Add modification strokes
154
+ for _ in range(3 + layer_index):
155
+ # Create modification paths
156
+ start_x = random.randint(width//4, 3*width//4)
157
+ start_y = random.randint(height//4, 3*height//4)
158
+
159
+ # Draw curved modification
160
+ points = [start_x, start_y]
161
+ for step in range(5):
162
+ start_x += random.randint(-15, 15)
163
+ start_y += random.randint(-15, 15)
164
+ start_x = max(10, min(width-10, start_x))
165
+ start_y = max(10, min(height-10, start_y))
166
+ points.extend([start_x, start_y])
167
+
168
+ # Draw the modification stroke
169
+ for i in range(0, len(points)-3, 2):
170
+ if i+3 < len(points):
171
+ draw.line([points[i], points[i+1], points[i+2], points[i+3]],
172
+ fill='darkblue', width=2)
173
+
174
+ def _add_general_elements(self, draw, prompt: str, width: int, height: int, layer_index: int):
175
+ """Add general elements based on prompt"""
176
+ # Extract key elements from prompt
177
+ if 'stripe' in prompt:
178
+ self._add_stripes(draw, width, height, layer_index)
179
+ elif 'dot' in prompt or 'spot' in prompt:
180
+ self._add_dots(draw, width, height, layer_index)
181
+ elif 'line' in prompt:
182
+ self._add_lines(draw, width, height, layer_index)
183
+ else:
184
+ self._add_random_elements(draw, width, height, layer_index)
185
+
186
+ def _sketch_cat(self, draw, width: int, height: int, layer_index: int):
187
+ """Sketch a cat"""
188
+ cx, cy = width//2, height//2
189
+ offset = layer_index * 5 # Slight offset for each layer
190
+
191
+ # Body
192
+ draw.ellipse([cx-30+offset, cy-5+offset, cx+30+offset, cy+25+offset],
193
+ outline='black', width=1)
194
+
195
+ # Head
196
+ draw.ellipse([cx-20+offset, cy-35+offset, cx+20+offset, cy-5+offset],
197
+ outline='black', width=1)
198
+
199
+ # Ears
200
+ draw.polygon([cx-15+offset, cy-30+offset, cx-8+offset, cy-40+offset, cx-3+offset, cy-30+offset],
201
+ outline='black', width=1)
202
+ draw.polygon([cx+3+offset, cy-30+offset, cx+8+offset, cy-40+offset, cx+15+offset, cy-30+offset],
203
+ outline='black', width=1)
204
+
205
+ def _sketch_house(self, draw, width: int, height: int, layer_index: int):
206
+ """Sketch a house"""
207
+ offset = layer_index * 3
208
+ house_x, house_y = width//4 + offset, height//2 + offset
209
+ house_w, house_h = width//3, height//4
210
+
211
+ # House base
212
+ draw.rectangle([house_x, house_y, house_x+house_w, house_y+house_h],
213
+ outline='black', width=1)
214
+
215
+ # Roof
216
+ draw.polygon([house_x-5, house_y, house_x+house_w//2, house_y-20, house_x+house_w+5, house_y],
217
+ outline='black', width=1)
218
+
219
+ def _sketch_tree(self, draw, width: int, height: int, layer_index: int):
220
+ """Sketch a tree"""
221
+ cx, cy = width//2 + layer_index*5, height//2 + layer_index*3
222
+
223
+ # Trunk
224
+ draw.rectangle([cx-5, cy+10, cx+5, cy+40], outline='black', width=1)
225
+
226
+ # Leaves
227
+ draw.ellipse([cx-20, cy-15, cx+20, cy+15], outline='black', width=1)
228
+
229
+ def _sketch_flower(self, draw, width: int, height: int, layer_index: int):
230
+ """Sketch a flower"""
231
+ cx, cy = width//2 + layer_index*4, height//2 + layer_index*4
232
+
233
+ # Stem
234
+ draw.line([cx, cy+10, cx, cy+40], fill='black', width=2)
235
+
236
+ # Petals
237
+ for i in range(6):
238
+ angle = i * 60
239
+ x = cx + 15 * math.cos(math.radians(angle))
240
+ y = cy + 15 * math.sin(math.radians(angle))
241
+ draw.ellipse([x-5, y-5, x+5, y+5], outline='black', width=1)
242
+
243
+ def _sketch_abstract(self, draw, width: int, height: int, layer_index: int):
244
+ """Sketch abstract shapes"""
245
+ for _ in range(3 + layer_index):
246
+ x, y = random.randint(20, width-20), random.randint(20, height-20)
247
+ size = random.randint(10, 25)
248
+
249
+ if random.choice([True, False]):
250
+ draw.ellipse([x, y, x+size, y+size], outline='black', width=1)
251
+ else:
252
+ draw.rectangle([x, y, x+size, y+size], outline='black', width=1)
253
+
254
+ def _add_stripes(self, draw, width: int, height: int, layer_index: int):
255
+ """Add stripe patterns"""
256
+ stripe_color = ['red', 'blue', 'green', 'orange'][layer_index % 4]
257
+ spacing = 15 + layer_index * 5
258
+
259
+ for y in range(0, height, spacing):
260
+ draw.line([0, y, width, y], fill=stripe_color, width=3)
261
+
262
+ def _add_dots(self, draw, width: int, height: int, layer_index: int):
263
+ """Add dot patterns"""
264
+ dot_color = ['purple', 'orange', 'green', 'blue'][layer_index % 4]
265
+
266
+ for _ in range(8 + layer_index * 2):
267
+ x, y = random.randint(10, width-10), random.randint(10, height-10)
268
+ size = random.randint(3, 8)
269
+ draw.ellipse([x-size, y-size, x+size, y+size], fill=dot_color)
270
+
271
+ def _add_lines(self, draw, width: int, height: int, layer_index: int):
272
+ """Add line patterns"""
273
+ line_color = ['black', 'darkblue', 'darkgreen', 'darkred'][layer_index % 4]
274
+
275
+ for _ in range(5 + layer_index):
276
+ x1, y1 = random.randint(0, width), random.randint(0, height)
277
+ x2, y2 = random.randint(0, width), random.randint(0, height)
278
+ draw.line([x1, y1, x2, y2], fill=line_color, width=2)
279
+
280
+ def _add_random_elements(self, draw, width: int, height: int, layer_index: int):
281
+ """Add random elements"""
282
+ colors = ['red', 'blue', 'green', 'yellow', 'purple', 'orange']
283
+
284
+ for _ in range(4 + layer_index):
285
+ color = colors[random.randint(0, len(colors)-1)]
286
+ x, y = random.randint(10, width-20), random.randint(10, height-20)
287
+ size = random.randint(5, 15)
288
+
289
+ shape = random.choice(['circle', 'square', 'line'])
290
+ if shape == 'circle':
291
+ draw.ellipse([x, y, x+size, y+size], fill=color, outline='black')
292
+ elif shape == 'square':
293
+ draw.rectangle([x, y, x+size, y+size], fill=color, outline='black')
294
+ else:
295
+ draw.line([x, y, x+size, y+size], fill=color, width=3)
requirements.txt CHANGED
@@ -1,24 +1,4 @@
1
- torch==2.0.1
2
- torchvision==0.15.2
3
- numpy>=1.21.0
4
- Pillow>=8.0.0
5
- cairosvg>=2.5.0
6
- omegaconf>=2.1.0
7
- diffusers>=0.20.0
8
- transformers>=4.20.0
9
- svgwrite>=1.4.0
10
- svgpathtools>=1.4.0
11
- freetype-py>=2.3.0
12
- shapely>=1.8.0
13
- opencv-python>=4.5.0
14
- scikit-image>=0.19.0
15
- matplotlib>=3.5.0
16
- scipy>=1.8.0
17
- einops>=0.4.0
18
- timm>=0.6.0
19
- ftfy>=6.1.0
20
- regex>=2022.0.0
21
- tqdm>=4.64.0
22
- lpips>=0.1.4
23
- clip-by-openai>=1.0.0
24
- xformers>=0.0.16
 
1
+ torch
2
+ torchvision
3
+ Pillow
4
+ numpy
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements_minimal.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ Pillow
4
+ numpy