jree423 commited on
Commit
fdbaec8
·
verified ·
1 Parent(s): ac9a037

Upload diffsketcher_handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. diffsketcher_handler.py +77 -134
diffsketcher_handler.py CHANGED
@@ -1,149 +1,92 @@
 
 
 
1
  import os
2
- import json
3
  import torch
4
- import base64
5
- from io import BytesIO
6
- from PIL import Image
7
- import cairosvg
8
  import numpy as np
 
 
 
 
9
 
10
- class DiffSketcherHandler:
11
- def __init__(self):
12
- self.initialized = False
13
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
14
- self.model = None
15
-
16
- def initialize(self, context):
17
- """Initialize the handler."""
18
- self.initialized = True
19
-
20
- # Import dependencies here to avoid issues during startup
21
- try:
22
- import pydiffvg
23
- self.diffvg = pydiffvg
24
- print("Successfully imported pydiffvg")
25
- except ImportError as e:
26
- print(f"Warning: Could not import pydiffvg: {e}")
27
- print("Will use placeholder SVG generation")
28
- self.diffvg = None
29
-
30
- # We'll initialize the actual model only when needed
31
- return None
32
-
33
- def _initialize_model(self):
34
- """Initialize the actual model when needed."""
35
- if self.model is not None:
36
- return
37
-
38
  try:
39
- # Try to import and initialize the actual model
40
- from diffusers import StableDiffusionPipeline
 
 
 
 
 
41
 
42
- # Load a small model for testing
43
- self.model = StableDiffusionPipeline.from_pretrained(
44
- "runwayml/stable-diffusion-v1-5",
45
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
46
- ).to(self.device)
 
 
 
 
 
 
47
 
48
- print("Successfully initialized the model")
 
49
  except Exception as e:
50
- print(f"Error initializing model: {e}")
51
- print("Will use placeholder generation")
52
- self.model = None
53
 
54
  def preprocess(self, data):
55
- """Preprocess the input data."""
56
- inputs = data.get("inputs", "")
57
- if not inputs:
58
- inputs = "a beautiful landscape"
59
-
60
- # Get parameters
61
- parameters = data.get("parameters", {})
62
- num_paths = parameters.get("num_paths", 96)
63
- token_ind = parameters.get("token_ind", 4)
64
- num_iter = parameters.get("num_iter", 800)
65
-
66
- return {
67
- "prompt": inputs,
68
- "num_paths": num_paths,
69
- "token_ind": token_ind,
70
- "num_iter": num_iter
71
- }
72
-
73
- def _generate_placeholder_svg(self, prompt):
74
- """Generate a placeholder SVG when the actual model is not available."""
75
- import svgwrite
76
-
77
- # Create a simple SVG
78
- dwg = svgwrite.Drawing(size=(512, 512))
79
- # Add a background rectangle
80
- dwg.add(dwg.rect(insert=(0, 0), size=('100%', '100%'), fill='#f0f0f0'))
81
- # Add a circle
82
- dwg.add(dwg.circle(center=(256, 256), r=100, fill='#3498db'))
83
- # Add the prompt as text
84
- dwg.add(dwg.text(prompt, insert=(50, 50), font_size=20, fill='black'))
85
- # Add a note that this is a placeholder
86
- dwg.add(dwg.text("Placeholder SVG - Model not available",
87
- insert=(50, 480), font_size=16, fill='red'))
88
-
89
- svg_string = dwg.tostring()
90
-
91
- # Convert SVG to PNG for preview
92
- png_data = cairosvg.svg2png(bytestring=svg_string.encode('utf-8'))
93
- image = Image.open(BytesIO(png_data))
94
-
95
- return svg_string, image
96
 
97
  def inference(self, inputs):
98
- """Run inference with the preprocessed inputs."""
99
- prompt = inputs["prompt"]
100
-
101
- # Try to initialize the model if not already done
102
- if self.model is None and self.diffvg is not None:
103
- try:
104
- self._initialize_model()
105
- except Exception as e:
106
- print(f"Error initializing model during inference: {e}")
107
-
108
- # If we have a working model, use it
109
- if self.model is not None and self.diffvg is not None:
110
- try:
111
- # This would be the actual DiffSketcher implementation
112
- # For now, we'll just generate a placeholder
113
- svg_string, image = self._generate_placeholder_svg(prompt)
114
- except Exception as e:
115
- print(f"Error during model inference: {e}")
116
- svg_string, image = self._generate_placeholder_svg(prompt)
117
- else:
118
- # Use placeholder if model is not available
119
- svg_string, image = self._generate_placeholder_svg(prompt)
120
-
121
- return {
122
- "svg": svg_string,
123
- "image": image
124
- }
125
 
126
  def postprocess(self, inference_output):
127
- """Post-process the model output."""
128
- svg_string = inference_output["svg"]
129
- image = inference_output["image"]
130
-
131
- # Convert image to base64 for JSON response
132
- buffered = BytesIO()
133
- image.save(buffered, format="PNG")
134
- img_str = base64.b64encode(buffered.getvalue()).decode()
135
- img_base64 = f"data:image/png;base64,{img_str}"
136
-
137
- return {
138
- "svg": svg_string,
139
- "image": img_base64
140
- }
141
-
142
- def handle(self, data, context):
143
- """Handle the request."""
144
- if not self.initialized:
145
- self.initialize(context)
146
 
147
- preprocessed_data = self.preprocess(data)
148
- inference_output = self.inference(preprocessed_data)
149
- return self.postprocess(inference_output)
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
  import os
5
+ import sys
6
  import torch
 
 
 
 
7
  import numpy as np
8
+ from PIL import Image
9
+ import io
10
+ import base64
11
+ from handler_template import BaseHandler
12
 
13
+ # Add DiffSketcher to path
14
+ sys.path.append("/app/model")
15
+
16
+ class Handler(BaseHandler):
17
+ def initialize(self):
18
+ """Load the DiffSketcher model"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  try:
20
+ from models.clip_text_encoder import CLIPTextEncoder
21
+ from models.sketch_generator import SketchGenerator
22
+
23
+ # Load text encoder
24
+ self.text_encoder = CLIPTextEncoder()
25
+ self.text_encoder.to(self.device)
26
+ self.text_encoder.eval()
27
 
28
+ # Load sketch generator
29
+ self.model = SketchGenerator()
30
+ weights_path = os.path.join("/app/model/weights", "diffsketcher_model.pth")
31
+ if os.path.exists(weights_path):
32
+ state_dict = torch.load(weights_path, map_location=self.device)
33
+ self.model.load_state_dict(state_dict)
34
+ else:
35
+ raise FileNotFoundError(f"Model weights not found at {weights_path}")
36
+
37
+ self.model.to(self.device)
38
+ self.model.eval()
39
 
40
+ self.initialized = True
41
+ print("DiffSketcher model initialized successfully")
42
  except Exception as e:
43
+ print(f"Error initializing DiffSketcher model: {str(e)}")
44
+ raise
 
45
 
46
  def preprocess(self, data):
47
+ """Process the input data"""
48
+ try:
49
+ # Extract prompt from the request
50
+ prompt = data.get("prompt", "")
51
+ if not prompt:
52
+ raise ValueError("No prompt provided in the request")
53
+
54
+ # Encode text with CLIP
55
+ with torch.no_grad():
56
+ text_embedding = self.text_encoder.encode_text(prompt)
57
+
58
+ return {
59
+ "text_embedding": text_embedding,
60
+ "prompt": prompt
61
+ }
62
+ except Exception as e:
63
+ print(f"Error in preprocessing: {str(e)}")
64
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  def inference(self, inputs):
67
+ """Generate SVG from text embedding"""
68
+ try:
69
+ text_embedding = inputs["text_embedding"]
70
+
71
+ # Run inference
72
+ with torch.no_grad():
73
+ svg_data = self.model.generate(text_embedding)
74
+
75
+ return svg_data
76
+ except Exception as e:
77
+ print(f"Error during inference: {str(e)}")
78
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  def postprocess(self, inference_output):
81
+ """Format the model output"""
82
+ try:
83
+ svg_content = inference_output["svg_content"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
+ # Return both the SVG content and base64 encoded version
86
+ return {
87
+ "svg_content": svg_content,
88
+ "svg_base64": self.svg_to_base64(svg_content)
89
+ }
90
+ except Exception as e:
91
+ print(f"Error in postprocessing: {str(e)}")
92
+ return {"error": str(e)}