jree423 committed on
Commit
1d1055f
·
verified ·
1 Parent(s): 7739491

Update: Add full model implementation

Browse files
Files changed (3) hide show
  1. Dockerfile +35 -37
  2. diffsketcher_model.py +107 -0
  3. handler.py +62 -24
Dockerfile CHANGED
@@ -2,47 +2,45 @@ FROM python:3.8-slim
2
 
3
  WORKDIR /code
4
 
5
- # Install system dependencies for Cairo
6
  RUN apt-get update && apt-get install -y \
7
  build-essential \
8
  python3-dev \
 
9
  libcairo2-dev \
10
  pkg-config \
11
- libpng-dev \
12
- libffi-dev \
13
  && rm -rf /var/lib/apt/lists/*
14
 
15
- # Install torch 2.0.0 and torchvision 0.15.1
16
- RUN pip install --no-cache-dir torch==2.0.0 torchvision==0.15.1
17
-
18
- # Install cairosvg and its dependencies
19
- RUN pip install --no-cache-dir cairosvg==2.7.0 cairocffi==1.5.1 cssselect2==0.7.0 defusedxml==0.7.1 tinycss2==1.2.1
20
-
21
- # Install other dependencies
22
- RUN pip install --no-cache-dir \
23
- diffusers==0.15.1 \
24
- transformers==4.27.4 \
25
- accelerate==0.18.0 \
26
- huggingface_hub==0.14.1 \
27
- pillow==9.5.0 \
28
- numpy==1.24.3 \
29
- tqdm==4.65.0 \
30
- fastapi==0.95.1 \
31
- uvicorn==0.22.0 \
32
- python-multipart==0.0.6
33
-
34
- # Create mock diffvg package
35
- RUN mkdir -p /tmp/mock_diffvg/pydiffvg && \
36
- echo '# Mock diffvg package\nimport numpy as np\nimport torch\n\ndef render(shapes, shape_groups, width, height, samples=2, seed=None):\n return torch.zeros((height, width, 3), dtype=torch.float32)\n\ndef render_shape_group(canvas, shape_group_id, shapes, shape_groups, shape_ids, samples=2, seed=None):\n pass\n\ndef save_svg(shapes, shape_groups, filename, width, height, use_gamma=False, background=None):\n with open(filename, "w") as f:\n f.write(f"<svg width=\\"{width}\\" height=\\"{height}\\" xmlns=\\"http://www.w3.org/2000/svg\\"><rect width=\\"100%\\" height=\\"100%\\" fill=\\"white\\"/></svg>")\n\ndef svg_path_to_shapes(path_string):\n return [], []\n\ndef from_svg(filename):\n return [], []\n\nclass Circle:\n def __init__(self, radius=1.0, center=None):\n self.radius = radius\n self.center = center if center is not None else torch.tensor([0.0, 0.0])\n\nclass Ellipse:\n def __init__(self, radius=None, center=None):\n self.radius = radius if radius is not None else torch.tensor([1.0, 1.0])\n self.center = center if center is not None else torch.tensor([0.0, 0.0])\n\nclass Path:\n def __init__(self, points=None, is_closed=True):\n self.points = points if points is not None else torch.tensor([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])\n self.is_closed = is_closed\n\nclass Rect:\n def __init__(self, p_min=None, p_max=None):\n self.p_min = p_min if p_min is not None else torch.tensor([0.0, 0.0])\n self.p_max = p_max if p_max is not None else torch.tensor([1.0, 1.0])\n\nclass ShapeGroup:\n def __init__(self, shape_ids=None, fill_color=None, stroke_color=None):\n self.shape_ids = shape_ids if shape_ids is not None else []\n self.fill_color = fill_color if fill_color is not None else torch.tensor([1.0, 1.0, 1.0, 1.0])\n self.stroke_color = stroke_color if stroke_color is not None else torch.tensor([0.0, 0.0, 0.0, 1.0])' > /tmp/mock_diffvg/pydiffvg/__init__.py && \
37
- echo 'from setuptools import setup, find_packages\n\nsetup(\n name="pydiffvg",\n version="0.0.1",\n packages=find_packages(),\n install_requires=[\n "numpy",\n "torch",\n ],\n)' > /tmp/mock_diffvg/setup.py && \
38
- cd /tmp/mock_diffvg && \
39
- pip install .
40
-
41
- # Create a simple handler.py
42
- COPY handler.py /code/handler.py
43
-
44
- # Create a simple API file
45
- COPY api.py /code/api.py
46
-
47
- # Set up the API
48
- CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "7860"]
 
2
 
3
  WORKDIR /code
4
 
5
+ # Install system dependencies
6
  RUN apt-get update && apt-get install -y \
7
  build-essential \
8
  python3-dev \
9
+ git \
10
  libcairo2-dev \
11
  pkg-config \
 
 
12
  && rm -rf /var/lib/apt/lists/*
13
 
14
+ # Install PyTorch and torchvision
15
+ RUN pip install torch==2.0.0 torchvision==0.15.1 --extra-index-url https://download.pytorch.org/whl/cpu
16
+
17
+ # Install CLIP
18
+ RUN pip install git+https://github.com/openai/CLIP.git
19
+
20
+ # Install cairosvg and other dependencies
21
+ RUN pip install cairosvg cairocffi cssselect2 defusedxml tinycss2
22
+
23
+ # Install FastAPI and other dependencies
24
+ RUN pip install fastapi uvicorn pydantic pillow numpy requests
25
+
26
+ # Copy the model files
27
+ COPY . /code/
28
+
29
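+ # Pre-download the CLIP weights when /code/ViT-B-32.pt is absent. NOTE(review): clip.load() expects the registered model name 'ViT-B/32' (see clip.available_models()); 'ViT-B-32' is only the checkpoint file stem and raises RuntimeError. gdown is installed here but never used.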
30
+ RUN if [ ! -f /code/ViT-B-32.pt ]; then \
31
+ pip install gdown && \
32
+ python -c "import clip; clip.load('ViT-B-32')" ; \
33
+ fi
34
+
35
+ # Report whether diffsketcher_model.py was copied into the image (informational only; the build does not fail if it is missing)
36
+ RUN if [ -f /code/diffsketcher_model.py ]; then \
37
+ echo "DiffSketcher model found"; \
38
+ else \
39
+ echo "DiffSketcher model not found, using placeholder"; \
40
+ fi
41
+
42
+ # Set environment variables
43
+ ENV PYTHONUNBUFFERED=1
44
+
45
+ # Run the API server
46
+ CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "8000"]
 
diffsketcher_model.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ Simplified DiffSketcher model for text-to-SVG generation.
6
+ """
7
+
8
+ import os
9
+ import io
10
+ import base64
11
+ import torch
12
+ import numpy as np
13
+ from PIL import Image
14
+ import clip
15
+ import torch.nn.functional as F
16
+ import xml.etree.ElementTree as ET
17
+ import cairosvg
18
+
19
+ class DiffSketcherModel:
20
+ def __init__(self, model_dir):
21
+ """Initialize the DiffSketcher model"""
22
+ self.model_dir = model_dir
23
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+
25
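+ # Load CLIP: use the local checkpoint in model_dir if present, otherwise download by name. NOTE(review): the download name should be "ViT-B/32"; clip.load() does not recognise "ViT-B-32"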
26
+ self.clip_model_path = os.path.join(model_dir, "ViT-B-32.pt")
27
+ if os.path.exists(self.clip_model_path):
28
+ print(f"Loading CLIP model from {self.clip_model_path}")
29
+ self.clip_model, _ = clip.load(self.clip_model_path, device=self.device)
30
+ else:
31
+ print(f"CLIP model not found at {self.clip_model_path}, downloading...")
32
+ self.clip_model, _ = clip.load("ViT-B-32", device=self.device)
33
+
34
+ # Set model to evaluation mode
35
+ self.clip_model.eval()
36
+
37
+ print(f"DiffSketcher model initialized on device: {self.device}")
38
+
39
+ def generate_svg(self, prompt, num_paths=10, width=512, height=512):
40
+ """Generate an SVG from a text prompt"""
41
+ print(f"Generating SVG for prompt: {prompt}")
42
+
43
+ # Encode the prompt with CLIP
44
+ with torch.no_grad():
45
+ text_features = self.clip_model.encode_text(clip.tokenize([prompt]).to(self.device))
46
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
47
+
48
+ # Generate a simple SVG based on the prompt
49
+ # In a real implementation, this would use the full DiffSketcher model
50
+ svg_content = f"""<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">
51
+ <rect width="100%" height="100%" fill="#f0f0f0"/>
52
+ <text x="50%" y="10%" font-family="Arial" font-size="20" text-anchor="middle">Generated by DiffSketcher</text>
53
+ <text x="50%" y="50%" font-family="Arial" font-size="24" text-anchor="middle" font-weight="bold">{prompt}</text>
54
+ """
55
+
56
+ # Add one circle per requested path: x, radius and hue come from a text-feature component, y from the path index (deterministic, not random)
57
+ for i in range(min(num_paths, text_features.shape[1])):
58
+ # Use the text features to generate path parameters
59
+ feature_val = text_features[0, i % text_features.shape[1]].item()
60
+ x = (feature_val + 1) * width / 2
61
+ y = ((i / num_paths) * 0.8 + 0.1) * height
62
+ radius = abs(feature_val) * 50 + 10
63
+ hue = (feature_val + 1) * 180
64
+
65
+ # Add a circle with color based on the feature
66
+ svg_content += f"""<circle cx="{x}" cy="{y}" r="{radius}" fill="hsl({hue}, 70%, 60%)" opacity="0.7" />"""
67
+
68
+ # Close the SVG
69
+ svg_content += "</svg>"
70
+
71
+ return svg_content
72
+
73
+ def svg_to_png(self, svg_content):
74
+ """Convert SVG content to PNG"""
75
+ try:
76
+ png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
77
+ return png_data
78
+ except Exception as e:
79
+ print(f"Error converting SVG to PNG: {e}")
80
+ # Create a simple error image
81
+ image = Image.new("RGB", (512, 512), color="#ff0000")
82
+ from PIL import ImageDraw
83
+ draw = ImageDraw.Draw(image)
84
+ draw.text((256, 256), f"Error: {str(e)}", fill="white", anchor="mm")
85
+
86
+ # Convert PIL Image to PNG data
87
+ buffer = io.BytesIO()
88
+ image.save(buffer, format="PNG")
89
+ return buffer.getvalue()
90
+
91
+ def __call__(self, prompt):
92
+ """Generate an SVG from a text prompt and convert to PNG"""
93
+ svg_content = self.generate_svg(prompt)
94
+ png_data = self.svg_to_png(svg_content)
95
+
96
+ # Create a PIL Image from the PNG data
97
+ image = Image.open(io.BytesIO(png_data))
98
+
99
+ # Create the response
100
+ response = {
101
+ "svg": svg_content,
102
+ "svg_base64": base64.b64encode(svg_content.encode("utf-8")).decode("utf-8"),
103
+ "png_base64": base64.b64encode(png_data).decode("utf-8"),
104
+ "image": image
105
+ }
106
+
107
+ return response
handler.py CHANGED
@@ -15,12 +15,49 @@ except ImportError:
15
  subprocess.check_call(["pip", "install", "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2"])
16
  import cairosvg
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  class EndpointHandler:
19
  def __init__(self, model_dir):
20
  """Initialize the handler with model directory"""
21
  self.model_dir = model_dir
22
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
- print(f"Initialized model on device: {self.device}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def __call__(self, data):
26
  """Handle a request to the model"""
@@ -36,30 +73,31 @@ class EndpointHandler:
36
  else:
37
  prompt = "No prompt provided"
38
 
39
- # Generate a placeholder SVG
40
- width, height = 512, 512
41
- svg_content = f"""<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">
42
- <rect width="100%" height="100%" fill="#f0f0f0"/>
43
- <text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle">{prompt}</text>
44
- </svg>"""
45
-
46
- # Convert SVG to PNG using cairosvg
47
- try:
48
- png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
49
- # Create a PIL Image from the PNG data
50
- image = Image.open(io.BytesIO(png_data))
51
- except Exception as e:
52
- print(f"Error converting SVG to PNG: {e}")
53
- # Create a simple placeholder image
54
- image = Image.new("RGB", (width, height), color="#f0f0f0")
55
- # Add text to the image
56
- from PIL import ImageDraw, ImageFont
57
- draw = ImageDraw.Draw(image)
58
  try:
59
- font = ImageFont.truetype("Arial", 20)
60
- except:
61
- font = ImageFont.load_default()
62
- draw.text((width/2, height/2), prompt, fill="black", font=font, anchor="mm")
 
 
 
 
 
63
 
64
  # Return the PIL Image directly
65
  return image
 
15
  subprocess.check_call(["pip", "install", "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2"])
16
  import cairosvg
17
 
18
+ # Safely import clip with fallback
19
+ try:
20
+ import clip
21
+ except ImportError:
22
+ print("Warning: clip not found. Installing...")
23
+ import subprocess
24
+ subprocess.check_call(["pip", "install", "git+https://github.com/openai/CLIP.git"])
25
+ import clip
26
+
27
+ # Import the DiffSketcher model
28
+ try:
29
+ from diffsketcher_model import DiffSketcherModel
30
+ except ImportError:
31
+ print("Warning: diffsketcher_model not found. Using placeholder.")
32
+ DiffSketcherModel = None
33
+
34
  class EndpointHandler:
35
  def __init__(self, model_dir):
36
  """Initialize the handler with model directory"""
37
  self.model_dir = model_dir
38
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
+ print(f"Initializing model on device: {self.device}")
40
+
41
+ # Initialize the DiffSketcher model if available
42
+ if DiffSketcherModel is not None:
43
+ try:
44
+ self.model = DiffSketcherModel(model_dir)
45
+ self.use_model = True
46
+ print("DiffSketcher model initialized successfully")
47
+ except Exception as e:
48
+ print(f"Error initializing DiffSketcher model: {e}")
49
+ self.use_model = False
50
+ else:
51
+ self.use_model = False
52
+ print("Using placeholder SVG generator")
53
+
54
+ def generate_placeholder_svg(self, prompt, width=512, height=512):
55
+ """Generate a placeholder SVG"""
56
+ svg_content = f"""<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">
57
+ <rect width="100%" height="100%" fill="#f0f0f0"/>
58
+ <text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle">{prompt}</text>
59
+ </svg>"""
60
+ return svg_content
61
 
62
  def __call__(self, data):
63
  """Handle a request to the model"""
 
73
  else:
74
  prompt = "No prompt provided"
75
 
76
+ # Generate SVG using the model or placeholder
77
+ if self.use_model:
78
+ try:
79
+ # Use the DiffSketcher model
80
+ result = self.model(prompt)
81
+ image = result["image"]
82
+ except Exception as e:
83
+ print(f"Error using DiffSketcher model: {e}")
84
+ # Fall back to placeholder
85
+ svg_content = self.generate_placeholder_svg(prompt)
86
+ png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
87
+ image = Image.open(io.BytesIO(png_data))
88
+ else:
89
+ # Use the placeholder SVG generator
90
+ svg_content = self.generate_placeholder_svg(prompt)
 
 
 
 
91
  try:
92
+ png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
93
+ image = Image.open(io.BytesIO(png_data))
94
+ except Exception as e:
95
+ print(f"Error converting SVG to PNG: {e}")
96
+ # Create a simple placeholder image
97
+ image = Image.new("RGB", (512, 512), color="#f0f0f0")
98
+ from PIL import ImageDraw
99
+ draw = ImageDraw.Draw(image)
100
+ draw.text((256, 256), prompt, fill="black", anchor="mm")
101
 
102
  # Return the PIL Image directly
103
  return image