"""Streamlit viewer for a Gymnasium environment.

Steps a LunarLander environment with random actions and displays the rendered
frames, with Reset and Pause/Resume controls.
"""

import time

import gymnasium as gym
import numpy as np
import streamlit as st
from PIL import Image
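
# Streamlit reruns this script from the top on every widget interaction, so
# anything that must persist (the environment, the latest frame, the pause
# flag) is stored in st.session_state.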
# Initialize session state variables if they don't exist
if "env" not in st.session_state:
    # Note: "LunarLander-v2" needs the box2d extra (gymnasium[box2d]); on
    # Gymnasium >= 1.0 the environment id is "LunarLander-v3".
    st.session_state.env = gym.make("LunarLander-v2", render_mode="rgb_array")
    st.session_state.env.reset()
    st.session_state.frame = st.session_state.env.render()
if "paused" not in st.session_state:
    st.session_state.paused = False
# Function to reset the environment
def reset_environment():
    st.session_state.env.reset()


# Function to toggle pause state
def toggle_pause():
    st.session_state.paused = not st.session_state.paused
# Create the Streamlit app
st.title("Gymnasium Environment Viewer")

# Add control buttons in a horizontal layout
col1, col2 = st.columns(2)
with col1:
    st.button("Reset Environment", on_click=reset_environment)
with col2:
    if st.session_state.paused:
        st.button("Resume", on_click=toggle_pause)
    else:
        st.button("Pause", on_click=toggle_pause)

# Create a placeholder for the image
image_placeholder = st.empty()
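# st.empty() reserves a single slot that is overwritten in place on every
# rerun, so successive frames appear as a continuous animation instead of
# stacking below one another.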
# Create a container for environment info
sidebar_container = st.sidebar.container()
# Main simulation loop using rerun
if not st.session_state.paused:
    # Take a random action
    action = st.session_state.env.action_space.sample()
    observation, reward, terminated, truncated, info = st.session_state.env.step(
        action
    )
    # Render the environment
    st.session_state.frame = st.session_state.env.render()
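    # `terminated` marks an episode end dictated by the environment itself
    # (e.g. landing or crashing), while `truncated` marks a time-limit cutoff.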
    # Reset if the episode is done
    if terminated or truncated:
        st.session_state.env.reset()
# Display the frame (caption reflects the pause state)
caption = (
    "Environment Visualization (Paused)"
    if st.session_state.paused
    else "Environment Visualization"
)
image_placeholder.image(
    st.session_state.frame,
    caption=caption,
    # use_column_width is deprecated in newer Streamlit releases in favor of
    # use_container_width; kept here for compatibility with older versions.
    use_column_width=True,
)
# Display some information about the environment
with sidebar_container:
    st.header("Environment Info")
    st.write(f"Action Space: {st.session_state.env.action_space}")
    st.write(f"Observation Space: {st.session_state.env.observation_space}")

# Add auto-refresh logic
if not st.session_state.paused:
    time.sleep(0.1)  # Add a small delay to control refresh rate
    st.rerun()
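
# The block below is a commented-out alternative version of this viewer: it
# drives the environment with a trained upside-down RL (UDRL) agent and shows
# per-feature importances derived from the policy's tree ensemble.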
# import gymnasium as gym
# import streamlit as st
# import numpy as np
# from udrl.policies import SklearnPolicy
# from udrl.agent import UpsideDownAgent, AgentHyper
# from pathlib import Path
# # import json
# def normalize_value(value, is_bounded, low=None, high=None):
#     return (value - low) / (high - low)
# def visualize_environment(
#     state,
#     env,
#     # paused,
#     feature_importances,
#     epoch,
#     max_epoch=200,
# ):
#     st.image(env.render())
#     # Render the Gym environment
#     # env_render = env.render()
#     # # Display the rendered image using Streamlit
#     # st.image(env_render, caption=f"Epoch {epoch}", use_column_width=True)
#     # Display feature importances using Streamlit metrics
#     # cols = st.columns(len(feature_importances))
#     # for i, col in enumerate(cols):
#     #     col.metric(
#     #         label=f"Importance {i}", value=f"{feature_importances[i]:.2f}"
#     #     )
#     # Create buttons using Streamlit
#     # reset_button = st.button("Reset")
#     # pause_play_button = st.button("Pause" if not paused else "Play")
#     # next_button = st.button("Next")
#     # save_button = st.button("Save")
#     # return reset_button, pause_play_button, next_button, save_button
# def run_visualization(
#     env_name,
#     agent,
#     init_desired_return,
#     init_desired_horizon,
#     max_epoch,
#     base_path,
# ):
#     # base_path = (
#     #     Path(base_path) / env_name / agent.policy.estimator.__str__()[:-2]
#     # )
#     # base_path.mkdir(parents=True, exist_ok=True)
#     desired_return = init_desired_return
#     desired_horizon = init_desired_horizon
#     # Initialize the Gym environment
#     env = gym.make(env_name, render_mode="rgb_array")
#     state, _ = env.reset()
#     epoch = 0
#     # save_index = 0
#     # paused = False
#     # step = False
#     # # Use Streamlit session state to manage paused state
#     # if "paused" not in st.session_state:
#     #     st.session_state.paused = False
#     while True:
#         # Render and display the environment
#         env_render = env.render()
#         # if not st.session_state.paused or step:
#         command = np.array(
#             [
#                 desired_return * agent.conf.return_scale,
#                 desired_horizon * agent.conf.horizon_scale,
#             ]
#         )
#         command = np.expand_dims(command, axis=0)
#         state = np.expand_dims(state, axis=0)
#         action = agent.policy(state, command, True)
#         ext_state = np.concatenate((state, command), axis=1)
#         state, reward, done, truncated, info = env.step(action)
#         feature_importances = {idx: [] for idx in range(ext_state.shape[1])}
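#         # For each tree in the ensemble, follow the decision path taken by
#         # the current (state, command) vector and credit every split feature
#         # with its impurity decrease along that path, accumulating the
#         # per-feature importances displayed for this step.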
#         for t in agent.policy.estimator.estimators_:
#             branch = np.array(t.decision_path(ext_state).todense(), dtype=bool)
#             imp = t.tree_.impurity[branch[0]]
#             for f, i in zip(
#                 t.tree_.feature[branch[0]][:-1], imp[:-1] - imp[1:]
#             ):
#                 feature_importances.setdefault(f, []).append(i)
#         # Line 8 Algorithm 2
#         desired_return -= reward
#         # Line 9 Algorithm 2
#         desired_horizon = max(desired_horizon - 1, 1)
#         summed_importances = [
#             sum(feature_importances.get(k, [0.001]))
#             for k in range(len(feature_importances.keys()))
#         ]
#         epoch += 1
#         visualize_environment(
#             state,
#             env,
#             # st.session_state.paused,  # Use session state
#             summed_importances,
#             epoch,
#             max_epoch,
#         )
#         # reset_button, pause_play_button, next_button, save_button = (
#         # )
#         if done or truncated:
#             state, _ = env.reset()
#             desired_horizon = init_desired_horizon
#             desired_return = init_desired_return
#             epoch = 0
#             # step = False
#         # Handle button clicks
#         # if reset_button:
#         #     state, _ = env.reset()
#         #     desired_horizon = init_desired_horizon
#         #     desired_return = init_desired_return
#         #     epoch = 0
#         # elif pause_play_button:
#         #     st.session_state.paused = (
#         #         not st.session_state.paused
#         #     )  # Toggle paused state
#         # elif next_button and st.session_state.paused:
#         #     step = True
#         # elif save_button:
#         #     # Save image and info using Streamlit
#         #     st.image(
#         #         env_render, caption=f"Epoch {epoch}", use_column_width=True
#         #     )
#         #     st.write(
#         #         {
#         #             "state": {i: str(val) for i, val in enumerate(state)},
#         #             "feature": {
#         #                 i: str(val) for i, val in enumerate(summed_importances)
#         #             },
#         #             "action": str(action),
#         #             "reward": str(reward),
#         #             "desired_return": str(desired_return + reward),
#         #             "desired_horizon": str(desired_horizon + 1),
#         #         }
#         #     )
#     env.close()
# env = "Acrobot-v1"
# desired_return = -79
# desired_horizon = 82
# max_epoch = 500
# policy = SklearnPolicy.load("policy")
# hyper = AgentHyper(
#     env,
#     warm_up=0,
# )
# agent = UpsideDownAgent(hyper, policy)
# run_visualization(
#     env, agent, desired_return, desired_horizon, max_epoch, "data/viz_examples"
# )