import ast
from collections import defaultdict

import numpy as np
from transformers import AutoTokenizer

from prismatic.vla.action_tokenizer import ActionTokenizer
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class Solver: | 
					
						
						|  | def __init__(self, action_tokenizer=None, verbose=True) -> None: | 
					
						
						|  | self.verbose = verbose | 
					
						
						|  | self.action_tokenizer = action_tokenizer | 
					
						
						|  | self.coordinates_key = "NEXT GRIPPER:" | 
					
						
						|  | self.movement_key = "MOVEMENT:" | 
					
						
						|  | self.policy_key = "POLICIES:" | 
					
						
						|  |  | 
					
						
						|  | def compare_movement(self, pred_pos, label_pos): | 
					
						
						|  |  | 
					
						
						|  | dist = np.sum(np.abs(pred_pos - label_pos)) | 
					
						
						|  | relative_dist = np.sum(np.abs(dist / label_pos)) | 
					
						
						|  | return dist, relative_dist, dist == 0 | 
					
						
						|  |  | 
					
						
						|  | def compare_policy(self, pred_pol, label_pol): | 
					
						
						|  | dist = 0 | 
					
						
						|  | cnt = 0 | 
					
						
						|  | for i in range(min(len(label_pol), len(pred_pol))): | 
					
						
						|  | for j in range(len(label_pol[0])): | 
					
						
						|  | dist += label_pol[i][j] == pred_pol[i][j] | 
					
						
						|  | cnt += 1 | 
					
						
						|  | assert cnt % 7 == 0 | 
					
						
						|  | return dist / cnt | 
					
						
						|  |  | 
					
						
						|  | def extract_2d_coordinates(self, text): | 
					
						
						|  | try: | 
					
						
						|  | coordinates_index = text.index(self.coordinates_key) + len(self.coordinates_key) | 
					
						
						|  | coord = text[coordinates_index:] | 
					
						
						|  | coord = [o for o in coord.split("\n") if len(o.strip()) != 0] | 
					
						
						|  | coord = eval(coord[0].strip()) | 
					
						
						|  | except Exception: | 
					
						
						|  | coord = [0, 0] | 
					
						
						|  | return coord | 
					
						
						|  |  | 
					
						
						|  | def extract_movement_plan(self, text): | 
					
						
						|  | require_unorm = None | 
					
						
						|  | try: | 
					
						
						|  |  | 
					
						
						|  | movement_index = text.index(self.movement_key) + len(self.movement_key) | 
					
						
						|  | movement_level = text[movement_index:] | 
					
						
						|  | movement_level = [o for o in movement_level.split("\n") if len(o.strip()) != 0] | 
					
						
						|  | movement_level = movement_level[0].strip() | 
					
						
						|  |  | 
					
						
						|  | if "gripper" not in movement_level: | 
					
						
						|  | require_unorm = True | 
					
						
						|  | movement_token_ids = self.action_tokenizer.tokenizer(movement_level, add_special_tokens=False).input_ids | 
					
						
						|  | movement_norm = self.action_tokenizer.decode_token_ids_to_actions(np.array(movement_token_ids)) | 
					
						
						|  | movement_norm = movement_norm[1:8] | 
					
						
						|  | assert len(movement_norm) == 7 | 
					
						
						|  | else: | 
					
						
						|  | require_unorm = False | 
					
						
						|  | movement_level = [o for o in movement_level.split(";") if len(o) > 0] | 
					
						
						|  | movement_level = movement_level[:7] | 
					
						
						|  |  | 
					
						
						|  | position = defaultdict(int) | 
					
						
						|  | movement_to_pos = dict( | 
					
						
						|  | move_backward=(-1, "y"), | 
					
						
						|  | move_forward=(1, "y"), | 
					
						
						|  | move_right=(-1, "x"), | 
					
						
						|  | move_left=(1, "x"), | 
					
						
						|  | move_downward=(-1, "z"), | 
					
						
						|  | move_upward=(1, "z"), | 
					
						
						|  | roll_downward=(-1, "ox"), | 
					
						
						|  | roll_upward=(1, "ox"), | 
					
						
						|  | swing_downward=(-1, "ox"), | 
					
						
						|  | swing_upward=(1, "ox"), | 
					
						
						|  | pitch_downward=(-1, "oy"), | 
					
						
						|  | pitch_upward=(1, "oy"), | 
					
						
						|  | yaw_downward=(-1, "oz"), | 
					
						
						|  | yaw_upward=(1, "oz"), | 
					
						
						|  | rotate_clockwise=(-1, "oz"), | 
					
						
						|  | rotate_counterclockwise=(1, "oz"), | 
					
						
						|  | close_gripper=(-1, "grip"), | 
					
						
						|  | open_gripper=(1, "grip"), | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | for ml in movement_level: | 
					
						
						|  | direction = "_".join(ml.split()[:2]) | 
					
						
						|  | sign, axis = movement_to_pos[direction] | 
					
						
						|  | scale = 1 | 
					
						
						|  | if "o" in axis: | 
					
						
						|  | scale = scale * 1e-3 | 
					
						
						|  | elif "grip" in axis: | 
					
						
						|  | scale = scale | 
					
						
						|  | else: | 
					
						
						|  | scale = scale / 180 * np.pi | 
					
						
						|  |  | 
					
						
						|  | if "grip" in axis: | 
					
						
						|  | level = round("open" in ml) | 
					
						
						|  | else: | 
					
						
						|  | level = int(ml.split()[2]) | 
					
						
						|  |  | 
					
						
						|  | position[axis] += sign * scale * level | 
					
						
						|  | movement_norm = [position[idx] for idx in ["x", "y", "z", "ox", "oy", "oz", "grip"]] | 
					
						
						|  |  | 
					
						
						|  | except: | 
					
						
						|  | movement_norm = [-100] * 7 | 
					
						
						|  |  | 
					
						
						|  | return require_unorm, np.array(movement_norm) | 
					
						
						|  |  | 
					
						
						|  | def extract_action_policies(self, text): | 
					
						
						|  | try: | 
					
						
						|  | if self.policy_key in text: | 
					
						
						|  |  | 
					
						
						|  | policy_index = text.index(self.policy_key) + len(self.policy_key) | 
					
						
						|  | policy = text[policy_index:] | 
					
						
						|  | remain_text = text[: text.index(self.policy_key)] | 
					
						
						|  | policies = [o for o in policy.split("\n") if len(o.strip()) != 0] | 
					
						
						|  | policies = policies[0].strip() | 
					
						
						|  | else: | 
					
						
						|  | policies = text.strip() | 
					
						
						|  | remain_text = "" | 
					
						
						|  |  | 
					
						
						|  | policies_num = [] | 
					
						
						|  | for policy_text in policies.split(";"): | 
					
						
						|  | policy_token = self.action_tokenizer.tokenizer(policy_text, add_special_tokens=False).input_ids | 
					
						
						|  | action_policy = self.action_tokenizer.decode_token_ids_to_actions(np.array(policy_token)) | 
					
						
						|  |  | 
					
						
						|  | action_policy = action_policy[1:] | 
					
						
						|  | action_policy = action_policy[:7] | 
					
						
						|  |  | 
					
						
						|  | if len(action_policy) != 7: | 
					
						
						|  | action_policy = [0] * 7 | 
					
						
						|  | policies_num.append(action_policy.tolist()) | 
					
						
						|  |  | 
					
						
						|  | except: | 
					
						
						|  | policies_num = [[0] * 7] | 
					
						
						|  | remain_text = text | 
					
						
						|  |  | 
					
						
						|  | return policies_num, remain_text | 
					
						
						|  |  | 
					
						
						|  | def evaluate_single(self, ground_truth, prediction, verbose=False): | 
					
						
						|  | gt_policies, ground_truth = self.extract_action_policies(ground_truth) | 
					
						
						|  | pred_policies, prediction = self.extract_action_policies(prediction) | 
					
						
						|  |  | 
					
						
						|  | _, pred_movement = self.extract_movement_plan(prediction) | 
					
						
						|  | _, gt_movement = self.extract_movement_plan(ground_truth) | 
					
						
						|  |  | 
					
						
						|  | dist, relative_dist, _ = self.compare_movement(label_pos=gt_movement, pred_pos=pred_movement) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | next_state_score = 0 | 
					
						
						|  |  | 
					
						
						|  | acc = self.compare_policy(label_pol=gt_policies, pred_pol=pred_policies) | 
					
						
						|  |  | 
					
						
						|  | return next_state_score, acc, dist, relative_dist, pred_policies, gt_policies | 
					
						
						|  |  | 
					
						
						|  | def evaluate_batch(self, batch_gt, batch_pred, verbose=False): | 
					
						
						|  | state_acc_ls = [] | 
					
						
						|  | action_acc_ls = [] | 
					
						
						|  | L1_loss_ls = [] | 
					
						
						|  | relative_L1_loss_ls = [] | 
					
						
						|  | pred_policies_ls = [] | 
					
						
						|  | gt_policies_ls = [] | 
					
						
						|  | for i in range(len(batch_gt)): | 
					
						
						|  | ground_truth = batch_gt[i] | 
					
						
						|  | prediction = batch_pred[i] | 
					
						
						|  | next_state_score, action_policy_score, L1_dist, relative_L1_dist, pred_policies, gt_policies = ( | 
					
						
						|  | self.evaluate_single(ground_truth, prediction) | 
					
						
						|  | ) | 
					
						
						|  | state_acc_ls.append(next_state_score) | 
					
						
						|  | action_acc_ls.append(action_policy_score) | 
					
						
						|  | L1_loss_ls.append(L1_dist) | 
					
						
						|  | relative_L1_loss_ls.append(relative_L1_dist) | 
					
						
						|  | pred_policies_ls.append(pred_policies) | 
					
						
						|  | gt_policies_ls.append(gt_policies) | 
					
						
						|  | if verbose: | 
					
						
						|  | print(f"Ground Truth:\n\n {ground_truth}") | 
					
						
						|  | print() | 
					
						
						|  | print(f"prediction:\n\n {prediction}") | 
					
						
						|  | print() | 
					
						
						|  | print(f"Ground Truth Policies:\n\n {gt_policies}") | 
					
						
						|  | print(f"prediction policies:\n\n {pred_policies}") | 
					
						
						|  | print("*" * 40) | 
					
						
						|  |  | 
					
						
						|  | return state_acc_ls, action_acc_ls, L1_loss_ls, relative_L1_loss_ls, pred_policies_ls, gt_policies_ls | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Module-level setup: build the Llama-2 tokenizer, wrap it in an
# ActionTokenizer, and expose a ready-to-use Solver instance.
# NOTE(review): this runs at import time and downloads/loads the
# pretrained tokenizer — importing this module has side effects and
# requires network/credentials for the gated meta-llama repo.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", model_max_length=2048, padding_side="right")
action_tokenizer = ActionTokenizer(tokenizer)
solver = Solver(action_tokenizer)
					
						
						|  |  | 
					
						
						|  |  |