# DART-LLM-Multi-Model / dag_visualizer.py
# Yongdong
# Add DAG visualization functionality for robot task planning
# 1ef829e
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg') # Use non-interactive backend for server environments
import networkx as nx
import json
import numpy as np
from loguru import logger
import os
import tempfile
from datetime import datetime
class DAGVisualizer:
    """Render robot task plans as dependency-graph (DAG) images.

    A task plan is a dictionary with a ``"tasks"`` list (see
    ``create_dag_from_tasks``). Each rendering method writes a PNG file into the
    system temporary directory and returns its path, or ``None`` on failure.
    """

    # Node/edge colours, shared by the node drawing and the optional legend.
    _START_COLOR = '#F24236'   # tasks with no dependencies (in-degree 0)
    _END_COLOR = '#A23B72'     # tasks nothing depends on (out-degree 0)
    _MID_COLOR = '#F18F01'     # all other tasks
    _EDGE_COLOR = '#2E86AB'

    def __init__(self):
        """Apply the plotting style used by the visualizations.

        NOTE: ``plt.rcParams`` is process-wide, so constructing a visualizer
        changes Matplotlib defaults for every figure created afterwards.
        """
        # Configure Matplotlib to use IEEE-style parameters
        plt.rcParams.update({
            'font.family': 'DejaVu Sans',  # Use available font instead of Times New Roman
            'font.size': 10,
            'axes.linewidth': 1.2,
            'axes.labelsize': 12,
            'xtick.labelsize': 10,
            'ytick.labelsize': 10,
            'legend.fontsize': 10,
            'figure.titlesize': 14
        })

    def create_dag_from_tasks(self, task_data):
        """
        Create a directed graph from task data.

        Args:
            task_data: Dictionary containing tasks with structure like:
                {
                    "tasks": [
                        {
                            "task": "task_name",
                            "instruction_function": {
                                "name": "function_name",
                                "robot_ids": ["robot1", "robot2"],
                                "dependencies": ["dependency_task"],
                                "object_keywords": ["object1", "object2"]
                            }
                        }
                    ]
                }
                ``robot_ids``, ``dependencies`` and ``object_keywords`` are
                optional and may be null.

        Returns:
            NetworkX DiGraph whose nodes are 1-based task positions. Each node
            carries ``name``, ``function``, ``robots`` and ``objects`` attributes.
            An edge ``dep -> task`` means ``task`` depends on ``dep``. Returns
            None when ``task_data`` is empty or has no ``"tasks"`` key.
        """
        if not task_data or "tasks" not in task_data:
            logger.warning("Invalid task data structure")
            return None

        tasks = task_data["tasks"] or []  # tolerate an explicit null task list
        G = nx.DiGraph()

        # First pass: add every node, so dependencies may refer to later tasks.
        # If two tasks share a name, the later one wins in the name -> id map.
        task_mapping = {}
        for task_id, task in enumerate(tasks, start=1):
            func = task["instruction_function"]
            task_mapping[task["task"]] = task_id
            G.add_node(task_id,
                       name=task["task"],
                       function=func["name"],
                       robots=func.get("robot_ids") or [],
                       objects=func.get("object_keywords") or [])

        # Second pass: add dependency edges (prerequisite -> dependent).
        for task_id, task in enumerate(tasks, start=1):
            for dep in task["instruction_function"].get("dependencies") or []:
                dep_id = task_mapping.get(dep)
                if dep_id is None:
                    logger.warning(
                        f"Task '{task['task']}' depends on unknown task '{dep}'; dependency ignored")
                    continue
                G.add_edge(dep_id, task_id)
        return G

    @staticmethod
    def _assign_layers(nodes, predecessors):
        """Map each node to the length of the longest dependency chain ending at it.

        Nodes without predecessors get layer 0. Results are memoized, so each
        node and edge is visited once instead of once per path. If the graph has
        a cycle, an edge back to a node still being expanded counts as a root;
        this guarantees termination.

        Args:
            nodes: Iterable of nodes to place.
            predecessors: Callable returning the direct predecessors of a node.

        Returns:
            dict mapping node -> layer index (int).
        """
        layer_of = {}
        on_path = set()

        def visit(node):
            if node in layer_of:
                return layer_of[node]
            if node in on_path:
                return 0  # back-edge of a dependency cycle
            on_path.add(node)
            layer = max((visit(pred) for pred in predecessors(node)), default=-1) + 1
            on_path.discard(node)
            layer_of[node] = layer
            return layer

        for node in nodes:
            visit(node)
        return layer_of

    def calculate_layout(self, G):
        """
        Calculate a hierarchical, top-down layout based on dependencies.

        Tasks with no dependencies sit on the top row. Every other task sits one
        row below its deepest prerequisite. Within a row, nodes are sorted by id
        and centred on x = 0.

        Returns:
            dict mapping node -> (x, y); empty dict for a missing or empty graph.
        """
        if not G:
            return {}

        layer_of = self._assign_layers(G.nodes(), G.predecessors)
        layers = {}
        for node in G.nodes():
            layers.setdefault(layer_of[node], []).append(node)

        layer_height = 3.0
        node_width = 4.0
        # Rank the distinct layer indices so rows stay contiguous even if a
        # cyclic input leaves gaps in the numbering.
        levels = sorted(layers)
        top_rank = len(levels) - 1

        pos = {}
        for rank, layer_idx in enumerate(levels):
            row = sorted(layers[layer_idx])
            y = layer_height * (top_rank - rank)
            start_x = -(len(row) - 1) * node_width / 2
            for i, node in enumerate(row):
                pos[node] = (start_x + i * node_width, y)
        return pos

    @classmethod
    def _role_color(cls, in_degree, out_degree):
        """Colour for a node given its degrees; isolated nodes count as start nodes."""
        if in_degree == 0:
            return cls._START_COLOR
        if out_degree == 0:
            return cls._END_COLOR
        return cls._MID_COLOR

    @staticmethod
    def _format_task_info(node, function_name, robots, objects):
        """Build the multi-line annotation text shown beside a task node.

        Empty robot/object lists are omitted. The lines are joined without a
        trailing newline, so a box never ends with a blank line.
        """
        lines = [f"Task {node}: {function_name.replace('_', ' ').title()}"]
        if robots:
            robot_text = ", ".join(
                str(r).replace('robot_', '').replace('_', ' ').title() for r in robots)
            lines.append(f"Robots: {robot_text}")
        if objects:
            object_text = ", ".join(str(o) for o in objects)
            lines.append(f"Objects: {object_text}")
        return "\n".join(lines)

    @staticmethod
    def _simplified_label(node, function_name, max_len=15):
        """Two-line node label "T<id>" plus a title-cased name truncated to ``max_len`` chars."""
        name = function_name.replace('_', ' ').title()
        if len(name) > max_len:
            name = name[:max_len - 3] + "..."
        return f"T{node}\n{name}"

    @staticmethod
    def _new_image_path(prefix):
        """Reserve a new, uniquely named ``.png`` file in the system temp dir.

        The name keeps the human-readable timestamp. ``mkstemp`` adds a random
        component, so calls within the same second never overwrite each other.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        fd, path = tempfile.mkstemp(prefix=f"{prefix}_{timestamp}_", suffix=".png",
                                    dir=tempfile.gettempdir())
        os.close(fd)  # only the reserved name is needed; savefig reopens the file
        return path

    @staticmethod
    def _discard(path):
        """Best-effort removal of a reserved or partially written output file."""
        if not path:
            return
        try:
            os.remove(path)
        except OSError:
            pass

    def create_dag_visualization(self, task_data, title="Robot Task Dependency Graph", *,
                                 show_legend=False):
        """
        Create a detailed DAG visualization from task data and return the image path.

        Nodes are labelled "T<id>" and coloured by role: red for tasks with no
        dependencies, purple for tasks nothing depends on, orange for the rest.
        Each node gets an annotation box with its function, robots and objects.

        Args:
            task_data: Task data dictionary (see ``create_dag_from_tasks``).
            title: Title for the graph.
            show_legend: Add a legend explaining the node colours. Off by default
                because it can cover the annotation boxes.

        Returns:
            str: Path to the generated PNG file, or None if the task data is
            invalid or empty or rendering failed (failures are logged).
        """
        fig = None
        image_path = None
        try:
            G = self.create_dag_from_tasks(task_data)
            if not G or len(G.nodes()) == 0:
                logger.warning("No tasks found or invalid graph structure")
                return None

            pos = self.calculate_layout(G)
            node_size = 3500
            fig, ax = plt.subplots(1, 1, figsize=(max(12, len(G.nodes()) * 2), 8))

            # node_size lets networkx end each arrow at the node boundary. With the
            # default (300), arrowheads stop inside these large circles and are
            # painted over when the nodes are drawn.
            nx.draw_networkx_edges(G, pos, ax=ax,
                                   node_size=node_size,
                                   edge_color=self._EDGE_COLOR,
                                   arrows=True,
                                   arrowsize=20,
                                   arrowstyle='->',
                                   width=2,
                                   alpha=0.8,
                                   connectionstyle="arc3,rad=0.1")

            node_colors = [self._role_color(G.in_degree(n), G.out_degree(n)) for n in G.nodes()]
            nx.draw_networkx_nodes(G, pos, ax=ax,
                                   node_color=node_colors,
                                   node_size=node_size,
                                   alpha=0.9,
                                   edgecolors='black',
                                   linewidths=2)

            node_labels = {node: f"T{node}" for node in G.nodes()}
            nx.draw_networkx_labels(G, pos, node_labels, ax=ax,
                                    font_size=18,
                                    font_weight='bold',
                                    font_color='white')

            bbox_props = dict(boxstyle="round,pad=0.4",
                              facecolor='white',
                              edgecolor='gray',
                              alpha=0.95,
                              linewidth=1)
            for i, node in enumerate(G.nodes()):
                x, y = pos[node]
                attrs = G.nodes[node]
                info_text = self._format_task_info(node, attrs['function'],
                                                   attrs['robots'], attrs['objects'])
                # Alternate boxes right/left and up/down so neighbouring annotations
                # are less likely to collide.
                offset_x = 2.2 if i % 2 == 0 else -2.2
                offset_y = 0.5 if i % 4 < 2 else -0.5
                h_align = 'left' if offset_x > 0 else 'right'
                ax.text(x + offset_x, y + offset_y, info_text,
                        bbox=bbox_props,
                        fontsize=12,
                        verticalalignment='center',
                        horizontalalignment=h_align,
                        weight='bold')
                # Dashed connector from the node to its annotation box
                ax.plot([x, x + offset_x], [y, y + offset_y],
                        linestyle='--', color='gray', alpha=0.6, linewidth=1)

            # Explicit limits with room for the nodes. Annotation boxes outside
            # them are still captured by bbox_inches='tight' when saving.
            x_vals = [coord[0] for coord in pos.values()]
            y_vals = [coord[1] for coord in pos.values()]
            ax.set_xlim(min(x_vals) - 4.0, max(x_vals) + 4.0)
            ax.set_ylim(min(y_vals) - 2.0, max(y_vals) + 2.0)

            ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
            ax.set_aspect('equal')
            ax.axis('off')

            if show_legend:
                legend_elements = [
                    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=self._START_COLOR,
                               markersize=10, label='Start Tasks', markeredgecolor='black'),
                    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=self._END_COLOR,
                               markersize=10, label='End Tasks', markeredgecolor='black'),
                    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=self._MID_COLOR,
                               markersize=10, label='Intermediate Tasks', markeredgecolor='black'),
                    plt.Line2D([0], [0], color=self._EDGE_COLOR, linewidth=2, label='Dependencies'),
                ]
                ax.legend(handles=legend_elements, loc='upper left', bbox_to_anchor=(1.05, 1.05))

            fig.tight_layout()
            image_path = self._new_image_path('dag_visualization')
            fig.savefig(image_path, dpi=400, bbox_inches='tight',
                        pad_inches=0.1, facecolor='white', edgecolor='none')
            logger.info(f"DAG visualization saved to: {image_path}")
            return image_path
        except Exception as e:
            logger.exception(f"Error creating DAG visualization: {e}")
            self._discard(image_path)
            return None
        finally:
            # Always release the figure; pyplot keeps open figures alive otherwise.
            if fig is not None:
                plt.close(fig)

    def create_simplified_dag_visualization(self, task_data, title="Robot Task Graph"):
        """
        Create a simplified DAG visualization suitable for smaller displays.

        Every node is light blue and labelled "T<id>" plus its function name,
        truncated to 15 characters.

        Args:
            task_data: Task data dictionary (see ``create_dag_from_tasks``).
            title: Title for the graph.

        Returns:
            str: Path to the generated PNG file, or None if the task data is
            invalid or empty or rendering failed (failures are logged).
        """
        fig = None
        image_path = None
        try:
            G = self.create_dag_from_tasks(task_data)
            if not G or len(G.nodes()) == 0:
                logger.warning("No tasks found or invalid graph structure")
                return None

            pos = self.calculate_layout(G)
            node_size = 3000
            fig, ax = plt.subplots(1, 1, figsize=(10, 6))

            # Pass node_size so arrowheads stop at the node edge instead of being
            # hidden under the nodes.
            nx.draw_networkx_edges(G, pos, ax=ax,
                                   node_size=node_size,
                                   edge_color='black',
                                   arrows=True,
                                   arrowsize=15,
                                   arrowstyle='->',
                                   width=1.5)
            nx.draw_networkx_nodes(G, pos, ax=ax,
                                   node_color='lightblue',
                                   node_size=node_size,
                                   edgecolors='black',
                                   linewidths=1.5)

            labels = {node: self._simplified_label(node, G.nodes[node]['function'])
                      for node in G.nodes()}
            nx.draw_networkx_labels(G, pos, labels, ax=ax,
                                    font_size=11,
                                    font_weight='bold')

            ax.set_title(title, fontsize=14, fontweight='bold')
            ax.axis('off')

            fig.tight_layout()
            image_path = self._new_image_path('simple_dag')
            fig.savefig(image_path, dpi=400, bbox_inches='tight')
            logger.info(f"Simplified DAG visualization saved to: {image_path}")
            return image_path
        except Exception as e:
            logger.exception(f"Error creating simplified DAG visualization: {e}")
            self._discard(image_path)
            return None
        finally:
            if fig is not None:
                plt.close(fig)