Spaces:
Sleeping
Sleeping
import matplotlib.pyplot as plt | |
import matplotlib | |
matplotlib.use('Agg') # Use non-interactive backend for server environments | |
import networkx as nx | |
import json | |
import numpy as np | |
from loguru import logger | |
import os | |
import tempfile | |
from datetime import datetime | |
class DAGVisualizer:
    """Render robot task-dependency data as DAG images.

    Tasks are turned into a ``networkx.DiGraph`` whose nodes are numbered
    1..N in input order and whose edges point from a dependency to the task
    that depends on it. The ``create_*_visualization`` methods draw that
    graph with Matplotlib, write a PNG into the system temp directory and
    return its path (or ``None`` on failure).
    """

    def __init__(self):
        # Configure Matplotlib to use IEEE-style parameters.
        # NOTE: rcParams is process-wide, so these settings also apply to any
        # other figure created in the same process after instantiation.
        plt.rcParams.update({
            'font.family': 'DejaVu Sans',  # Use available font instead of Times New Roman
            'font.size': 10,
            'axes.linewidth': 1.2,
            'axes.labelsize': 12,
            'xtick.labelsize': 10,
            'ytick.labelsize': 10,
            'legend.fontsize': 10,
            'figure.titlesize': 14
        })

    def create_dag_from_tasks(self, task_data):
        """
        Create a directed graph from task data.

        Args:
            task_data: Dictionary containing tasks with structure like:
                {
                    "tasks": [
                        {
                            "task": "task_name",
                            "instruction_function": {
                                "name": "function_name",
                                "robot_ids": ["robot1", "robot2"],
                                "dependencies": ["dependency_task"],
                                "object_keywords": ["object1", "object2"]
                            }
                        }
                    ]
                }
                "robot_ids", "dependencies" and "object_keywords" are
                optional; a missing or null value is treated as empty.

        Returns:
            NetworkX DiGraph whose nodes are 1-based task indices carrying
            ``name``, ``function``, ``robots`` and ``objects`` attributes,
            with an edge dependency -> dependent task for every dependency
            that names a task in the list. Returns ``None`` if ``task_data``
            is empty or has no "tasks" key.
        """
        if not task_data or "tasks" not in task_data:
            logger.warning("Invalid task data structure")
            return None

        tasks = task_data["tasks"]
        G = nx.DiGraph()

        # Pass 1: add every task as a node first, so a dependency may name a
        # task that appears later in the list.
        task_mapping = {}
        for task_id, task in enumerate(tasks, start=1):
            spec = task["instruction_function"]
            task_mapping[task["task"]] = task_id
            G.add_node(task_id,
                       name=task["task"],
                       function=spec["name"],
                       robots=spec.get("robot_ids") or [],
                       objects=spec.get("object_keywords") or [])

        # Pass 2: dependency edges. "dependencies" is optional (missing or
        # null means "no dependencies"); unknown names are reported, not fatal.
        for task_id, task in enumerate(tasks, start=1):
            for dep in task["instruction_function"].get("dependencies") or []:
                dep_id = task_mapping.get(dep)
                if dep_id is None:
                    logger.warning(
                        f"Task '{task['task']}' depends on unknown task '{dep}'; dependency ignored")
                    continue
                G.add_edge(dep_id, task_id)

        return G

    def _compute_layers(self, G):
        """Return ``{node: layer}``, where layer is the node's longest dependency-chain depth."""
        if nx.is_directed_acyclic_graph(G):
            # One pass in topological order: O(V + E), instead of re-walking
            # every path from every node (exponential with many converging
            # branches). A node's predecessors are always processed first.
            depth = {}
            for node in nx.topological_sort(G):
                depth[node] = max((depth[p] + 1 for p in G.predecessors(node)), default=0)
            return depth

        # Cyclic input (e.g. circular or self dependencies): walk back along
        # predecessors, treating a node already on the current path as depth
        # 0 so that cycles cannot recurse forever.
        def depth_along_path(node, on_path):
            if node in on_path:
                return 0
            on_path = on_path | {node}
            preds = list(G.predecessors(node))
            if not preds:
                return 0
            return max(depth_along_path(p, on_path) for p in preds) + 1

        return {node: depth_along_path(node, frozenset()) for node in G.nodes()}

    def calculate_layout(self, G):
        """
        Calculate hierarchical layout for the graph based on dependencies.

        A node's layer is the length of the longest dependency chain leading
        into it (tasks without dependencies are layer 0). Layer 0 is drawn at
        the top; nodes of a layer are spread horizontally, centred on x = 0
        and ordered by node id.

        Args:
            G: NetworkX DiGraph (as built by ``create_dag_from_tasks``) or None.

        Returns:
            dict mapping node -> (x, y); empty dict for None or an empty graph.
        """
        if not G:
            return {}

        depth = self._compute_layers(G)
        layers = {}
        for node in G.nodes():
            layers.setdefault(depth[node], []).append(node)

        layer_height = 3.0
        node_width = 4.0
        # Anchor on the deepest layer index (not the number of layers) so
        # layer 0 stays on top even when some layer indices are unused.
        max_layer = max(layers)
        pos = {}
        for layer_idx, nodes in layers.items():
            y = layer_height * (max_layer - layer_idx)
            start_x = -(len(nodes) - 1) * node_width / 2
            for i, node in enumerate(sorted(nodes)):
                pos[node] = (start_x + i * node_width, y)
        return pos

    @staticmethod
    def _save_figure(fig, prefix, **savefig_kwargs):
        """Save ``fig`` as a new, uniquely named PNG in the temp dir and return its path."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # mkstemp creates the file atomically with a random suffix, so two
        # renders in the same second (e.g. concurrent requests) cannot
        # overwrite each other's image.
        fd, image_path = tempfile.mkstemp(prefix=f"{prefix}_{timestamp}_", suffix=".png")
        os.close(fd)
        try:
            fig.savefig(image_path, **savefig_kwargs)
        except Exception:
            # Best-effort removal of the empty placeholder before re-raising.
            try:
                os.remove(image_path)
            except OSError:
                pass
            raise
        return image_path

    @staticmethod
    def _node_colors(G):
        """Return one fill colour per node (in ``G.nodes()`` order) based on its role."""
        colors = []
        for node in G.nodes():
            if G.in_degree(node) == 0:      # Start nodes: no dependencies
                colors.append('#F24236')
            elif G.out_degree(node) == 0:   # End nodes: nothing depends on them
                colors.append('#A23B72')
            else:                           # Intermediate nodes
                colors.append('#F18F01')
        return colors

    @staticmethod
    def _draw_task_info_boxes(ax, G, pos):
        """Draw a details text box (function, robots, objects) beside every node."""
        bbox_props = dict(boxstyle="round,pad=0.4",
                          facecolor='white',
                          edgecolor='gray',
                          alpha=0.95,
                          linewidth=1)
        for i, node in enumerate(G.nodes()):
            x, y = pos[node]
            attrs = G.nodes[node]
            robots = attrs['robots']
            objects = attrs['objects']

            info_text = f"Task {node}: {attrs['function'].replace('_', ' ').title()}\n"
            if robots:
                robot_text = ", ".join(r.replace('robot_', '').replace('_', ' ').title() for r in robots)
                info_text += f"Robots: {robot_text}\n"
            if objects:
                object_text = ", ".join(objects)
                info_text += f"Objects: {object_text}"

            # Alternate left/right and up/down by node order to reduce overlaps.
            offset_x = 2.2 if i % 2 == 0 else -2.2
            offset_y = 0.5 if i % 4 < 2 else -0.5
            h_align = 'left' if offset_x > 0 else 'right'

            ax.text(x + offset_x, y + offset_y, info_text,
                    bbox=bbox_props,
                    fontsize=12,
                    verticalalignment='center',
                    horizontalalignment=h_align,
                    weight='bold')
            # Dashed connector from the node centre to its text box.
            ax.plot([x, x + offset_x], [y, y + offset_y],
                    linestyle='--', color='gray', alpha=0.6, linewidth=1)

    def create_dag_visualization(self, task_data, title="Robot Task Dependency Graph"):
        """
        Create a DAG visualization from task data and return the image path.

        Args:
            task_data: Task data dictionary (see ``create_dag_from_tasks``)
            title: Title for the graph

        Returns:
            str: Path to the generated PNG in the temp directory, or None if
            the data is invalid/empty or rendering fails (the error is logged).
        """
        fig = None
        try:
            G = self.create_dag_from_tasks(task_data)
            if G is None or G.number_of_nodes() == 0:
                logger.warning("No tasks found or invalid graph structure")
                return None

            pos = self.calculate_layout(G)
            fig, ax = plt.subplots(1, 1, figsize=(max(12, G.number_of_nodes() * 2), 8))

            # Draw onto ``ax`` explicitly rather than relying on pyplot's
            # global "current axes".
            nx.draw_networkx_edges(G, pos, ax=ax,
                                   edge_color='#2E86AB',
                                   arrows=True,
                                   arrowsize=20,
                                   arrowstyle='->',
                                   width=2,
                                   alpha=0.8,
                                   connectionstyle="arc3,rad=0.1")
            nx.draw_networkx_nodes(G, pos, ax=ax,
                                   node_color=self._node_colors(G),
                                   node_size=3500,
                                   alpha=0.9,
                                   edgecolors='black',
                                   linewidths=2)
            nx.draw_networkx_labels(G, pos, {node: f"T{node}" for node in G.nodes()}, ax=ax,
                                    font_size=18,
                                    font_weight='bold',
                                    font_color='white')
            self._draw_task_info_boxes(ax, G, pos)

            # Expand axis limits so nodes and their side boxes fit.
            x_vals = [x for x, _ in pos.values()]
            y_vals = [y for _, y in pos.values()]
            ax.set_xlim(min(x_vals) - 4.0, max(x_vals) + 4.0)
            ax.set_ylim(min(y_vals) - 2.0, max(y_vals) + 2.0)

            ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
            ax.set_aspect('equal')
            ax.margins(0.2)
            ax.axis('off')
            # A node-type legend is deliberately omitted: it covered the info boxes.

            fig.tight_layout()
            image_path = self._save_figure(fig, 'dag_visualization',
                                           dpi=400, bbox_inches='tight',
                                           pad_inches=0.1, facecolor='white', edgecolor='none')
            logger.info(f"DAG visualization saved to: {image_path}")
            return image_path
        except Exception as e:
            # Top-level boundary: callers expect None, not an exception.
            logger.exception(f"Error creating DAG visualization: {e}")
            return None
        finally:
            # Always release the figure; pyplot keeps every open figure alive,
            # so one leaked on an error path accumulates memory in a server.
            if fig is not None:
                plt.close(fig)

    def create_simplified_dag_visualization(self, task_data, title="Robot Task Graph"):
        """
        Create a simplified DAG visualization suitable for smaller displays.

        Args:
            task_data: Task data dictionary (see ``create_dag_from_tasks``)
            title: Title for the graph

        Returns:
            str: Path to the generated PNG in the temp directory, or None if
            the data is invalid/empty or rendering fails (the error is logged).
        """
        fig = None
        try:
            G = self.create_dag_from_tasks(task_data)
            if G is None or G.number_of_nodes() == 0:
                logger.warning("No tasks found or invalid graph structure")
                return None

            pos = self.calculate_layout(G)
            fig, ax = plt.subplots(1, 1, figsize=(10, 6))

            nx.draw_networkx_edges(G, pos, ax=ax,
                                   edge_color='black',
                                   arrows=True,
                                   arrowsize=15,
                                   arrowstyle='->',
                                   width=1.5)
            nx.draw_networkx_nodes(G, pos, ax=ax,
                                   node_color='lightblue',
                                   node_size=3000,
                                   edgecolors='black',
                                   linewidths=1.5)

            # Label: task id plus a readable function name, truncated to
            # keep it inside the node circle.
            labels = {}
            for node in G.nodes():
                simplified_name = G.nodes[node]['function'].replace('_', ' ').title()
                if len(simplified_name) > 15:
                    simplified_name = simplified_name[:12] + "..."
                labels[node] = f"T{node}\n{simplified_name}"
            nx.draw_networkx_labels(G, pos, labels, ax=ax,
                                    font_size=11,
                                    font_weight='bold')

            ax.set_title(title, fontsize=14, fontweight='bold')
            ax.axis('off')

            fig.tight_layout()
            image_path = self._save_figure(fig, 'simple_dag', dpi=400, bbox_inches='tight')
            logger.info(f"Simplified DAG visualization saved to: {image_path}")
            return image_path
        except Exception as e:
            # Top-level boundary: callers expect None, not an exception.
            logger.exception(f"Error creating simplified DAG visualization: {e}")
            return None
        finally:
            # Always release the figure, including on error paths.
            if fig is not None:
                plt.close(fig)