diffsketchedit / handler.py
jree423's picture
Fix: Remove local json import that was causing variable scope issues
9391275 verified
import torch
import torch.nn.functional as F
import numpy as np
import json
import base64
import io
from PIL import Image
import svgwrite
from typing import Dict, Any, List, Optional, Union
import diffusers
from diffusers import StableDiffusionPipeline, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
import torchvision.transforms as transforms
import random
import math
import re
class EndpointHandler:
    """Inference-endpoint handler that draws prompt-driven SVG sketches.

    A request is parsed into a list of prompts, an edit type ("generate",
    "replace", "refine" or "reweight") and a parameter dict.  The matching
    edit method builds an SVG with ``svgwrite`` out of simple hand-coded
    shapes chosen by keyword, and the SVG is rasterised to a PIL image whose
    ``info`` dict carries the SVG text and the edit metadata.

    The constructor also tries to load a Stable Diffusion pipeline, but none
    of the drawing methods in this class call it.
    """

    DEFAULT_SIZE = 224
    DEFAULT_PROMPT = "a simple sketch"

    # Colour words understood by the "replace" edit, mapped to their fill.
    COLOR_MAP = {
        'red': '#FF0000',
        'blue': '#0000FF',
        'green': '#00FF00',
        'yellow': '#FFFF00',
        'purple': '#800080',
        'orange': '#FFA500',
    }

    # Ordered (category, keywords) table; the first category with a keyword
    # present in the prompt wins.
    CATEGORY_KEYWORDS = (
        ("person", ("person", "people", "human", "man", "woman", "men", "women")),
        ("animal", ("animal", "cat", "dog", "bird", "horse")),
        ("building", ("house", "building", "architecture")),
        ("nature", ("tree", "nature", "landscape")),
        ("vehicle", ("car", "vehicle", "transport")),
    )

    # Feature flag -> words that switch it on.
    FEATURE_KEYWORDS = (
        ("detailed", ("detailed", "complex", "intricate")),
        ("simple", ("simple", "minimal", "basic")),
        ("colorful", ("colorful", "bright", "vibrant")),
        ("large", ("large", "big", "huge")),
        ("small", ("small", "tiny", "mini")),
    )

    # "(word:1.5)" and "[word:0.5]" notations.  The word part may not contain
    # the delimiters themselves, so a match never spans two groups, and the
    # number part only accepts text that float() can parse.
    _INCREASE_PATTERN = re.compile(r'\(([^:()]+):(\d+\.?\d*|\.\d+)\)')
    _DECREASE_PATTERN = re.compile(r'\[([^:\[\]]+):(\d+\.?\d*|\.\d+)\]')

    def __init__(self, path=""):
        """Try to load the diffusion pipeline; fall back to ``None`` on failure.

        Args:
            path: Accepted for interface compatibility; not read here.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_id = "runwayml/stable-diffusion-v1-5"
        try:
            # Half precision only when a GPU is available.
            dtype = torch.float16 if torch.cuda.is_available() else torch.float32
            self.pipe = StableDiffusionPipeline.from_pretrained(
                self.model_id,
                torch_dtype=dtype,
                safety_checker=None,
                requires_safety_checker=False,
            ).to(self.device)
            # Use DDIM scheduler for better control
            self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
            # CLIP model for guidance
            self.clip_model = self.pipe.text_encoder
            self.clip_tokenizer = self.pipe.tokenizer
            print("DiffSketchEdit handler initialized successfully!")
        except Exception as e:
            # Best effort: the shape-based drawing below works without the model.
            print(f"Warning: Could not load diffusion model: {e}")
            self.pipe = None
            self.clip_model = None
            self.clip_tokenizer = None

    # ------------------------------------------------------------------
    # Request handling
    # ------------------------------------------------------------------

    def __call__(self, inputs: Union[str, Dict[str, Any]]) -> "Image.Image":
        """Perform sketch editing using the DiffSketchEdit approach.

        Args:
            inputs: A plain prompt string, a JSON string encoding a dict, or a
                dict.  A dict may hold an ``"inputs"`` entry (a prompt string
                or a dict with ``"prompts"``/``"prompt"`` and ``"edit_type"``)
                and a ``"parameters"`` dict with ``width``, ``height``,
                ``seed`` and ``input_svg``.

        Returns:
            The rendered image.  ``image.info`` holds ``svg_content`` plus the
            edit metadata; on failure it holds ``error`` and ``edit_type`` and
            the image shows a fallback sketch.
        """
        # Bound before the try block so the except handler can always build
        # a fallback image, whatever failed and however early.
        prompts = [self.DEFAULT_PROMPT]
        edit_type = "generate"
        width = height = self.DEFAULT_SIZE
        try:
            prompts, edit_type, parameters = self._parse_inputs(inputs)

            # Extract parameters with defaults
            width = self._coerce_dimension(parameters.get("width"), self.DEFAULT_SIZE)
            height = self._coerce_dimension(parameters.get("height"), self.DEFAULT_SIZE)
            seed = parameters.get("seed")
            input_svg = parameters.get("input_svg")

            if seed is not None:
                torch.manual_seed(seed)
                np.random.seed(seed)
                random.seed(seed)

            print(f"Processing edit type: '{edit_type}' with prompts: {prompts}")

            # Process based on edit type
            if edit_type == "replace" and len(prompts) >= 2:
                svg_content, metadata = self.word_replacement_edit(
                    prompts[0], prompts[1], width, height, input_svg)
            elif edit_type == "reweight":
                svg_content, metadata = self.attention_reweighting_edit(
                    prompts[0], width, height, input_svg)
            elif edit_type == "generate":
                svg_content, metadata = self.simple_generation(prompts[0], width, height)
            else:
                # "refine", plus the default for anything unrecognised.
                svg_content, metadata = self.prompt_refinement_edit(
                    prompts[0], width, height, input_svg)

            # Convert SVG to PIL Image for HF API compatibility
            pil_image = self.svg_to_pil_image(svg_content, width, height)

            # Store metadata; nested values are JSON-encoded, the rest stringified.
            pil_image.info['svg_content'] = svg_content
            for key, value in metadata.items():
                if isinstance(value, (dict, list)):
                    pil_image.info[key] = json.dumps(value)
                else:
                    pil_image.info[key] = str(value)
            return pil_image
        except Exception as e:
            print(f"Error in handler: {e}")
            # Return fallback image
            fallback_svg = self.create_fallback_svg(
                prompts[0] if prompts else "error", width, height)
            fallback_image = self.svg_to_pil_image(fallback_svg, width, height)
            fallback_image.info['error'] = str(e)
            fallback_image.info['edit_type'] = edit_type
            return fallback_image

    def _parse_inputs(self, inputs):
        """Split a raw request into ``(prompts, edit_type, parameters)``.

        ``prompts`` is always a non-empty list of strings and ``parameters``
        is always a dict.
        """
        if isinstance(inputs, str):
            try:
                decoded = json.loads(inputs)
            except ValueError:  # json.JSONDecodeError is a ValueError
                decoded = None
            if not isinstance(decoded, dict):
                # Simple prompt - treat as generation
                return [inputs], "generate", {}
            inputs = decoded

        if not isinstance(inputs, dict):
            # Unsupported payload type: same lenient default as a bad "inputs" entry.
            return [self.DEFAULT_PROMPT], "generate", {}

        payload = inputs.get("inputs", inputs)
        if isinstance(payload, str):
            prompts, edit_type = [payload], "generate"
        elif isinstance(payload, dict):
            prompts = payload.get("prompts", [payload.get("prompt", self.DEFAULT_PROMPT)])
            edit_type = payload.get("edit_type", "generate")
        else:
            prompts, edit_type = [self.DEFAULT_PROMPT], "generate"

        parameters = inputs.get("parameters", {})
        if not isinstance(parameters, dict):
            parameters = {}
        return self._normalize_prompts(prompts), edit_type, parameters

    def _normalize_prompts(self, prompts):
        """Return ``prompts`` as a non-empty list of strings."""
        if isinstance(prompts, str):
            # A bare string would otherwise be indexed character by character.
            prompts = [prompts]
        elif not isinstance(prompts, (list, tuple)):
            prompts = []
        cleaned = [str(p) for p in prompts if p is not None]
        return cleaned or [self.DEFAULT_PROMPT]

    @staticmethod
    def _coerce_dimension(value, default):
        """Return ``value`` as a positive int, or ``default`` if it is not one."""
        try:
            size = int(value)
        except (TypeError, ValueError, OverflowError):
            return default
        return size if size > 0 else default

    # ------------------------------------------------------------------
    # Edit entry points: each returns (svg_text, metadata_dict)
    # ------------------------------------------------------------------

    def word_replacement_edit(self, source_prompt: str, target_prompt: str,
                              width: int, height: int, input_svg: Optional[str] = None):
        """Perform word replacement editing.

        Returns the edited SVG and metadata listing the words added to and
        removed from the prompt; on error, a fallback SVG and the error text.
        """
        try:
            print(f"Word replacement: '{source_prompt}' -> '{target_prompt}'")
            # Analyze word differences
            added_words, removed_words = self.analyze_word_differences(source_prompt, target_prompt)
            print(f"Added words: {added_words}, Removed words: {removed_words}")
            # Generate or use base SVG
            base_svg = input_svg if input_svg else self.generate_base_svg(source_prompt, width, height)
            # Apply word replacement transformations
            edited_svg = self.apply_word_replacement(
                base_svg, source_prompt, target_prompt, added_words, removed_words, width, height)
            metadata = {
                "edit_type": "replace",
                "source_prompt": source_prompt,
                "target_prompt": target_prompt,
                # Sorted so the metadata does not depend on set iteration order.
                "added_words": sorted(added_words),
                "removed_words": sorted(removed_words),
            }
            return edited_svg, metadata
        except Exception as e:
            print(f"Error in word_replacement_edit: {e}")
            fallback_svg = self.create_fallback_svg(source_prompt, width, height)
            return fallback_svg, {"edit_type": "replace", "error": str(e)}

    def prompt_refinement_edit(self, prompt: str, width: int, height: int,
                               input_svg: Optional[str] = None):
        """Perform prompt refinement editing."""
        try:
            print(f"Prompt refinement for: '{prompt}'")
            # Generate or use base SVG
            base_svg = input_svg if input_svg else self.generate_base_svg(prompt, width, height)
            # Apply refinement based on prompt analysis
            refined_svg = self.apply_refinement(base_svg, prompt, width, height)
            return refined_svg, {"edit_type": "refine", "prompt": prompt}
        except Exception as e:
            print(f"Error in prompt_refinement_edit: {e}")
            fallback_svg = self.create_fallback_svg(prompt, width, height)
            return fallback_svg, {"edit_type": "refine", "error": str(e)}

    def attention_reweighting_edit(self, prompt: str, width: int, height: int,
                                   input_svg: Optional[str] = None):
        """Perform attention reweighting editing."""
        try:
            print(f"Attention reweighting for: '{prompt}'")
            # Parse attention weights from prompt (e.g., "(cat:1.5)" or "[table:0.5]")
            weighted_prompt, attention_weights = self.parse_attention_weights(prompt)
            print(f"Weighted prompt: '{weighted_prompt}', weights: {attention_weights}")
            # Generate or use base SVG
            base_svg = input_svg if input_svg else self.generate_base_svg(weighted_prompt, width, height)
            # Apply attention reweighting
            reweighted_svg = self.apply_attention_reweighting(
                base_svg, weighted_prompt, attention_weights, width, height)
            metadata = {
                "edit_type": "reweight",
                "prompt": prompt,
                "weighted_prompt": weighted_prompt,
                "attention_weights": attention_weights,
            }
            return reweighted_svg, metadata
        except Exception as e:
            print(f"Error in attention_reweighting_edit: {e}")
            fallback_svg = self.create_fallback_svg(prompt, width, height)
            return fallback_svg, {"edit_type": "reweight", "error": str(e)}

    def simple_generation(self, prompt: str, width: int, height: int):
        """Perform simple SVG generation."""
        try:
            print(f"Simple generation for: '{prompt}'")
            svg_content = self.generate_base_svg(prompt, width, height)
            return svg_content, {"edit_type": "generate", "prompt": prompt}
        except Exception as e:
            print(f"Error in simple_generation: {e}")
            fallback_svg = self.create_fallback_svg(prompt, width, height)
            return fallback_svg, {"edit_type": "generate", "error": str(e)}

    # ------------------------------------------------------------------
    # Prompt analysis
    # ------------------------------------------------------------------

    @staticmethod
    def _prompt_words(prompt):
        """Return the lowercase words of ``prompt``, plus naive singular forms.

        Whole-word matching avoids substring false positives such as 'car'
        in "cartoon"; stripping a trailing "s" keeps plurals like "cats"
        matching their keyword.
        """
        words = set(re.findall(r"[a-z]+", str(prompt).lower()))
        singulars = {w[:-1] for w in words if len(w) > 1 and w.endswith("s")}
        return words | singulars

    def _classify_prompt(self, prompt):
        """Return the first matching category name, or ``"abstract"``."""
        words = self._prompt_words(prompt)
        for category, keywords in self.CATEGORY_KEYWORDS:
            if not words.isdisjoint(keywords):
                return category
        return "abstract"

    def analyze_word_differences(self, source: str, target: str):
        """Return ``(added_words, removed_words)`` between two prompts.

        Both are sets of lowercase whitespace-separated tokens.
        """
        source_words = set(source.lower().split())
        target_words = set(target.lower().split())
        return target_words - source_words, source_words - target_words

    def parse_attention_weights(self, prompt: str):
        """Parse ``(word:weight)`` / ``[word:weight]`` notations from a prompt.

        Returns:
            ``(weighted_prompt, attention_weights)``: the prompt with each
            notation replaced by its bare word, and a dict mapping each word
            to its float weight.  Notations with a malformed number are left
            in the prompt untouched.
        """
        attention_weights = {}
        weighted_prompt = prompt
        # Increase notations first, then decrease; both are searched in the
        # original prompt and substituted in the running result.
        for pattern in (self._INCREASE_PATTERN, self._DECREASE_PATTERN):
            for match in pattern.finditer(prompt):
                word = match.group(1).strip()
                attention_weights[word] = float(match.group(2))
                # Remove the weight notation from prompt
                weighted_prompt = weighted_prompt.replace(match.group(0), word)
        return weighted_prompt.strip(), attention_weights

    def extract_semantic_features(self, prompt: str):
        """Return boolean style flags derived from the words of ``prompt``.

        The result has the keys ``detailed``, ``simple``, ``colorful``,
        ``large`` and ``small``.
        """
        words = self._prompt_words(prompt)
        return {
            flag: not words.isdisjoint(keywords)
            for flag, keywords in self.FEATURE_KEYWORDS
        }

    # ------------------------------------------------------------------
    # SVG construction
    # ------------------------------------------------------------------

    @staticmethod
    def _new_canvas(width, height):
        """Return a drawing of the given size with a white background."""
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
        return dwg

    def generate_base_svg(self, prompt: str, width: int, height: int):
        """Generate base SVG from prompt and return it as a string."""
        dwg = self._new_canvas(width, height)
        self.add_content_based_on_prompt(dwg, prompt, width, height)
        return dwg.tostring()

    def apply_word_replacement(self, base_svg: str, source_prompt: str, target_prompt: str,
                               added_words: set, removed_words: set, width: int, height: int):
        """Apply word replacement transformations to SVG.

        For now the sketch is regenerated from the target prompt; ``base_svg``,
        ``source_prompt`` and ``removed_words`` are accepted but not used.
        """
        features = self.extract_semantic_features(target_prompt)
        dwg = self._new_canvas(width, height)
        added = set(added_words)
        # Apply changes based on word differences
        if added & set(self.COLOR_MAP):
            # Color change.  Sorted so seeded runs draw colours in a fixed order.
            self.add_colored_elements(dwg, width, height, sorted(added))
        elif added & {'big', 'large', 'huge'}:
            # Size change
            self.add_large_elements(dwg, width, height, features)
        elif added & {'small', 'tiny', 'mini'}:
            # Size change
            self.add_small_elements(dwg, width, height, features)
        else:
            # General content change
            self.add_content_based_on_prompt(dwg, target_prompt, width, height)
        return dwg.tostring()

    def apply_refinement(self, base_svg: str, prompt: str, width: int, height: int):
        """Apply refinement to existing SVG.

        For now a fresh sketch is drawn; ``base_svg`` is accepted but not used.
        """
        features = self.extract_semantic_features(prompt)
        dwg = self._new_canvas(width, height)
        # Add refined elements based on prompt
        if features.get('detailed', False):
            self.add_detailed_elements(dwg, width, height, features)
        else:
            self.add_content_based_on_prompt(dwg, prompt, width, height)
        return dwg.tostring()

    def apply_attention_reweighting(self, base_svg: str, prompt: str, attention_weights: dict,
                                    width: int, height: int):
        """Apply attention reweighting to SVG.

        One marker is drawn per weighted word (weights equal to 1.0 draw
        nothing), then the prompt's base content.  ``base_svg`` is not used.
        """
        dwg = self._new_canvas(width, height)
        # Apply different emphasis based on attention weights
        for word, weight in attention_weights.items():
            if weight > 1.0:
                # Emphasize this element
                self.add_emphasized_element(dwg, word, weight, width, height)
            elif weight < 1.0:
                # De-emphasize this element
                self.add_deemphasized_element(dwg, word, weight, width, height)
        # Add base content
        self.add_content_based_on_prompt(dwg, prompt, width, height)
        return dwg.tostring()

    def add_content_based_on_prompt(self, dwg, prompt: str, width: int, height: int):
        """Draw the shape set matching the prompt's category onto ``dwg``."""
        features = self.extract_semantic_features(prompt)
        painters = {
            "person": self.add_person_elements,
            "animal": self.add_animal_elements,
            "building": self.add_building_elements,
            "nature": self.add_nature_elements,
            "vehicle": self.add_vehicle_elements,
        }
        painter = painters.get(self._classify_prompt(prompt), self.add_abstract_elements)
        painter(dwg, width, height, features)

    # ------------------------------------------------------------------
    # Shape sets
    # ------------------------------------------------------------------

    def add_person_elements(self, dwg, width, height, features):
        """Add person-like elements"""
        center_x, center_y = width // 2, height // 2
        # Head
        dwg.add(dwg.circle(center=(center_x, center_y - 40), r=20,
                           fill='#FDBCB4', stroke='black', stroke_width=2))
        # Body
        body_height = 60
        body_width = 30
        half_body = body_width // 2
        dwg.add(dwg.rect(
            insert=(center_x - half_body, center_y - 10),
            size=(body_width, body_height),
            fill='#4A90E2',
            stroke='black',
            stroke_width=2,
        ))
        # Arms (left, then right)
        for side in (-1, 1):
            dwg.add(dwg.line(
                start=(center_x + side * half_body, center_y),
                end=(center_x + side * 40, center_y + 20),
                stroke='black', stroke_width=3))
        # Legs (left, then right)
        for side in (-1, 1):
            dwg.add(dwg.line(
                start=(center_x + side * 10, center_y + body_height - 10),
                end=(center_x + side * 20, center_y + body_height + 30),
                stroke='black', stroke_width=3))

    def add_animal_elements(self, dwg, width, height, features):
        """Add animal-like elements"""
        center_x, center_y = width // 2, height // 2
        # Body (oval)
        dwg.add(dwg.ellipse(center=(center_x, center_y), r=(40, 25),
                            fill='#8B4513', stroke='black', stroke_width=2))
        # Head
        dwg.add(dwg.circle(center=(center_x - 30, center_y - 10), r=20,
                           fill='#A0522D', stroke='black', stroke_width=2))
        # Legs
        for x_offset in (-20, -10, 10, 20):
            dwg.add(dwg.line(
                start=(center_x + x_offset, center_y + 25),
                end=(center_x + x_offset, center_y + 45),
                stroke='black',
                stroke_width=3,
            ))
        # Tail
        dwg.add(dwg.path(
            d=f"M {center_x + 40},{center_y} Q {center_x + 60},{center_y - 20} {center_x + 50},{center_y - 35}",
            stroke='black',
            stroke_width=3,
            fill='none',
        ))

    def add_building_elements(self, dwg, width, height, features):
        """Add building-like elements"""
        # Main building
        building_width = width * 0.6
        building_height = height * 0.7
        x = (width - building_width) // 2
        y = height - building_height - 10
        dwg.add(dwg.rect(
            insert=(x, y),
            size=(building_width, building_height),
            fill='#CD853F',
            stroke='black',
            stroke_width=2,
        ))
        # Roof
        roof_points = [(x, y), (x + building_width // 2, y - 30), (x + building_width, y)]
        dwg.add(dwg.polygon(points=roof_points, fill='#8B0000', stroke='black', stroke_width=2))
        # Windows: a 3x4 grid, skipping any cell that would leave the facade.
        window_size = 15
        for i in range(3):
            for j in range(4):
                wx = x + 15 + i * 30
                wy = y + 15 + j * 25
                fits_vertically = wy < y + building_height - 20
                fits_horizontally = wx + window_size <= x + building_width
                if fits_vertically and fits_horizontally:
                    dwg.add(dwg.rect(
                        insert=(wx, wy),
                        size=(window_size, window_size),
                        fill='#87CEEB',
                        stroke='black',
                        stroke_width=1,
                    ))
        # Door
        door_width = 20
        door_height = 40
        door_x = x + building_width // 2 - door_width // 2
        door_y = y + building_height - door_height
        dwg.add(dwg.rect(
            insert=(door_x, door_y),
            size=(door_width, door_height),
            fill='#8B4513',
            stroke='black',
            stroke_width=2,
        ))

    def add_nature_elements(self, dwg, width, height, features):
        """Add nature-like elements (a single tree)."""
        center_x, center_y = width // 2, height // 2
        # Trunk
        trunk_width = 15
        trunk_height = height // 3
        trunk_x = center_x - trunk_width // 2
        trunk_y = height - trunk_height - 10
        dwg.add(dwg.rect(
            insert=(trunk_x, trunk_y),
            size=(trunk_width, trunk_height),
            fill='#8B4513',
            stroke='black',
            stroke_width=1,
        ))
        # Crown (multiple circles for foliage), each a little smaller than the last.
        crown_radius = 30
        offsets = [(-15, -20), (15, -20), (0, -35), (-10, -50), (10, -50)]
        for i, (dx, dy) in enumerate(offsets):
            dwg.add(dwg.circle(
                center=(center_x + dx, center_y + dy),
                r=crown_radius - i * 3,
                fill='#228B22',
                stroke='#006400',
                stroke_width=1,
                opacity=0.8,
            ))

    def add_vehicle_elements(self, dwg, width, height, features):
        """Add vehicle-like elements"""
        center_y = height // 2
        # Car body
        car_width = width * 0.6
        car_height = height * 0.3
        car_x = (width - car_width) // 2
        car_y = center_y + 10
        dwg.add(dwg.rect(
            insert=(car_x, car_y),
            size=(car_width, car_height),
            fill='#FF4500',
            stroke='black',
            stroke_width=2,
            rx=5,
        ))
        # Windshield
        windshield_width = car_width * 0.6
        windshield_height = car_height * 0.4
        windshield_x = car_x + (car_width - windshield_width) // 2
        windshield_y = car_y - windshield_height + 5
        dwg.add(dwg.rect(
            insert=(windshield_x, windshield_y),
            size=(windshield_width, windshield_height),
            fill='#87CEEB',
            stroke='black',
            stroke_width=1,
        ))
        # Wheels
        wheel_radius = 12
        wheel_y = car_y + car_height - 5
        for wheel_x in (car_x + 25, car_x + car_width - 25):
            dwg.add(dwg.circle(center=(wheel_x, wheel_y), r=wheel_radius, fill='black'))

    def add_abstract_elements(self, dwg, width, height, features):
        """Add five random abstract shapes (circles, rectangles or lines)."""
        colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7']
        for _ in range(5):
            shape_type = random.choice(['circle', 'rect', 'path'])
            color = random.choice(colors)
            if shape_type == 'circle':
                radius = random.randint(10, 30)
                # max() keeps the range valid on canvases smaller than the shape.
                x = random.randint(radius, max(radius, width - radius))
                y = random.randint(radius, max(radius, height - radius))
                dwg.add(dwg.circle(center=(x, y), r=radius, fill=color, opacity=0.7))
            elif shape_type == 'rect':
                w = random.randint(20, 60)
                h = random.randint(20, 60)
                x = random.randint(0, max(0, width - w))
                y = random.randint(0, max(0, height - h))
                dwg.add(dwg.rect(insert=(x, y), size=(w, h), fill=color, opacity=0.7))
            else:
                # Random path
                start_x = random.randint(0, width)
                start_y = random.randint(0, height)
                end_x = random.randint(0, width)
                end_y = random.randint(0, height)
                dwg.add(dwg.line(start=(start_x, start_y), end=(end_x, end_y),
                                 stroke=color, stroke_width=3))

    def add_colored_elements(self, dwg, width, height, color_words):
        """Add one randomly placed circle per recognised colour word."""
        center_x, center_y = width // 2, height // 2
        for word in color_words:
            color = self.COLOR_MAP.get(word)
            if color is None:
                continue
            # Add a colored shape
            dwg.add(dwg.circle(
                center=(center_x + random.randint(-50, 50), center_y + random.randint(-50, 50)),
                r=random.randint(15, 35),
                fill=color,
                opacity=0.8,
            ))

    def add_large_elements(self, dwg, width, height, features):
        """Add large-sized elements"""
        # Large central element
        dwg.add(dwg.circle(
            center=(width // 2, height // 2),
            r=min(width, height) // 3,
            fill='#4A90E2',
            stroke='black',
            stroke_width=3,
        ))

    def add_small_elements(self, dwg, width, height, features):
        """Add eight small randomly placed dots."""
        for _ in range(8):
            # max() keeps the range valid on canvases narrower than 20 px.
            x = random.randint(10, max(10, width - 10))
            y = random.randint(10, max(10, height - 10))
            dwg.add(dwg.circle(
                center=(x, y),
                r=random.randint(3, 8),
                fill='#E74C3C',
                opacity=0.7,
            ))

    def add_detailed_elements(self, dwg, width, height, features):
        """Add detailed elements for refinement"""
        # Add more complex shapes and details
        self.add_abstract_elements(dwg, width, height, features)
        # Add decorative elements: four dots on a circle of radius 40.
        center_x, center_y = width // 2, height // 2
        for i in range(4):
            angle = i * math.pi / 2
            x = center_x + 40 * math.cos(angle)
            y = center_y + 40 * math.sin(angle)
            dwg.add(dwg.circle(center=(x, y), r=8, fill='#9B59B6', opacity=0.6))

    def add_emphasized_element(self, dwg, word: str, weight: float, width: int, height: int):
        """Add emphasized element based on attention weight"""
        center_x, center_y = width // 2, height // 2
        # Scale size based on weight, but never beyond half the canvas.
        base_size = 20
        size = min(int(base_size * weight), max(1, min(width, height) // 2))
        dwg.add(dwg.circle(
            center=(center_x + random.randint(-30, 30), center_y + random.randint(-30, 30)),
            r=size,
            fill='#FF6B6B',
            opacity=min(1.0, weight / 2),
            stroke='black',
            stroke_width=2,
        ))

    def add_deemphasized_element(self, dwg, word: str, weight: float, width: int, height: int):
        """Add de-emphasized element based on attention weight"""
        center_x, center_y = width // 2, height // 2
        # Scale size based on weight
        base_size = 15
        size = int(base_size * weight)
        dwg.add(dwg.circle(
            center=(center_x + random.randint(-40, 40), center_y + random.randint(-40, 40)),
            r=max(3, size),
            fill='#CCCCCC',
            opacity=weight,
            stroke='gray',
            stroke_width=1,
        ))

    # ------------------------------------------------------------------
    # Output helpers
    # ------------------------------------------------------------------

    def svg_to_pil_image(self, svg_content: str, width: int, height: int):
        """Convert SVG content to an RGB PIL Image.

        Uses ``cairosvg`` when it can be imported; if it is missing or the
        conversion fails, a blank white image of the requested size is
        returned instead.
        """
        try:
            import cairosvg
            # Convert SVG to PNG bytes
            png_bytes = cairosvg.svg2png(
                bytestring=svg_content.encode('utf-8'),
                output_width=width,
                output_height=height,
            )
            # Convert to PIL Image
            return Image.open(io.BytesIO(png_bytes)).convert('RGB')
        except ImportError:
            print("cairosvg not available, creating simple image representation")
        except Exception as e:
            print(f"Error converting SVG to image: {e}")
        # Fallback: create a simple image
        return Image.new('RGB', (width, height), 'white')

    def create_fallback_svg(self, prompt: str, width: int, height: int):
        """Create a simple fallback SVG: a title line and the prompt below it.

        The prompt is cut to 30 characters, with "..." appended only when it
        was actually shortened.
        """
        dwg = self._new_canvas(width, height)
        label = str(prompt) if prompt else "error"
        if len(label) > 30:
            label = label[:30] + "..."
        # Two separate text elements: a newline inside one SVG text element
        # does not produce a line break.
        for line, y_offset in (("DiffSketchEdit", -8), (label, 8)):
            dwg.add(dwg.text(
                line,
                insert=(width / 2, height / 2 + y_offset),
                text_anchor="middle",
                font_size="12px",
                fill="black",
            ))
        return dwg.tostring()