|
import os |
|
import sys |
|
import json |
|
import torch |
|
from pathlib import Path |
|
|
|
|
|
def get_model_type(repo_root="/repository"):
    """Detect which SVG-generation model this deployment should serve.

    The git config file ``<repo_root>/.git/config`` is read and searched,
    case-insensitively, for a model name. Any failure to read it (missing
    repository, missing ``.git``, unreadable file) falls back to the default.

    Args:
        repo_root: Directory of the checked-out repository. Defaults to
            ``"/repository"``, the path the detection has always used.

    Returns:
        ``"svgdreamer"`` if the config mentions "svgdreamer"; otherwise
        ``"diffsketcher_edit"`` if it mentions "diffsketcher_edit" or
        "diffsketcher-edit"; otherwise ``"diffsketcher"``.
    """
    model_type = "diffsketcher"
    config_path = Path(repo_root) / ".git" / "config"

    try:
        # errors="replace": a stray non-UTF-8 byte must not abort detection.
        config = config_path.read_text(encoding="utf-8", errors="replace").lower()
    except OSError:
        # e.g. no repository, no .git directory, .git being a plain file,
        # or a permission problem: keep the default model.
        config = ""

    # "svgdreamer" is checked first, so it wins if several names appear.
    if "svgdreamer" in config:
        model_type = "svgdreamer"
    elif "diffsketcher_edit" in config or "diffsketcher-edit" in config:
        model_type = "diffsketcher_edit"

    print(f"Detected model type: {model_type}")
    return model_type
|
|
|
|
|
def import_handler():
    """Instantiate the handler class that matches the detected model type.

    Each handler module is imported lazily inside its own branch, so only the
    module for the selected model is ever imported. This function does not
    call ``initialize()`` on the returned object.

    Returns:
        A new ``SVGDreamerHandler`` or ``DiffSketcherEditHandler`` instance,
        or a ``DiffSketcherHandler`` instance for any other model type.
    """
    selected = get_model_type()

    if selected == "svgdreamer":
        from svgdreamer_handler import SVGDreamerHandler
        handler_cls = SVGDreamerHandler
    elif selected == "diffsketcher_edit":
        from diffsketcher_edit_handler import DiffSketcherEditHandler
        handler_cls = DiffSketcherEditHandler
    else:
        # Fallback for "diffsketcher" and anything unrecognised.
        from diffsketcher_handler import DiffSketcherHandler
        handler_cls = DiffSketcherHandler

    return handler_cls()
|
|
|
|
|
# Module-level setup: choose the handler for the detected model and
# initialise it once, at import time, so every inference() call reuses it.
# NOTE(review): initialize() receives None, presumably in place of a serving
# context object — confirm all handlers accept that.
handler = import_handler()
handler.initialize(None)
|
|
|
|
|
def inference(model_inputs):
    """Run one generation request through the module-level ``handler``.

    Args:
        model_inputs: Request payload forwarded unchanged to
            ``handler.handle``. Presumably a dict with an ``"inputs"`` prompt
            string and a ``"parameters"`` dict (as in this file's
            ``__main__`` example) — confirm against the handlers.

    Returns:
        Whatever ``handler.handle`` returns.
    """
    # `handler` is only read here, never rebound, so no `global` is needed.
    # NOTE(review): None is passed as handle()'s second argument, presumably a
    # serving context — confirm.
    return handler.handle(model_inputs, None)
|
|
|
|
|
if __name__ == "__main__":
    # Local smoke test: generate one SVG for a fixed prompt and save it.
    sample_input = {
        "inputs": "a beautiful mountain landscape",
        "parameters": {},
    }

    result = inference(sample_input)
    # Assumes the handler returns a dict holding an "svg" string — TODO confirm.
    print(f"Generated SVG with {len(result['svg'])} characters")

    # Explicit UTF-8 so the file does not depend on the platform's locale
    # encoding (XML/SVG documents default to UTF-8).
    with open("output.svg", "w", encoding="utf-8") as f:
        f.write(result["svg"])

    print("SVG saved to output.svg")