diffsketcher / app.py
import os
import sys
import json
import torch
from pathlib import Path
# Determine which model we're running based on the repository name
def get_model_type():
    # Default to diffsketcher if we can't determine
    model_type = "diffsketcher"

    # Check if we're in a Hugging Face environment
    if os.path.exists("/repository"):
        repo_path = Path("/repository")

        # Try to determine the model type from the repository's git config
        if os.path.exists("/repository/.git"):
            try:
                with open("/repository/.git/config", "r") as f:
                    config = f.read()
                    if "svgdreamer" in config.lower():
                        model_type = "svgdreamer"
                    elif "diffsketcher_edit" in config.lower() or "diffsketcher-edit" in config.lower():
                        model_type = "diffsketcher_edit"
            except Exception:
                pass

    print(f"Detected model type: {model_type}")
    return model_type
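
# For example, if /repository/.git/config contains a remote URL such as
# "https://huggingface.co/<user>/svgdreamer" (hypothetical), the substring
# check above would select the svgdreamer handler.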
# Import the appropriate handler based on model type
def import_handler():
    model_type = get_model_type()

    if model_type == "svgdreamer":
        from svgdreamer_handler import SVGDreamerHandler
        return SVGDreamerHandler()
    elif model_type == "diffsketcher_edit":
        from diffsketcher_edit_handler import DiffSketcherEditHandler
        return DiffSketcherEditHandler()
    else:
        from diffsketcher_handler import DiffSketcherHandler
        return DiffSketcherHandler()
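
# Design note: importing each handler lazily inside its branch means only the
# selected model's dependencies are loaded at startup.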
# Initialize the handler
handler = import_handler()
handler.initialize(None)
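
# Note: the handler classes are assumed to follow the common two-method
# serving interface (initialize(context) at startup, handle(inputs, context)
# per request); None is passed because no serving context is available here.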
# Define the inference function for the API
def inference(model_inputs):
    global handler
    return handler.handle(model_inputs, None)
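
# The payload shape mirrors the standard Inference API schema
# ({"inputs": "<text prompt>", "parameters": {...}}), and the handler is
# assumed to return a dict containing at least an "svg" key with the SVG
# markup, as exercised by the local test below.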
# This is used when running locally
if __name__ == "__main__":
    # Test the handler with a sample input
    sample_input = {
        "inputs": "a beautiful mountain landscape",
        "parameters": {}
    }

    result = inference(sample_input)
    print(f"Generated SVG with {len(result['svg'])} characters")

    # Save the SVG to a file
    with open("output.svg", "w") as f:
        f.write(result["svg"])
    print("SVG saved to output.svg")