|
|
|
|
|
|
|
""" |
|
DiffSketcher endpoint implementation for Hugging Face. |
|
""" |
|
|
|
import base64
import io
import json
import os
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path
from xml.sax.saxutils import escape

import cairosvg
import numpy as np
import torch
from PIL import Image
|
|
|
class DiffSketcherEndpoint:
    """Text-to-SVG endpoint that wraps the DiffSketcher repository.

    Construction is heavyweight and has side effects: it clones the
    DiffSketcher repository into a fresh temporary directory, runs
    ``pip install`` for its dependencies, writes a stub ``diffvg`` package
    and appends the temporary directory to ``sys.path``.

    If the DiffSketcher modules cannot be imported, or a generation run
    fails, the endpoint degrades to a placeholder SVG that shows the prompt
    text instead of raising.
    """

    _REPO_URL = "https://github.com/ximinng/DiffSketcher.git"

    _PIP_PACKAGES = (
        "svgwrite", "svgpathtools", "cssutils", "numba", "torch", "torchvision",
        "diffusers", "transformers", "accelerate", "xformers", "omegaconf",
        "einops", "kornia",
    )

    _CLIP_PACKAGE = "git+https://github.com/openai/CLIP.git"

    # Source of the stub ``diffvg`` package. Every renderer returns an
    # all-zero RGBA tensor and every writer is a no-op, so anything produced
    # through this stub carries no real drawing.
    _MOCK_DIFFVG_SOURCE = """
# Mock diffvg module
import torch


def render(scene, width, height, samples=2, seed=None):
    return torch.zeros((height, width, 4), dtype=torch.float32)


def render_wrt_shapes(scene, shapes, width, height, samples=2, seed=None):
    return torch.zeros((height, width, 4), dtype=torch.float32)


def render_wrt_camera(scene, camera, width, height, samples=2, seed=None):
    return torch.zeros((height, width, 4), dtype=torch.float32)


def imwrite(img, filename, gamma=2.2):
    pass


def save_svg(scene, filename):
    pass


def set_use_gpu(use_gpu):
    pass


def set_print_timing(print_timing):
    pass
"""

    def __init__(self, model_dir):
        """Initialize the DiffSketcher endpoint.

        Args:
            model_dir: Stored on the instance as ``model_dir``; not otherwise
                used by this class.

        Raises:
            subprocess.CalledProcessError: If ``git clone`` exits non-zero.
            OSError: If the ``git`` executable cannot be launched.
        """
        self.model_dir = model_dir
        # Set first so the attribute exists even if a later step raises.
        self.model_initialized = False
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing DiffSketcher endpoint on device: {self.device}")

        # Private scratch area; removed again in __del__.
        self.temp_dir = tempfile.mkdtemp()
        self.temp_model_dir = Path(self.temp_dir) / "DiffSketcher"

        if not self.temp_model_dir.exists():
            print("Cloning DiffSketcher repository...")
            subprocess.run(
                ["git", "clone", self._REPO_URL, str(self.temp_model_dir)],
                check=True,
            )

        # The parent directory goes on sys.path so the clone is importable
        # as the top-level package ``DiffSketcher``.
        self._add_to_sys_path(self.temp_model_dir.parent)

        self._install_dependencies()
        self._initialize_model()

    @staticmethod
    def _add_to_sys_path(path):
        """Append ``path`` to ``sys.path`` unless it is already listed."""
        entry = str(path)
        if entry not in sys.path:
            sys.path.append(entry)

    def _install_dependencies(self):
        """Install the required dependencies and write the diffvg stub.

        Best effort: failures are reported on stdout and swallowed, so that
        construction can continue and later fall back to placeholder output.
        """
        # Invoke pip through the running interpreter so packages land in the
        # environment this process actually imports from.
        pip_install = [sys.executable, "-m", "pip", "install"]
        try:
            print("Installing Python dependencies...")
            subprocess.run(pip_install + list(self._PIP_PACKAGES), check=True)

            print("Installing CLIP...")
            subprocess.run(pip_install + [self._CLIP_PACKAGE], check=True)

            diffvg_dir = Path(self.temp_dir) / "diffvg"
            diffvg_dir.mkdir(exist_ok=True)
            (diffvg_dir / "__init__.py").write_text(
                self._MOCK_DIFFVG_SOURCE, encoding="utf-8"
            )

            # Appended (not inserted), so an already-installed ``diffvg``
            # found earlier on sys.path wins over the stub.
            self._add_to_sys_path(diffvg_dir.parent)
        except (subprocess.CalledProcessError, OSError) as e:
            print(f"Error installing dependencies: {e}")

    def _initialize_model(self):
        """Probe whether the DiffSketcher modules import and record the result.

        Sets ``self.model_initialized`` to ``True`` when both imports succeed
        and to ``False`` otherwise. No model object is created or stored.
        """
        try:
            # Imported purely as an availability check; the names are unused.
            from DiffSketcher.methods.painter.diffsketcher import Painter  # noqa: F401
            from DiffSketcher.methods.diffusers_warp import init_diffusion_pipeline  # noqa: F401

            self.model_initialized = True
            print("DiffSketcher model initialized successfully")
        except Exception as e:
            # Importing third-party code can fail in arbitrary ways; any
            # failure simply disables the model path.
            print(f"Error initializing DiffSketcher model: {e}")
            self.model_initialized = False

    def generate_svg(self, prompt, num_paths=10, width=512, height=512):
        """Generate an SVG from a text prompt.

        Args:
            prompt: Text prompt to sketch.
            num_paths: Number of paths passed to DiffSketcher.
            width: Passed to DiffSketcher as ``image_size``; also the width of
                the placeholder SVG.
            height: Height of the placeholder SVG only. The model path does
                not use it.

        Returns:
            SVG markup as a string. When the model is unavailable or the run
            fails, a placeholder SVG showing the prompt is returned instead.
        """
        print(f"Generating SVG for prompt: {prompt}")

        if not getattr(self, "model_initialized", False):
            return self._generate_placeholder_svg(prompt, width, height)

        try:
            # A context-managed directory, so each call cleans up after itself.
            with tempfile.TemporaryDirectory() as scratch:
                output_dir = Path(scratch)
                config_path = self._write_config(output_dir, prompt, num_paths, width)
                try:
                    return self._run_model(
                        config_path, output_dir, prompt, num_paths, width
                    )
                except Exception as e:
                    print(f"Error running DiffSketcher model: {e}")
                    return self._generate_placeholder_svg(prompt, width, height)
        except Exception as e:
            print(f"Error generating SVG: {e}")
            return self._generate_placeholder_svg(prompt, width, height)

    @staticmethod
    def _write_config(output_dir, prompt, num_paths, image_size):
        """Write the run's ``config.yaml`` into ``output_dir`` and return its path."""
        config_path = output_dir / "config.yaml"
        # json.dumps yields a double-quoted, escaped scalar, so prompts that
        # contain ':', '#', quotes or newlines cannot break or extend the YAML.
        quoted_prompt = json.dumps(str(prompt), ensure_ascii=False)
        lines = [
            "task: diffsketcher",
            "model_id: sd15",
            f"prompt: {quoted_prompt}",
            'negative_prompt: ""',
            f"num_paths: {num_paths}",
            "width: 1.5",
            f"image_size: {image_size}",
            "num_iter: 500",
            "lr: 1.0",
            "sds:",
            "  warmup: 0",
            "  grad_scale: 1.0",
            "  t_range: [0.02, 0.98]",
            "  guidance_scale: 7.5",
        ]
        config_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
        return config_path

    @staticmethod
    def _run_model(config_path, output_dir, prompt, num_paths, image_size):
        """Run DiffSketcher once and return the text of the SVG it produced.

        Raises:
            FileNotFoundError: If no ``*.svg`` file exists under
                ``output_dir`` after the run.
        """
        from DiffSketcher.run_painterly_render import main
        from DiffSketcher.libs.engine import merge_and_update_config
        from omegaconf import OmegaConf

        args = OmegaConf.create({
            "task": "diffsketcher",
            "config": str(config_path),
            "prompt": prompt,
            "negative_prompt": "",
            "num_paths": num_paths,
            "width": 1.5,
            "image_size": image_size,
            "num_iter": 500,
            "lr": 1.0,
            "sds": {
                "warmup": 0,
                "grad_scale": 1.0,
                "t_range": [0.02, 0.98],
                "guidance_scale": 7.5,
            },
            "seed": 42,
            "batch_size": 1,
            "render_batch": False,
            "make_video": False,
            "print_timing": False,
            "download": True,
            "force_download": False,
            "resume_download": False,
        })

        args = merge_and_update_config(args)
        main(args, None)

        # NOTE(review): nothing above hands ``output_dir`` to DiffSketcher, so
        # it is unclear whether results are written here - confirm against
        # the DiffSketcher API.
        # Sorted so the chosen file does not depend on directory order.
        svg_files = sorted(output_dir.glob("**/*.svg"))
        if not svg_files:
            raise FileNotFoundError("No SVG file generated")
        return svg_files[0].read_text(encoding="utf-8")

    def _generate_placeholder_svg(self, prompt, width=512, height=512):
        """Return a grey placeholder SVG with the prompt centred as text.

        The prompt is XML-escaped, so characters such as ``&`` and ``<``
        yield well-formed markup instead of broken or injected elements.
        """
        safe_prompt = escape(str(prompt))
        return (
            f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">\n'
            '    <rect width="100%" height="100%" fill="#f0f0f0"/>\n'
            '    <text x="50%" y="50%" font-family="Arial" font-size="20" '
            f'text-anchor="middle">{safe_prompt}</text>\n'
            "</svg>"
        )

    def svg_to_png(self, svg_content):
        """Convert SVG content to PNG bytes.

        If the conversion raises, a 512x512 red image carrying the error
        message is returned instead, so callers always receive PNG bytes.
        """
        try:
            return cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
        except Exception as e:
            print(f"Error converting SVG to PNG: {e}")

            from PIL import ImageDraw

            image = Image.new("RGB", (512, 512), color="#ff0000")
            draw = ImageDraw.Draw(image)
            draw.text((256, 256), f"Error: {str(e)}", fill="white", anchor="mm")

            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            return buffer.getvalue()

    def __call__(self, prompt):
        """Generate an SVG from a text prompt and convert it to PNG.

        Returns:
            A dict with keys ``svg`` (markup string), ``svg_base64``,
            ``png_base64`` (base64 text) and ``image`` (a PIL image opened
            from the PNG bytes).
        """
        svg_content = self.generate_svg(prompt)
        png_data = self.svg_to_png(svg_content)
        image = Image.open(io.BytesIO(png_data))

        return {
            "svg": svg_content,
            "svg_base64": base64.b64encode(svg_content.encode("utf-8")).decode("utf-8"),
            "png_base64": base64.b64encode(png_data).decode("utf-8"),
            "image": image,
        }

    def __del__(self):
        """Clean up temporary files."""
        # temp_dir is missing when __init__ failed before creating it.
        temp_dir = getattr(self, "temp_dir", None)
        if temp_dir and os.path.exists(temp_dir):
            # A finalizer must not raise; ignore partial-removal errors.
            shutil.rmtree(temp_dir, ignore_errors=True)