import os
import random
import threading
import time
from glob import glob

import cv2
import dm_env
import h5py
import numpy as np
import rclpy
from absl import logging
from scipy.spatial.transform import Rotation
from sensor_msgs.msg import Image, JointState
from std_msgs.msg import Bool
from std_msgs.msg import Int32

# NOTE: ``visualize_utils.window`` is imported lazily inside ``initialize``.


class AnubisRobotEnv:
    """ROS 2 evaluation environment for the bimanual Anubis robot.

    Subscribes to three camera streams and the ``/eef_pose`` topic, publishes
    end-effector targets on ``/teleop/eef_pose`` and drives a subclass-provided
    policy (``bringup_model`` / ``inference``) from a fixed-rate control
    thread. Each ``/done`` message toggles a rollout on/off; rollouts are
    recorded to video through a ``visualize_utils`` window.

    Policy-side vector (20 values, two arms of 10):
        [0:3] xyz, [3:9] 6D rotation (first two rotation-matrix columns,
        row-major), [9] gripper -- then the same for the second arm at [10:20].
    Wire-side ``JointState.position`` (16 values, two arms of 8):
        xyz, quaternion (x, y, z, w), gripper.
    Subclasses must implement ``bringup_model`` and ``inference``.
    """

    # (height, width, channels) every camera frame is reshaped to.
    IMAGE_SHAPE = (480, 640, 3)

    def __init__(self, hz=20, max_timestep=1000, task_name='', num_rollout=1,
                 demo_root='/home/rllab/workspace/jellyho/demo_collection'):
        """Validate the task, start ROS 2 and the worker threads, load the model.

        Args:
            hz: control frequency of the ``timer_callback`` thread.
            max_timestep: control steps after which a rollout is ended.
            task_name: key into ``lang_dict``; selects instruction and demos.
            num_rollout: stored for callers; not read in this class.
            demo_root: directory holding ``<task_name>/*.hdf5`` demos.

        Raises:
            KeyError: ``task_name`` is not a known task.
            FileNotFoundError: no demo ``.hdf5`` files were found.
        """
        self.lang_dict = {
            'anubis_brush_to_pan': 'insert the brush to the dustpan',
            'anubis_carrot_to_bag': 'pick up the carrot and put into the bag',
            'anubis_towel_kirby': 'take the towel off the kirby doll',
        }
        # Validate before rclpy.init()/thread start so a bad config does not
        # leave a live ROS context or non-daemon threads behind.
        if task_name not in self.lang_dict:
            raise KeyError(
                f'Unknown task_name {task_name!r}; expected one of {sorted(self.lang_dict)}')
        self.task_name = task_name
        self.instruction = self.lang_dict[self.task_name]
        # Sorted so that rollout i always replays the same demo across runs
        # (glob order is filesystem dependent).
        self.data_list = sorted(glob(os.path.join(demo_root, self.task_name, '*.hdf5')))
        if not self.data_list:
            raise FileNotFoundError(
                f'No demo .hdf5 files found in {os.path.join(demo_root, self.task_name)}')

        rclpy.init()
        # initialize ROS2 node
        self._node = rclpy.create_node('anubis_robot_env_node')
        self._subscriber_bringup()
        print('ROS2 node created')

        self.window = None
        self.start = False
        self.thread_done = False
        self.hz = hz  # control frequency
        self.action_counter = 0
        self.num_rollout = num_rollout
        self.rollout_counter = 0
        self.curr_timestep = 0
        self.overlay_img = None
        self.max_timestep = max_timestep

        # Reference end-effector pose in the 16-value wire layout (per arm:
        # xyz, quaternion x/y/z/w, gripper). Not read elsewhere in this class.
        # NOTE(review): the original listing had the second arm's x value
        # (0.21136149716403216) commented out, leaving 15 values for the
        # 16-slot layout; it is restored here -- confirm.
        self.init_action = JointState()
        self.init_action.position = [
            0.20620185010895048, 0.16183641523267392, 0.2277105000367078,
            -0.42093861525667453, 0.6546518510233503, -0.5770953981378887, 0.24739146627474096,
            -1.6,
            0.21136149716403216, -0.16027684481842075, 0.21879985782478842,
            0.6606782591766969, -0.428768621033297, 0.2340722378552696, -0.569975345900049,
            -1.6,
        ]

        print('Initializing Anubis Robot Environment')
        self.thread = PeriodicThread(1 / self.hz, self.timer_callback)
        self.thread.start()
        self.video_thread = PeriodicThread(1 / 30, self.video_timer_callback)
        self.video_thread.start()
        # rclpy.spin services all subscriptions on this daemon thread.
        self.timer_thread = threading.Thread(target=rclpy.spin, args=(self._node,), daemon=True)
        self.timer_thread.start()
        print('Threads started')

        self.bringup_model()
        self.initialize()
        logging.set_verbosity(logging.INFO)
        logging.info('AnubisRobotEnv successfully initialized.')

    def init_robot_pose(self, demo):
        """Publish the first ``eef_pose`` action of demo ``demo`` (wrapped modulo count)."""
        demo_idx = demo % len(self.data_list)
        print('Initializing robot pose', demo_idx)
        # Context manager closes the HDF5 file; the slice is read eagerly.
        with h5py.File(self.data_list[demo_idx], 'r') as root:
            first_action = root['action']['eef_pose'][0]
        self.publish_action(first_action)

    def initialize(self):
        """Prepare a new rollout: reset the step counter, video and robot pose."""
        self.curr_timestep = 0
        if self.window is None:
            # Imported on first use only.
            from visualize_utils import window
            # NOTE(review): self.model_name is presumably set by the subclass's
            # bringup_model(), which __init__ calls before this -- confirm.
            self.window = window('ENV Observation', video_path=f'{self.model_name}-{self.task_name}',
                                 video_fps=30, video_size=(640, 480), show=False)
        else:
            self.window.init_video()
        self.send_demo(self.rollout_counter)
        self.init_robot_pose(self.rollout_counter)

    def reset(self):
        """Block until the control thread raises ``thread_done``, then return a restart step."""
        while not self.thread_done:
            time.sleep(0.01)
        self.thread_done = False
        return dm_env.restart(observation=self._observation())

    def bringup_model(self):
        """Load the policy; must be implemented by subclasses."""
        raise NotImplementedError

    def inference(self):
        """Run one policy step; must be implemented by subclasses."""
        raise NotImplementedError

    def ros_close(self):
        """Stop worker threads, finish any recording and shut down ROS 2.

        ``threading.Thread`` has no ``stop()``; the spin thread is ended by
        ``rclpy.shutdown()`` and then joined.
        """
        current = threading.current_thread()
        workers = (self.thread, self.video_thread)
        for worker in workers:
            worker.stop()
        # Join first so no callback is mid-write when the video is closed.
        for worker in workers:
            if worker is not current:
                worker.join(timeout=1.0)
        self._stop_recording()
        self._node.destroy_node()
        if rclpy.ok():
            rclpy.shutdown()
        if self.timer_thread is not current:
            self.timer_thread.join(timeout=1.0)

    def _subscriber_bringup(self):
        '''
        Note: This function creates the camera, end-effector-pose and /done
        subscribers, fills ``self.obs`` with zero placeholders, and creates
        the /demo and /teleop/eef_pose publishers.
        '''
        ###### Initial Setup #####
        self.obs = {}
        self.action = {}

        ###### OBSERVATION ######
        # image
        self._node.create_subscription(Image, '/camera_center/camera/color/image_raw', self.agentview_image_callback, 10)
        self.obs['agentview_image'] = np.zeros(shape=self.IMAGE_SHAPE, dtype=np.uint8)
        self._node.create_subscription(Image, '/camera_right/camera/color/image_raw', self.rightview_image_callback, 10)
        self.obs['rightview_image'] = np.zeros(shape=self.IMAGE_SHAPE, dtype=np.uint8)
        self._node.create_subscription(Image, '/camera_left/camera/color/image_raw', self.leftview_image_callback, 10)
        self.obs['leftview_image'] = np.zeros(shape=self.IMAGE_SHAPE, dtype=np.uint8)

        # end-effector pose states (20-value policy layout)
        self._node.create_subscription(JointState, '/eef_pose', self.eef_pose_callback, 10)
        self.obs['eef_pose'] = np.zeros(shape=(20,), dtype=np.float64)

        self.obs['language_instruction'] = ''

        ##### TRIGGER #####
        self._node.create_subscription(Bool, '/done', self.done_callback, 10)
        self.demo_pub = self._node.create_publisher(Int32, '/demo', 10)
        self.action_pub = self._node.create_publisher(JointState, '/teleop/eef_pose', 10)

    def send_demo(self, num):
        """Publish the demo index ``num`` on ``/demo``."""
        demo_msg = Int32()
        demo_msg.data = num
        self.demo_pub.publish(demo_msg)

    #### OBS ###########
    def _image_from_msg(self, msg):
        """Reshape an ``Image`` payload to ``IMAGE_SHAPE`` (raises if the size differs)."""
        return np.reshape(msg.data, self.IMAGE_SHAPE)

    def agentview_image_callback(self, msg):
        self.obs['agentview_image'] = self._image_from_msg(msg)

    def rightview_image_callback(self, msg):
        # This camera's frame is rotated by 180 degrees before storing.
        self.obs['rightview_image'] = np.rot90(self._image_from_msg(msg), 2)

    def leftview_image_callback(self, msg):
        self.obs['leftview_image'] = self._image_from_msg(msg)

    def eef_pose_callback(self, msg):
        """Convert the 16-value wire pose into the 20-value policy layout."""
        received = np.asarray(msg.position, dtype=np.float64)
        eef_pose = np.zeros(shape=(20,), dtype=np.float64)
        # (wire offset, policy offset) for each arm.
        for src, dst in ((0, 0), (8, 10)):
            eef_pose[dst:dst + 3] = received[src:src + 3]
            eef_pose[dst + 3:dst + 9] = self.quat_to_6d(received[src + 3:src + 7], scalar_first=False)
            eef_pose[dst + 9] = received[src + 7]
        self.obs['eef_pose'] = eef_pose

    def _pack_eef_action(self, action):
        """Build a ``JointState`` (16-value wire layout) from a 20-value action.

        Extra singleton dimensions are squeezed; values past index 19 are
        ignored.

        Raises:
            ValueError: the squeezed action is not 1-D with >= 20 values.
        """
        action = np.asarray(action, dtype=np.float64).squeeze()
        if action.ndim != 1 or action.shape[0] < 20:
            raise ValueError(f'Expected a 20-value action, got shape {action.shape}')
        msg_data = np.zeros(16)
        # (policy offset, wire offset) for each arm.
        for src, dst in ((0, 0), (10, 8)):
            msg_data[dst:dst + 3] = action[src:src + 3]
            msg_data[dst + 3:dst + 7] = self.sixd_to_quat(action[src + 3:src + 9])
            msg_data[dst + 7] = action[src + 9]
        action_msg = JointState()
        action_msg.position = msg_data.astype(float).tolist()
        return action_msg

    def send_action(self, act):
        """Publish a 20-value action, but only while a rollout is active."""
        if self.start:
            self.action_pub.publish(self._pack_eef_action(act))

    def publish_action(self, action):
        """Publish a 20-value action unconditionally (used to set the initial pose)."""
        self.action_pub.publish(self._pack_eef_action(action))

    def _stop_recording(self):
        """Stop the video writer if a recording is in progress."""
        if self.window is not None and self.window.video_recording:
            self.window.video_stop()

    def done_callback(self, msg):
        """Toggle a rollout: the first message starts it, the next ends it and re-initialises."""
        if not self.start:
            print('Inference & Video Recording Start')
            self.start = True
            self.window.video_start()
        else:
            self.start = False
            self.action_counter = 0
            self.rollout_counter += 1
            self._stop_recording()
            self.initialize()
            print('Next Inference Ready')

    def timer_callback(self):
        """Run one control step; end and re-initialise the rollout at ``max_timestep``."""
        if self.start:
            self.inference()
            self.curr_timestep += 1
            if self.curr_timestep >= self.max_timestep:
                print("Max timestep reached, resetting environment.")
                self.start = False
                self._stop_recording()
                self.rollout_counter += 1
                self.action_counter = 0
                self.curr_timestep = 0
                self.initialize()
                # NOTE(review): the original file's indentation was lost; this
                # assumes the flag that reset() waits on is raised only when a
                # rollout ends here. If it was meant to be set after every
                # tick, dedent it out of both branches.
                self.thread_done = True

    def video_timer_callback(self):
        """Write one video frame while a rollout is being recorded."""
        if self.start and self.window.video_recording:
            self.window.video_write()

    def quat_to_6d(self, quat, scalar_first=False):
        """Return the first two rotation-matrix columns of ``quat``, flattened row-major."""
        mat = Rotation.from_quat(quat, scalar_first=scalar_first).as_matrix()
        return mat[:, :2].flatten()

    def sixd_to_quat(self, sixd, scalar_first=False):
        """Convert a 6D rotation (inverse of ``quat_to_6d``) to a quaternion.

        The two columns are Gram-Schmidt orthonormalised before the third is
        formed by a cross product, so raw policy outputs map to a proper
        rotation.

        Raises:
            ValueError: the columns are (near) zero or parallel.
        """
        cols = np.asarray(sixd, dtype=np.float64).reshape(3, 2)
        first_norm = np.linalg.norm(cols[:, 0])
        if first_norm < 1e-8:
            raise ValueError('Degenerate 6D rotation: first column is zero')
        first = cols[:, 0] / first_norm
        second = cols[:, 1] - np.dot(first, cols[:, 1]) * first
        second_norm = np.linalg.norm(second)
        if second_norm < 1e-8:
            raise ValueError('Degenerate 6D rotation: columns are parallel')
        second = second / second_norm
        mat = np.column_stack((first, second, np.cross(first, second)))
        return Rotation.from_matrix(mat).as_quat(scalar_first=scalar_first)


class PeriodicThread(threading.Thread):
    """Thread that calls ``function(*args, **kwargs)`` every ``interval`` seconds.

    The period is measured start-to-start; an overrunning call is followed
    immediately by the next. ``stop()`` also interrupts the inter-call wait.
    """

    def __init__(self, interval, function, *args, **kwargs):
        super().__init__()
        self.interval = interval
        self.function = function
        self.args = args
        self.kwargs = kwargs
        self.stop_event = threading.Event()
        self._lock = threading.Lock()

    def run(self):
        while not self.stop_event.is_set():
            start_time = time.monotonic()
            self.function(*self.args, **self.kwargs)
            with self._lock:  # pairs with change_period()
                interval = self.interval
            elapsed_time = time.monotonic() - start_time
            # Event.wait is an interruptible sleep: stop() wakes it at once.
            self.stop_event.wait(max(0.0, interval - elapsed_time))

    def stop(self):
        """Ask the loop to exit after the current call (does not join)."""
        self.stop_event.set()

    def change_period(self, new_interval):
        """Set a new period; takes effect from the next wait."""
        with self._lock:
            self.interval = new_interval