vimmoos@Thor
commited on
Commit
·
610d19f
1
Parent(s):
bcb5af3
update website
Browse files- udrl/app/home.py +239 -52
- udrl/app/sim.py +4 -3
- udrl/data_proc.py +108 -0
- udrl/inference.py +152 -28
- udrl/test.py +0 -137
udrl/app/home.py
CHANGED
@@ -4,63 +4,188 @@ import streamlit as st
|
|
4 |
st.image("logo.jpg")
|
5 |
|
6 |
st.html(
|
7 |
-
"""
|
8 |
-
<
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
</div>
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
<li>Match the performance of neural networks</li>
|
27 |
-
<li>Provide more interpretable policies</li>
|
28 |
-
<li>Enhance the explainability of reinforcement learning systems</li>
|
29 |
-
</ul>
|
30 |
-
|
31 |
-
<div class="results">
|
32 |
-
<h2>Results</h2>
|
33 |
-
<p>We tested three different implementations of the Behaviour Function:</p>
|
34 |
-
<ul>
|
35 |
-
<li>Neural Networks (NN)</li>
|
36 |
-
<li>Random Forests (RF)</li>
|
37 |
-
<li>Extremely Randomized Trees (ET)</li>
|
38 |
-
</ul>
|
39 |
-
<p>Tests were conducted on three popular OpenAI Gym environments:</p>
|
40 |
-
<ul>
|
41 |
-
<li>CartPole</li>
|
42 |
-
<li>Acrobot</li>
|
43 |
-
<li>Lunar-Lander</li>
|
44 |
-
</ul>
|
45 |
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
-
<h2>Key Findings</h2>
|
49 |
-
<ul>
|
50 |
-
<li>Tree-based methods performed comparably to neural networks</li>
|
51 |
-
<li>Random Forests and Extremely Randomized Trees provided fully interpretable policies</li>
|
52 |
-
<li>Feature importance analysis revealed insights into decision-making processes</li>
|
53 |
-
</ul>
|
54 |
-
|
55 |
-
<h2>Implications</h2>
|
56 |
-
<p>This research opens new avenues for:</p>
|
57 |
-
<ul>
|
58 |
-
<li>More explainable reinforcement learning systems</li>
|
59 |
-
<li>Enhanced safety in AI decision-making</li>
|
60 |
-
<li>Better understanding of agent behavior in complex environments</li>
|
61 |
-
</ul>
|
62 |
</div>
|
63 |
-
|
|
|
64 |
)
|
65 |
st.html(
|
66 |
"""
|
@@ -74,8 +199,12 @@ st.html(
|
|
74 |
margin: 0 auto;
|
75 |
padding: 20px;
|
76 |
}
|
|
|
|
|
|
|
77 |
h1, h2 {
|
78 |
color: #81a1c1;
|
|
|
79 |
}
|
80 |
.abstract {
|
81 |
background-color: #2e3440;
|
@@ -108,3 +237,61 @@ st.html(
|
|
108 |
</style>
|
109 |
"""
|
110 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
st.image("logo.jpg")
|
5 |
|
6 |
st.html(
|
7 |
+
"""
|
8 |
+
<body>
|
9 |
+
<div class="container">
|
10 |
+
<h1 >Upside-Down Reinforcement Learning for More Interpretable Optimal Control</h1>
|
11 |
+
<div class="authors" style="text-align: center;">
|
12 |
+
<p class="author-names">Juan Cardenas-Cartagena, Massimiliano Falzari, Marco Zullich, Matthia Sabatelli</p>
|
13 |
+
<p class="institution">Bernoulli Institute, University of Groningen, The Netherlands</p>
|
14 |
+
</div>
|
15 |
+
<h2><a href="https://arxiv.org/abs/2411.11457" target="_blank">Read the full paper on arXiv</a></h2>
|
16 |
|
17 |
+
<section class="motivation">
|
18 |
+
<h2>Research Motivation</h2>
|
19 |
+
<p>The dramatic growth in adoption of Neural Networks (NNs) within the last 15 years has sparked a crucial need for increased transparency, especially in high-stake applications. While NNs have demonstrated remarkable performance across various domains, they are essentially black boxes whose decision-making processes remain opaque to human understanding. This research addresses this fundamental challenge by exploring alternative approaches that maintain performance while dramatically improving interpretability.</p>
|
|
|
20 |
|
21 |
+
<div class="key-challenges">
|
22 |
+
<h3>Current Challenges in Reinforcement Learning</h3>
|
23 |
+
<p>Traditional approaches to Reinforcement Learning (RL) face several key limitations:</p>
|
24 |
+
<ul>
|
25 |
+
<li>Complex neural network policies are difficult to interpret and explain</li>
|
26 |
+
<li>Lack of transparency in decision-making processes poses risks in critical applications</li>
|
27 |
+
<li>Traditional RL approaches either focus on predicting rewards or learning environment models, making interpretation challenging</li>
|
28 |
+
<li>The gap between performance and interpretability has been difficult to bridge</li>
|
29 |
+
</ul>
|
30 |
+
</div>
|
31 |
+
</section>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
+
<section class="udrl-framework">
|
34 |
+
<h2>The UDRL Framework: A Novel Approach</h2>
|
35 |
+
<p>Upside-Down Reinforcement Learning represents a fundamental shift in how we approach reinforcement learning problems. Instead of traditional methods that focus on predicting rewards or learning environment models, UDRL transforms the reinforcement learning problem into a supervised learning task.</p>
|
36 |
+
|
37 |
+
<div class="framework-details">
|
38 |
+
<h3>Key Components</h3>
|
39 |
+
<p>The UDRL approach centers around learning a behavior function f(st, dr, dt) = at where:</p>
|
40 |
+
<ul>
|
41 |
+
<li><strong>st</strong>: The current state of the environment</li>
|
42 |
+
<li><strong>dr</strong>: The desired reward the agent aims to achieve</li>
|
43 |
+
<li><strong>dt</strong>: The time horizon within which to achieve the reward</li>
|
44 |
+
<li><strong>at</strong>: The action to take to achieve the desired reward</li>
|
45 |
+
</ul>
|
46 |
+
</div>
|
47 |
+
|
48 |
+
<div class="mathematical-framework">
|
49 |
+
<h3>Mathematical Foundation</h3>
|
50 |
+
<p>The framework is built on a Markov Decision Process (MDP) defined as a tuple M = ⟨S,A,P,R⟩ where:</p>
|
51 |
+
<ul>
|
52 |
+
<li>S: The state space of the environment</li>
|
53 |
+
<li>A: The action space modeling all possible actions</li>
|
54 |
+
<li>P: The transition function P : S × A × S → [0,1]</li>
|
55 |
+
<li>R: The reward function R : S × A × S → R</li>
|
56 |
+
</ul>
|
57 |
+
</div>
|
58 |
+
</section>
|
59 |
+
|
60 |
+
<section class="implementation">
|
61 |
+
<h2>Implementation and Methodology</h2>
|
62 |
+
<div class="algorithms-detailed">
|
63 |
+
<h3>Studied Algorithms</h3>
|
64 |
+
<div class="tree-based">
|
65 |
+
<h4>Tree-Based Methods</h4>
|
66 |
+
<p>We extensively evaluated two primary tree-based approaches:</p>
|
67 |
+
<ul>
|
68 |
+
<li><strong>Random Forests (RF):</strong> An ensemble method building multiple decision trees and merging their predictions</li>
|
69 |
+
<li><strong>Extremely Randomized Trees (ET):</strong> A variation that adds additional randomization in the tree-building process</li>
|
70 |
+
</ul>
|
71 |
+
</div>
|
72 |
+
|
73 |
+
<div class="boosting">
|
74 |
+
<h4>Boosting Algorithms</h4>
|
75 |
+
<p>We also investigated sequential ensemble methods:</p>
|
76 |
+
<ul>
|
77 |
+
<li><strong>AdaBoost:</strong> Adaptive Boosting for sequential tree construction</li>
|
78 |
+
<li><strong>XGBoost:</strong> A more advanced implementation of gradient boosting</li>
|
79 |
+
</ul>
|
80 |
+
</div>
|
81 |
+
|
82 |
+
<div class="baseline">
|
83 |
+
<h4>Baseline Methods</h4>
|
84 |
+
<ul>
|
85 |
+
<li><strong>Neural Networks:</strong> Traditional multi-layer perceptron architecture</li>
|
86 |
+
<li><strong>K-Nearest Neighbours:</strong> Non-parametric baseline for comparison</li>
|
87 |
+
</ul>
|
88 |
+
</div>
|
89 |
+
</div>
|
90 |
+
|
91 |
+
<div class="experimental-environments">
|
92 |
+
<h3>Test Environments</h3>
|
93 |
+
<div class="cartpole">
|
94 |
+
<h4>CartPole</h4>
|
95 |
+
<p>A 4-dimensional continuous state space including:</p>
|
96 |
+
<ul>
|
97 |
+
<li>Cart Position (x)</li>
|
98 |
+
<li>Cart Velocity (ẋ)</li>
|
99 |
+
<li>Pole Angle (θ)</li>
|
100 |
+
<li>Pole Angular Velocity (θ̇)</li>
|
101 |
+
</ul>
|
102 |
+
</div>
|
103 |
+
|
104 |
+
<div class="acrobot">
|
105 |
+
<h4>Acrobot</h4>
|
106 |
+
<p>A 6-dimensional state space representing:</p>
|
107 |
+
<ul>
|
108 |
+
<li>First Link: sin(θ1), cos(θ1), θ̇1</li>
|
109 |
+
<li>Second Link: sin(θ2), cos(θ2), θ̇2</li>
|
110 |
+
</ul>
|
111 |
+
</div>
|
112 |
+
|
113 |
+
<div class="lunar-lander">
|
114 |
+
<h4>Lunar Lander</h4>
|
115 |
+
<p>An 8-dimensional state space including:</p>
|
116 |
+
<ul>
|
117 |
+
<li>Position (x, y)</li>
|
118 |
+
<li>Velocity (ẋ, ẏ)</li>
|
119 |
+
<li>Angle (θ) and Angular velocity (θ̇)</li>
|
120 |
+
<li>Left and right leg contact points</li>
|
121 |
+
</ul>
|
122 |
+
</div>
|
123 |
+
</div>
|
124 |
+
</section>
|
125 |
+
|
126 |
+
<section class="results">
|
127 |
+
<h2>Comprehensive Results and Analysis</h2>
|
128 |
+
<div class="performance-analysis">
|
129 |
+
<h3>Performance Metrics</h3>
|
130 |
+
<p>Our experiments revealed surprising competitiveness of tree-based methods:</p>
|
131 |
+
<ul>
|
132 |
+
<li><strong>CartPole Environment:</strong>
|
133 |
+
<ul>
|
134 |
+
<li>Neural Networks: 199.93 ± 0.255</li>
|
135 |
+
<li>Random Forests: 188.25 ± 13.82</li>
|
136 |
+
<li>XGBoost: 199.27 ± 4.06</li>
|
137 |
+
</ul>
|
138 |
+
</li>
|
139 |
+
<li><strong>Acrobot Environment:</strong>
|
140 |
+
<ul>
|
141 |
+
<li>Neural Networks: -75.00 ± 15.36</li>
|
142 |
+
<li>Random Forests: -100.05 ± 62.80</li>
|
143 |
+
<li>Extra Trees: -100.00 ± 93.72</li>
|
144 |
+
</ul>
|
145 |
+
</li>
|
146 |
+
<li><strong>Lunar Lander:</strong>
|
147 |
+
<ul>
|
148 |
+
<li>Random Forests: -54.74 ± 96.22</li>
|
149 |
+
<li>XGBoost: -76.96 ± 89.69</li>
|
150 |
+
<li>Neural Networks: -157.04 ± 71.26</li>
|
151 |
+
</ul>
|
152 |
+
</li>
|
153 |
+
</ul>
|
154 |
+
</div>
|
155 |
+
|
156 |
+
<div class="interpretability-analysis">
|
157 |
+
<h3>Interpretability Insights</h3>
|
158 |
+
<p>The tree-based methods provided unprecedented insights into decision-making:</p>
|
159 |
+
<ul>
|
160 |
+
<li><strong>CartPole:</strong> Pole angular velocity emerged as the most crucial feature for balancing</li>
|
161 |
+
<li><strong>Acrobot:</strong> Angular velocities of both links proved essential for control</li>
|
162 |
+
<li><strong>Lunar Lander:</strong> Vertical position showed highest importance for landing decisions</li>
|
163 |
+
</ul>
|
164 |
+
</div>
|
165 |
+
</section>
|
166 |
+
|
167 |
+
<section class="future-directions">
|
168 |
+
<h2>Future Research Directions</h2>
|
169 |
+
<p>Our findings open several promising avenues for future research:</p>
|
170 |
+
<ul>
|
171 |
+
<li>Scaling to High-Dimensional Spaces:
|
172 |
+
<p>Investigating the applicability of tree-based UDRL to more complex environments with higher-dimensional state spaces</p>
|
173 |
+
</li>
|
174 |
+
<li>Enhanced Interpretation Tools:
|
175 |
+
<p>Development of specialized tools for analyzing and visualizing decision processes in tree-based UDRL systems</p>
|
176 |
+
</li>
|
177 |
+
<li>Real-World Applications:
|
178 |
+
<p>Exploring applications in safety-critical domains where interpretability is crucial</p>
|
179 |
+
</li>
|
180 |
+
<li>Theoretical Analysis:
|
181 |
+
<p>Deeper investigation of the theoretical foundations underlying the success of tree-based methods in UDRL</p>
|
182 |
+
</li>
|
183 |
+
</ul>
|
184 |
+
</section>
|
185 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
</div>
|
187 |
+
</body>
|
188 |
+
"""
|
189 |
)
|
190 |
st.html(
|
191 |
"""
|
|
|
199 |
margin: 0 auto;
|
200 |
padding: 20px;
|
201 |
}
|
202 |
+
h3 {
|
203 |
+
text-align: center;
|
204 |
+
}
|
205 |
h1, h2 {
|
206 |
color: #81a1c1;
|
207 |
+
text-align: center;
|
208 |
}
|
209 |
.abstract {
|
210 |
background-color: #2e3440;
|
|
|
237 |
</style>
|
238 |
"""
|
239 |
)
|
240 |
+
|
241 |
+
|
242 |
+
# """<div>
|
243 |
+
# <h1>Upside-Down Reinforcement Learning for More Interpretable Optimal Control</h1>
|
244 |
+
# <div class="abstract">
|
245 |
+
# <h2>Abstract</h2>
|
246 |
+
# <p>This research introduces a novel approach to reinforcement learning that emphasizes interpretability and explainability. By leveraging tree-based methods within the Upside-Down Reinforcement Learning (UDRL) framework, we demonstrate that it's possible to achieve performance comparable to neural networks while gaining significant advantages in terms of interpretability.</p>
|
247 |
+
# </div>
|
248 |
+
|
249 |
+
# <h2>What is Upside-Down Reinforcement Learning?</h2>
|
250 |
+
# <p>UDRL is an innovative paradigm that transforms reinforcement learning problems into supervised learning tasks. Unlike traditional approaches that focus on predicting rewards or learning environment models, UDRL learns to predict actions based on:</p>
|
251 |
+
# <ul>
|
252 |
+
# <li>Current state (s<sub>t</sub>)</li>
|
253 |
+
# <li>Desired reward (d<sub>r</sub>)</li>
|
254 |
+
# <li>Time horizon (d<sub>t</sub>)</li>
|
255 |
+
# </ul>
|
256 |
+
|
257 |
+
# <h2>Motivation</h2>
|
258 |
+
# <p>While neural networks have been the go-to choice for implementing UDRL, they lack interpretability. Our research explores whether other supervised learning algorithms, particularly tree-based methods, can:</p>
|
259 |
+
# <ul>
|
260 |
+
# <li>Match the performance of neural networks</li>
|
261 |
+
# <li>Provide more interpretable policies</li>
|
262 |
+
# <li>Enhance the explainability of reinforcement learning systems</li>
|
263 |
+
# </ul>
|
264 |
+
|
265 |
+
# <div class="results">
|
266 |
+
# <h2>Results</h2>
|
267 |
+
# <p>We tested three different implementations of the Behaviour Function:</p>
|
268 |
+
# <ul>
|
269 |
+
# <li>Neural Networks (NN)</li>
|
270 |
+
# <li>Random Forests (RF)</li>
|
271 |
+
# <li>Extremely Randomized Trees (ET)</li>
|
272 |
+
# </ul>
|
273 |
+
# <p>Tests were conducted on three popular OpenAI Gym environments:</p>
|
274 |
+
# <ul>
|
275 |
+
# <li>CartPole</li>
|
276 |
+
# <li>Acrobot</li>
|
277 |
+
# <li>Lunar-Lander</li>
|
278 |
+
# </ul>
|
279 |
+
|
280 |
+
# </div>
|
281 |
+
|
282 |
+
# <h2>Key Findings</h2>
|
283 |
+
# <ul>
|
284 |
+
# <li>Tree-based methods performed comparably to neural networks</li>
|
285 |
+
# <li>Random Forests and Extremely Randomized Trees provided fully interpretable policies</li>
|
286 |
+
# <li>Feature importance analysis revealed insights into decision-making processes</li>
|
287 |
+
# </ul>
|
288 |
+
|
289 |
+
# <h2>Implications</h2>
|
290 |
+
# <p>This research opens new avenues for:</p>
|
291 |
+
# <ul>
|
292 |
+
# <li>More explainable reinforcement learning systems</li>
|
293 |
+
# <li>Enhanced safety in AI decision-making</li>
|
294 |
+
# <li>Better understanding of agent behavior in complex environments</li>
|
295 |
+
# </ul>
|
296 |
+
# </div>
|
297 |
+
# """
|
udrl/app/sim.py
CHANGED
@@ -265,7 +265,8 @@ def make_viz(state):
|
|
265 |
|
266 |
def make_commands(state):
|
267 |
# Add control buttons in a horizontal layout
|
268 |
-
col1, col2, col3, col4 = st.columns(4)
|
|
|
269 |
with col1:
|
270 |
st.button("Reset Environment", on_click=state.reset_env)
|
271 |
with col2:
|
@@ -274,8 +275,8 @@ def make_commands(state):
|
|
274 |
)
|
275 |
with col3:
|
276 |
st.button("Next", on_click=state.next_epoch)
|
277 |
-
with col4:
|
278 |
-
|
279 |
return state
|
280 |
|
281 |
|
|
|
265 |
|
266 |
def make_commands(state):
|
267 |
# Add control buttons in a horizontal layout
|
268 |
+
# col1, col2, col3, col4 = st.columns(4)
|
269 |
+
col1, col2, col3 = st.columns(3)
|
270 |
with col1:
|
271 |
st.button("Reset Environment", on_click=state.reset_env)
|
272 |
with col2:
|
|
|
275 |
)
|
276 |
with col3:
|
277 |
st.button("Next", on_click=state.next_epoch)
|
278 |
+
# with col4:
|
279 |
+
# st.button("Save")
|
280 |
return state
|
281 |
|
282 |
|
udrl/data_proc.py
CHANGED
@@ -3,10 +3,112 @@ import numpy as np
|
|
3 |
import json
|
4 |
import csv
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
naming = {
|
7 |
"neural": "NN",
|
8 |
"ensemble.ExtraTreesClassifier": "ET",
|
9 |
"ensemble.RandomForestClassifier": "RF",
|
|
|
|
|
|
|
|
|
10 |
}
|
11 |
|
12 |
if __name__ == "__main__":
|
@@ -23,6 +125,10 @@ if __name__ == "__main__":
|
|
23 |
"neural": ([], [], [], []),
|
24 |
"ensemble.ExtraTreesClassifier": ([], [], [], []),
|
25 |
"ensemble.RandomForestClassifier": ([], [], [], []),
|
|
|
|
|
|
|
|
|
26 |
}
|
27 |
for exp in all_paths:
|
28 |
print(exp)
|
@@ -30,6 +136,8 @@ if __name__ == "__main__":
|
|
30 |
|
31 |
with open((exp / "conf.json"), "r") as f:
|
32 |
conf = json.load(f)
|
|
|
|
|
33 |
|
34 |
estimators[conf["estimator_name"]][0].append(list(rewards[:, 0]))
|
35 |
estimators[conf["estimator_name"]][1].append(list(rewards[:, 1]))
|
|
|
3 |
import json
|
4 |
import csv
|
5 |
|
6 |
+
|
7 |
+
def convert_json_to_pgfplots(json_file, output_file):
|
8 |
+
"""
|
9 |
+
Convert JSON feature data to pgfplots-compatible format
|
10 |
+
"""
|
11 |
+
# Read JSON file
|
12 |
+
with open(json_file, "r") as f:
|
13 |
+
data = json.load(f)
|
14 |
+
|
15 |
+
# Extract feature values
|
16 |
+
features = data["feature"]
|
17 |
+
|
18 |
+
# Create pgfplots data
|
19 |
+
# Each line will be "x y" format
|
20 |
+
plot_data = []
|
21 |
+
for i, value in features.items():
|
22 |
+
plot_data.append(f"{i} {float(value)}")
|
23 |
+
|
24 |
+
# Write to file
|
25 |
+
with open(output_file, "w") as f:
|
26 |
+
f.write("\n".join(plot_data))
|
27 |
+
|
28 |
+
|
29 |
+
def convert_all_viz(
|
30 |
+
base=Path("data") / "viz_examples", algo="RandomForestClassifier"
|
31 |
+
):
|
32 |
+
|
33 |
+
for env_p in base.iterdir():
|
34 |
+
path = env_p / algo
|
35 |
+
for fil in path.iterdir():
|
36 |
+
if "info" not in fil.name:
|
37 |
+
continue
|
38 |
+
convert_json_to_pgfplots(
|
39 |
+
str(fil), f"{str(fil.parent)}/{fil.stem}.dat"
|
40 |
+
)
|
41 |
+
|
42 |
+
|
43 |
+
def smooth_curve(
|
44 |
+
df,
|
45 |
+
column_name="mean_reward",
|
46 |
+
window_size=10,
|
47 |
+
method="exponential",
|
48 |
+
alpha=0.1,
|
49 |
+
):
|
50 |
+
"""
|
51 |
+
Smooth a column in a pandas DataFrame using different methods.
|
52 |
+
|
53 |
+
Parameters:
|
54 |
+
- df: pandas DataFrame containing the data
|
55 |
+
- column_name: name of the column to smooth
|
56 |
+
- window_size: size of the rolling window for simple moving average
|
57 |
+
- method: 'simple' for simple moving average, 'exponential' for exponential moving average
|
58 |
+
- alpha: smoothing factor for exponential moving average (0 < alpha < 1)
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
- DataFrame with both original and smoothed data
|
62 |
+
"""
|
63 |
+
|
64 |
+
# Create a copy to avoid modifying the original DataFrame
|
65 |
+
df_smoothed = df.copy()
|
66 |
+
|
67 |
+
if method == "simple":
|
68 |
+
# Simple Moving Average
|
69 |
+
df_smoothed[f"{column_name}_smoothed"] = (
|
70 |
+
df[column_name]
|
71 |
+
.rolling(window=window_size, center=True, min_periods=1)
|
72 |
+
.mean()
|
73 |
+
)
|
74 |
+
|
75 |
+
# # Handle NaN values at the beginning and end
|
76 |
+
df_smoothed[f"{column_name}_smoothed"].fillna(
|
77 |
+
df[column_name], inplace=True
|
78 |
+
)
|
79 |
+
|
80 |
+
elif method == "exponential":
|
81 |
+
# Exponential Moving Average
|
82 |
+
df_smoothed[f"{column_name}_smoothed"] = (
|
83 |
+
df[column_name].ewm(alpha=alpha, adjust=False).mean()
|
84 |
+
)
|
85 |
+
|
86 |
+
return df_smoothed
|
87 |
+
|
88 |
+
|
89 |
+
# import pandas as pd
|
90 |
+
|
91 |
+
# path = "data/csvs/Acrobot-v1.csv"
|
92 |
+
# path = "data/csvs/CartPole-v0.csv"
|
93 |
+
# path = "data/csvs/LunarLander-v2.csv"
|
94 |
+
|
95 |
+
# dat = pd.read_csv(path)
|
96 |
+
|
97 |
+
|
98 |
+
# for name in ["NN", "ET", "RF", "KNN", "SVM", "AdaBoost", "XGBoost"]:
|
99 |
+
# dat = smooth_curve(dat, name + "_mean", window_size=20, method="simple")
|
100 |
+
|
101 |
+
# dat.to_csv(path)
|
102 |
+
|
103 |
+
|
104 |
naming = {
|
105 |
"neural": "NN",
|
106 |
"ensemble.ExtraTreesClassifier": "ET",
|
107 |
"ensemble.RandomForestClassifier": "RF",
|
108 |
+
"neighbors.KNeighborsClassifier": "KNN",
|
109 |
+
"svm.SVC": "SVM",
|
110 |
+
"ensemble.AdaBoostClassifier": "AdaBoost",
|
111 |
+
"ensemble.GradientBoostingClassifier": "XGBoost",
|
112 |
}
|
113 |
|
114 |
if __name__ == "__main__":
|
|
|
125 |
"neural": ([], [], [], []),
|
126 |
"ensemble.ExtraTreesClassifier": ([], [], [], []),
|
127 |
"ensemble.RandomForestClassifier": ([], [], [], []),
|
128 |
+
"neighbors.KNeighborsClassifier": ([], [], [], []),
|
129 |
+
"svm.SVC": ([], [], [], []),
|
130 |
+
"ensemble.AdaBoostClassifier": ([], [], [], []),
|
131 |
+
"ensemble.GradientBoostingClassifier": ([], [], [], []),
|
132 |
}
|
133 |
for exp in all_paths:
|
134 |
print(exp)
|
|
|
136 |
|
137 |
with open((exp / "conf.json"), "r") as f:
|
138 |
conf = json.load(f)
|
139 |
+
if conf["estimator_name"] not in list(estimators.keys()):
|
140 |
+
continue
|
141 |
|
142 |
estimators[conf["estimator_name"]][0].append(list(rewards[:, 0]))
|
143 |
estimators[conf["estimator_name"]][1].append(list(rewards[:, 1]))
|
udrl/inference.py
CHANGED
@@ -1,22 +1,25 @@
|
|
1 |
-
import matplotlib.pyplot as plt
|
2 |
import numpy as np
|
3 |
from udrl.policies import SklearnPolicy, NeuralPolicy
|
4 |
from udrl.agent import UpsideDownAgent, AgentHyper
|
5 |
from pathlib import Path
|
6 |
from collections import Counter
|
7 |
-
from
|
|
|
|
|
|
|
8 |
|
|
|
9 |
|
10 |
-
def get_common(base, env, conf, seed):
|
11 |
|
12 |
-
|
|
|
|
|
13 |
|
14 |
if not path.exists():
|
15 |
print("Cannot find path")
|
16 |
return None, None
|
17 |
-
algo_name
|
18 |
-
|
19 |
-
)
|
20 |
|
21 |
des_ret = np.load(str(path / "desired_returns.npy")).astype(int)
|
22 |
des_hor = np.load(str(path / "desired_horizons.npy")).astype(int)
|
@@ -53,14 +56,12 @@ def get_common(base, env, conf, seed):
|
|
53 |
return common_ret, common_hor
|
54 |
|
55 |
|
56 |
-
def test_desired(base, env, conf, des_ret, des_hor):
|
57 |
|
58 |
-
algo_name = (
|
59 |
-
"NN" if "neural" in conf else ("ET" if "Extra" in conf else "RT")
|
60 |
-
)
|
61 |
if des_hor is None or des_ret is None:
|
62 |
print(f"Invalid desired for {env}:{algo_name}")
|
63 |
return
|
|
|
64 |
for path in (base / env / conf).iterdir():
|
65 |
if "neural" in conf:
|
66 |
policy = NeuralPolicy.load(str(path / "policy"))
|
@@ -80,43 +81,166 @@ def test_desired(base, env, conf, des_ret, des_hor):
|
|
80 |
)[0]
|
81 |
for _ in range(100)
|
82 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
print(
|
84 |
f"{env}:{algo_name}:{path.name}:r.{des_ret}:h.{des_hor}"
|
85 |
f" -> {np.median(final_r):.2f} +- {np.std(final_r):.2f}"
|
86 |
f",max {np.max(final_r):.2f},min {np.min(final_r):.2f}"
|
87 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
|
90 |
-
base = Path("/home/vimmoos/upside_down_rl/data")
|
91 |
-
confs = {
|
92 |
-
"NN": "estimator_nameneural_batch_size256_warm_up260",
|
93 |
-
"ET": "estimator_nameensemble.ExtraTreesClassifier_train_per_iter1",
|
94 |
-
"RT": "train_per_iter1",
|
95 |
-
}
|
96 |
envs = ["LunarLander-v2", "Acrobot-v1"]
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
|
99 |
res = {}
|
100 |
|
101 |
|
102 |
for env in envs:
|
103 |
res[env] = {}
|
104 |
-
|
105 |
-
|
106 |
-
for
|
107 |
-
|
108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
|
110 |
|
111 |
pprint(res)
|
|
|
112 |
|
113 |
for env, algos in res.items():
|
114 |
for algo, seeds in algos.items():
|
115 |
for _, vals in seeds.items():
|
116 |
-
test_desired(base, env,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import numpy as np
|
2 |
from udrl.policies import SklearnPolicy, NeuralPolicy
|
3 |
from udrl.agent import UpsideDownAgent, AgentHyper
|
4 |
from pathlib import Path
|
5 |
from collections import Counter
|
6 |
+
from pprint import pprint
|
7 |
+
import re
|
8 |
+
from dataclasses import dataclass, field
|
9 |
+
from typing import Dict, Any
|
10 |
|
11 |
+
# from tqdm import trange, tqdm
|
12 |
|
|
|
13 |
|
14 |
+
def get_common(base="", env="", conf="", seed="", path=None, algo_name=None):
|
15 |
+
if path is None:
|
16 |
+
path = base / env / conf / seed
|
17 |
|
18 |
if not path.exists():
|
19 |
print("Cannot find path")
|
20 |
return None, None
|
21 |
+
if algo_name is None:
|
22 |
+
algo_name = "UNKNOWN"
|
|
|
23 |
|
24 |
des_ret = np.load(str(path / "desired_returns.npy")).astype(int)
|
25 |
des_hor = np.load(str(path / "desired_horizons.npy")).astype(int)
|
|
|
56 |
return common_ret, common_hor
|
57 |
|
58 |
|
59 |
+
def test_desired(base, env, conf, algo_name, des_ret, des_hor):
|
60 |
|
|
|
|
|
|
|
61 |
if des_hor is None or des_ret is None:
|
62 |
print(f"Invalid desired for {env}:{algo_name}")
|
63 |
return
|
64 |
+
ret = []
|
65 |
for path in (base / env / conf).iterdir():
|
66 |
if "neural" in conf:
|
67 |
policy = NeuralPolicy.load(str(path / "policy"))
|
|
|
81 |
)[0]
|
82 |
for _ in range(100)
|
83 |
]
|
84 |
+
ret.append(
|
85 |
+
{
|
86 |
+
"env": env,
|
87 |
+
"algo": algo_name,
|
88 |
+
"seed": path.name,
|
89 |
+
"des_ret": des_ret,
|
90 |
+
"des_hor": des_hor,
|
91 |
+
"final_r": np.median(final_r),
|
92 |
+
"final_r_std": np.std(final_r),
|
93 |
+
"final_r_max": np.max(final_r),
|
94 |
+
"final_r_min": np.min(final_r),
|
95 |
+
"final_raw": final_r,
|
96 |
+
}
|
97 |
+
)
|
98 |
print(
|
99 |
f"{env}:{algo_name}:{path.name}:r.{des_ret}:h.{des_hor}"
|
100 |
f" -> {np.median(final_r):.2f} +- {np.std(final_r):.2f}"
|
101 |
f",max {np.max(final_r):.2f},min {np.min(final_r):.2f}"
|
102 |
)
|
103 |
+
return ret
|
104 |
+
|
105 |
+
|
106 |
+
@dataclass
|
107 |
+
class RunStats:
|
108 |
+
median: float
|
109 |
+
mean: float
|
110 |
+
std: float
|
111 |
+
min_val: float
|
112 |
+
max_val: float
|
113 |
+
infos: Dict[str, Any] = field(repr=False)
|
114 |
+
weights: Dict[str, float] = field(
|
115 |
+
repr=False,
|
116 |
+
default_factory=lambda: {
|
117 |
+
"median": 0.5,
|
118 |
+
"mean": 0.1,
|
119 |
+
"std_penalty": 0.3,
|
120 |
+
"range_penalty": 0.1,
|
121 |
+
},
|
122 |
+
)
|
123 |
+
score: float = field(init=False, default=-np.inf)
|
124 |
+
|
125 |
+
def __post_init__(self):
|
126 |
+
self.score = self.calculate_score()
|
127 |
+
|
128 |
+
def calculate_score(self):
|
129 |
+
"""
|
130 |
+
Calculate a composite score for a run.
|
131 |
+
|
132 |
+
The score is calculated as:
|
133 |
+
score = (w1 * median + w2 * mean) * stability_factor
|
134 |
+
|
135 |
+
where stability_factor penalizes high std and wide ranges
|
136 |
+
"""
|
137 |
+
shift = 1000 if self.median < 0 else 0
|
138 |
+
|
139 |
+
base_score = self.weights["median"] * (
|
140 |
+
self.median + shift
|
141 |
+
) + self.weights["mean"] * (self.mean + shift)
|
142 |
+
|
143 |
+
std_factor = 1 / (1 + self.weights["std_penalty"] * self.std)
|
144 |
+
range_factor = 1 / (
|
145 |
+
1 + self.weights["range_penalty"] * (self.max_val - self.min_val)
|
146 |
+
)
|
147 |
+
|
148 |
+
return base_score + std_factor + range_factor
|
149 |
+
|
150 |
+
|
151 |
+
def extract_statistics(run):
|
152 |
+
return RunStats(
|
153 |
+
median=run["final_r"],
|
154 |
+
mean=np.mean(run["final_raw"]),
|
155 |
+
std=run["final_r_std"],
|
156 |
+
min_val=run["final_r_min"],
|
157 |
+
max_val=run["final_r_max"],
|
158 |
+
infos=run,
|
159 |
+
)
|
160 |
+
|
161 |
+
|
162 |
+
base = Path("data")
|
163 |
|
164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
envs = ["LunarLander-v2", "Acrobot-v1"]
|
166 |
+
algo_name_extract = r"(.*/)+(estimator_name(.+?)_train.*save_desired.*)$"
|
167 |
+
available_algo = []
|
168 |
+
special_confs = {
|
169 |
+
"NN": "estimator_nameneural_batch_size256_warm_up260_save_desiredTrue",
|
170 |
+
"RT": "train_per_iter1_save_desiredTrue",
|
171 |
+
}
|
172 |
+
wanted_algo = [
|
173 |
+
"ensemble.ExtraTreesClassifier",
|
174 |
+
"neighbors.KNeighborsClassifier",
|
175 |
+
"NN",
|
176 |
+
"RT",
|
177 |
+
"svm.SVC",
|
178 |
+
"ensemble.AdaBoostClassifier",
|
179 |
+
"ensemble.GradientBoostingClassifier",
|
180 |
+
]
|
181 |
|
182 |
res = {}
|
183 |
|
184 |
|
185 |
for env in envs:
|
186 |
res[env] = {}
|
187 |
+
available_algo = [
|
188 |
+
re.match(algo_name_extract, str(x))
|
189 |
+
for x in (base / env).iterdir()
|
190 |
+
if re.match(algo_name_extract, str(x))
|
191 |
+
]
|
192 |
+
for algo_match in available_algo:
|
193 |
+
conf = algo_match.group(2)
|
194 |
+
algo_name = algo_match.group(3)
|
195 |
+
res[env][(algo_name, conf)] = {}
|
196 |
+
for seed_path in Path(algo_match.string).iterdir():
|
197 |
+
seed = seed_path.stem
|
198 |
+
ret, hor = get_common(path=seed_path, algo_name=algo_name)
|
199 |
+
res[env][(algo_name, conf)][seed] = (ret, hor)
|
200 |
+
for algo_name, conf in special_confs.items():
|
201 |
+
res[env][(algo_name, conf)] = {}
|
202 |
+
for seed_path in (base / env / conf).iterdir():
|
203 |
+
seed = seed_path.stem
|
204 |
+
ret, hor = get_common(path=seed_path, algo_name=algo_name)
|
205 |
+
res[env][(algo_name, conf)][seed] = (ret, hor)
|
206 |
|
207 |
|
208 |
pprint(res)
|
209 |
+
data = []
|
210 |
|
211 |
for env, algos in res.items():
|
212 |
for algo, seeds in algos.items():
|
213 |
for _, vals in seeds.items():
|
214 |
+
data.append(test_desired(base, env, algo[1], algo[0], *vals))
|
215 |
+
|
216 |
+
|
217 |
+
best_res = {env: {algo: None for algo in wanted_algo} for env in envs}
|
218 |
+
|
219 |
+
for runs in data:
|
220 |
+
if runs[0]["algo"] not in wanted_algo:
|
221 |
+
continue
|
222 |
+
for run in runs:
|
223 |
+
stats = extract_statistics(run)
|
224 |
+
print(
|
225 |
+
f"{run['env']}:{run['algo']}:{run['seed']}:r.{run['des_ret']}:h.{run['des_hor']}"
|
226 |
+
f" -> SCORE {stats.score} \t {run['final_r']:.2f} +- {run['final_r_std']:.2f}"
|
227 |
+
f",max {run['final_r_max']:.2f},min {run['final_r_min']:.2f}"
|
228 |
+
)
|
229 |
+
current_best = best_res[run["env"]][run["algo"]]
|
230 |
+
if current_best is None:
|
231 |
+
best_res[run["env"]][run["algo"]] = stats
|
232 |
+
continue
|
233 |
+
|
234 |
+
if current_best.score < stats.score:
|
235 |
+
best_res[run["env"]][run["algo"]] = stats
|
236 |
|
237 |
|
238 |
+
pprint(best_res)
|
239 |
+
|
240 |
+
best_res["LunarLander-v2"]["svm.SVC"].infos
|
241 |
+
for env, vs in best_res.items():
|
242 |
+
for algo, v in vs.items():
|
243 |
+
run = v.infos
|
244 |
+
print(
|
245 |
+
f"{run['env']}:{run['algo']}: r. {run['des_ret']} : h. {run['des_hor']}"
|
246 |
+
)
|
udrl/test.py
DELETED
@@ -1,137 +0,0 @@
|
|
1 |
-
# import gymnasium as gym
|
2 |
-
# import pygame
|
3 |
-
# import numpy as np
|
4 |
-
|
5 |
-
|
6 |
-
# def normalize_value(value, is_bounded, low=None, high=None):
|
7 |
-
# if is_bounded:
|
8 |
-
# return (value - low) / (high - low)
|
9 |
-
# else:
|
10 |
-
# return 0.5 * (np.tanh(value / 2) + 1)
|
11 |
-
|
12 |
-
|
13 |
-
# def draw_bar(screen, start, value, max_length, color, height=20):
|
14 |
-
# bar_length = value * max_length
|
15 |
-
# pygame.draw.rect(screen, color, (*start, bar_length, height))
|
16 |
-
# pygame.draw.rect(
|
17 |
-
# screen, (0, 0, 0), (*start, max_length, height), 2
|
18 |
-
# ) # Border
|
19 |
-
# mid_x = start[0] + max_length / 2
|
20 |
-
# pygame.draw.line(
|
21 |
-
# screen, (0, 0, 0), (mid_x, start[1]), (mid_x, start[1] + height), 2
|
22 |
-
# )
|
23 |
-
|
24 |
-
|
25 |
-
# def visualize_environment(screen, state, env):
|
26 |
-
# screen_width, screen_height = screen.get_size()
|
27 |
-
# screen.fill((255, 255, 255))
|
28 |
-
|
29 |
-
# # Visualize environment-specific elements
|
30 |
-
# if env.spec.id.startswith("CartPole"):
|
31 |
-
# cart_x = int(state[0] * 50 + screen_width // 2)
|
32 |
-
# cart_y = screen_height - 100
|
33 |
-
# pole_angle = state[2]
|
34 |
-
# pygame.draw.rect(screen, (0, 0, 0), (cart_x - 30, cart_y - 15, 60, 30))
|
35 |
-
# pygame.draw.line(
|
36 |
-
# screen,
|
37 |
-
# (0, 0, 0),
|
38 |
-
# (cart_x, cart_y),
|
39 |
-
# (
|
40 |
-
# cart_x + int(np.sin(pole_angle) * 100),
|
41 |
-
# cart_y - int(np.cos(pole_angle) * 100),
|
42 |
-
# ),
|
43 |
-
# 6,
|
44 |
-
# )
|
45 |
-
# elif env.spec.id.startswith("Acrobot"):
|
46 |
-
# center_x, center_y = screen_width // 2, screen_height // 2
|
47 |
-
# l1, l2 = 100, 100 # Length of links
|
48 |
-
# s0, s1 = state[0], state[1] # sin(theta1), sin(theta2)
|
49 |
-
# c0, c1 = state[2], state[3] # cos(theta1), cos(theta2)
|
50 |
-
# x0, y0 = center_x, center_y
|
51 |
-
# x1 = x0 + l1 * s0
|
52 |
-
# y1 = y0 + l1 * c0
|
53 |
-
# x2 = x1 + l2 * s1
|
54 |
-
# y2 = y1 + l2 * c1
|
55 |
-
# pygame.draw.line(screen, (0, 0, 0), (x0, y0), (x1, y1), 6)
|
56 |
-
# pygame.draw.line(screen, (0, 0, 0), (x1, y1), (x2, y2), 6)
|
57 |
-
# pygame.draw.circle(screen, (0, 0, 255), (int(x0), int(y0)), 10)
|
58 |
-
# pygame.draw.circle(screen, (0, 255, 0), (int(x1), int(y1)), 10)
|
59 |
-
# pygame.draw.circle(screen, (255, 0, 0), (int(x2), int(y2)), 10)
|
60 |
-
# # Add more environment-specific visualizations here as needed
|
61 |
-
|
62 |
-
# # Draw bars for each state dimension
|
63 |
-
# num_dims = env.observation_space.shape[0]
|
64 |
-
# bar_colors = [
|
65 |
-
# (255, 0, 0),
|
66 |
-
# (0, 255, 0),
|
67 |
-
# (0, 0, 255),
|
68 |
-
# (255, 255, 0),
|
69 |
-
# (255, 0, 255),
|
70 |
-
# (0, 255, 255),
|
71 |
-
# ]
|
72 |
-
# bar_starts = [(50, 50 + i * 70) for i in range(num_dims)]
|
73 |
-
# max_length = 300
|
74 |
-
|
75 |
-
# for i, (start, color) in enumerate(zip(bar_starts, bar_colors)):
|
76 |
-
# is_bounded = not (
|
77 |
-
# env.observation_space.high[i] > 100
|
78 |
-
# ) and not np.isinf(env.observation_space.low[i] < -100)
|
79 |
-
# normalized_value = normalize_value(
|
80 |
-
# state[i],
|
81 |
-
# is_bounded,
|
82 |
-
# env.observation_space.low[i],
|
83 |
-
# env.observation_space.high[i],
|
84 |
-
# )
|
85 |
-
# draw_bar(screen, start, normalized_value, max_length, color)
|
86 |
-
|
87 |
-
# # Draw labels
|
88 |
-
# font = pygame.font.Font(None, 30)
|
89 |
-
# text = font.render(f"Dim {i}: {state[i]:.2f}", True, (0, 0, 0))
|
90 |
-
# screen.blit(text, (start[0], start[1] - 30))
|
91 |
-
|
92 |
-
# # Add description of bar representation
|
93 |
-
# if is_bounded:
|
94 |
-
# desc = f"(Range: {env.observation_space.low[i]:.2f} to {env.observation_space.high[i]:.2f})"
|
95 |
-
# else:
|
96 |
-
# desc = "(Unbounded: Center is 0, edges are ±∞)"
|
97 |
-
# desc_text = pygame.font.Font(None, 24).render(
|
98 |
-
# desc, True, (100, 100, 100)
|
99 |
-
# )
|
100 |
-
# screen.blit(desc_text, (start[0], start[1] + 25))
|
101 |
-
|
102 |
-
# pygame.display.flip()
|
103 |
-
|
104 |
-
|
105 |
-
# def run_visualization(env_name):
|
106 |
-
# pygame.init()
|
107 |
-
# screen = pygame.display.set_mode((800, 600))
|
108 |
-
# pygame.display.set_caption(f"{env_name} Visualization")
|
109 |
-
|
110 |
-
# env = gym.make(env_name)
|
111 |
-
# state, _ = env.reset()
|
112 |
-
|
113 |
-
# clock = pygame.time.Clock()
|
114 |
-
|
115 |
-
# running = True
|
116 |
-
# while running:
|
117 |
-
# visualize_environment(screen, state, env)
|
118 |
-
# action = env.action_space.sample()
|
119 |
-
# state, reward, done, truncated, info = env.step(action)
|
120 |
-
|
121 |
-
# if done or truncated:
|
122 |
-
# state, _ = env.reset()
|
123 |
-
|
124 |
-
# for event in pygame.event.get():
|
125 |
-
# if event.type == pygame.QUIT:
|
126 |
-
# running = False
|
127 |
-
|
128 |
-
# clock.tick(60) # Limit to 60 FPS
|
129 |
-
|
130 |
-
# env.close()
|
131 |
-
# pygame.quit()
|
132 |
-
|
133 |
-
|
134 |
-
# # Example usage
|
135 |
-
# # run_visualization("CartPole-v1")
|
136 |
-
# # Uncomment the line below to run Acrobot visualization
|
137 |
-
# run_visualization("Acrobot-v1")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|