vimmoos@Thor committed
Commit c6a28ec · 1 Parent(s): 62f50f1

base gym env

app.py CHANGED
@@ -0,0 +1,276 @@
+import streamlit as st
+import gymnasium as gym
+import numpy as np
+from PIL import Image
+import time
+
+# Initialize session state variables if they don't exist
+if "env" not in st.session_state:
+    st.session_state.env = gym.make("LunarLander-v2", render_mode="rgb_array")
+    st.session_state.env.reset()
+    st.session_state.frame = st.session_state.env.render()
+if "paused" not in st.session_state:
+    st.session_state.paused = False
+
+
+# Function to reset the environment
+def reset_environment():
+    st.session_state.env.reset()
+
+
+# Function to toggle pause state
+def toggle_pause():
+    st.session_state.paused = not st.session_state.paused
+
+
+# Create the Streamlit app
+st.title("Gymnasium Environment Viewer")
+
+# Add control buttons in a horizontal layout
+col1, col2 = st.columns(2)
+with col1:
+    st.button("Reset Environment", on_click=reset_environment)
+with col2:
+    if st.session_state.paused:
+        st.button("Resume", on_click=toggle_pause)
+    else:
+        st.button("Pause", on_click=toggle_pause)
+
+# Create a placeholder for the image
+image_placeholder = st.empty()
+
+# Create a container for environment info
+sidebar_container = st.sidebar.container()
+
+# Main simulation loop using rerun
+if not st.session_state.paused:
+    # Take a random action
+    action = st.session_state.env.action_space.sample()
+    observation, reward, terminated, truncated, info = (
+        st.session_state.env.step(action)
+    )
+
+    # Render the environment
+    st.session_state.frame = st.session_state.env.render()
+
+    # Reset if the episode is done
+    if terminated or truncated:
+        st.session_state.env.reset()
+
+# Display the frame
+if st.session_state.paused:
+    image_placeholder.image(
+        st.session_state.frame,
+        caption="Environment Visualization (Paused)",
+        use_column_width=True,
+    )
+else:
+    image_placeholder.image(
+        st.session_state.frame,
+        caption="Environment Visualization",
+        use_column_width=True,
+    )
+
+# Display some information about the environment
+with sidebar_container:
+    st.header("Environment Info")
+    st.write(f"Action Space: {st.session_state.env.action_space}")
+    st.write(f"Observation Space: {st.session_state.env.observation_space}")
+
+# Add auto-refresh logic
+if not st.session_state.paused:
+    time.sleep(0.1)  # Add a small delay to control refresh rate
+    st.rerun()
+
+# fig, ax = plt.subplots()
+# ax.imshow(env.render())
+# st.pyplot(fig)
+# st.image(env.render())
+
+
+# import gymnasium as gym
+# import streamlit as st
+# import numpy as np
+# from udrl.policies import SklearnPolicy
+# from udrl.agent import UpsideDownAgent, AgentHyper
+# from pathlib import Path
+
+# # import json
+
+
+# def normalize_value(value, is_bounded, low=None, high=None):
+#     return (value - low) / (high - low)
+
+
+# def visualize_environment(
+#     state,
+#     env,
+#     # paused,
+#     feature_importances,
+#     epoch,
+#     max_epoch=200,
+# ):
+
+#     st.image(env.render())
+#     st.image(e)
+#     # Render the Gym environment
+#     # env_render = env.render()
+
+#     # # Display the rendered image using Streamlit
+#     # st.image(env_render, caption=f"Epoch {epoch}", use_column_width=True)
+
+#     # Display feature importances using Streamlit metrics
+#     # cols = st.columns(len(feature_importances))
+#     # for i, col in enumerate(cols):
+#     #     col.metric(
+#     #         label=f"Importance {i}", value=f"{feature_importances[i]:.2f}"
+#     #     )
+
+#     # Create buttons using Streamlit
+#     # reset_button = st.button("Reset")
+#     # pause_play_button = st.button("Pause" if not paused else "Play")
+#     # next_button = st.button("Next")
+#     # save_button = st.button("Save")
+
+#     # return reset_button, pause_play_button, next_button, save_button
+
+
+# def run_visualization(
+#     env_name,
+#     agent,
+#     init_desired_return,
+#     init_desired_horizon,
+#     max_epoch,
+#     base_path,
+# ):
+#     # base_path = (
+#     #     Path(base_path) / env_name / agent.policy.estimator.__str__()[:-2]
+#     # )
+#     # base_path.mkdir(parents=True, exist_ok=True)
+#     desired_return = init_desired_return
+#     desired_horizon = init_desired_horizon
+
+#     # Initialize the Gym environment
+#     env = gym.make(env_name, render_mode="rgb_array")
+#     state, _ = env.reset()
+
+#     epoch = 0
+#     # save_index = 0
+
+#     # paused = False
+#     # step = False
+
+#     # # Use Streamlit session state to manage paused state
+#     # if "paused" not in st.session_state:
+#     #     st.session_state.paused = False
+
+#     while True:
+#         # Render and display the environment
+#         env_render = env.render()
+#         # if not st.session_state.paused or step:
+#         command = np.array(
+#             [
+#                 desired_return * agent.conf.return_scale,
+#                 desired_horizon * agent.conf.horizon_scale,
+#             ]
+#         )
+#         command = np.expand_dims(command, axis=0)
+#         state = np.expand_dims(state, axis=0)
+
+#         action = agent.policy(state, command, True)
+
+#         ext_state = np.concatenate((state, command), axis=1)
+
+#         state, reward, done, truncated, info = env.step(action)
+
+#         feature_importances = {idx: [] for idx in range(ext_state.shape[1])}
+
+#         for t in agent.policy.estimator.estimators_:
+#             branch = np.array(t.decision_path(ext_state).todense(), dtype=bool)
+#             imp = t.tree_.impurity[branch[0]]
+
+#             for f, i in zip(
+#                 t.tree_.feature[branch[0]][:-1], imp[:-1] - imp[1:]
+#             ):
+#                 feature_importances.setdefault(f, []).append(i)
+
+#         # Line 8 Algorithm 2
+#         desired_return -= reward
+#         # Line 9 Algorithm 2
+#         desired_horizon = max(desired_horizon - 1, 1)
+
+#         summed_importances = [
+#             sum(feature_importances.get(k, [0.001]))
+#             for k in range(len(feature_importances.keys()))
+#         ]
+
+#         epoch += 1
+#         visualize_environment(
+#             state,
+#             env,
+#             # st.session_state.paused,  # Use session state
+#             summed_importances,
+#             epoch,
+#             max_epoch,
+#         )
+#         # reset_button, pause_play_button, next_button, save_button = (
+
+#         # )
+
+#         if done or truncated:
+#             state, _ = env.reset()
+#             desired_horizon = init_desired_horizon
+#             desired_return = init_desired_return
+#             epoch = 0
+
+#             # step = False
+
+#         # Handle button clicks
+#         # if reset_button:
+#         #     state, _ = env.reset()
+#         #     desired_horizon = init_desired_horizon
+#         #     desired_return = init_desired_return
+#         #     epoch = 0
+#         # elif pause_play_button:
+#         #     st.session_state.paused = (
+#         #         not st.session_state.paused
+#         #     )  # Toggle paused state
+#         # elif next_button and st.session_state.paused:
+#         #     step = True
+#         # elif save_button:
+#         #     # Save image and info using Streamlit
+#         #     st.image(
+#         #         env_render, caption=f"Epoch {epoch}", use_column_width=True
+#         #     )
+#         #     st.write(
+#         #         {
+#         #             "state": {i: str(val) for i, val in enumerate(state)},
+#         #             "feature": {
+#         #                 i: str(val) for i, val in enumerate(summed_importances)
+#         #             },
+#         #             "action": str(action),
+#         #             "reward": str(reward),
+#         #             "desired_return": str(desired_return + reward),
+#         #             "desired_horizon": str(desired_horizon + 1),
+#         #         }
+#         #     )
+
+#     env.close()
+
+
+# env = "Acrobot-v1"
+# desired_return = -79
+# desired_horizon = 82
+# max_epoch = 500
+
+
+# policy = SklearnPolicy.load("policy")
+# hyper = AgentHyper(
+#     env,
+#     warm_up=0,
+# )
+
+# agent = UpsideDownAgent(hyper, policy)
+
+# run_visualization(
+#     env, agent, desired_return, desired_horizon, max_epoch, "data/viz_examples"
+# )
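
The viewer above leans on Streamlit's rerun model: each script execution takes exactly one environment step, draws one frame, then calls st.rerun() to schedule the next pass, so the animation is a chain of full script runs rather than an in-process loop. A minimal, self-contained sketch of that pattern, with illustrative names that are not part of this commit:

import time

import streamlit as st

# Streamlit re-executes the whole script on every rerun, so any state
# that must survive between frames has to live in st.session_state.
if "step_count" not in st.session_state:
    st.session_state.step_count = 0

st.session_state.step_count += 1  # one simulation step per script run
st.write(f"frame {st.session_state.step_count}")

time.sleep(0.1)  # throttle the refresh rate
st.rerun()  # schedule the next script execution

As usual for Streamlit apps, the file is launched with streamlit run app.py.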
old_code/experiment_3/q_networks/buffers/CartPole-v0/1/DQN/memory_buffer.p DELETED
The diff for this file is too large to render. See raw diff
 
poetry.lock CHANGED
The diff for this file is too large to render. See raw diff
 
pyproject.toml CHANGED
@@ -26,6 +26,7 @@ python = "3.10.14"
 scikit-learn = "^1.5.2"
 matplotlib = "^3.9.2"
 gymnasium = {extras = ["box2d"], version = "^0.29.1"}
+numpy = "1.24.4"
 scikit-image = "^0.24.0"
 tqdm = "^4.66.5"
 torch = "^2.4.1"
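
The numpy = "1.24.4" pin is most plausibly a compatibility guard: gymnasium 0.29.x and the box2d extra still touch aliases such as np.bool8 that newer numpy releases removed, though the commit itself does not state the reason. A quick sanity check that the pinned stack imports and steps together, assuming the versions from this pyproject.toml:

import numpy as np
import gymnasium as gym

# Build the same environment app.py uses and take one random step.
env = gym.make("LunarLander-v2", render_mode="rgb_array")
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
print(np.__version__, obs.shape)  # expected: 1.24.4 (8,)
env.close()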
udrl/__main__.py CHANGED
@@ -1,6 +1,6 @@
-from .agent import UpsideDownAgent, AgentHyper
-from .policies import SklearnPolicy, NeuralPolicy
-from .catch import CatchAdaptor
+from udrl.agent import UpsideDownAgent, AgentHyper
+from udrl.policies import SklearnPolicy, NeuralPolicy
+from udrl.catch import CatchAdaptor
 from dataclasses import dataclass, asdict
 import gymnasium as gym
 from tqdm import trange
udrl/agent.py CHANGED
@@ -2,9 +2,9 @@ from dataclasses import dataclass
 import gymnasium as gym
 import numpy as np
 
-from .catch import CatchAdaptor
-from .policies import ABCPolicy
-from .buffer import ReplayBuffer
+from udrl.catch import CatchAdaptor
+from udrl.policies import ABCPolicy
+from udrl.buffer import ReplayBuffer
 
 
 @dataclass
udrl/inference.py CHANGED
@@ -1,7 +1,7 @@
 import matplotlib.pyplot as plt
 import numpy as np
-from .policies import SklearnPolicy, NeuralPolicy
-from .agent import UpsideDownAgent, AgentHyper
+from udrl.policies import SklearnPolicy, NeuralPolicy
+from udrl.agent import UpsideDownAgent, AgentHyper
 from pathlib import Path
 from collections import Counter
 from tqdm import trange
udrl/plot.py CHANGED
@@ -1,5 +1,5 @@
-from .policies import SklearnPolicy
-from .agent import UpsideDownAgent, AgentHyper
+from udrl.policies import SklearnPolicy
+from udrl.agent import UpsideDownAgent, AgentHyper
 from pathlib import Path
 import matplotlib.pyplot as plt
 import numpy as np
udrl/viz.py CHANGED
@@ -1,8 +1,8 @@
 import gymnasium as gym
 import pygame
 import numpy as np
-from .policies import SklearnPolicy
-from .agent import UpsideDownAgent, AgentHyper
+from udrl.policies import SklearnPolicy
+from udrl.agent import UpsideDownAgent, AgentHyper
 from pathlib import Path
 import json
 
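
All of the udrl/* edits above make the same change: relative imports become absolute. A likely motivation, inferred from the new app.py rather than stated in the commit message, is that scripts launched without a package context (streamlit run app.py, a REPL) cannot resolve from .agent import ..., while absolute imports work both there and under python -m udrl:

# Resolves whenever the udrl package is on sys.path, regardless of whether
# the process was started as a script, from a REPL, or via python -m udrl.
from udrl.agent import UpsideDownAgent, AgentHyper
from udrl.policies import SklearnPolicy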