"""Gradio app that turns a natural-language description into SVG markup.

A fine-tuned causal language model is prompted with an Alpaca-style template,
its response is validated as XML, and the result is shown both as source code
and as an inline preview.
"""

import gc
import hmac
import html
import os
import re
import xml.etree.ElementTree as ET

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Clear GPU memory before the model is loaded.
torch.cuda.empty_cache()
gc.collect()

# Alpaca prompt template. The three placeholders are, in order: instruction,
# input and response (left empty so the model fills it in).
# NOTE(review): line breaks follow the standard Alpaca layout - confirm they
# match the template the model was fine-tuned with.
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

# Load model with memory optimizations.
model_path = "vinoku89/qwen3-4B-svg-code-gen"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True,
    trust_remote_code=True,  # Add this if needed for custom models
)

# First complete <svg>...</svg> element inside arbitrary text.
_SVG_DOCUMENT_RE = re.compile(r"<svg[^>]*>.*?</svg>", re.DOTALL | re.IGNORECASE)
# Opening tag of the root <svg> element.
_SVG_OPEN_TAG_RE = re.compile(r"<svg\b[^>]*>", re.IGNORECASE)
# Attribute matchers that do not fire on hyphenated names such as
# "stroke-width".
_DIMENSION_ATTR_RES = {
    name: re.compile(rf"(?<![\w-]){name}\s*=") for name in ("width", "height")
}
# Element prefixes that indicate a bare SVG fragment worth wrapping.
# NOTE(review): reconstructed list - confirm against the intended tag set.
_SVG_FRAGMENT_TAGS = (
    "<circle",
    "<rect",
    "<path",
    "<line",
    "<polygon",
    "<polyline",
    "<ellipse",
    "<text",
)
# Side length, in pixels, given to previews whose root tag has no size.
# NOTE(review): reconstructed value ("moderate size") - confirm.
_PREVIEW_SIZE = 250


def validate_svg(svg_content):
    """Check that *svg_content* holds well-formed SVG markup.

    The first complete ``<svg>...</svg>`` element is extracted from the text,
    so anything the model emitted before or after it is ignored. If there is
    no ``<svg>`` element but the text contains known shape elements, the text
    is wrapped in a new root element.

    Args:
        svg_content: Raw text produced by the model.

    Returns:
        ``(True, svg_markup)`` when the markup parses as XML, otherwise
        ``(False, error_message)``. This function does not raise.
    """
    try:
        svg_content = svg_content.strip()

        svg_match = _SVG_DOCUMENT_RE.search(svg_content)
        if svg_match:
            svg_content = svg_match.group(0)
        elif not svg_content.startswith("<svg"):
            # No complete SVG found: wrap bare shapes in SVG tags.
            lowered = svg_content.lower()
            if not any(tag in lowered for tag in _SVG_FRAGMENT_TAGS):
                raise ValueError("No valid SVG elements found")
            # NOTE(review): wrapper attributes are reconstructed - confirm.
            svg_content = (
                '<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 200 200">'
                f"{svg_content}</svg>"
            )
        # Otherwise the text starts with an unterminated <svg>; let the parser
        # report the precise error.

        # Parse XML to validate structure.
        ET.fromstring(svg_content)
        return True, svg_content
    except ET.ParseError as e:
        return False, f"XML Parse Error: {e}"
    except Exception as e:  # Validation boundary: report instead of raising.
        return False, f"Validation Error: {e}"


def _ensure_dimensions(svg_markup, size=_PREVIEW_SIZE):
    """Return *svg_markup* with width/height present on the root <svg> tag.

    Only the root tag is inspected and only the missing attributes are added,
    so an existing attribute is never duplicated.
    """
    open_tag = _SVG_OPEN_TAG_RE.search(svg_markup)
    if open_tag is None:
        return svg_markup

    tag_text = open_tag.group(0)
    missing = [
        f'{name}="{size}"'
        for name, pattern in _DIMENSION_ATTR_RES.items()
        if not pattern.search(tag_text)
    ]
    if not missing:
        return svg_markup

    insert_at = open_tag.start() + len("<svg")
    return svg_markup[:insert_at] + " " + " ".join(missing) + svg_markup[insert_at:]


@spaces.GPU(duration=60)  # Add duration limit
def generate_svg(prompt):
    """Generate SVG code for *prompt* and build an HTML preview of it.

    Args:
        prompt: Natural-language description of the desired image.

    Returns:
        A ``(svg_code, svg_display)`` tuple: the raw model response and an
        HTML snippet that either embeds the validated SVG or explains why it
        cannot be previewed.
    """
    # Clear cache before generation.
    torch.cuda.empty_cache()

    instruction = "Generate SVG code based on the given description."
    formatted_prompt = alpaca_prompt.format(
        instruction,
        prompt,
        "",  # Empty response - model will fill this
    )

    # Move inputs to the same device as the model.
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    try:
        with torch.no_grad():  # Disable gradient computation to save memory
            outputs = model.generate(
                **inputs,
                # Only max_new_tokens is passed; combining it with max_length
                # is contradictory.
                max_new_tokens=512,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens. Searching the decoded text
        # for "### Response:" breaks when the user prompt contains that marker.
        svg_code = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()
    finally:
        # Clear cache after generation, even when generation fails.
        torch.cuda.empty_cache()

    is_valid, result = validate_svg(svg_code)
    if is_valid:
        # Ensure the SVG has proper dimensions for display.
        validated_svg = _ensure_dimensions(result)
        # NOTE(review): model-generated SVG is embedded as-is and could carry
        # event-handler attributes; sanitize if the output is not trusted.
        # NOTE(review): wrapper styling is reconstructed - confirm.
        svg_display = (
            '<div style="display: flex; justify-content: center; padding: 16px;">'
            f"{validated_svg}"
            "</div>"
        )
    else:
        # The message may quote fragments of the model output, so escape it.
        svg_display = f"""
<div style="text-align: center; padding: 16px; color: #b00020;">
  <p><strong>🚫 Preview Not Available</strong></p>
  <p>Generated SVG contains errors:</p>
  <pre style="white-space: pre-wrap;">{html.escape(result)}</pre>
</div>
"""

    return svg_code, svg_display


def authenticate(username, password):
    """Check credentials against the HF Space secrets ``user`` and ``password``.

    When a secret is not set, a fallback value is used for local testing and a
    warning is printed.

    Returns:
        True if both values match, False otherwise.
    """
    valid_username = os.getenv("user")
    valid_password = os.getenv("password")

    # Fallback credentials if secrets are not available (for local testing).
    if valid_username is None:
        valid_username = "user"
        print("Warning: 'user' secret not found, using fallback")
    if valid_password is None:
        valid_password = "password"
        print("Warning: 'password' secret not found, using fallback")

    if not isinstance(username, str) or not isinstance(password, str):
        return False

    # Constant-time comparison; both checks always run.
    user_ok = hmac.compare_digest(username.encode(), valid_username.encode())
    password_ok = hmac.compare_digest(password.encode(), valid_password.encode())
    return user_ok and password_ok


# Minimal CSS for slightly larger HTML preview only.
custom_css = """
div[data-testid="HTML"] {
    min-height: 320px !important;
}
"""

gradio_app = gr.Interface(
    fn=generate_svg,
    inputs=gr.Textbox(
        lines=2,
        placeholder="Describe the SVG you want (e.g., 'a red circle with blue border')...",
    ),
    outputs=[
        gr.Code(label="Generated SVG Code", language="html"),
        gr.HTML(label="SVG Preview"),
    ],
    title="SVG Code Generator",
    description="Generate SVG code from natural language using a fine-tuned LLM.",
    css=custom_css,
)

if __name__ == "__main__":
    # NOTE(review): credentials are read straight from the secrets here, so
    # the fallback values in authenticate() are never accepted on a shared
    # link. Confirm before switching to auth=authenticate.
    gradio_app.launch(
        auth=(os.getenv("user"), os.getenv("password")),
        share=True,
        ssr_mode=False,
    )