# Repository page residue (not part of the program): "Cbphcr's picture", "Init repo", commit dfe35ce (verified).
import argparse
import os
import sys
import json
# Make the project importable when this file is run directly as a script.
# The project root is taken to be three directory levels above this file.
project_root_path = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
# Prepend the root first and its "chinatravel" sub-directory second, each only
# if it is not already on sys.path; when both are inserted the "chinatravel"
# directory therefore ends up ahead of the root.
for _import_dir in (
    project_root_path,
    os.path.join(project_root_path, "chinatravel"),
):
    if _import_dir not in sys.path:
        sys.path.insert(0, _import_dir)
from chinatravel.evaluation.utils import load_json_file, validate_json
from chinatravel.evaluation.commonsense_constraint import evaluate_commonsense_constraints
from chinatravel.evaluation.hard_constraint import evaluate_hard_constraints
from chinatravel.evaluation.preference import evaluate_preference
# Names of the methods whose results can be evaluated. Each name is used as
# the sub-directory that holds that method's result files. "example" is
# skipped when "--method all" is requested.
METHOD_LIST = [
    # A comma was missing after "example"; adjacent string literals are
    # concatenated by Python, which silently produced one bogus entry.
    "example",
    "act_Deepseek_zeroshot",
    "act_GPT4o_zeroshot",
    "react_Deepseek_zeroshot",
    "react_GPT4o_zeroshot",
    "react_GLM4Plus_zeroshot",
    "react_Deepseek_oneshot",
    "react_GPT4o_oneshot",
    "naive_ns_Deepseek",
    "naive_ns_GPT4o",
    "naive_ns_GLM4Plus",
]
def load_result(args, query_index, verbose=False):
    """Load the saved plans of one method, or of every known method.

    Args:
        args: Parsed command-line namespace; only ``args.method`` is read.
            The special value ``"all"`` selects every entry of
            ``METHOD_LIST`` except ``"example"``; any other value is used
            as the single method name.
        query_index: Iterable of query ids. One ``<query_id>.json`` file is
            looked up per id under ``../results/<method>/``.
        verbose: When true, report result files that failed to load and
            print the collected results.

    Returns:
        A ``(method_list, result)`` tuple. ``result`` maps each method name
        to a dict ``{query_id: plan}``; a missing or unreadable result file
        yields an empty dict for that query.
    """

    def load_result_for_method(method):
        """Return ``{query_id: plan}`` for every id in ``query_index``."""
        plans = {}
        for query_id in query_index:
            # NOTE: this path is relative to the current working directory,
            # not to the location of this file.
            result_file = os.path.join(
                "../results/", method, "{}.json".format(query_id)
            )
            # Best effort: default to an empty plan so that one absent or
            # broken file does not abort the whole evaluation run.
            plans[query_id] = {}
            if not os.path.exists(result_file):
                continue
            try:
                plans[query_id] = load_json_file(result_file)
            except Exception as exc:
                # Deliberately broad (the loader's failure modes are not
                # visible here), but unlike a bare ``except`` this lets
                # KeyboardInterrupt / SystemExit propagate.
                if verbose:
                    print("failed to load {}: {}".format(result_file, exc))
        return plans

    if args.method == "all":
        method_list = [name for name in METHOD_LIST if name != "example"]
    else:
        method_list = [args.method]

    result = {method: load_result_for_method(method) for method in method_list}

    if verbose:
        print(result)
    return method_list, result
if __name__ == "__main__":
    # Command-line entry point: load the queries of a split and the saved
    # results of the requested method(s), run the commonsense and logical
    # (hard) constraint evaluations and, optionally, the preference
    # evaluation, and write one CSV per evaluation under
    # eval_res/splits_<splits>/<method>/.
    parser = argparse.ArgumentParser()
    parser.add_argument("--splits", "-s", type=str, default="example")
    # "--method" is intentionally not restricted to METHOD_LIST: load_result
    # also accepts the special value "all" and arbitrary method names.
    parser.add_argument("--method", "-m", type=str, default="example")
    parser.add_argument("--preference", "-p", action="store_true", default=False)
    args = parser.parse_args()

    # NOTE(review): load_query is neither defined nor imported in the visible
    # part of this file -- confirm it is brought into scope elsewhere.
    query_index, query_data = load_query(args)
    method_list, result_data = load_result(args, query_index)

    # exist_ok=True replaces the race-prone "check, then create" sequence.
    os.makedirs("eval_res/splits_{}/".format(args.splits), exist_ok=True)

    for method in method_list:
        method_dir = "eval_res/splits_{}/{}/".format(args.splits, method)
        os.makedirs(method_dir, exist_ok=True)

        macro_comm, micro_comm, result_agg = evaluate_commonsense_constraints(
            query_index, query_data, result_data[method], verbose=False
        )

        res_file = method_dir + "commonsense.csv"
        result_agg.to_csv(res_file, index=False)
        print("save to {}".format(res_file))

        print("Method: {}".format(method))
        print("Commonsense constraints:")
        print("micro accuracy: {}".format(micro_comm))
        print("macro accuracy: {}".format(macro_comm))

        # Record the ids of the queries that pass the commonsense
        # constraints: the first column holds the query id, and a row is
        # treated as passing when all of its remaining columns sum to 0.
        commonsense_pass_info = result_agg.iloc[:, 1:]
        id_list = result_agg.iloc[:, 0].tolist()
        commonsense_pass = [
            query_id
            for row, query_id in enumerate(id_list)
            if commonsense_pass_info.iloc[row].sum() == 0
        ]

        print("Logical constraints:")
        macro_logi, micro_logi, result_agg = evaluate_hard_constraints(
            query_index, query_data, result_data[method], verbose=False
        )

        print("micro accuracy: {}".format(micro_logi))
        print("macro accuracy: {}".format(macro_logi))

        res_file = method_dir + "logical.csv"
        result_agg.to_csv(res_file, index=False)
        print("save to {}".format(res_file))

        if args.preference:
            print("Preference:")
            # The preference evaluation also receives the ids of the queries
            # that passed the commonsense check above.
            result_agg = evaluate_preference(
                query_index,
                query_data,
                result_data[method],
                commonsense_pass,
            )

            res_file = method_dir + "preference.csv"
            result_agg.to_csv(res_file, index=False)
            print("save to {}".format(res_file))