import json
import os
import tempfile
from datetime import datetime

import matplotlib

# Select the non-interactive backend before pyplot is imported so that no GUI
# toolkit is ever loaded in server environments.
matplotlib.use('Agg')

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from loguru import logger


class DAGVisualizer:
    """Render robot task dependency graphs as PNG images.

    Task data is converted to a ``networkx.DiGraph``, laid out in horizontal
    layers by dependency depth, drawn with Matplotlib and written to the
    system temporary directory.
    """

    def __init__(self):
        # Configure Matplotlib to use IEEE-style parameters.
        # NOTE: rcParams is process-global, so this affects every figure
        # created after a DAGVisualizer has been instantiated.
        plt.rcParams.update({
            'font.family': 'DejaVu Sans',  # Use available font instead of Times New Roman
            'font.size': 10,
            'axes.linewidth': 1.2,
            'axes.labelsize': 12,
            'xtick.labelsize': 10,
            'ytick.labelsize': 10,
            'legend.fontsize': 10,
            'figure.titlesize': 14
        })

    def create_dag_from_tasks(self, task_data):
        """
        Create a directed graph from task data.

        Args:
            task_data: Dictionary containing tasks with structure like:
                {
                    "tasks": [
                        {
                            "task": "task_name",
                            "instruction_function": {
                                "name": "function_name",
                                "robot_ids": ["robot1", "robot2"],
                                "dependencies": ["dependency_task"],
                                "object_keywords": ["object1", "object2"]
                            }
                        }
                    ]
                }
                ``robot_ids``, ``dependencies`` and ``object_keywords`` are
                optional; ``task`` and ``instruction_function.name`` are
                required.

        Returns:
            NetworkX DiGraph object whose nodes are 1-based task indices with
            ``name``, ``function``, ``robots`` and ``objects`` attributes and
            whose edges point from a dependency to the task that needs it.
            ``None`` when ``task_data`` is empty or has no ``"tasks"`` key.

        Raises:
            KeyError: If a task lacks ``"task"``, ``"instruction_function"``
                or the function ``"name"``.
        """
        if not task_data or "tasks" not in task_data:
            logger.warning("Invalid task data structure")
            return None

        tasks = task_data["tasks"]
        G = nx.DiGraph()

        # First pass: add every node and remember task name -> node ID so
        # that dependencies may refer to tasks declared later in the list.
        task_mapping = {}
        for task_id, task in enumerate(tasks, start=1):
            task_name = task["task"]
            spec = task["instruction_function"]
            task_mapping[task_name] = task_id
            G.add_node(
                task_id,
                name=task_name,
                function=spec["name"],
                robots=spec.get("robot_ids", []),
                objects=spec.get("object_keywords", []),
            )

        # Second pass: add dependency edges (dependency -> dependent task).
        for task_id, task in enumerate(tasks, start=1):
            # "dependencies" is optional, like the other list-valued fields.
            dependencies = task["instruction_function"].get("dependencies") or []
            for dep in dependencies:
                dep_id = task_mapping.get(dep)
                if dep_id is None:
                    # Best effort: an unresolved name is skipped, but it is
                    # reported so a typo does not silently drop an edge.
                    logger.warning(
                        f"Task '{task['task']}' depends on unknown task "
                        f"'{dep}'; dependency ignored"
                    )
                    continue
                G.add_edge(dep_id, task_id)

        return G

    def calculate_layout(self, G):
        """
        Calculate hierarchical layout for the graph based on dependencies.

        A node's layer is the length of the longest dependency chain leading
        to it (nodes without predecessors are in layer 0). Layer 0 is placed
        at the top and nodes within a layer are centred around x = 0 in
        ascending node order.

        Args:
            G: NetworkX DiGraph (or None).

        Returns:
            dict mapping node -> (x, y); empty when ``G`` is None or has no
            nodes.
        """
        if G is None or len(G) == 0:
            return {}

        depth = {}
        if nx.is_directed_acyclic_graph(G):
            # Visiting in topological order guarantees every predecessor's
            # depth is known, so each edge is examined exactly once.
            for node in nx.topological_sort(G):
                depth[node] = max(
                    (depth[pred] + 1 for pred in G.predecessors(node)),
                    default=0,
                )
        else:
            # Cyclic input (e.g. a task depending on itself): fall back to a
            # path-tracking recursion that treats a revisited node as depth 0
            # so the layout still terminates.
            def get_layer(node_id, visited=None):
                if visited is None:
                    visited = set()
                if node_id in visited:
                    return 0
                visited.add(node_id)
                predecessors = list(G.predecessors(node_id))
                if not predecessors:
                    return 0
                return max(get_layer(pred, visited.copy()) for pred in predecessors) + 1

            for node in G.nodes():
                depth[node] = get_layer(node)

        # Group nodes by layer, preserving graph insertion order.
        layers = {}
        for node in G.nodes():
            layers.setdefault(depth[node], []).append(node)

        # Calculate positions by layer.
        pos = {}
        layer_height = 3.0
        node_width = 4.0
        for layer_idx, nodes in layers.items():
            y = layer_height * (len(layers) - 1 - layer_idx)
            start_x = -(len(nodes) - 1) * node_width / 2
            for i, node in enumerate(sorted(nodes)):
                pos[node] = (start_x + i * node_width, y)

        return pos

    @staticmethod
    def _save_figure(fig, prefix, **savefig_kwargs):
        """Save ``fig`` as a 400-dpi PNG in the temp directory; return its path."""
        # Microseconds plus the process ID keep the name unique when several
        # graphs are rendered within the same second.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        image_path = os.path.join(
            tempfile.gettempdir(), f"{prefix}_{timestamp}_{os.getpid()}.png"
        )
        fig.savefig(image_path, dpi=400, bbox_inches='tight', **savefig_kwargs)
        return image_path

    @staticmethod
    def _format_task_info(node, attrs):
        """Build the multi-line description shown in a task's info box."""
        lines = [f"Task {node}: {attrs['function'].replace('_', ' ').title()}"]
        robots = attrs['robots']
        if robots:
            robot_text = ", ".join(
                r.replace('robot_', '').replace('_', ' ').title() for r in robots
            )
            lines.append(f"Robots: {robot_text}")
        objects = attrs['objects']
        if objects:
            lines.append(f"Objects: {', '.join(objects)}")
        # Joining (rather than appending "\n" per line) avoids a trailing
        # blank line in the box when a task has no objects.
        return "\n".join(lines)

    @staticmethod
    def _node_color(G, node):
        """Return the fill colour for a start, end or intermediate node."""
        if G.in_degree(node) == 0:  # Start nodes
            return '#F24236'
        if G.out_degree(node) == 0:  # End nodes
            return '#A23B72'
        return '#F18F01'  # Intermediate nodes

    def create_dag_visualization(self, task_data, title="Robot Task Dependency Graph"):
        """
        Create a DAG visualization from task data and return the image path.

        Each task is drawn as a coloured node labelled ``T<id>`` with an
        attached text box listing its function, robots and objects.

        Args:
            task_data: Task data dictionary
            title: Title for the graph

        Returns:
            str: Path to the generated image file, or None when the task data
            yields no graph or any error occurs while rendering (the error is
            logged).
        """
        fig = None
        try:
            # Create graph
            G = self.create_dag_from_tasks(task_data)
            if G is None or len(G) == 0:
                logger.warning("No tasks found or invalid graph structure")
                return None

            # Calculate layout
            pos = self.calculate_layout(G)

            # Create figure; width grows with the number of tasks.
            fig, ax = plt.subplots(1, 1, figsize=(max(12, len(G.nodes()) * 2), 8))

            node_size = 3500

            # Draw edges with arrows. node_size is passed so that arrowheads
            # stop at the node boundary instead of being hidden beneath the
            # (much larger than default) node markers.
            nx.draw_networkx_edges(G, pos, ax=ax,
                                   edge_color='#2E86AB',
                                   arrows=True,
                                   arrowsize=20,
                                   arrowstyle='->',
                                   width=2,
                                   alpha=0.8,
                                   connectionstyle="arc3,rad=0.1",
                                   node_size=node_size)

            # Color nodes based on their position in the graph
            node_colors = [self._node_color(G, node) for node in G.nodes()]

            # Draw nodes
            nx.draw_networkx_nodes(G, pos, ax=ax,
                                   node_color=node_colors,
                                   node_size=node_size,
                                   alpha=0.9,
                                   edgecolors='black',
                                   linewidths=2)

            # Label nodes with task IDs
            node_labels = {node: f"T{node}" for node in G.nodes()}
            nx.draw_networkx_labels(G, pos, node_labels, ax=ax,
                                    font_size=18,
                                    font_weight='bold',
                                    font_color='white')

            # Add detailed info text boxes for each task
            bbox_props = dict(boxstyle="round,pad=0.4",
                              facecolor='white',
                              edgecolor='gray',
                              alpha=0.95,
                              linewidth=1)
            for i, node in enumerate(G.nodes()):
                x, y = pos[node]
                info_text = self._format_task_info(node, G.nodes[node])

                # Alternate the offset direction by node index to reduce
                # overlaps between neighbouring boxes.
                offset_x = 2.2 if i % 2 == 0 else -2.2
                offset_y = 0.5 if i % 4 < 2 else -0.5

                # Choose alignment based on offset direction
                h_align = 'left' if offset_x > 0 else 'right'

                # Draw text box
                ax.text(x + offset_x, y + offset_y, info_text,
                        bbox=bbox_props,
                        fontsize=12,
                        verticalalignment='center',
                        horizontalalignment=h_align,
                        weight='bold')

                # Draw dashed connector line from node to text box
                ax.plot([x, x + offset_x], [y, y + offset_y],
                        linestyle='--', color='gray', alpha=0.6, linewidth=1)

            # Expand axis limits to fit everything
            x_vals = [coord[0] for coord in pos.values()]
            y_vals = [coord[1] for coord in pos.values()]
            ax.set_xlim(min(x_vals) - 4.0, max(x_vals) + 4.0)
            ax.set_ylim(min(y_vals) - 2.0, max(y_vals) + 2.0)

            # Set overall figure properties
            ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
            ax.set_aspect('equal')
            ax.margins(0.2)
            ax.axis('off')

            # NOTE: a legend for node types (start / end / intermediate tasks
            # and dependency edges) is intentionally not drawn, to avoid
            # covering content.

            # Adjust layout and save
            fig.tight_layout()
            image_path = self._save_figure(fig, 'dag_visualization',
                                           pad_inches=0.1,
                                           facecolor='white',
                                           edgecolor='none')

            logger.info(f"DAG visualization saved to: {image_path}")
            return image_path

        except Exception as e:
            logger.error(f"Error creating DAG visualization: {e}")
            return None
        finally:
            # Close the figure on every path; pyplot keeps a global reference
            # to open figures, so skipping this after an error leaks memory.
            if fig is not None:
                plt.close(fig)

    def create_simplified_dag_visualization(self, task_data, title="Robot Task Graph"):
        """
        Create a simplified DAG visualization suitable for smaller displays.

        Nodes are uniformly coloured and labelled with the task ID plus the
        function name, truncated to 15 characters.

        Args:
            task_data: Task data dictionary
            title: Title for the graph

        Returns:
            str: Path to the generated image file, or None when the task data
            yields no graph or any error occurs while rendering (the error is
            logged).
        """
        fig = None
        try:
            # Create graph
            G = self.create_dag_from_tasks(task_data)
            if G is None or len(G) == 0:
                logger.warning("No tasks found or invalid graph structure")
                return None

            # Calculate layout
            pos = self.calculate_layout(G)

            # Create figure for simplified graph
            fig, ax = plt.subplots(1, 1, figsize=(10, 6))

            node_size = 3000

            # Draw edges; node_size keeps arrowheads visible at the node edge.
            nx.draw_networkx_edges(G, pos, ax=ax,
                                   edge_color='black',
                                   arrows=True,
                                   arrowsize=15,
                                   arrowstyle='->',
                                   width=1.5,
                                   node_size=node_size)

            # Draw nodes
            nx.draw_networkx_nodes(G, pos, ax=ax,
                                   node_color='lightblue',
                                   node_size=node_size,
                                   edgecolors='black',
                                   linewidths=1.5)

            # Add node labels with simplified names
            labels = {}
            for node in G.nodes():
                simplified_name = G.nodes[node]['function'].replace('_', ' ').title()
                if len(simplified_name) > 15:
                    simplified_name = simplified_name[:12] + "..."
                labels[node] = f"T{node}\n{simplified_name}"

            nx.draw_networkx_labels(G, pos, labels, ax=ax,
                                    font_size=11, font_weight='bold')

            ax.set_title(title, fontsize=14, fontweight='bold')
            ax.axis('off')

            # Adjust layout and save
            fig.tight_layout()
            image_path = self._save_figure(fig, 'simple_dag')

            logger.info(f"Simplified DAG visualization saved to: {image_path}")
            return image_path

        except Exception as e:
            logger.error(f"Error creating simplified DAG visualization: {e}")
            return None
        finally:
            # Close the figure on every path to free memory.
            if fig is not None:
                plt.close(fig)