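"""Evaluation pipeline for ChinaTravel plan results: JSON-schema validation
(DR), commonsense verification (EPR), hard logical constraints (LPR / C-LPR),
and an overall all-pass rate (FPR)."""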
import os
import jsonschema
import pandas as pd
from jsonschema import validate
from chinatravel.data.load_datasets import load_query
from chinatravel.evaluation.utils import load_json_file
from chinatravel.symbol_verification.commonsense_constraint import (
Is_time_correct,
Is_space_correct,
Is_hotels_correct,
Is_transport_correct,
Is_attractions_correct,
Is_restaurants_correct,
Is_intercity_transport_correct,
)
from chinatravel.symbol_verification.hard_constraint import evaluate_constraints_py
# Keep Hugging Face datasets offline so evaluation never touches the network.
os.environ["HF_DATASETS_OFFLINE"] = "1"
def load_result(result_dir, query_index):
    """Load one JSON plan per query id; track which ids have a readable result file."""
result = {}
matched_uid = []
unmatched_uid = []
for query_id in query_index:
result_file = os.path.join(result_dir, f"{query_id}.json")
# print(f"Loading result for {query_id} from {result_file}")
try:
if os.path.exists(result_file):
result[query_id] = load_json_file(result_file)
matched_uid.append(query_id)
else:
result[query_id] = {}
unmatched_uid.append(query_id)
        except Exception:
            # Unreadable files are treated the same as missing ones.
            result[query_id] = {}
            unmatched_uid.append(query_id)
return result, matched_uid, unmatched_uid
def validate_json(json_data, schema):
    """Return True iff json_data conforms to the given JSON schema."""
    try:
        validate(instance=json_data, schema=schema)
        return True
    except jsonschema.exceptions.ValidationError:
        return False
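# For instance, validate_json({"days": []}, {"type": "object"}) returns True,
# while validate_json([], {"type": "object"}) returns False; the key "days"
# here is purely illustrative, not part of the actual output schema.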
def evaluate_schema_constraints(data_index, plan_json_dict, schema, result):
    """Validate each plan against the output schema; report the schema pass rate (DR)."""
total_correct = 0
result_agg = pd.DataFrame(columns=["data_id", "schema"])
result_agg["data_id"] = data_index
pass_id = []
total = len(data_index)
for ii, idx in enumerate(data_index):
plan_json = plan_json_dict[idx]
succ_flag = 0
try:
if validate_json(plan_json, schema):
succ_flag = 1
pass_id.append(idx)
        except Exception:
            # Malformed plans simply count as schema failures.
            pass
result_agg.loc[ii, "schema"] = succ_flag
total_correct += succ_flag
yield {
"stage": "schema",
"progress": (ii + 1) / total * 100,
}
    result["DR"] = total_correct / total * 100
result["S_pass_id"] = pass_id
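# Illustrative numbers (not from the benchmark): with 4 queries of which 3
# yield schema-valid plans, result["DR"] = 3 / 4 * 100 = 75.0 and
# result["S_pass_id"] holds the three passing query ids.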
"""
Constraints:
Available
1. Intercity transport information exsits and is objective: ID, time, startpos and endpos need to be correct.
2. Attractions
3. Hotels
4. Restaurants
5. transportation
6. Times
7. space
"""
def evaluate_commonsense_constraints(
    data_index, symbolic_input_dict, plan_json_dict, result
):
    """Run every commonsense verifier on each plan; report EPR micro/macro rates."""
func_list = [
Is_intercity_transport_correct,
Is_attractions_correct,
Is_hotels_correct,
Is_restaurants_correct,
Is_transport_correct,
Is_time_correct,
Is_space_correct,
]
result_agg = pd.DataFrame(columns=["data_id"])
result_agg["data_id"] = data_index
individual_succ = 0
pass_id = []
total = len(data_index)
for ii, idx in enumerate(data_index):
symbolic_input, plan_json = symbolic_input_dict[idx], plan_json_dict[idx]
try:
for func in func_list:
table_res, _ = func(symbolic_input, plan_json, verbose=False)
                for column_i in table_res.columns:
                    if column_i not in result_agg.columns:
                        result_agg[column_i] = 0
                    result_agg.loc[ii, column_i] = table_res[column_i].loc[0]
            # A row summing to zero means every commonsense check passed.
            if result_agg.loc[ii][1:].sum() == 0:
                individual_succ += 1
                pass_id.append(idx)
        except Exception:
            # A crashing verifier leaves this plan counted as a failure.
            pass
yield {
"stage": "commonsense",
"progress": (ii + 1) / total * 100,
}
    # Micro: fraction of individual checks passed over all plans;
    # macro: fraction of plans passing every check.
    micro_accuracy = 1.0 - result_agg.drop("data_id", axis=1).sum().sum() / (
        total * (result_agg.shape[1] - 1)
    )
    macro_accuracy = individual_succ / total
result["EPR_micro"] = micro_accuracy * 100
result["EPR_macro"] = macro_accuracy * 100
result["E_pass_id"] = pass_id
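# Illustrative numbers (not from the benchmark): with 2 plans and 7 checks
# each, where plan A fails 1 check and plan B fails none, EPR_micro is
# (1 - 1 / 14) * 100 ≈ 92.9 and EPR_macro is 1 / 2 * 100 = 50.0.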
def evaluate_hard_constraints_v2(
    data_index, symbolic_input_dict, plan_json_dict, env_pass_id, result: dict
):
    """Evaluate each query's hard logical constraints; report LPR micro/macro and C-LPR."""
max_logic_num = 0
for idx in data_index:
max_logic_num = max(
max_logic_num, len(symbolic_input_dict[idx]["hard_logic_py"])
)
    columns = ["data_id"] + [f"logic_py_{i}" for i in range(max_logic_num)]
    result_agg = pd.DataFrame(columns=columns)
    for col_i in result_agg.columns[1:]:
        result_agg[col_i] = 0
macro_count, macro_succ_count = 0, 0
micro_count, micro_succ_count = 0, 0
conditional_micro_succ_count, conditional_macro_succ_count = 0, 0
results = []
passed_id = []
total = len(data_index)
for ii, idx in enumerate(data_index):
symbolic_input, plan_json = symbolic_input_dict[idx], plan_json_dict[idx]
result_ii = evaluate_constraints_py(
symbolic_input["hard_logic_py"], plan_json, verbose=False
)
results.append(result_ii)
dict_ii = {}
succ_c_sum = 0
for logic_i in range(len(symbolic_input["hard_logic_py"])):
dict_ii[f"logic_py_{logic_i}"] = int(result_ii[logic_i])
succ_c_sum += int(result_ii[logic_i])
macro_count += 1
macro_succ_count += succ_c_sum == len(dict_ii)
micro_count += len(dict_ii)
micro_succ_count += succ_c_sum
if idx in env_pass_id:
conditional_micro_succ_count += succ_c_sum
conditional_macro_succ_count += succ_c_sum == len(dict_ii)
if succ_c_sum == len(dict_ii):
passed_id.append(idx)
dict_ii["data_id"] = idx
result_agg.loc[ii] = pd.Series(dict_ii)
yield {
"stage": "logic",
"progress": (ii + 1) / total * 100,
}
macro = macro_succ_count / macro_count
micro = micro_succ_count / micro_count
    c_macro = conditional_macro_succ_count / macro_count  # computed but not reported
    c_micro = conditional_micro_succ_count / micro_count
result["LPR_micro"] = micro * 100
result["LPR_macro"] = macro * 100
result["C-LPR"] = c_micro * 100
result["L_pass_id"] = passed_id
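# Illustrative numbers (not from the benchmark): with 2 plans carrying 3 hard
# constraints each, where plan A satisfies all 3 and plan B satisfies 1,
# LPR_micro = 4 / 6 * 100 ≈ 66.7 and LPR_macro = 1 / 2 * 100 = 50.0. C-LPR
# counts only constraints satisfied by plans that also appear in env_pass_id,
# i.e. plans that passed the commonsense stage.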
def evaluate(args, result):
    """Run all evaluation stages in sequence, yielding progress events throughout."""
    eval_result = {}
query_index, query_data = load_query(args)
result_data, matched_uid, unmatched_uid = load_result(
args.result_dir, query_index=query_index
)
eval_result["matched_uid"] = matched_uid
eval_result["unmatched_uid"] = unmatched_uid
schema_file_path = "chinatravel/evaluation/output_schema.json"
schema = load_json_file(schema_file_path)
# schema pass rate
yield from evaluate_schema_constraints(
query_index, result_data, schema=schema, result=eval_result
)
# commonsense pass rate
yield from evaluate_commonsense_constraints(
query_index, query_data, result_data, result=eval_result
)
# hard logic pass rate
yield from evaluate_hard_constraints_v2(
query_index,
query_data,
result_data,
env_pass_id=eval_result.get("E_pass_id", []),
result=eval_result,
)
    # All-pass rate: a query passes only if it appears in every stage's pass-id set.
all_pass_id = set(query_index) # Initialize with all query IDs
for key in eval_result:
if "pass_id" in key:
all_pass_id.intersection_update(set(eval_result[key]))
eval_result["FPR"] = len(all_pass_id) / len(query_index) * 100
    # Drop the intermediate pass-id lists before reporting.
    del_keys = [key for key in eval_result if "pass_id" in key]
    for key in del_keys:
        del eval_result[key]
    # Update the caller's dict in place; rebinding the local name
    # ("result = eval_result") would not propagate back to the caller.
    result.update(eval_result)
yield {
"stage": "final",
"progress": 100,
"result": result,
}
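
if __name__ == "__main__":
    # Minimal driver sketch. Only `result_dir` is confirmed by the code above;
    # load_query(args) may expect further attributes on `args` (e.g. a dataset
    # split), which are repo-specific and not shown here.
    import argparse

    parser = argparse.ArgumentParser(description="Evaluate ChinaTravel plan results.")
    parser.add_argument(
        "--result_dir", required=True, help="Directory of <query_id>.json plans"
    )
    cli_args = parser.parse_args()

    final_result = {}
    for event in evaluate(cli_args, final_result):
        print(f"[{event['stage']}] {event['progress']:.1f}%")
    print(final_result)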