#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Inference handler for the DiffSketcher text-to-SVG model."""
import base64
import io  # currently unused in this module; retained for compatibility
import os
import sys

import numpy as np  # currently unused in this module; retained for compatibility
import torch
from PIL import Image  # currently unused in this module; retained for compatibility

from handler_template import BaseHandler

# Add DiffSketcher to path so the ``models.*`` packages imported lazily in
# ``Handler.initialize`` can be resolved.
sys.path.append("/app/model")


class Handler(BaseHandler):
    """Encode a text prompt with CLIP and generate an SVG sketch from it."""

    # Location of the SketchGenerator state dict loaded by ``initialize``.
    # Exposed as a class attribute so deployments/subclasses can override it.
    WEIGHTS_PATH = os.path.join("/app/model/weights", "diffsketcher_model.pth")

    def initialize(self):
        """Load the CLIP text encoder and the DiffSketcher sketch generator.

        Reads ``self.device`` (assumed to be set beforehand, presumably by
        ``BaseHandler`` — confirm). On success sets ``self.text_encoder``,
        ``self.model`` and ``self.initialized = True``.

        Raises:
            FileNotFoundError: if the generator weights file does not exist.
            Exception: any error while importing, building, or loading the
                models is printed and re-raised unchanged.
        """
        try:
            # Imported lazily from the DiffSketcher sources added to sys.path above.
            from models.clip_text_encoder import CLIPTextEncoder
            from models.sketch_generator import SketchGenerator

            weights_path = self.WEIGHTS_PATH
            # Check the weights before loading anything so a missing file
            # fails fast instead of after the text encoder has been built.
            if not os.path.isfile(weights_path):
                raise FileNotFoundError(f"Model weights not found at {weights_path}")

            # Load text encoder (used by preprocess to embed prompts).
            self.text_encoder = CLIPTextEncoder()
            self.text_encoder.to(self.device)
            self.text_encoder.eval()

            # Load sketch generator with its trained weights.
            self.model = SketchGenerator()
            # NOTE(review): torch.load unpickles the file, so only trusted
            # checkpoints must be placed here. If the file is a plain state
            # dict, consider weights_only=True (torch >= 1.13).
            state_dict = torch.load(weights_path, map_location=self.device)
            self.model.load_state_dict(state_dict)
            self.model.to(self.device)
            self.model.eval()

            self.initialized = True
            print("DiffSketcher model initialized successfully")
        except Exception as e:
            print(f"Error initializing DiffSketcher model: {str(e)}")
            raise

    def preprocess(self, data):
        """Validate the request and embed its prompt with the CLIP text encoder.

        Args:
            data: request mapping; the prompt is read from its ``"prompt"`` key.

        Returns:
            dict with ``"text_embedding"`` (the result of
            ``self.text_encoder.encode_text``) and the original ``"prompt"``.

        Raises:
            ValueError: if the prompt is missing, empty, or only whitespace.
        """
        try:
            # Extract prompt from the request
            prompt = data.get("prompt", "")
            # A whitespace-only string is as useless as an empty one.
            if not prompt or (isinstance(prompt, str) and not prompt.strip()):
                raise ValueError("No prompt provided in the request")

            # Encode text with CLIP; no gradients are needed for the prompt.
            with torch.no_grad():
                text_embedding = self.text_encoder.encode_text(prompt)

            return {
                "text_embedding": text_embedding,
                "prompt": prompt,
            }
        except Exception as e:
            print(f"Error in preprocessing: {str(e)}")
            raise

    def inference(self, inputs):
        """Generate an SVG from the text embedding produced by ``preprocess``.

        Args:
            inputs: mapping containing ``"text_embedding"``.

        Returns:
            Whatever ``self.model.generate`` returns; ``postprocess`` accepts
            either a mapping with an ``"svg_content"`` entry or the SVG markup
            itself.
        """
        try:
            text_embedding = inputs["text_embedding"]

            # NOTE(review): generation runs with autograd disabled. If
            # SketchGenerator.generate optimises strokes by gradient descent,
            # this context would break it — confirm against the model code.
            with torch.no_grad():
                svg_data = self.model.generate(text_embedding)

            return svg_data
        except Exception as e:
            print(f"Error during inference: {str(e)}")
            raise

    def postprocess(self, inference_output):
        """Package the generated SVG as raw markup plus a base64-encoded copy.

        Args:
            inference_output: either a mapping holding the SVG under
                ``"svg_content"`` or the SVG markup itself (``str``/``bytes``).

        Returns:
            ``{"svg_content": ..., "svg_base64": ...}`` on success. On any
            error, ``{"error": message}`` is returned instead of raising.
        """
        try:
            if isinstance(inference_output, (str, bytes)):
                svg_content = inference_output
            else:
                svg_content = inference_output["svg_content"]

            # Return both the SVG content and base64 encoded version
            return {
                "svg_content": svg_content,
                "svg_base64": self._svg_base64(svg_content),
            }
        except Exception as e:
            print(f"Error in postprocessing: {str(e)}")
            return {"error": str(e)}

    def _svg_base64(self, svg_content):
        """Base64-encode SVG markup, preferring an inherited ``svg_to_base64``.

        ``svg_to_base64`` is not defined in this module. If the base class
        provides it, it is used unchanged. Otherwise the markup is UTF-8
        encoded (when it is text) and converted with the standard ``base64``
        module, returning an ASCII ``str``.
        """
        encoder = getattr(self, "svg_to_base64", None)
        if callable(encoder):
            return encoder(svg_content)
        raw = svg_content if isinstance(svg_content, bytes) else str(svg_content).encode("utf-8")
        return base64.b64encode(raw).decode("ascii")