|
import argparse |
|
import os |
|
import sys |
|
import json |
|
|
|
|
|
# Make the repository root and its ``chinatravel`` package directory
# importable. Each entry is prepended only when it is not already on
# sys.path; the package directory is handled second, so when both are added
# it ends up ahead of the root.
_script_dir = os.path.dirname(os.path.abspath(__file__))
project_root_path = os.path.dirname(os.path.dirname(_script_dir))

for _candidate in (project_root_path, os.path.join(project_root_path, "chinatravel")):
    if _candidate not in sys.path:
        sys.path.insert(0, _candidate)

del _script_dir, _candidate
|
|
|
from chinatravel.evaluation.utils import load_json_file, validate_json |
|
|
|
from chinatravel.evaluation.commonsense_constraint import evaluate_commonsense_constraints |
|
from chinatravel.evaluation.hard_constraint import evaluate_hard_constraints |
|
from chinatravel.evaluation.preference import evaluate_preference |
|
|
|
# Result-directory names (one per evaluated method) under the results root.
# "example" is a sample entry that load_result skips when --method is "all".
# Every entry needs its own trailing comma: adjacent string literals are
# concatenated implicitly, which previously merged "example" with the next
# entry into a bogus "exampleact_Deepseek_zeroshot".
METHOD_LIST = [
    "example",
    "act_Deepseek_zeroshot",
    "act_GPT4o_zeroshot",
    "react_Deepseek_zeroshot",
    "react_GPT4o_zeroshot",
    "react_GLM4Plus_zeroshot",
    "react_Deepseek_oneshot",
    "react_GPT4o_oneshot",
    "naive_ns_Deepseek",
    "naive_ns_GPT4o",
    "naive_ns_GLM4Plus",
]
|
|
|
|
|
|
|
def load_result(args, query_index, verbose=False, results_dir="../results/"):
    """Load the saved plan files for the method(s) selected by ``args.method``.

    Args:
        args: Parsed CLI namespace; only ``args.method`` is read. The value
            ``"all"`` selects every entry of ``METHOD_LIST`` except
            ``"example"``; any other value is treated as a single method name.
        query_index: Iterable of query ids. Each id is looked up as the file
            ``<results_dir>/<method>/<query_id>.json``.
        verbose: If true, print the whole loaded mapping.
        results_dir: Root directory containing one sub-directory per method.
            A relative path is resolved against the current working directory;
            the default keeps the historical ``"../results/"`` location.

    Returns:
        A tuple ``(method_list, result)`` where ``method_list`` is the list of
        selected method names and ``result[method][query_id]`` is what
        ``load_json_file`` returned for that file, or ``{}`` when the file is
        missing or could not be loaded.
    """

    def load_result_for_method(method):
        # Map every query id to its loaded plan, falling back to {} so that
        # downstream evaluation still sees an entry for each query.
        plans = {}
        for query_id in query_index:
            result_file = os.path.join(results_dir, method, "{}.json".format(query_id))
            if not os.path.exists(result_file):
                plans[query_id] = {}
                continue
            try:
                plans[query_id] = load_json_file(result_file)
            except Exception as err:
                # Best effort: an unreadable/malformed result counts as an
                # empty plan, but report it rather than hiding it (the former
                # bare `except:` also swallowed KeyboardInterrupt/SystemExit).
                print(
                    "warning: failed to load {}: {}".format(result_file, err),
                    file=sys.stderr,
                )
                plans[query_id] = {}
        return plans

    if args.method == "all":
        method_list = [mi for mi in METHOD_LIST if mi != "example"]
    else:
        method_list = [args.method]

    result = {method: load_result_for_method(method) for method in method_list}

    if verbose:
        print(result)

    return method_list, result
|
|
|
|
|
def _save_csv(frame, path):
    """Write ``frame`` to ``path`` as CSV without the index, then report the path.

    ``frame`` is whatever the evaluation functions return; it only needs a
    ``to_csv`` method (presumably a pandas DataFrame — TODO confirm).
    """
    frame.to_csv(path, index=False)
    print("save to {}".format(path))


if __name__ == "__main__":
    # CLI: which split to evaluate, which method's results to load ("all"
    # evaluates every non-example method), and whether to run the
    # preference evaluation as well.
    parser = argparse.ArgumentParser()
    parser.add_argument("--splits", "-s", type=str, default="example")
    parser.add_argument("--method", "-m", type=str, default="example")
    parser.add_argument("--preference", "-p", action="store_true", default=False)
    args = parser.parse_args()

    # NOTE(review): `load_query` is neither defined nor imported in this file,
    # so this call raises NameError as written. Import the project's query
    # loader (presumably from chinatravel's data-loading module) — TODO
    # confirm the module path and add the import at the top of the file.
    query_index, query_data = load_query(args)

    method_list, result_data = load_result(args, query_index)

    # Outputs go to eval_res/splits_<split>/<method>/ relative to the CWD.
    split_dir = "eval_res/splits_{}/".format(args.splits)
    # exist_ok=True replaces the racy exists()-then-makedirs() pattern.
    os.makedirs(split_dir, exist_ok=True)

    for method in method_list:
        method_dir = os.path.join(split_dir, method)
        os.makedirs(method_dir, exist_ok=True)
        plans = result_data[method]

        # --- Commonsense constraints ---------------------------------------
        macro_comm, micro_comm, result_agg = evaluate_commonsense_constraints(
            query_index, query_data, plans, verbose=False
        )
        _save_csv(result_agg, os.path.join(method_dir, "commonsense.csv"))

        print("Method: {}".format(method))
        print("Commonsense constraints:")
        print("micro accuracy: {}".format(micro_comm))
        print("macro accuracy: {}".format(macro_comm))

        # Column 0 holds the query id; a query counts as passing commonsense
        # when the remaining columns of its row sum to 0 (presumably they are
        # per-constraint failure flags — TODO confirm). The passing ids feed
        # the preference evaluation below.
        id_list = result_agg.iloc[:, 0].tolist()
        commonsense_pass_info = result_agg.iloc[:, 1:]
        commonsense_pass = [
            query_id
            for row_pos, query_id in enumerate(id_list)
            if commonsense_pass_info.iloc[row_pos].sum() == 0
        ]

        # --- Logical (hard) constraints ------------------------------------
        print("Logical constraints:")
        macro_logi, micro_logi, result_agg = evaluate_hard_constraints(
            query_index, query_data, plans, verbose=False
        )
        print("micro accuracy: {}".format(micro_logi))
        print("macro accuracy: {}".format(macro_logi))
        _save_csv(result_agg, os.path.join(method_dir, "logical.csv"))

        # --- Preference (optional) -----------------------------------------
        if args.preference:
            print("Preference:")
            result_agg = evaluate_preference(
                query_index,
                query_data,
                plans,
                commonsense_pass,
            )
            _save_csv(result_agg, os.path.join(method_dir, "preference.csv"))
|
|