import gradio as gr
import spaces  # Import spaces module for ZeroGPU
from huggingface_hub import login
import os
from json_processor import JsonProcessor
from dag_visualizer import DAGVisualizer
import json

# 1) Read the Hugging Face token from the Space secrets
hf_token = os.getenv("HUGGINGFACE_TOKEN")
if not hf_token:
    raise RuntimeError("❌ HUGGINGFACE_TOKEN not detected, please check Space Settings → Secrets")

# 2) Log in so that all subsequent from_pretrained calls have the proper permissions
login(hf_token)

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import warnings

warnings.filterwarnings("ignore")

# Model configurations
MODEL_CONFIGS = {
    "1B": {
        "name": "Dart-llm-model-1B",
        "base_model": "meta-llama/Llama-3.2-1B",
        "lora_model": "YongdongWang/llama-3.2-1b-lora-qlora-dart-llm"
    },
    "3B": {
        "name": "Dart-llm-model-3B",
        "base_model": "meta-llama/Llama-3.2-3B",
        "lora_model": "YongdongWang/llama-3.2-3b-lora-qlora-dart-llm"
    },
    "8B": {
        "name": "Dart-llm-model-8B",
        "base_model": "meta-llama/Llama-3.1-8B",
        "lora_model": "YongdongWang/llama-3.1-8b-lora-qlora-dart-llm"
    }
}

DEFAULT_MODEL = "1B"  # Set 1B as default

# Global variables to store model and tokenizer
model = None
tokenizer = None
current_model_config = None   # Which model's tokenizer is currently loaded
current_gpu_model = None      # Which model's weights are currently resident on the GPU
model_loaded = False


def load_model_and_tokenizer(selected_model=DEFAULT_MODEL):
    """Load tokenizer - executed on CPU."""
    global tokenizer, model_loaded, current_model_config

    if model_loaded and current_model_config == selected_model:
        return

    print(f"🔄 Loading tokenizer for {MODEL_CONFIGS[selected_model]['name']}...")

    # Load tokenizer (on CPU)
    base_model = MODEL_CONFIGS[selected_model]["base_model"]
    tokenizer = AutoTokenizer.from_pretrained(
        base_model,
        use_fast=False,
        trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    current_model_config = selected_model
    model_loaded = True
    print("✅ Tokenizer loaded successfully!")


@spaces.GPU(duration=60)  # Request a GPU while loading the model
def load_model_on_gpu(selected_model=DEFAULT_MODEL):
    """Load the selected model on the GPU."""
    global model, current_gpu_model

    # If the same model is already loaded on the GPU, return it
    if model is not None and current_gpu_model == selected_model:
        return model

    # Clear the existing model if switching
    if model is not None:
        print("🗑️ Clearing existing model from GPU...")
        del model
        torch.cuda.empty_cache()
        model = None
        current_gpu_model = None

    model_config = MODEL_CONFIGS[selected_model]
    print(f"🔄 Loading {model_config['name']} on GPU...")

    try:
        # 4-bit quantization configuration
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )

        # Load base model
        base_model = AutoModelForCausalLM.from_pretrained(
            model_config["base_model"],
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            use_safetensors=True
        )

        # Load LoRA adapter
        model = PeftModel.from_pretrained(
            base_model,
            model_config["lora_model"],
            torch_dtype=torch.float16,
            use_safetensors=True
        )

        model.eval()
        current_gpu_model = selected_model
        print(f"✅ {model_config['name']} loaded on GPU successfully!")
        return model

    except Exception as load_error:
        print(f"❌ Model loading failed: {load_error}")
        raise load_error
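# Optional diagnostic sketch (not part of the original flow): to confirm how much GPU memory the
# 4-bit base model plus LoRA adapter actually occupy, the following lines could be placed right
# after a successful load inside the @spaces.GPU context. Only standard torch.cuda APIs are
# assumed; the printout is illustrative, not required by the app.
#
#     allocated_gb = torch.cuda.memory_allocated() / 1024 ** 3
#     print(f"📏 GPU memory allocated after load: {allocated_gb:.2f} GB")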
def process_json_in_response(response):
    """Process and format JSON content in the response."""
    try:
        # Check if response contains JSON-like content
        if '{' in response and '}' in response:
            processor = JsonProcessor()
            # Try to process the response for JSON content
            processed_json = processor.process_response(response)
            if processed_json:
                # Format the JSON nicely
                formatted_json = json.dumps(processed_json, indent=2, ensure_ascii=False)
                # Replace the raw JSON part of the response with the formatted version
                import re
                json_pattern = r'\{.*\}'
                match = re.search(json_pattern, response, re.DOTALL)
                if match:
                    response = response.replace(match.group(), formatted_json)
        return response
    except Exception:
        # If processing fails, return the original response
        return response


@spaces.GPU(duration=60)  # GPU inference
def generate_response_gpu(prompt, max_tokens=512, selected_model=DEFAULT_MODEL):
    """Generate response - executed on GPU."""
    global model

    # Ensure tokenizer is loaded
    if tokenizer is None or current_model_config != selected_model:
        load_model_and_tokenizer(selected_model)

    # Ensure the selected model's weights are the ones loaded on the GPU
    if model is None or current_gpu_model != selected_model:
        model = load_model_on_gpu(selected_model)

    if model is None:
        return "❌ Model failed to load. Please check the Space logs."

    try:
        formatted_prompt = (
            "### Instruction:\n"
            f"{prompt.strip()}\n\n"
            "### Response:\n"
        )

        # Encode input
        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048
        ).to(model.device)

        # Generate response (greedy decoding)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                do_sample=False,
                temperature=None,
                top_p=None,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
                early_stopping=True,
                no_repeat_ngram_size=3
            )

        # Decode output
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract the generated part
        if "### Response:" in response:
            response = response.split("### Response:")[-1].strip()
        elif len(response) > len(formatted_prompt):
            response = response[len(formatted_prompt):].strip()

        # Process JSON if present in the response
        response = process_json_in_response(response)

        return response if response else "❌ No response generated. Please try again with a different prompt."

    except Exception as generation_error:
        return f"❌ Generation Error: {str(generation_error)}"
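# Hypothetical smoke test (not wired into the UI): a quick way to exercise the GPU generation path
# outside Gradio. The DART_LLM_SMOKE_TEST flag is an assumption for illustration only; the sample
# command mirrors the placeholder shown in the Robot Command textbox below.
#
#     if os.getenv("DART_LLM_SMOKE_TEST"):
#         load_model_and_tokenizer(DEFAULT_MODEL)
#         print(generate_response_gpu("Deploy Excavator 1 to Soil Area 1", max_tokens=256))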
def create_dag_visualization(task_json_str):
    """Create DAG visualization from task JSON."""
    try:
        if not task_json_str.strip():
            return None, "Please provide task JSON data"

        # Parse JSON
        task_data = json.loads(task_json_str)

        # Create DAG visualizer
        dag_visualizer = DAGVisualizer()

        # Generate visualization
        image_path = dag_visualizer.create_dag_visualization(task_data)

        if image_path:
            return image_path, "✅ DAG visualization created successfully!"
        else:
            return None, "❌ Failed to create DAG visualization"

    except json.JSONDecodeError as e:
        return None, f"❌ JSON Parse Error: {str(e)}"
    except Exception as e:
        return None, f"❌ DAG Creation Error: {str(e)}"


def chat_interface(message, history, max_tokens, selected_model):
    """Chat interface - runs on CPU, calls GPU functions."""
    if not message.strip():
        return history, ""

    # Initialize tokenizer (if needed)
    if tokenizer is None or current_model_config != selected_model:
        load_model_and_tokenizer(selected_model)

    try:
        # Call GPU function to generate response
        response = generate_response_gpu(message, max_tokens, selected_model)
        history.append((message, response))
        return history, ""
    except Exception as chat_error:
        error_msg = f"❌ Chat Error: {str(chat_error)}"
        history.append((message, error_msg))
        return history, ""


# Load tokenizer at startup with the default model
load_model_and_tokenizer(DEFAULT_MODEL)

# Create Gradio application
with gr.Blocks(
    title="Robot Task Planning - DART-LLM Multi-Model",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1200px;
        margin: auto;
    }
    """
) as app:
    gr.Markdown("""
    # 🤖 DART-LLM Multi-Model - Robot Task Planning

    Choose from **three fine-tuned models** specialized for **robot task planning** using the QLoRA technique:

    - **🚀 Dart-llm-model-1B**: Ready for Jetson Nano deployment (870MB GGUF)
    - **⚖️ Dart-llm-model-3B**: Ready for Jetson Xavier NX deployment (1.9GB GGUF)
    - **🎯 Dart-llm-model-8B**: Ready for Jetson AGX Xavier/Orin deployment (4.6GB GGUF)

    **Capabilities**: Convert natural language robot commands into structured task sequences for excavators, dump trucks, and other construction robots.
    **Edge-ready for Jetson devices with DAG Visualization!**

    ## 🔧 Recommended for Jetson Deployment (GGUF Models)

    For optimal edge deployment performance, use these GGUF quantized models:

    - **[YongdongWang/llama-3.2-1b-lora-qlora-dart-llm-gguf](https://huggingface.co/YongdongWang/llama-3.2-1b-lora-qlora-dart-llm-gguf)** (870MB) - Jetson Nano/Orin Nano
    - **[YongdongWang/llama-3.2-3b-lora-qlora-dart-llm-gguf](https://huggingface.co/YongdongWang/llama-3.2-3b-lora-qlora-dart-llm-gguf)** (1.9GB) - Jetson Orin NX/AGX Orin
    - **[YongdongWang/llama-3.1-8b-lora-qlora-dart-llm-gguf](https://huggingface.co/YongdongWang/llama-3.1-8b-lora-qlora-dart-llm-gguf)** (4.6GB) - High-end Jetson AGX Orin

    💡 **Deploy with**: Ollama, llama.cpp, or llama-cpp-python for efficient edge inference
    """)
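    # A minimal llama-cpp-python sketch for running one of the GGUF variants on a Jetson device
    # (outside this Space). The file name is a placeholder for whichever quantization you download
    # from the model repos linked above, and the prompt reuses the same Instruction/Response format
    # as generate_response_gpu():
    #
    #     from llama_cpp import Llama
    #     llm = Llama(model_path="llama-3.2-1b-lora-qlora-dart-llm.Q4_K_M.gguf", n_ctx=2048)
    #     out = llm("### Instruction:\nDeploy Excavator 1 to Soil Area 1\n\n### Response:\n",
    #               max_tokens=512)
    #     print(out["choices"][0]["text"])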
    with gr.Tabs():
        with gr.Tab("💬 Task Planning"):
            with gr.Row():
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(
                        label="Task Planning Results",
                        height=500,
                        show_label=True,
                        container=True,
                        bubble_full_width=False,
                        show_copy_button=True
                    )

                    msg = gr.Textbox(
                        label="Robot Command",
                        placeholder="Enter robot task command (e.g., 'Deploy Excavator 1 to Soil Area 1')...",
                        lines=2,
                        max_lines=5,
                        show_label=True,
                        container=True
                    )

                    with gr.Row():
                        send_btn = gr.Button("🚀 Generate Tasks", variant="primary", size="sm")
                        clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="sm")

                with gr.Column(scale=1):
                    gr.Markdown("### ⚙️ Generation Settings")

                    model_selector = gr.Dropdown(
                        choices=[(config["name"], key) for key, config in MODEL_CONFIGS.items()],
                        value=DEFAULT_MODEL,
                        label="Model Size",
                        info="Select model for your Jetson device (1B = Nano, 3B = Xavier NX, 8B = AGX)",
                        interactive=True
                    )

                    max_tokens = gr.Slider(
                        minimum=50,
                        maximum=5000,
                        value=512,
                        step=10,
                        label="Max Tokens",
                        info="Maximum number of tokens to generate"
                    )

                    gr.Markdown("""
                    ### 🔧 GGUF Models for Jetson Deployment

                    **Recommended for edge deployment:**
                    - **1B (870MB)**: Jetson Nano/Orin Nano (2GB RAM)
                    - **3B (1.9GB)**: Jetson Orin NX/AGX Orin (4GB RAM)
                    - **8B (4.6GB)**: High-end Jetson AGX Orin (8GB RAM)

                    💡 Use **Ollama** or **llama.cpp** for efficient inference
                    """)

        with gr.Tab("📊 DAG Visualization"):
            with gr.Row():
                with gr.Column(scale=2):
                    json_input = gr.Textbox(
                        label="Task JSON Data",
                        placeholder="Paste the generated task JSON here to create a DAG visualization...",
                        lines=15,
                        max_lines=25,
                        show_label=True,
                        container=True
                    )

                    with gr.Row():
                        dag_btn = gr.Button("🎨 Generate DAG", variant="primary", size="sm")
                        dag_clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="sm")

                    dag_status = gr.Textbox(
                        label="Status",
                        value="Ready to generate DAG visualization",
                        interactive=False,
                        show_label=True
                    )

                with gr.Column(scale=3):
                    dag_output = gr.Image(
                        label="Task Dependency Graph",
                        show_label=True,
                        container=True,
                        height=600
                    )

                    gr.Markdown("""
                    ### 📈 DAG Features
                    - **Node Colors**: Red (Start), Orange (Intermediate), Purple (End)
                    - **Arrows**: Show task dependencies
                    - **Layout**: Hierarchical based on dependencies
                    - **Details**: Task info boxes with robots and objects
                    """)

    # Example conversations
    gr.Examples(
        examples=[
            "Dump truck 1 goes to the puddle for inspection, after which all robots avoid the puddle.",
            "Drive the Excavator 1 to the obstacle, and perform excavation to clear the obstacle.",
            "Send Excavator 1 and Dump Truck 1 to the soil area; Excavator 1 will excavate and unload, followed by Dump Truck 1 proceeding to the puddle for unloading.",
            "Move Excavator 1 and Dump Truck 1 to soil area 2; Excavator 1 will excavate and unload, then Dump Truck 1 returns to the starting position to unload.",
            "Excavator 1 is guided to the obstacle to excavate and unload to clear the obstacle, then excavator 1 and dump truck 1 are moved to the soil area, and the excavator excavates and unloads. Finally, dump truck 1 unloads the soil into the puddle.",
            "Excavator 1 goes to the obstacle to excavate and unload to clear the obstacle. Once the obstacle is cleared, mobilize all available robots to proceed to the puddle area for inspection.",
        ],
        inputs=msg,
        label="💡 Example Operator Commands"
    )

    # Event handling
    msg.submit(
        chat_interface,
        inputs=[msg, chatbot, max_tokens, model_selector],
        outputs=[chatbot, msg]
    )

    send_btn.click(
        chat_interface,
        inputs=[msg, chatbot, max_tokens, model_selector],
        outputs=[chatbot, msg]
    )

    clear_btn.click(
        lambda: ([], ""),
        outputs=[chatbot, msg]
    )

    # DAG visualization event handlers
    dag_btn.click(
        create_dag_visualization,
        inputs=[json_input],
        outputs=[dag_output, dag_status]
    )

    dag_clear_btn.click(
        lambda: ("", None, "Ready to generate DAG visualization"),
        outputs=[json_input, dag_output, dag_status]
    )

if __name__ == "__main__":
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True
    )