|
import torch |
|
import torch.nn.functional as F |
|
import numpy as np |
|
import json |
|
import base64 |
|
import io |
|
from PIL import Image |
|
import svgwrite |
|
from typing import Dict, Any, List, Optional, Union |
|
import diffusers |
|
from diffusers import StableDiffusionPipeline, DDIMScheduler |
|
from transformers import CLIPTextModel, CLIPTokenizer |
|
import torchvision.transforms as transforms |
|
import random |
|
import math |
|
import re |
|
|
|
class EndpointHandler:
    """Inference-endpoint handler that renders rule-based SVG sketches.

    A request selects one of four edit modes:

    * ``"generate"`` - draw a figure chosen from keywords in the prompt;
    * ``"replace"``  - compare two prompts and redraw based on the words added
      in the second one (color and size words get special handling);
    * ``"refine"``   - redraw the prompt, adding extra shapes when the prompt
      asks for detail;
    * ``"reweight"`` - honour ``(word:1.5)`` / ``[word:0.5]`` weight syntax by
      drawing emphasised / de-emphasised markers.

    Every mode produces SVG text via ``svgwrite``; ``__call__`` rasterises it
    to a PIL image and stores the SVG plus edit metadata in ``image.info``.

    NOTE(review): ``__init__`` loads a Stable Diffusion pipeline, but no method
    in this class reads ``self.pipe``, ``self.clip_model`` or
    ``self.clip_tokenizer`` - all drawing below is procedural.
    """

    # Subject categories in priority order; the first category with a keyword
    # present in the prompt decides which figure is drawn. Shared by the
    # "generate" path and the refine/replace/reweight paths so every mode
    # classifies a prompt the same way.
    _SUBJECT_KEYWORDS = (
        ("person", ("person", "people", "human", "man", "woman")),
        ("animal", ("animal", "cat", "dog", "bird", "horse")),
        ("building", ("house", "building", "architecture")),
        ("nature", ("tree", "nature", "landscape")),
        ("vehicle", ("car", "vehicle", "transport")),
    )

    # Style/size flags reported by extract_semantic_features().
    _FEATURE_KEYWORDS = {
        "detailed": ("detailed", "complex", "intricate"),
        "simple": ("simple", "minimal", "basic"),
        "colorful": ("colorful", "bright", "vibrant"),
        "large": ("large", "big", "huge"),
        "small": ("small", "tiny", "mini"),
    }

    # Color words understood by "replace" edits, mapped to SVG fill colors.
    _COLOR_HEX = {
        "red": "#FF0000",
        "blue": "#0000FF",
        "green": "#00FF00",
        "yellow": "#FFFF00",
        "purple": "#800080",
        "orange": "#FFA500",
    }

    # Tokeniser for prompt diffs: runs of word characters / apostrophes, so
    # trailing punctuation ("red,") does not hide a word.
    _WORD_RE = re.compile(r"[\w']+")

    # "(word:1.5)" marks emphasis and "[word:0.5]" de-emphasis. The word part
    # excludes ':' and the pattern's own brackets so an unrelated "(...)"
    # earlier in the prompt cannot be swallowed into a match, and the weight
    # must be a well-formed number so float() cannot fail on e.g. "1.2.3".
    _EMPHASIS_RE = re.compile(r"\(([^:()]+):(\d+(?:\.\d*)?|\.\d+)\)")
    _DEEMPHASIS_RE = re.compile(r"\[([^:\[\]]+):(\d+(?:\.\d*)?|\.\d+)\]")

    def __init__(self, path=""):
        """Load the Stable Diffusion v1.5 pipeline onto the best available device.

        Args:
            path: Model directory supplied by the endpoint runtime. Unused - the
                pipeline is always fetched by ``self.model_id``.

        Loading is best-effort: on any failure a warning is printed and
        ``pipe``, ``clip_model`` and ``clip_tokenizer`` are set to ``None``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_id = "runwayml/stable-diffusion-v1-5"

        try:
            self.pipe = StableDiffusionPipeline.from_pretrained(
                self.model_id,
                # Half precision only on GPU; CPU kernels need float32.
                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                safety_checker=None,
                requires_safety_checker=False,
            ).to(self.device)
            self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
            self.clip_model = self.pipe.text_encoder
            self.clip_tokenizer = self.pipe.tokenizer
            print("DiffSketchEdit handler initialized successfully!")
        except Exception as e:  # best-effort: the SVG drawing paths never use the model
            print(f"Warning: Could not load diffusion model: {e}")
            self.pipe = None
            self.clip_model = None
            self.clip_tokenizer = None

    def __call__(self, inputs: Union[str, Dict[str, Any]]) -> "Image.Image":
        """Run one sketch request and return the rendered image.

        Accepted request shapes:
          * a plain prompt string;
          * a JSON string encoding a dict (handled like the dict form);
          * ``{"inputs": "<prompt>", "parameters": {...}}``;
          * ``{"inputs": {"prompts": [...] | "prompt": "...", "edit_type": ...},
            "parameters": {...}}``;
          * a dict without ``"inputs"``, which is itself the input payload.

        Recognised ``parameters``: ``width`` / ``height`` (default 224, must be
        positive), ``seed`` (seeds torch, numpy and ``random``) and
        ``input_svg``.

        Returns:
            A PIL image whose ``info`` holds ``svg_content`` plus the edit
            metadata (lists/dicts JSON-encoded, other values ``str``-ed). On any
            failure a fallback image is returned with ``info['error']`` and
            ``info['edit_type']`` set.
        """
        # Bind everything the error path reads before anything can fail, so
        # the fallback below never hits an unbound local.
        prompts: List[Any] = []
        edit_type: Any = "generate"
        width, height = 224, 224

        try:
            prompts, edit_type, parameters = self._parse_request(inputs)
            if not prompts:
                raise ValueError("no usable prompt in request")

            req_width = int(parameters.get("width", 224))
            req_height = int(parameters.get("height", 224))
            if req_width <= 0 or req_height <= 0:
                raise ValueError(
                    f"width and height must be positive, got {req_width}x{req_height}"
                )
            width, height = req_width, req_height

            seed = parameters.get("seed", None)
            input_svg = parameters.get("input_svg", None)

            if seed is not None:
                torch.manual_seed(seed)
                np.random.seed(seed)
                random.seed(seed)

            print(f"Processing edit type: '{edit_type}' with prompts: {prompts}")

            if edit_type == "replace" and len(prompts) >= 2:
                svg_content, metadata = self.word_replacement_edit(
                    prompts[0], prompts[1], width, height, input_svg
                )
            elif edit_type == "reweight":
                svg_content, metadata = self.attention_reweighting_edit(
                    prompts[0], width, height, input_svg
                )
            elif edit_type == "generate":
                svg_content, metadata = self.simple_generation(prompts[0], width, height)
            else:
                # "refine", unknown edit types, and "replace" with fewer than
                # two prompts all fall back to prompt refinement.
                svg_content, metadata = self.prompt_refinement_edit(
                    prompts[0], width, height, input_svg
                )

            pil_image = self.svg_to_pil_image(svg_content, width, height)
            pil_image.info['svg_content'] = svg_content
            for key, value in metadata.items():
                pil_image.info[key] = (
                    json.dumps(value) if isinstance(value, (dict, list)) else str(value)
                )
            return pil_image

        except Exception as e:
            print(f"Error in handler: {e}")
            label = prompts[0] if prompts else "error"
            fallback_svg = self.create_fallback_svg(label, width, height)
            fallback_image = self.svg_to_pil_image(fallback_svg, width, height)
            fallback_image.info['error'] = str(e)
            fallback_image.info['edit_type'] = edit_type
            return fallback_image

    @staticmethod
    def _parse_request(inputs: Any):
        """Normalise a raw request into ``(prompts, edit_type, parameters)``.

        ``prompts`` is always a list (a bare string becomes a one-item list; an
        unusable value becomes ``[]``), and ``parameters`` is always a dict.
        """
        default_prompts = ["a simple sketch"]

        if isinstance(inputs, str):
            try:
                decoded = json.loads(inputs)
            except ValueError:  # not JSON: the string itself is the prompt
                decoded = None
            if not isinstance(decoded, dict):
                return [inputs], "generate", {}
            inputs = decoded

        if not isinstance(inputs, dict):
            return list(default_prompts), "generate", {}

        payload = inputs.get("inputs", inputs)
        if isinstance(payload, str):
            prompts, edit_type = [payload], "generate"
        elif isinstance(payload, dict):
            prompts = payload.get("prompts", [payload.get("prompt", "a simple sketch")])
            edit_type = payload.get("edit_type", "generate")
        else:
            prompts, edit_type = list(default_prompts), "generate"

        # A bare string under "prompts" would otherwise be indexed per character.
        if isinstance(prompts, str):
            prompts = [prompts]
        elif isinstance(prompts, (list, tuple)):
            prompts = list(prompts)
        else:
            prompts = []  # unusable value; __call__ reports it as an error

        parameters = inputs.get("parameters", {})
        if not isinstance(parameters, dict):
            parameters = {}
        return prompts, edit_type, parameters

    def word_replacement_edit(self, source_prompt: str, target_prompt: str, width: int,
                              height: int, input_svg: Optional[str] = None):
        """Redraw a sketch after editing ``source_prompt`` into ``target_prompt``.

        Returns:
            ``(svg_text, metadata)``; on failure a fallback SVG and metadata
            containing an ``"error"`` entry.
        """
        try:
            print(f"Word replacement: '{source_prompt}' -> '{target_prompt}'")
            added_words, removed_words = self.analyze_word_differences(source_prompt, target_prompt)
            print(f"Added words: {added_words}, Removed words: {removed_words}")

            # NOTE(review): apply_word_replacement ignores base_svg, but building
            # it may consume `random` draws, so it is kept to preserve seeded output.
            base_svg = input_svg or self.generate_base_svg(source_prompt, width, height)
            edited_svg = self.apply_word_replacement(
                base_svg, source_prompt, target_prompt, added_words, removed_words, width, height
            )

            metadata = {
                "edit_type": "replace",
                "source_prompt": source_prompt,
                "target_prompt": target_prompt,
                # Sorted so the metadata is deterministic (set order is not).
                "added_words": sorted(added_words),
                "removed_words": sorted(removed_words),
            }
            return edited_svg, metadata

        except Exception as e:
            print(f"Error in word_replacement_edit: {e}")
            fallback_svg = self.create_fallback_svg(source_prompt, width, height)
            return fallback_svg, {"edit_type": "replace", "error": str(e)}

    def prompt_refinement_edit(self, prompt: str, width: int, height: int,
                               input_svg: Optional[str] = None):
        """Redraw *prompt*, adding extra detail shapes when it asks for detail.

        Returns:
            ``(svg_text, metadata)``; on failure a fallback SVG and metadata
            containing an ``"error"`` entry.
        """
        try:
            print(f"Prompt refinement for: '{prompt}'")
            # NOTE(review): apply_refinement ignores base_svg (see word_replacement_edit).
            base_svg = input_svg or self.generate_base_svg(prompt, width, height)
            refined_svg = self.apply_refinement(base_svg, prompt, width, height)
            return refined_svg, {"edit_type": "refine", "prompt": prompt}

        except Exception as e:
            print(f"Error in prompt_refinement_edit: {e}")
            fallback_svg = self.create_fallback_svg(prompt, width, height)
            return fallback_svg, {"edit_type": "refine", "error": str(e)}

    def attention_reweighting_edit(self, prompt: str, width: int, height: int,
                                   input_svg: Optional[str] = None):
        """Draw *prompt* with markers for its ``(word:w)`` / ``[word:w]`` weights.

        Returns:
            ``(svg_text, metadata)`` where metadata includes the prompt with the
            weight syntax removed and the parsed ``{word: weight}`` map; on
            failure a fallback SVG and metadata containing an ``"error"`` entry.
        """
        try:
            print(f"Attention reweighting for: '{prompt}'")
            weighted_prompt, attention_weights = self.parse_attention_weights(prompt)
            print(f"Weighted prompt: '{weighted_prompt}', weights: {attention_weights}")

            # NOTE(review): apply_attention_reweighting ignores base_svg.
            base_svg = input_svg or self.generate_base_svg(weighted_prompt, width, height)
            reweighted_svg = self.apply_attention_reweighting(
                base_svg, weighted_prompt, attention_weights, width, height
            )

            metadata = {
                "edit_type": "reweight",
                "prompt": prompt,
                "weighted_prompt": weighted_prompt,
                "attention_weights": attention_weights,
            }
            return reweighted_svg, metadata

        except Exception as e:
            print(f"Error in attention_reweighting_edit: {e}")
            fallback_svg = self.create_fallback_svg(prompt, width, height)
            return fallback_svg, {"edit_type": "reweight", "error": str(e)}

    def simple_generation(self, prompt: str, width: int, height: int):
        """Generate a fresh sketch for *prompt*.

        Returns:
            ``(svg_text, metadata)``; on failure a fallback SVG and metadata
            containing an ``"error"`` entry.
        """
        try:
            print(f"Simple generation for: '{prompt}'")
            svg_content = self.generate_base_svg(prompt, width, height)
            return svg_content, {"edit_type": "generate", "prompt": prompt}

        except Exception as e:
            print(f"Error in simple_generation: {e}")
            fallback_svg = self.create_fallback_svg(prompt, width, height)
            return fallback_svg, {"edit_type": "generate", "error": str(e)}

    @staticmethod
    def _new_canvas(width: int, height: int):
        """Return an ``svgwrite.Drawing`` of the given size with a white background."""
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
        return dwg

    @staticmethod
    def _mentions(text: str, keywords) -> bool:
        """Return True if any of *keywords* appears in *text* as a whole word.

        Case-insensitive; a simple plural ("cats", "houses") also counts. Whole
        word matching keeps "cartoon" from meaning "car", "street" from meaning
        "tree" and "minimal" from meaning "mini".
        """
        pattern = r"\b(?:" + "|".join(map(re.escape, keywords)) + r")(?:s|es)?\b"
        return re.search(pattern, text.lower()) is not None

    def _subject_category(self, prompt: str) -> str:
        """Return the first matching category name from ``_SUBJECT_KEYWORDS``, else "abstract"."""
        for category, keywords in self._SUBJECT_KEYWORDS:
            if self._mentions(prompt, keywords):
                return category
        return "abstract"

    def generate_base_svg(self, prompt: str, width: int, height: int):
        """Return SVG text for a white canvas holding the figure that matches *prompt*."""
        dwg = self._new_canvas(width, height)
        self.add_content_based_on_prompt(dwg, prompt, width, height)
        return dwg.tostring()

    def analyze_word_differences(self, source: str, target: str):
        """Return ``(added, removed)`` word sets between two prompts.

        Words are lower-cased and split on anything that is not a word
        character or apostrophe, so punctuation does not stick to words.
        """
        source_words = set(self._WORD_RE.findall(source.lower()))
        target_words = set(self._WORD_RE.findall(target.lower()))
        return target_words - source_words, source_words - target_words

    def parse_attention_weights(self, prompt: str):
        """Strip ``(word:w)`` / ``[word:w]`` weight syntax from *prompt*.

        Returns:
            ``(clean_prompt, weights)`` where each weighted span is replaced by
            its bare word and ``weights`` maps word -> float (emphasis entries
            first; a repeated word keeps its last weight). Malformed spans such
            as ``(cat:1.2.3)`` are left as literal text.
        """
        attention_weights: Dict[str, float] = {}

        def _record(match):
            word = match.group(1).strip()
            attention_weights[word] = float(match.group(2))
            return word

        weighted_prompt = self._EMPHASIS_RE.sub(_record, prompt)
        weighted_prompt = self._DEEMPHASIS_RE.sub(_record, weighted_prompt)
        return weighted_prompt.strip(), attention_weights

    def apply_word_replacement(self, base_svg: str, source_prompt: str, target_prompt: str,
                               added_words: set, removed_words: set, width: int, height: int):
        """Return SVG text redrawn for *target_prompt* according to the added words.

        Added color words win, then size words ("big"/"large"/"huge",
        "small"/"tiny"/"mini"); otherwise the target prompt is drawn normally.
        ``base_svg``, ``source_prompt`` and ``removed_words`` are accepted but
        not used.
        """
        features = self.extract_semantic_features(target_prompt)
        dwg = self._new_canvas(width, height)

        # Any color the palette knows triggers the color branch (previously
        # purple/orange were mapped but could never trigger it).
        if any(color in added_words for color in self._COLOR_HEX):
            self.add_colored_elements(dwg, width, height, added_words)
        elif any(word in added_words for word in ('big', 'large', 'huge')):
            self.add_large_elements(dwg, width, height, features)
        elif any(word in added_words for word in ('small', 'tiny', 'mini')):
            self.add_small_elements(dwg, width, height, features)
        else:
            self.add_content_based_on_prompt(dwg, target_prompt, width, height)

        return dwg.tostring()

    def apply_refinement(self, base_svg: str, prompt: str, width: int, height: int):
        """Return SVG text for *prompt*: detail shapes if it asks for detail, else its figure.

        ``base_svg`` is accepted but not used.
        """
        features = self.extract_semantic_features(prompt)
        dwg = self._new_canvas(width, height)

        if features.get('detailed', False):
            self.add_detailed_elements(dwg, width, height, features)
        else:
            self.add_content_based_on_prompt(dwg, prompt, width, height)

        return dwg.tostring()

    def apply_attention_reweighting(self, base_svg: str, prompt: str, attention_weights: dict,
                                    width: int, height: int):
        """Return SVG text with one marker per weighted word, then *prompt*'s figure.

        Weights above 1.0 get an emphasised marker, below 1.0 a de-emphasised
        one; exactly 1.0 draws nothing. ``base_svg`` is accepted but not used.
        """
        dwg = self._new_canvas(width, height)

        for word, weight in attention_weights.items():
            if weight > 1.0:
                self.add_emphasized_element(dwg, word, weight, width, height)
            elif weight < 1.0:
                self.add_deemphasized_element(dwg, word, weight, width, height)

        self.add_content_based_on_prompt(dwg, prompt, width, height)
        return dwg.tostring()

    def add_person_elements(self, dwg, width, height, features):
        """Draw a stick-style person (head, body, arms, legs) centred on the canvas."""
        center_x, center_y = width // 2, height // 2
        head_radius = 20
        body_width, body_height = 30, 60

        dwg.add(dwg.circle(center=(center_x, center_y - 40), r=head_radius,
                           fill='#FDBCB4', stroke='black', stroke_width=2))
        dwg.add(dwg.rect(
            insert=(center_x - body_width // 2, center_y - 10),
            size=(body_width, body_height),
            fill='#4A90E2',
            stroke='black',
            stroke_width=2,
        ))
        # Arms from the shoulders, angled downwards.
        dwg.add(dwg.line(start=(center_x - body_width // 2, center_y),
                         end=(center_x - 40, center_y + 20), stroke='black', stroke_width=3))
        dwg.add(dwg.line(start=(center_x + body_width // 2, center_y),
                         end=(center_x + 40, center_y + 20), stroke='black', stroke_width=3))
        # Legs from the lower body.
        dwg.add(dwg.line(start=(center_x - 10, center_y + body_height - 10),
                         end=(center_x - 20, center_y + body_height + 30),
                         stroke='black', stroke_width=3))
        dwg.add(dwg.line(start=(center_x + 10, center_y + body_height - 10),
                         end=(center_x + 20, center_y + body_height + 30),
                         stroke='black', stroke_width=3))

    def add_animal_elements(self, dwg, width, height, features):
        """Draw a four-legged animal (body, head, legs, curved tail) centred on the canvas."""
        center_x, center_y = width // 2, height // 2

        dwg.add(dwg.ellipse(center=(center_x, center_y), r=(40, 25),
                            fill='#8B4513', stroke='black', stroke_width=2))
        dwg.add(dwg.circle(center=(center_x - 30, center_y - 10), r=20,
                           fill='#A0522D', stroke='black', stroke_width=2))

        for x_offset in (-20, -10, 10, 20):
            dwg.add(dwg.line(
                start=(center_x + x_offset, center_y + 25),
                end=(center_x + x_offset, center_y + 45),
                stroke='black',
                stroke_width=3,
            ))

        # Tail: quadratic Bezier curving up from the back of the body.
        dwg.add(dwg.path(
            d=f"M {center_x + 40},{center_y} Q {center_x + 60},{center_y - 20} "
              f"{center_x + 50},{center_y - 35}",
            stroke='black',
            stroke_width=3,
            fill='none',
        ))

    def add_building_elements(self, dwg, width, height, features):
        """Draw a house: walls, triangular roof, a grid of windows and a door."""
        building_width = width * 0.6
        building_height = height * 0.7
        x = (width - building_width) // 2
        y = height - building_height - 10

        dwg.add(dwg.rect(insert=(x, y), size=(building_width, building_height),
                         fill='#CD853F', stroke='black', stroke_width=2))

        roof_points = [(x, y), (x + building_width // 2, y - 30), (x + building_width, y)]
        dwg.add(dwg.polygon(points=roof_points, fill='#8B0000', stroke='black', stroke_width=2))

        # Up to 3 x 4 windows; rows too close to the bottom edge are skipped.
        window_size = 15
        for col in range(3):
            for row in range(4):
                wx = x + 15 + col * 30
                wy = y + 15 + row * 25
                if wy < y + building_height - 20:
                    dwg.add(dwg.rect(insert=(wx, wy), size=(window_size, window_size),
                                     fill='#87CEEB', stroke='black', stroke_width=1))

        door_width, door_height = 20, 40
        door_x = x + building_width // 2 - door_width // 2
        door_y = y + building_height - door_height
        dwg.add(dwg.rect(insert=(door_x, door_y), size=(door_width, door_height),
                         fill='#8B4513', stroke='black', stroke_width=2))

    def add_nature_elements(self, dwg, width, height, features):
        """Draw a tree: a trunk near the bottom and five overlapping crown circles."""
        center_x, center_y = width // 2, height // 2

        trunk_width = 15
        trunk_height = height // 3
        trunk_x = center_x - trunk_width // 2
        trunk_y = height - trunk_height - 10
        dwg.add(dwg.rect(insert=(trunk_x, trunk_y), size=(trunk_width, trunk_height),
                         fill='#8B4513', stroke='black', stroke_width=1))

        # Crown circles shrink by 3px each, placed relative to the canvas centre.
        crown_radius = 30
        offsets = [(-15, -20), (15, -20), (0, -35), (-10, -50), (10, -50)]
        for i, (dx, dy) in enumerate(offsets):
            dwg.add(dwg.circle(
                center=(center_x + dx, center_y + dy),
                r=crown_radius - i * 3,
                fill='#228B22',
                stroke='#006400',
                stroke_width=1,
                opacity=0.8,
            ))

    def add_vehicle_elements(self, dwg, width, height, features):
        """Draw a car: rounded body, windshield on top and two wheels."""
        center_x, center_y = width // 2, height // 2

        car_width = width * 0.6
        car_height = height * 0.3
        car_x = (width - car_width) // 2
        car_y = center_y + 10
        dwg.add(dwg.rect(insert=(car_x, car_y), size=(car_width, car_height),
                         fill='#FF4500', stroke='black', stroke_width=2, rx=5))

        windshield_width = car_width * 0.6
        windshield_height = car_height * 0.4
        windshield_x = car_x + (car_width - windshield_width) // 2
        windshield_y = car_y - windshield_height + 5
        dwg.add(dwg.rect(insert=(windshield_x, windshield_y),
                         size=(windshield_width, windshield_height),
                         fill='#87CEEB', stroke='black', stroke_width=1))

        wheel_radius = 12
        wheel_y = car_y + car_height - 5
        dwg.add(dwg.circle(center=(car_x + 25, wheel_y), r=wheel_radius, fill='black'))
        dwg.add(dwg.circle(center=(car_x + car_width - 25, wheel_y), r=wheel_radius, fill='black'))

    def add_abstract_elements(self, dwg, width, height, features):
        """Draw five random circles / rectangles / lines in a fixed palette.

        Uses the module-level ``random`` generator, so output follows any seed
        set in ``__call__``. Position ranges are clamped so canvases smaller
        than a shape no longer make ``random.randint`` raise; for normal sizes
        the random draws are unchanged.
        """
        colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7']

        for _ in range(5):
            shape_type = random.choice(['circle', 'rect', 'path'])
            color = random.choice(colors)

            if shape_type == 'circle':
                radius = random.randint(10, 30)
                x = random.randint(radius, max(radius, width - radius))
                y = random.randint(radius, max(radius, height - radius))
                dwg.add(dwg.circle(center=(x, y), r=radius, fill=color, opacity=0.7))
            elif shape_type == 'rect':
                w = random.randint(20, 60)
                h = random.randint(20, 60)
                x = random.randint(0, max(0, width - w))
                y = random.randint(0, max(0, height - h))
                dwg.add(dwg.rect(insert=(x, y), size=(w, h), fill=color, opacity=0.7))
            else:
                start_x = random.randint(0, width)
                start_y = random.randint(0, height)
                end_x = random.randint(0, width)
                end_y = random.randint(0, height)
                dwg.add(dwg.line(start=(start_x, start_y), end=(end_x, end_y),
                                 stroke=color, stroke_width=3))

    def add_colored_elements(self, dwg, width, height, color_words):
        """Draw one randomly placed circle per recognised color word.

        Words are visited in sorted order so that, for a given seed, the
        drawing is reproducible (iterating a set of strings is not).
        """
        center_x, center_y = width // 2, height // 2

        for word in sorted(color_words):
            if word in self._COLOR_HEX:
                dwg.add(dwg.circle(
                    center=(center_x + random.randint(-50, 50), center_y + random.randint(-50, 50)),
                    r=random.randint(15, 35),
                    fill=self._COLOR_HEX[word],
                    opacity=0.8,
                ))

    def add_large_elements(self, dwg, width, height, features):
        """Draw one big circle (radius = a third of the shorter side) at the centre."""
        center_x, center_y = width // 2, height // 2
        dwg.add(dwg.circle(
            center=(center_x, center_y),
            r=min(width, height) // 3,
            fill='#4A90E2',
            stroke='black',
            stroke_width=3,
        ))

    def add_small_elements(self, dwg, width, height, features):
        """Draw eight small red dots at random positions.

        Position ranges are clamped so canvases under 20px no longer make
        ``random.randint`` raise.
        """
        for _ in range(8):
            x = random.randint(10, max(10, width - 10))
            y = random.randint(10, max(10, height - 10))
            dwg.add(dwg.circle(center=(x, y), r=random.randint(3, 8), fill='#E74C3C', opacity=0.7))

    def add_detailed_elements(self, dwg, width, height, features):
        """Draw the abstract shapes plus four dots at 90-degree steps around the centre."""
        self.add_abstract_elements(dwg, width, height, features)

        center_x, center_y = width // 2, height // 2
        for i in range(4):
            angle = i * math.pi / 2
            x = center_x + 40 * math.cos(angle)
            y = center_y + 40 * math.sin(angle)
            dwg.add(dwg.circle(center=(x, y), r=8, fill='#9B59B6', opacity=0.6))

    def add_emphasized_element(self, dwg, word: str, weight: float, width: int, height: int):
        """Draw a red marker near the centre whose radius scales with *weight* (> 1)."""
        center_x, center_y = width // 2, height // 2
        size = int(20 * weight)
        dwg.add(dwg.circle(
            center=(center_x + random.randint(-30, 30), center_y + random.randint(-30, 30)),
            r=size,
            fill='#FF6B6B',
            opacity=min(1.0, weight / 2),
            stroke='black',
            stroke_width=2,
        ))

    def add_deemphasized_element(self, dwg, word: str, weight: float, width: int, height: int):
        """Draw a faint grey marker (radius >= 3, opacity = *weight*) near the centre."""
        center_x, center_y = width // 2, height // 2
        size = int(15 * weight)
        dwg.add(dwg.circle(
            center=(center_x + random.randint(-40, 40), center_y + random.randint(-40, 40)),
            r=max(3, size),
            fill='#CCCCCC',
            opacity=weight,
            stroke='gray',
            stroke_width=1,
        ))

    def add_content_based_on_prompt(self, dwg, prompt: str, width: int, height: int):
        """Draw the figure for *prompt*'s subject category onto *dwg*.

        Categories come from ``_SUBJECT_KEYWORDS`` (whole-word match); prompts
        matching none get random abstract shapes.
        """
        features = self.extract_semantic_features(prompt)
        drawers = {
            "person": self.add_person_elements,
            "animal": self.add_animal_elements,
            "building": self.add_building_elements,
            "nature": self.add_nature_elements,
            "vehicle": self.add_vehicle_elements,
            "abstract": self.add_abstract_elements,
        }
        drawers[self._subject_category(prompt)](dwg, width, height, features)

    def extract_semantic_features(self, prompt: str):
        """Return boolean style flags (detailed/simple/colorful/large/small) for *prompt*."""
        return {
            name: self._mentions(prompt, keywords)
            for name, keywords in self._FEATURE_KEYWORDS.items()
        }

    def svg_to_pil_image(self, svg_content: str, width: int, height: int):
        """Rasterise SVG text to an RGB PIL image of the given size.

        Uses ``cairosvg`` when importable; if it is missing or conversion
        fails, a message is printed and a blank white image is returned.
        """
        try:
            import cairosvg

            png_bytes = cairosvg.svg2png(
                bytestring=svg_content.encode('utf-8'),
                output_width=width,
                output_height=height,
            )
            return Image.open(io.BytesIO(png_bytes)).convert('RGB')

        except ImportError:
            print("cairosvg not available, creating simple image representation")
            return Image.new('RGB', (width, height), 'white')
        except Exception as e:
            print(f"Error converting SVG to image: {e}")
            return Image.new('RGB', (width, height), 'white')

    def create_fallback_svg(self, prompt: str, width: int, height: int):
        """Return SVG text for an error placeholder: a title line over the prompt.

        The prompt is cut to 30 characters, with "..." appended only when it
        was actually cut; a falsy prompt is shown as "error". The two lines
        are separate text elements because SVG does not treat a newline inside
        one text element as a line break.
        """
        dwg = self._new_canvas(width, height)

        text = str(prompt) if prompt else "error"
        label = f"{text[:30]}..." if len(text) > 30 else text
        for line, dy in (("DiffSketchEdit", -8), (label, 8)):
            dwg.add(dwg.text(
                line,
                insert=(width / 2, height / 2 + dy),
                text_anchor="middle",
                font_size="12px",
                fill="black",
            ))

        return dwg.tostring()