Yongdong commited on
Commit
1ef829e
Β·
1 Parent(s): 792bd1c

Add DAG visualization functionality for robot task planning

Browse files
Files changed (5) hide show
  1. app.py +131 -52
  2. dag_visualizer.py +334 -0
  3. json_processor.py +46 -46
  4. requirements.txt +3 -0
  5. test_dag_integration.py +175 -0
app.py CHANGED
@@ -3,6 +3,7 @@ import spaces # Import spaces module for ZeroGPU
3
  from huggingface_hub import login
4
  import os
5
  from json_processor import JsonProcessor
 
6
  import json
7
 
8
  # 1) Read Secrets
@@ -213,6 +214,31 @@ def generate_response_gpu(prompt, max_tokens=512, selected_model=DEFAULT_MODEL):
213
  except Exception as generation_error:
214
  return f"❌ Generation Error: {str(generation_error)}"
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  def chat_interface(message, history, max_tokens, selected_model):
217
  """Chat interface - runs on CPU, calls GPU functions"""
218
  if not message.strip():
@@ -255,7 +281,7 @@ with gr.Blocks(
255
  - **βš–οΈ Dart-llm-model-3B**: Balanced performance and quality
256
  - **🎯 Dart-llm-model-8B**: Best quality output, higher latency
257
 
258
- **Capabilities**: Convert natural language robot commands into structured task sequences for excavators, dump trucks, and other construction robots.
259
 
260
  **Models**:
261
  - [YongdongWang/llama-3.2-1b-lora-qlora-dart-llm](https://huggingface.co/YongdongWang/llama-3.2-1b-lora-qlora-dart-llm) (Default)
@@ -265,59 +291,100 @@ with gr.Blocks(
265
  ⚑ **Using ZeroGPU**: This Space uses dynamic GPU allocation (Nvidia H200). First generation might take a bit longer.
266
  """)
267
 
268
- with gr.Row():
269
- with gr.Column(scale=3):
270
- chatbot = gr.Chatbot(
271
- label="Task Planning Results",
272
- height=500,
273
- show_label=True,
274
- container=True,
275
- bubble_full_width=False,
276
- show_copy_button=True
277
- )
278
-
279
- msg = gr.Textbox(
280
- label="Robot Command",
281
- placeholder="Enter robot task command (e.g., 'Deploy Excavator 1 to Soil Area 1')...",
282
- lines=2,
283
- max_lines=5,
284
- show_label=True,
285
- container=True
286
- )
287
-
288
  with gr.Row():
289
- send_btn = gr.Button("πŸš€ Generate Tasks", variant="primary", size="sm")
290
- clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary", size="sm")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
292
- with gr.Column(scale=1):
293
- gr.Markdown("### βš™οΈ Generation Settings")
294
-
295
- model_selector = gr.Dropdown(
296
- choices=[(config["name"], key) for key, config in MODEL_CONFIGS.items()],
297
- value=DEFAULT_MODEL,
298
- label="Model Size",
299
- info="Select model size (1B = fastest, 8B = best quality)",
300
- interactive=True
301
- )
302
-
303
- max_tokens = gr.Slider(
304
- minimum=50,
305
- maximum=5000,
306
- value=512,
307
- step=10,
308
- label="Max Tokens",
309
- info="Maximum number of tokens to generate"
310
- )
311
-
312
- gr.Markdown("""
313
- ### πŸ“Š Model Status
314
- - **Hardware**: ZeroGPU (Dynamic Nvidia H200)
315
- - **Status**: Ready
316
- - **Note**: First generation allocates GPU resources
317
- - **Dart-llm-model-1B**: Fastest inference (Default)
318
- - **Dart-llm-model-3B**: Balanced speed/quality
319
- - **Dart-llm-model-8B**: Best quality, slower
320
- """)
 
 
 
 
 
 
 
 
 
321
 
322
  # Example conversations
323
  gr.Examples(
@@ -350,6 +417,18 @@ with gr.Blocks(
350
  lambda: ([], ""),
351
  outputs=[chatbot, msg]
352
  )
 
 
 
 
 
 
 
 
 
 
 
 
353
 
354
  if __name__ == "__main__":
355
  app.launch(
 
3
  from huggingface_hub import login
4
  import os
5
  from json_processor import JsonProcessor
6
+ from dag_visualizer import DAGVisualizer
7
  import json
8
 
9
  # 1) Read Secrets
 
214
  except Exception as generation_error:
215
  return f"❌ Generation Error: {str(generation_error)}"
216
 
217
+ def create_dag_visualization(task_json_str):
218
+ """Create DAG visualization from task JSON"""
219
+ try:
220
+ if not task_json_str.strip():
221
+ return None, "Please provide task JSON data"
222
+
223
+ # Parse JSON
224
+ task_data = json.loads(task_json_str)
225
+
226
+ # Create DAG visualizer
227
+ dag_visualizer = DAGVisualizer()
228
+
229
+ # Generate visualization
230
+ image_path = dag_visualizer.create_dag_visualization(task_data)
231
+
232
+ if image_path:
233
+ return image_path, "βœ… DAG visualization created successfully!"
234
+ else:
235
+ return None, "❌ Failed to create DAG visualization"
236
+
237
+ except json.JSONDecodeError as e:
238
+ return None, f"❌ JSON Parse Error: {str(e)}"
239
+ except Exception as e:
240
+ return None, f"❌ DAG Creation Error: {str(e)}"
241
+
242
  def chat_interface(message, history, max_tokens, selected_model):
243
  """Chat interface - runs on CPU, calls GPU functions"""
244
  if not message.strip():
 
281
  - **βš–οΈ Dart-llm-model-3B**: Balanced performance and quality
282
  - **🎯 Dart-llm-model-8B**: Best quality output, higher latency
283
 
284
+ **Capabilities**: Convert natural language robot commands into structured task sequences for excavators, dump trucks, and other construction robots. **Now with DAG Visualization!**
285
 
286
  **Models**:
287
  - [YongdongWang/llama-3.2-1b-lora-qlora-dart-llm](https://huggingface.co/YongdongWang/llama-3.2-1b-lora-qlora-dart-llm) (Default)
 
291
  ⚑ **Using ZeroGPU**: This Space uses dynamic GPU allocation (Nvidia H200). First generation might take a bit longer.
292
  """)
293
 
294
+ with gr.Tabs():
295
+ with gr.Tab("πŸ’¬ Task Planning"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  with gr.Row():
297
+ with gr.Column(scale=3):
298
+ chatbot = gr.Chatbot(
299
+ label="Task Planning Results",
300
+ height=500,
301
+ show_label=True,
302
+ container=True,
303
+ bubble_full_width=False,
304
+ show_copy_button=True
305
+ )
306
+
307
+ msg = gr.Textbox(
308
+ label="Robot Command",
309
+ placeholder="Enter robot task command (e.g., 'Deploy Excavator 1 to Soil Area 1')...",
310
+ lines=2,
311
+ max_lines=5,
312
+ show_label=True,
313
+ container=True
314
+ )
315
+
316
+ with gr.Row():
317
+ send_btn = gr.Button("πŸš€ Generate Tasks", variant="primary", size="sm")
318
+ clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary", size="sm")
319
+
320
+ with gr.Column(scale=1):
321
+ gr.Markdown("### βš™οΈ Generation Settings")
322
+
323
+ model_selector = gr.Dropdown(
324
+ choices=[(config["name"], key) for key, config in MODEL_CONFIGS.items()],
325
+ value=DEFAULT_MODEL,
326
+ label="Model Size",
327
+ info="Select model size (1B = fastest, 8B = best quality)",
328
+ interactive=True
329
+ )
330
+
331
+ max_tokens = gr.Slider(
332
+ minimum=50,
333
+ maximum=5000,
334
+ value=512,
335
+ step=10,
336
+ label="Max Tokens",
337
+ info="Maximum number of tokens to generate"
338
+ )
339
+
340
+ gr.Markdown("""
341
+ ### πŸ“Š Model Status
342
+ - **Hardware**: ZeroGPU (Dynamic Nvidia H200)
343
+ - **Status**: Ready
344
+ - **Note**: First generation allocates GPU resources
345
+ - **Dart-llm-model-1B**: Fastest inference (Default)
346
+ - **Dart-llm-model-3B**: Balanced speed/quality
347
+ - **Dart-llm-model-8B**: Best quality, slower
348
+ """)
349
 
350
+ with gr.Tab("πŸ“Š DAG Visualization"):
351
+ with gr.Row():
352
+ with gr.Column(scale=2):
353
+ json_input = gr.Textbox(
354
+ label="Task JSON Data",
355
+ placeholder="Paste the generated task JSON here to create a DAG visualization...",
356
+ lines=15,
357
+ max_lines=25,
358
+ show_label=True,
359
+ container=True
360
+ )
361
+
362
+ with gr.Row():
363
+ dag_btn = gr.Button("🎨 Generate DAG", variant="primary", size="sm")
364
+ dag_clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary", size="sm")
365
+
366
+ dag_status = gr.Textbox(
367
+ label="Status",
368
+ value="Ready to generate DAG visualization",
369
+ interactive=False,
370
+ show_label=True
371
+ )
372
+
373
+ with gr.Column(scale=3):
374
+ dag_output = gr.Image(
375
+ label="Task Dependency Graph",
376
+ show_label=True,
377
+ container=True,
378
+ height=600
379
+ )
380
+
381
+ gr.Markdown("""
382
+ ### πŸ“ˆ DAG Features
383
+ - **Node Colors**: Red (Start), Orange (Intermediate), Purple (End)
384
+ - **Arrows**: Show task dependencies
385
+ - **Layout**: Hierarchical based on dependencies
386
+ - **Details**: Task info boxes with robots and objects
387
+ """)
388
 
389
  # Example conversations
390
  gr.Examples(
 
417
  lambda: ([], ""),
418
  outputs=[chatbot, msg]
419
  )
420
+
421
+ # DAG visualization event handlers
422
+ dag_btn.click(
423
+ create_dag_visualization,
424
+ inputs=[json_input],
425
+ outputs=[dag_output, dag_status]
426
+ )
427
+
428
+ dag_clear_btn.click(
429
+ lambda: ("", None, "Ready to generate DAG visualization"),
430
+ outputs=[json_input, dag_output, dag_status]
431
+ )
432
 
433
  if __name__ == "__main__":
434
  app.launch(
dag_visualizer.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import matplotlib
3
+ matplotlib.use('Agg') # Use non-interactive backend for server environments
4
+ import networkx as nx
5
+ import json
6
+ import numpy as np
7
+ from loguru import logger
8
+ import os
9
+ import tempfile
10
+ from datetime import datetime
11
+
12
+ class DAGVisualizer:
13
+ def __init__(self):
14
+ # Configure Matplotlib to use IEEE-style parameters
15
+ plt.rcParams.update({
16
+ 'font.family': 'DejaVu Sans', # Use available font instead of Times New Roman
17
+ 'font.size': 10,
18
+ 'axes.linewidth': 1.2,
19
+ 'axes.labelsize': 12,
20
+ 'xtick.labelsize': 10,
21
+ 'ytick.labelsize': 10,
22
+ 'legend.fontsize': 10,
23
+ 'figure.titlesize': 14
24
+ })
25
+
26
+ def create_dag_from_tasks(self, task_data):
27
+ """
28
+ Create a directed graph from task data.
29
+
30
+ Args:
31
+ task_data: Dictionary containing tasks with structure like:
32
+ {
33
+ "tasks": [
34
+ {
35
+ "task": "task_name",
36
+ "instruction_function": {
37
+ "name": "function_name",
38
+ "robot_ids": ["robot1", "robot2"],
39
+ "dependencies": ["dependency_task"],
40
+ "object_keywords": ["object1", "object2"]
41
+ }
42
+ }
43
+ ]
44
+ }
45
+
46
+ Returns:
47
+ NetworkX DiGraph object
48
+ """
49
+ if not task_data or "tasks" not in task_data:
50
+ logger.warning("Invalid task data structure")
51
+ return None
52
+
53
+ # Create a directed graph
54
+ G = nx.DiGraph()
55
+
56
+ # Add nodes and store mapping from task name to ID
57
+ task_mapping = {}
58
+ for i, task in enumerate(task_data["tasks"]):
59
+ task_id = i + 1
60
+ task_name = task["task"]
61
+ task_mapping[task_name] = task_id
62
+
63
+ # Add node with attributes
64
+ G.add_node(task_id,
65
+ name=task_name,
66
+ function=task["instruction_function"]["name"],
67
+ robots=task["instruction_function"].get("robot_ids", []),
68
+ objects=task["instruction_function"].get("object_keywords", []))
69
+
70
+ # Add dependency edges
71
+ for i, task in enumerate(task_data["tasks"]):
72
+ task_id = i + 1
73
+ dependencies = task["instruction_function"]["dependencies"]
74
+ for dep in dependencies:
75
+ if dep in task_mapping:
76
+ dep_id = task_mapping[dep]
77
+ G.add_edge(dep_id, task_id)
78
+
79
+ return G
80
+
81
+ def calculate_layout(self, G):
82
+ """
83
+ Calculate hierarchical layout for the graph based on dependencies.
84
+ """
85
+ if not G:
86
+ return {}
87
+
88
+ # Calculate layers based on dependencies
89
+ layers = {}
90
+
91
+ def get_layer(node_id, visited=None):
92
+ if visited is None:
93
+ visited = set()
94
+ if node_id in visited:
95
+ return 0
96
+ visited.add(node_id)
97
+
98
+ predecessors = list(G.predecessors(node_id))
99
+ if not predecessors:
100
+ return 0
101
+ return max(get_layer(pred, visited.copy()) for pred in predecessors) + 1
102
+
103
+ for node in G.nodes():
104
+ layer = get_layer(node)
105
+ layers.setdefault(layer, []).append(node)
106
+
107
+ # Calculate positions by layer
108
+ pos = {}
109
+ layer_height = 3.0
110
+ node_width = 4.0
111
+
112
+ for layer_idx, nodes in layers.items():
113
+ y = layer_height * (len(layers) - 1 - layer_idx)
114
+ start_x = -(len(nodes) - 1) * node_width / 2
115
+ for i, node in enumerate(sorted(nodes)):
116
+ pos[node] = (start_x + i * node_width, y)
117
+
118
+ return pos
119
+
120
+ def create_dag_visualization(self, task_data, title="Robot Task Dependency Graph"):
121
+ """
122
+ Create a DAG visualization from task data and return the image path.
123
+
124
+ Args:
125
+ task_data: Task data dictionary
126
+ title: Title for the graph
127
+
128
+ Returns:
129
+ str: Path to the generated image file
130
+ """
131
+ try:
132
+ # Create graph
133
+ G = self.create_dag_from_tasks(task_data)
134
+ if not G or len(G.nodes()) == 0:
135
+ logger.warning("No tasks found or invalid graph structure")
136
+ return None
137
+
138
+ # Calculate layout
139
+ pos = self.calculate_layout(G)
140
+
141
+ # Create figure
142
+ fig, ax = plt.subplots(1, 1, figsize=(max(12, len(G.nodes()) * 2), 8))
143
+
144
+ # Draw edges with arrows
145
+ nx.draw_networkx_edges(G, pos,
146
+ edge_color='#2E86AB',
147
+ arrows=True,
148
+ arrowsize=20,
149
+ arrowstyle='->',
150
+ width=2,
151
+ alpha=0.8,
152
+ connectionstyle="arc3,rad=0.1")
153
+
154
+ # Color nodes based on their position in the graph
155
+ node_colors = []
156
+ for node in G.nodes():
157
+ if G.in_degree(node) == 0: # Start nodes
158
+ node_colors.append('#F24236')
159
+ elif G.out_degree(node) == 0: # End nodes
160
+ node_colors.append('#A23B72')
161
+ else: # Intermediate nodes
162
+ node_colors.append('#F18F01')
163
+
164
+ # Draw nodes
165
+ nx.draw_networkx_nodes(G, pos,
166
+ node_color=node_colors,
167
+ node_size=3500,
168
+ alpha=0.9,
169
+ edgecolors='black',
170
+ linewidths=2)
171
+
172
+ # Label nodes with task IDs
173
+ node_labels = {node: f"T{node}" for node in G.nodes()}
174
+ nx.draw_networkx_labels(G, pos, node_labels,
175
+ font_size=18,
176
+ font_weight='bold',
177
+ font_color='white')
178
+
179
+ # Add detailed info text boxes for each task
180
+ for i, node in enumerate(G.nodes()):
181
+ x, y = pos[node]
182
+ function_name = G.nodes[node]['function']
183
+ robots = G.nodes[node]['robots']
184
+ objects = G.nodes[node]['objects']
185
+
186
+ # Create info text content
187
+ info_text = f"Task {node}: {function_name.replace('_', ' ').title()}\n"
188
+ if robots:
189
+ robot_text = ", ".join([r.replace('robot_', '').replace('_', ' ').title() for r in robots])
190
+ info_text += f"Robots: {robot_text}\n"
191
+ if objects:
192
+ object_text = ", ".join(objects)
193
+ info_text += f"Objects: {object_text}"
194
+
195
+ # Calculate offset based on node position to avoid overlaps
196
+ offset_x = 2.2 if i % 2 == 0 else -2.2
197
+ offset_y = 0.5 if i % 4 < 2 else -0.5
198
+
199
+ # Choose alignment based on offset direction
200
+ h_align = 'left' if offset_x > 0 else 'right'
201
+
202
+ # Draw text box
203
+ bbox_props = dict(boxstyle="round,pad=0.4",
204
+ facecolor='white',
205
+ edgecolor='gray',
206
+ alpha=0.95,
207
+ linewidth=1)
208
+
209
+ ax.text(x + offset_x, y + offset_y, info_text,
210
+ bbox=bbox_props,
211
+ fontsize=12,
212
+ verticalalignment='center',
213
+ horizontalalignment=h_align,
214
+ weight='bold')
215
+
216
+ # Draw dashed connector line from node to text box
217
+ ax.plot([x, x + offset_x], [y, y + offset_y],
218
+ linestyle='--', color='gray', alpha=0.6, linewidth=1)
219
+
220
+ # Expand axis limits to fit everything
221
+ x_vals = [coord[0] for coord in pos.values()]
222
+ y_vals = [coord[1] for coord in pos.values()]
223
+ ax.set_xlim(min(x_vals) - 4.0, max(x_vals) + 4.0)
224
+ ax.set_ylim(min(y_vals) - 2.0, max(y_vals) + 2.0)
225
+
226
+ # Set overall figure properties
227
+ ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
228
+ ax.set_aspect('equal')
229
+ ax.margins(0.2)
230
+ ax.axis('off')
231
+
232
+ # Add legend for node types - Hidden to avoid covering content
233
+ # legend_elements = [
234
+ # plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='#F24236',
235
+ # markersize=10, label='Start Tasks', markeredgecolor='black'),
236
+ # plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='#A23B72',
237
+ # markersize=10, label='End Tasks', markeredgecolor='black'),
238
+ # plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='#F18F01',
239
+ # markersize=10, label='Intermediate Tasks', markeredgecolor='black'),
240
+ # plt.Line2D([0], [0], color='#2E86AB', linewidth=2, label='Dependencies')
241
+ # ]
242
+ # ax.legend(handles=legend_elements, loc='upper left', bbox_to_anchor=(1.05, 1.05))
243
+
244
+ # Adjust layout and save
245
+ plt.tight_layout()
246
+
247
+ # Create temporary file for saving the image
248
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
249
+ temp_dir = tempfile.gettempdir()
250
+ image_path = os.path.join(temp_dir, f'dag_visualization_{timestamp}.png')
251
+
252
+ plt.savefig(image_path, dpi=400, bbox_inches='tight',
253
+ pad_inches=0.1, facecolor='white', edgecolor='none')
254
+ plt.close(fig) # Close figure to free memory
255
+
256
+ logger.info(f"DAG visualization saved to: {image_path}")
257
+ return image_path
258
+
259
+ except Exception as e:
260
+ logger.error(f"Error creating DAG visualization: {e}")
261
+ return None
262
+
263
+ def create_simplified_dag_visualization(self, task_data, title="Robot Task Graph"):
264
+ """
265
+ Create a simplified DAG visualization suitable for smaller displays.
266
+
267
+ Args:
268
+ task_data: Task data dictionary
269
+ title: Title for the graph
270
+
271
+ Returns:
272
+ str: Path to the generated image file
273
+ """
274
+ try:
275
+ # Create graph
276
+ G = self.create_dag_from_tasks(task_data)
277
+ if not G or len(G.nodes()) == 0:
278
+ logger.warning("No tasks found or invalid graph structure")
279
+ return None
280
+
281
+ # Calculate layout
282
+ pos = self.calculate_layout(G)
283
+
284
+ # Create figure for simplified graph
285
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
286
+
287
+ # Draw edges
288
+ nx.draw_networkx_edges(G, pos,
289
+ edge_color='black',
290
+ arrows=True,
291
+ arrowsize=15,
292
+ arrowstyle='->',
293
+ width=1.5)
294
+
295
+ # Draw nodes
296
+ nx.draw_networkx_nodes(G, pos,
297
+ node_color='lightblue',
298
+ node_size=3000,
299
+ edgecolors='black',
300
+ linewidths=1.5)
301
+
302
+ # Add node labels with simplified names
303
+ labels = {}
304
+ for node in G.nodes():
305
+ function_name = G.nodes[node]['function']
306
+ simplified_name = function_name.replace('_', ' ').title()
307
+ if len(simplified_name) > 15:
308
+ simplified_name = simplified_name[:12] + "..."
309
+ labels[node] = f"T{node}\n{simplified_name}"
310
+
311
+ nx.draw_networkx_labels(G, pos, labels,
312
+ font_size=11,
313
+ font_weight='bold')
314
+
315
+ ax.set_title(title, fontsize=14, fontweight='bold')
316
+ ax.axis('off')
317
+
318
+ # Adjust layout and save
319
+ plt.tight_layout()
320
+
321
+ # Create temporary file for saving the image
322
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
323
+ temp_dir = tempfile.gettempdir()
324
+ image_path = os.path.join(temp_dir, f'simple_dag_{timestamp}.png')
325
+
326
+ plt.savefig(image_path, dpi=400, bbox_inches='tight')
327
+ plt.close(fig) # Close figure to free memory
328
+
329
+ logger.info(f"Simplified DAG visualization saved to: {image_path}")
330
+ return image_path
331
+
332
+ except Exception as e:
333
+ logger.error(f"Error creating simplified DAG visualization: {e}")
334
+ return None
json_processor.py CHANGED
@@ -1,46 +1,46 @@
1
- import json
2
- import re
3
- import ast
4
- from loguru import logger
5
-
6
- class JsonProcessor:
7
- def process_response(self, response):
8
- try:
9
- # Search for JSON string in the response
10
- json_str_match = re.search(r'\{.*\}', response, re.DOTALL)
11
- if json_str_match:
12
- # Get the matched JSON string
13
- json_str = json_str_match.group()
14
- logger.debug(f"Full JSON string: {json_str}")
15
-
16
- # Try to parse as Python literal first, then convert to JSON
17
- try:
18
- # First try to evaluate as Python literal
19
- python_obj = ast.literal_eval(json_str)
20
- # Convert to proper JSON
21
- response_json = json.loads(json.dumps(python_obj))
22
- except (ValueError, SyntaxError):
23
- # Fall back to string replacement method
24
- # Replace escape characters and remove trailing commas
25
- json_str = json_str.replace("\\", "")
26
- json_str = json_str.replace(r'\\_', '_')
27
- json_str = re.sub(r',\s*}', '}', json_str)
28
- json_str = re.sub(r',\s*\]', ']', json_str)
29
-
30
- # Convert Python format to JSON format
31
- json_str = json_str.replace("'", '"') # Single quotes to double quotes
32
- json_str = json_str.replace('None', 'null') # Python None to JSON null
33
-
34
- # Parse the JSON string
35
- response_json = json.loads(json_str)
36
- return response_json
37
- else:
38
- logger.error("No JSON string match found in response.")
39
- return None
40
-
41
- except json.JSONDecodeError as e:
42
- logger.error(f"JSONDecodeError: {e}")
43
- except Exception as e:
44
- logger.error(f"Unexpected error: {e}")
45
-
46
- return None
 
1
+ import json
2
+ import re
3
+ import ast
4
+ from loguru import logger
5
+
6
+ class JsonProcessor:
7
+ def process_response(self, response):
8
+ try:
9
+ # Search for JSON string in the response
10
+ json_str_match = re.search(r'\{.*\}', response, re.DOTALL)
11
+ if json_str_match:
12
+ # Get the matched JSON string
13
+ json_str = json_str_match.group()
14
+ logger.debug(f"Full JSON string: {json_str}")
15
+
16
+ # Try to parse as Python literal first, then convert to JSON
17
+ try:
18
+ # First try to evaluate as Python literal
19
+ python_obj = ast.literal_eval(json_str)
20
+ # Convert to proper JSON
21
+ response_json = json.loads(json.dumps(python_obj))
22
+ except (ValueError, SyntaxError):
23
+ # Fall back to string replacement method
24
+ # Replace escape characters and remove trailing commas
25
+ json_str = json_str.replace("\\", "")
26
+ json_str = json_str.replace(r'\\_', '_')
27
+ json_str = re.sub(r',\s*}', '}', json_str)
28
+ json_str = re.sub(r',\s*\]', ']', json_str)
29
+
30
+ # Convert Python format to JSON format
31
+ json_str = json_str.replace("'", '"') # Single quotes to double quotes
32
+ json_str = json_str.replace('None', 'null') # Python None to JSON null
33
+
34
+ # Parse the JSON string
35
+ response_json = json.loads(json_str)
36
+ return response_json
37
+ else:
38
+ logger.error("No JSON string match found in response.")
39
+ return None
40
+
41
+ except json.JSONDecodeError as e:
42
+ logger.error(f"JSONDecodeError: {e}")
43
+ except Exception as e:
44
+ logger.error(f"Unexpected error: {e}")
45
+
46
+ return None
requirements.txt CHANGED
@@ -10,3 +10,6 @@ sentencepiece
10
  protobuf
11
  spaces
12
  loguru
 
 
 
 
10
  protobuf
11
  spaces
12
  loguru
13
+ matplotlib
14
+ networkx
15
+ numpy
test_dag_integration.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test script for DAG integration in DART-LLM-Multi-Model
4
+ """
5
+
6
+ from dag_visualizer import DAGVisualizer
7
+ from json_processor import JsonProcessor
8
+ import json
9
+
10
+ def test_dag_integration():
11
+ """Test the DAG integration with sample task data"""
12
+ print("Testing DAG integration...")
13
+
14
+ # Sample response with task data (similar to what the model might generate)
15
+ sample_response = """
16
+ Based on your command, here are the robot tasks:
17
+
18
+ {
19
+ "tasks": [
20
+ {
21
+ "task": "move_excavator_to_soil_area",
22
+ "instruction_function": {
23
+ "name": "move_to_position",
24
+ "robot_ids": ["robot_excavator_01"],
25
+ "dependencies": [],
26
+ "object_keywords": ["soil_area_1"]
27
+ }
28
+ },
29
+ {
30
+ "task": "excavate_soil",
31
+ "instruction_function": {
32
+ "name": "excavate_material",
33
+ "robot_ids": ["robot_excavator_01"],
34
+ "dependencies": ["move_excavator_to_soil_area"],
35
+ "object_keywords": ["soil"]
36
+ }
37
+ },
38
+ {
39
+ "task": "move_dump_truck",
40
+ "instruction_function": {
41
+ "name": "move_to_position",
42
+ "robot_ids": ["robot_dump_truck_01"],
43
+ "dependencies": [],
44
+ "object_keywords": ["soil_area_1"]
45
+ }
46
+ },
47
+ {
48
+ "task": "load_soil_to_truck",
49
+ "instruction_function": {
50
+ "name": "transfer_material",
51
+ "robot_ids": ["robot_excavator_01", "robot_dump_truck_01"],
52
+ "dependencies": ["excavate_soil", "move_dump_truck"],
53
+ "object_keywords": ["soil"]
54
+ }
55
+ },
56
+ {
57
+ "task": "transport_to_dump_site",
58
+ "instruction_function": {
59
+ "name": "move_to_position",
60
+ "robot_ids": ["robot_dump_truck_01"],
61
+ "dependencies": ["load_soil_to_truck"],
62
+ "object_keywords": ["dump_site"]
63
+ }
64
+ }
65
+ ]
66
+ }
67
+ """
68
+
69
+ # Test JSON processing
70
+ processor = JsonProcessor()
71
+ print("1. Testing JSON processing...")
72
+ processed_json = processor.process_response(sample_response)
73
+
74
+ if processed_json:
75
+ print("βœ“ JSON processing successful")
76
+ print(f" Found {len(processed_json['tasks'])} tasks")
77
+ else:
78
+ print("βœ— JSON processing failed")
79
+ return False
80
+
81
+ # Test DAG visualization
82
+ print("2. Testing DAG visualization...")
83
+ visualizer = DAGVisualizer()
84
+
85
+ try:
86
+ dag_image_path = visualizer.create_dag_visualization(
87
+ processed_json,
88
+ title="Test Robot Task Dependency Graph"
89
+ )
90
+
91
+ if dag_image_path:
92
+ print(f"βœ“ DAG visualization created: {dag_image_path}")
93
+ return True
94
+ else:
95
+ print("βœ— DAG visualization failed")
96
+ return False
97
+
98
+ except Exception as e:
99
+ print(f"βœ— DAG visualization error: {e}")
100
+ return False
101
+
102
+ def test_simplified_dag():
103
+ """Test simplified DAG visualization"""
104
+ print("\n3. Testing simplified DAG...")
105
+
106
+ simple_task_data = {
107
+ "tasks": [
108
+ {
109
+ "task": "move_robot",
110
+ "instruction_function": {
111
+ "name": "move_to_position",
112
+ "robot_ids": ["robot_01"],
113
+ "dependencies": [],
114
+ "object_keywords": ["target_area"]
115
+ }
116
+ },
117
+ {
118
+ "task": "perform_operation",
119
+ "instruction_function": {
120
+ "name": "excavate",
121
+ "robot_ids": ["robot_01"],
122
+ "dependencies": ["move_robot"],
123
+ "object_keywords": ["soil"]
124
+ }
125
+ }
126
+ ]
127
+ }
128
+
129
+ visualizer = DAGVisualizer()
130
+
131
+ try:
132
+ dag_image_path = visualizer.create_simplified_dag_visualization(
133
+ simple_task_data,
134
+ title="Simplified Test DAG"
135
+ )
136
+
137
+ if dag_image_path:
138
+ print(f"βœ“ Simplified DAG visualization created: {dag_image_path}")
139
+ return True
140
+ else:
141
+ print("βœ— Simplified DAG visualization failed")
142
+ return False
143
+
144
+ except Exception as e:
145
+ print(f"βœ— Simplified DAG visualization error: {e}")
146
+ return False
147
+
148
+ def main():
149
+ """Run all tests"""
150
+ print("=" * 60)
151
+ print("DART-LLM-Multi-Model DAG Integration Test")
152
+ print("=" * 60)
153
+
154
+ success_count = 0
155
+ total_tests = 2
156
+
157
+ if test_dag_integration():
158
+ success_count += 1
159
+
160
+ if test_simplified_dag():
161
+ success_count += 1
162
+
163
+ print("\n" + "=" * 60)
164
+ print(f"Test Results: {success_count}/{total_tests} passed")
165
+
166
+ if success_count == total_tests:
167
+ print("πŸŽ‰ All DAG integration tests passed!")
168
+ return True
169
+ else:
170
+ print("❌ Some tests failed!")
171
+ return False
172
+
173
+ if __name__ == "__main__":
174
+ success = main()
175
+ exit(0 if success else 1)