diff options
author | ShelbySolomonDarnell | 2024-08-16 17:26:14 +0300 |
---|---|---|
committer | ShelbySolomonDarnell | 2024-08-16 17:26:14 +0300 |
commit | 50f0ed1d717d6877cb0562b1f2d54f0f242312d9 (patch) | |
tree | 18123164e626cd91d6c79d205532ccd3cc5fa900 /gnqa/paper2_eval/src | |
parent | 6e4a45e92f25b1084b34404991706e6c4ce3fd6f (diff) | |
download | gn-ai-50f0ed1d717d6877cb0562b1f2d54f0f242312d9.tar.gz |
added paper2_eval
Diffstat (limited to 'gnqa/paper2_eval/src')
-rw-r--r-- | gnqa/paper2_eval/src/parsejson.py | 63 | ||||
-rw-r--r-- | gnqa/paper2_eval/src/retrieve_context.py | 171 |
2 files changed, 234 insertions, 0 deletions
diff --git a/gnqa/paper2_eval/src/parsejson.py b/gnqa/paper2_eval/src/parsejson.py new file mode 100644 index 0000000..b49a898 --- /dev/null +++ b/gnqa/paper2_eval/src/parsejson.py @@ -0,0 +1,63 @@ +import json +import sys + + +def iterate_json(obj, thedict): + if isinstance(obj, dict): + for key, val in obj.items(): + if (key == "text"): + thedict["contexts"].append(val.replace("\n", " ").strip()) + elif (key == "answer"): + thedict["answer"] = val.replace("\n", " ").strip() + elif (key == "question"): + thedict["question"] = val.replace("\n", " ").strip() + else: + if (len(obj.items()) == 1 ): + print(key, " --> ", val) + iterate_json(val, thedict) + elif isinstance(obj, list): + for item in obj: + iterate_json(item, thedict) + +def create_dataset_from_files(tag, file_name, rag_out): + for the_file in file_name[tag]: + ragas_output = { + "contexts": [], + "answer": "", + "question": ""} + #print(the_file) + with open("./data/"+the_file, "r") as r_file: + data_file = json.load(r_file) + iterate_json(data_file, ragas_output) + rag_out["answer"].append(ragas_output["answer"]) + rag_out["question"].append(ragas_output["question"]) + rag_out["contexts"].append(ragas_output["contexts"]) + +def create_resultset_from_file(file_name): + with open("./data/"+the_file, "r") as r_file: + data_file = json.load(r_file) + iterate_json(data_file, ragas_output) + + +file_list_tag = str(sys.argv[1]) +read_file = str(sys.argv[2]) # e.g. doc_list.json +outp_file = str(sys.argv[3]) + +rag_out = { + "question": [], + "answer": [], + "contexts": [] +} + +cntxt_lst = [] + +# this should be a json file with a list of input files and an output file +with open(read_file, "r") as r_file: + file_lst = json.load(r_file) + +create_dataset_from_files(file_list_tag, file_lst, rag_out) + +with open(outp_file, "a") as the_data: + #json.dump(ragas_output, the_data) + the_data.write(",\n") + the_data.write(json.dumps(rag_out, indent=2)) diff --git a/gnqa/paper2_eval/src/retrieve_context.py b/gnqa/paper2_eval/src/retrieve_context.py new file mode 100644 index 0000000..58b9d47 --- /dev/null +++ b/gnqa/paper2_eval/src/retrieve_context.py @@ -0,0 +1,171 @@ +import os +import sys +import json +import time +import configparser +import apis.process as gnqa +from apis.process import get_gnqa, get_response_from_taskid + + +config = configparser.ConfigParser() +config.read('_config.cfg') + +''' +the refs object is a list of items containing doc_id, bibInfo, and comboTxt +We only need comboTxt +''' +def simplifyContext(refs): + result = [] + for item in refs: + combo_text = item['comboTxt'] + combo_text = combo_text.replace('\n','') + combo_text = combo_text.replace('\t','') + result.append(combo_text) + return result + +def writeDatasetFile(responses, outp_file): + print(outp_file) + output = json.dumps(responses, indent=2) + if os.path.exists(outp_file): + with open(outp_file, "a") as the_data: + the_data.write('' + output) + else: + with open(outp_file, "a") as the_data: + the_data.write(output) + + +def reset_responses(): + return { + 'question': [], + 'answer': [], + 'contexts': [], + 'task_id': [] + } + +def parse_document(jsonfile): + print('Parse document') + for item in jsonfile: + level = item['level'] + domain = item['domain'] + query_lst = item['query'] + create_datasets(query_lst, domain, level) + +def create_datasets(query_list, domain, level): + print('Creating dataset') + responses = reset_responses() + ndx = 0 + for query in query_list: + print(query) + task_id, answer, refs = get_gnqa(query, config['key.api']['fahamuai'], config['DEFAULT']['DATA_DIR']) + responses['question'].append(query) + responses['answer'].append(answer) + responses['task_id'].append(task_id) + responses['contexts'].append(simplifyContext(refs)) + ndx+=1 + time.sleep(10) # sleep a bit to not overtask the api + if ndx % 5 == 0: + print('Will print to file number {0}'.format(int(ndx/5))) + outp_file = '{0}dataset_{1}_{2}_{3}.json'.format(config['out.response.dataset']['gpt4o_dir'],level,domain,str(int(ndx/5))) + writeDatasetFile(responses, outp_file) + responses = reset_responses() + if len(responses['question']) > 0: + outp_file = '{0}dataset_{1}_{2}_{3}.json'.format(config['out.response.dataset']['gpt4o_dir'],level,domain,str(int(ndx/5)+1)) + writeDatasetFile(responses, outp_file) + +def parse_responses(jsonfile): + print('Parsing human responses') + de_dict_general = {"level": "domainexpert", "domain": "general", "query": [], "task_id": []} + de_dict_aging = {"level": "domainexpert", "domain": "aging", "query": [], "task_id": []} + de_dict_diabetes = {"level": "domainexpert", "domain": "diabetes", "query": [], "task_id": []} + cs_dict_general = {"level": "citizenscientist", "domain": "general", "query": [], "task_id": []} + cs_dict_aging = {"level": "citizenscientist", "domain": "aging", "query": [], "task_id": []} + cs_dict_diabetes = {"level": "citizenscientist", "domain": "diabetes", "query": [], "task_id": []} + j = 0 + for _, val in jsonfile.items(): + ndx = 0 + lvl = val.get("level") + for qry in val.get("query"): + ans = val.get("answer")[ndx] if "answer" in val else "" + tpc = val.get("topic")[ndx] + tpc = "general" if tpc==0 else "aging" if tpc==1 else "diabetes" + tskd = val.get("task_id")[ndx] + if lvl == 'cs' and tpc == 'general': + addToDataList(cs_dict_general, qry, ans, tskd) + elif lvl == 'cs' and tpc == 'aging': + addToDataList(cs_dict_aging, qry, ans, tskd) + elif lvl == 'cs' and tpc == 'diabetes': + addToDataList(cs_dict_diabetes, qry, ans, tskd) + elif lvl == 'de' and tpc == 'general': + addToDataList(de_dict_general, qry, ans, tskd) + elif lvl == 'de' and tpc == 'aging': + addToDataList(de_dict_aging, qry, ans, tskd) + elif lvl == 'de' and tpc == 'diabetes': + addToDataList(de_dict_diabetes, qry, ans, tskd) + else: + print('Somehow there is a query without a topic or expertise level') + ndx+=1 + j+=1 + create_datasets_from_taskid(de_dict_general) + create_datasets_from_taskid(de_dict_aging) + create_datasets_from_taskid(de_dict_diabetes) + create_datasets_from_taskid(cs_dict_general) + create_datasets_from_taskid(cs_dict_aging) + create_datasets_from_taskid(cs_dict_diabetes) + +def addToDataList(data_lst, qry, ans, tskd): + data_lst["query"].append(qry) + data_lst["task_id"].append(tskd) + if "answer" not in data_lst.keys(): + data_lst["answer"] = [] + data_lst["answer"].append(ans) + + +def create_datasets_from_taskid(info_dict):#task_list, query_list, answers, domain, level): + print('Creating dataset of questions from {0} in the topic of {1}'.format(info_dict["level"], info_dict["domain"])) + responses = reset_responses() + ndx = 0 + query_list = info_dict["query"] + if "answer" in info_dict: + answers = info_dict["answer"] + else: + info_dict["answer"] = [] + answers = [] + + for task_id in info_dict["task_id"]: + _, an_answer, refs = get_response_from_taskid(config['key.api']['fahamuai'], task_id) + responses['question'].append(query_list[ndx]) + if answers[ndx] == "": + responses['answer'].append(an_answer) + else: + responses['answer'].append(answers[ndx]) + responses['task_id'].append(task_id) + responses['contexts'].append(simplifyContext(refs)) + ndx+=1 + time.sleep(10) # sleep a bit to not overtask the api + if ndx % 5 == 0: + #print('Will print to file number {0}'.format(int(ndx/5))) + outp_file = '{0}dataset_{1}_{2}_{3}_two.json'.format(config['out.response.dataset']['human_dir'],info_dict["level"],info_dict["domain"],str(int(ndx/5))) + writeDatasetFile(responses, outp_file) + responses = reset_responses() + if len(responses['question']) > 0: + #print('Will print to file number {0}'.format(int((ndx/5)+1))) + #print(responses) + outp_file = '{0}dataset_{1}_{2}_{3}_two.json'.format(config['out.response.dataset']['human_dir'],info_dict["level"],info_dict["domain"],str(int(ndx/5)+1)) + writeDatasetFile(responses, outp_file) + +try: + + read_file = str(sys.argv[1]) + file_type = str(sys.argv[2]) + +except: + exit('Example use "python3 retrieve_context.py data/queries/qlist.json human/gpt4o"') + + +print('Read input file') +with open(read_file, "r") as r_file: + file_lst = json.load(r_file) +if file_type == "gpt4o": + parse_document(file_lst) +else: + parse_responses(file_lst)
\ No newline at end of file |