# udrl/app.py — Streamlit viewer for a Gymnasium environment
# (repo: udrl, "base gym env", commit c6a28ec)
import streamlit as st
import gymnasium as gym
import numpy as np
from PIL import Image
import time
# --- One-time session-state setup (persists across Streamlit reruns) ---
if "env" not in st.session_state:
    # render_mode="rgb_array" makes env.render() return a numpy image frame.
    st.session_state.env = gym.make("LunarLander-v2", render_mode="rgb_array")
    st.session_state.env.reset()
    st.session_state.frame = st.session_state.env.render()

if "paused" not in st.session_state:
    st.session_state.paused = False
def reset_environment() -> None:
    """Button callback: restart the current episode from its initial state."""
    st.session_state.env.reset()
def toggle_pause() -> None:
    """Button callback: flip the paused flag kept in session state."""
    st.session_state.paused = not st.session_state.paused
# --- Page layout ---
st.title("Gymnasium Environment Viewer")

# Control buttons in a horizontal layout.
col1, col2 = st.columns(2)
with col1:
    st.button("Reset Environment", on_click=reset_environment)
with col2:
    # One button whose label reflects the action it will perform
    # (replaces two near-duplicate branches).
    label = "Resume" if st.session_state.paused else "Pause"
    st.button(label, on_click=toggle_pause)

# Placeholder the current frame is drawn into on every rerun.
image_placeholder = st.empty()
# Sidebar container for environment metadata.
sidebar_container = st.sidebar.container()

# --- Simulation step: advance the env by one random action per rerun ---
if not st.session_state.paused:
    action = st.session_state.env.action_space.sample()
    observation, reward, terminated, truncated, info = (
        st.session_state.env.step(action)
    )
    # Cache the rendered frame so it is still shown while paused.
    st.session_state.frame = st.session_state.env.render()
    # Start a fresh episode when the current one ends.
    if terminated or truncated:
        st.session_state.env.reset()

# Display the latest cached frame; annotate the caption while paused
# (single call replaces two branches that differed only in the caption).
caption = "Environment Visualization"
if st.session_state.paused:
    caption += " (Paused)"
# NOTE(review): use_column_width is deprecated in recent Streamlit releases;
# switch to use_container_width=True once the minimum version allows it.
image_placeholder.image(
    st.session_state.frame,
    caption=caption,
    use_column_width=True,
)

# Environment metadata in the sidebar.
with sidebar_container:
    st.header("Environment Info")
    st.write(f"Action Space: {st.session_state.env.action_space}")
    st.write(f"Observation Space: {st.session_state.env.observation_space}")

# Auto-refresh: while running, throttle briefly then rerun the whole script.
if not st.session_state.paused:
    time.sleep(0.1)  # small delay to control the refresh rate
    st.rerun()
# fig, ax = plt.subplots()
# ax.imshow(env.render())
# st.pyplot(fig)
# st.image(env.render())
# import gymnasium as gym
# import streamlit as st
# import numpy as np
# from udrl.policies import SklearnPolicy
# from udrl.agent import UpsideDownAgent, AgentHyper
# from pathlib import Path
# # import json
# def normalize_value(value, is_bounded, low=None, high=None):
# return (value - low) / (high - low)
# def visualize_environment(
# state,
# env,
# # paused,
# feature_importances,
# epoch,
# max_epoch=200,
# ):
# st.image(env.render())
# st.image(e)
# # Render the Gym environment
# # env_render = env.render()
# # # Display the rendered image using Streamlit
# # st.image(env_render, caption=f"Epoch {epoch}", use_column_width=True)
# # Display feature importances using Streamlit metrics
# # cols = st.columns(len(feature_importances))
# # for i, col in enumerate(cols):
# # col.metric(
# # label=f"Importance {i}", value=f"{feature_importances[i]:.2f}"
# # )
# # Create buttons using Streamlit
# # reset_button = st.button("Reset")
# # pause_play_button = st.button("Pause" if not paused else "Play")
# # next_button = st.button("Next")
# # save_button = st.button("Save")
# # return reset_button, pause_play_button, next_button, save_button
# def run_visualization(
# env_name,
# agent,
# init_desired_return,
# init_desired_horizon,
# max_epoch,
# base_path,
# ):
# # base_path = (
# # Path(base_path) / env_name / agent.policy.estimator.__str__()[:-2]
# # )
# # base_path.mkdir(parents=True, exist_ok=True)
# desired_return = init_desired_return
# desired_horizon = init_desired_horizon
# # Initialize the Gym environment
# env = gym.make(env_name, render_mode="rgb_array")
# state, _ = env.reset()
# epoch = 0
# # save_index = 0
# # paused = False
# # step = False
# # # Use Streamlit session state to manage paused state
# # if "paused" not in st.session_state:
# # st.session_state.paused = False
# while True:
# # Render and display the environment
# env_render = env.render()
# # if not st.session_state.paused or step:
# command = np.array(
# [
# desired_return * agent.conf.return_scale,
# desired_horizon * agent.conf.horizon_scale,
# ]
# )
# command = np.expand_dims(command, axis=0)
# state = np.expand_dims(state, axis=0)
# action = agent.policy(state, command, True)
# ext_state = np.concatenate((state, command), axis=1)
# state, reward, done, truncated, info = env.step(action)
# feature_importances = {idx: [] for idx in range(ext_state.shape[1])}
# for t in agent.policy.estimator.estimators_:
# branch = np.array(t.decision_path(ext_state).todense(), dtype=bool)
# imp = t.tree_.impurity[branch[0]]
# for f, i in zip(
# t.tree_.feature[branch[0]][:-1], imp[:-1] - imp[1:]
# ):
# feature_importances.setdefault(f, []).append(i)
# # Line 8 Algorithm 2
# desired_return -= reward
# # Line 9 Algorithm 2
# desired_horizon = max(desired_horizon - 1, 1)
# summed_importances = [
# sum(feature_importances.get(k, [0.001]))
# for k in range(len(feature_importances.keys()))
# ]
# epoch += 1
# visualize_environment(
# state,
# env,
# # st.session_state.paused, # Use session state
# summed_importances,
# epoch,
# max_epoch,
# )
# # reset_button, pause_play_button, next_button, save_button = (
# # )
# if done or truncated:
# state, _ = env.reset()
# desired_horizon = init_desired_horizon
# desired_return = init_desired_return
# epoch = 0
# # step = False
# # Handle button clicks
# # if reset_button:
# # state, _ = env.reset()
# # desired_horizon = init_desired_horizon
# # desired_return = init_desired_return
# # epoch = 0
# # elif pause_play_button:
# # st.session_state.paused = (
# # not st.session_state.paused
# # ) # Toggle paused state
# # elif next_button and st.session_state.paused:
# # step = True
# # elif save_button:
# # # Save image and info using Streamlit
# # st.image(
# # env_render, caption=f"Epoch {epoch}", use_column_width=True
# # )
# # st.write(
# # {
# # "state": {i: str(val) for i, val in enumerate(state)},
# # "feature": {
# # i: str(val) for i, val in enumerate(summed_importances)
# # },
# # "action": str(action),
# # "reward": str(reward),
# # "desired_return": str(desired_return + reward),
# # "desired_horizon": str(desired_horizon + 1),
# # }
# # )
# env.close()
# env = "Acrobot-v1"
# desired_return = -79
# desired_horizon = 82
# max_epoch = 500
# policy = SklearnPolicy.load("policy")
# hyper = AgentHyper(
# env,
# warm_up=0,
# )
# agent = UpsideDownAgent(hyper, policy)
# run_visualization(
# env, agent, desired_return, desired_horizon, max_epoch, "data/viz_examples"
# )