aboutsummaryrefslogtreecommitdiff
path: root/gnqa/paper2_eval/src
diff options
context:
space:
mode:
authorShelbySolomonDarnell2024-08-16 17:26:14 +0300
committerShelbySolomonDarnell2024-08-16 17:26:14 +0300
commit50f0ed1d717d6877cb0562b1f2d54f0f242312d9 (patch)
tree18123164e626cd91d6c79d205532ccd3cc5fa900 /gnqa/paper2_eval/src
parent6e4a45e92f25b1084b34404991706e6c4ce3fd6f (diff)
downloadgn-ai-50f0ed1d717d6877cb0562b1f2d54f0f242312d9.tar.gz
added paper2_eval
Diffstat (limited to 'gnqa/paper2_eval/src')
-rw-r--r--gnqa/paper2_eval/src/parsejson.py63
-rw-r--r--gnqa/paper2_eval/src/retrieve_context.py171
2 files changed, 234 insertions, 0 deletions
diff --git a/gnqa/paper2_eval/src/parsejson.py b/gnqa/paper2_eval/src/parsejson.py
new file mode 100644
index 0000000..b49a898
--- /dev/null
+++ b/gnqa/paper2_eval/src/parsejson.py
@@ -0,0 +1,63 @@
+import json
+import sys
+
+
def iterate_json(obj, thedict):
    """Recursively walk a parsed-JSON structure, collecting RAGAS fields.

    Any "text" value is appended to thedict["contexts"]; "answer" and
    "question" values overwrite the corresponding scalar slots.  Newlines
    are flattened to spaces and surrounding whitespace is stripped.
    Other keys are descended into; a single-entry dict is echoed to stdout
    before descending (debug aid kept from the original).
    """
    if isinstance(obj, dict):
        single_entry = len(obj) == 1
        for key, val in obj.items():
            if key == "text":
                thedict["contexts"].append(val.replace("\n", " ").strip())
            elif key == "answer":
                thedict["answer"] = val.replace("\n", " ").strip()
            elif key == "question":
                thedict["question"] = val.replace("\n", " ").strip()
            else:
                if single_entry:
                    print(key, " --> ", val)
                iterate_json(val, thedict)
    elif isinstance(obj, list):
        for element in obj:
            iterate_json(element, thedict)
+
def create_dataset_from_files(tag, file_name, rag_out, data_dir="./data/"):
    """Aggregate per-file RAGAS fields into the shared *rag_out* accumulator.

    tag       -- key into *file_name* selecting the list of files to process
    file_name -- dict mapping tags to lists of JSON file names
    rag_out   -- dict with "question"/"answer"/"contexts" lists, appended in place
    data_dir  -- directory the files live in (was hard-coded to "./data/";
                 now a parameter with the same default for compatibility)
    """
    for the_file in file_name[tag]:
        # Fresh per-file accumulator so files don't bleed into each other.
        ragas_output = {
            "contexts": [],
            "answer": "",
            "question": ""}
        with open(data_dir + the_file, "r") as r_file:
            data_file = json.load(r_file)
            iterate_json(data_file, ragas_output)
        rag_out["answer"].append(ragas_output["answer"])
        rag_out["question"].append(ragas_output["question"])
        rag_out["contexts"].append(ragas_output["contexts"])
+
def create_resultset_from_file(file_name, data_dir="./data/"):
    """Parse a single JSON file into a RAGAS-style result dict and return it.

    Bug fix: the original body referenced undefined names (`the_file`,
    `ragas_output`) instead of its `file_name` parameter, so it raised
    NameError on every call and discarded the parsed result.
    """
    ragas_output = {
        "contexts": [],
        "answer": "",
        "question": ""}
    with open(data_dir + file_name, "r") as r_file:
        data_file = json.load(r_file)
        iterate_json(data_file, ragas_output)
    return ragas_output
+
+
# --- Script entry: build a RAGAS-style dataset from a list of JSON files ---
#
# Command-line arguments:
#   1: tag selecting which file list to process from the file-list JSON
#   2: path to the file-list JSON (e.g. doc_list.json)
#   3: output file the combined dataset is appended to
file_list_tag = str(sys.argv[1])
read_file = str(sys.argv[2]) # e.g. doc_list.json
outp_file = str(sys.argv[3])

# Accumulator filled by create_dataset_from_files: one entry per input file.
rag_out = {
    "question": [],
    "answer": [],
    "contexts": []
}

cntxt_lst = []  # NOTE(review): never used in this script — candidate for removal

# this should be a json file with a list of input files and an output file
with open(read_file, "r") as r_file:
    file_lst = json.load(r_file)

create_dataset_from_files(file_list_tag, file_lst, rag_out)

# Appends with a leading ",\n", so the output file is a JSON *fragment*,
# not a standalone valid document — presumably stitched together later.
with open(outp_file, "a") as the_data:
    #json.dump(ragas_output, the_data)
    the_data.write(",\n")
    the_data.write(json.dumps(rag_out, indent=2))
diff --git a/gnqa/paper2_eval/src/retrieve_context.py b/gnqa/paper2_eval/src/retrieve_context.py
new file mode 100644
index 0000000..58b9d47
--- /dev/null
+++ b/gnqa/paper2_eval/src/retrieve_context.py
@@ -0,0 +1,171 @@
+import os
+import sys
+import json
+import time
+import configparser
+import apis.process as gnqa
+from apis.process import get_gnqa, get_response_from_taskid
+
+
# Load API keys and output-directory settings from the local config file.
# Sections read elsewhere in this script: [key.api], [DEFAULT],
# [out.response.dataset].
config = configparser.ConfigParser()
config.read('_config.cfg')
+
def simplifyContext(refs):
    """Reduce a list of reference items to their text.

    Each item in *refs* carries doc_id, bibInfo, and comboTxt; only comboTxt
    is kept, with newline and tab characters removed.
    """
    return [
        item['comboTxt'].replace('\n', '').replace('\t', '')
        for item in refs
    ]
+
def writeDatasetFile(responses, outp_file):
    """Append *responses* as pretty-printed JSON to *outp_file*.

    The original branched on os.path.exists(), but both branches opened the
    file in append mode and wrote identical bytes ('' + output == output),
    so the check was dead weight — a single append is equivalent.

    Note: repeated calls append multiple JSON documents back to back; the
    resulting file is not one valid JSON document (kept from the original).
    """
    print(outp_file)  # progress output, kept from the original
    output = json.dumps(responses, indent=2)
    with open(outp_file, "a") as the_data:
        the_data.write(output)
+
+
def reset_responses():
    """Return a fresh, empty accumulator for one batch of API responses."""
    return {key: [] for key in ('question', 'answer', 'contexts', 'task_id')}
+
def parse_document(jsonfile):
    """Build one dataset per entry of *jsonfile*.

    Each entry is expected to carry "level", "domain", and "query" keys;
    the query list is handed to create_datasets for that level/domain.
    """
    print('Parse document')
    for entry in jsonfile:
        expertise = entry['level']
        topic = entry['domain']
        queries = entry['query']
        create_datasets(queries, topic, expertise)
+
def create_datasets(query_list, domain, level):
    """Query the GNQA API for each item in *query_list* and write the
    collected responses to numbered dataset files.

    Responses are flushed in batches of 5: file N holds queries 5N-4..5N,
    and any remainder goes into one final file numbered int(ndx/5)+1.
    Output path pattern: {gpt4o_dir}dataset_{level}_{domain}_{N}.json
    """
    print('Creating dataset')
    responses = reset_responses()
    ndx = 0
    for query in query_list:
        print(query)
        # get_gnqa returns (task_id, answer, refs); refs carry the contexts.
        task_id, answer, refs = get_gnqa(query, config['key.api']['fahamuai'], config['DEFAULT']['DATA_DIR'])
        responses['question'].append(query)
        responses['answer'].append(answer)
        responses['task_id'].append(task_id)
        responses['contexts'].append(simplifyContext(refs))
        ndx+=1
        time.sleep(10) # sleep a bit to not overtask the api
        if ndx % 5 == 0:
            print('Will print to file number {0}'.format(int(ndx/5)))
            outp_file = '{0}dataset_{1}_{2}_{3}.json'.format(config['out.response.dataset']['gpt4o_dir'],level,domain,str(int(ndx/5)))
            writeDatasetFile(responses, outp_file)
            # Start a fresh batch after each flush.
            responses = reset_responses()
    # Flush the remainder (fewer than 5 responses) into one last file.
    if len(responses['question']) > 0:
        outp_file = '{0}dataset_{1}_{2}_{3}.json'.format(config['out.response.dataset']['gpt4o_dir'],level,domain,str(int(ndx/5)+1))
        writeDatasetFile(responses, outp_file)
+
def parse_responses(jsonfile):
    """Sort human-review queries into (expertise level, topic) buckets and
    build a dataset for each bucket.

    Improvements over the original: the six-way if/elif dispatch is replaced
    by a (level, topic) lookup table, and the unused counter `j` is removed.
    Output order of create_datasets_from_taskid calls is unchanged.
    """
    print('Parsing human responses')
    # One accumulator per (level, topic) pair, in the original flush order:
    # all domain-expert buckets first, then all citizen-scientist buckets.
    buckets = {
        ('de', 'general'): {"level": "domainexpert", "domain": "general", "query": [], "task_id": []},
        ('de', 'aging'): {"level": "domainexpert", "domain": "aging", "query": [], "task_id": []},
        ('de', 'diabetes'): {"level": "domainexpert", "domain": "diabetes", "query": [], "task_id": []},
        ('cs', 'general'): {"level": "citizenscientist", "domain": "general", "query": [], "task_id": []},
        ('cs', 'aging'): {"level": "citizenscientist", "domain": "aging", "query": [], "task_id": []},
        ('cs', 'diabetes'): {"level": "citizenscientist", "domain": "diabetes", "query": [], "task_id": []},
    }
    # Topic codes: 0 -> general, 1 -> aging, anything else -> diabetes.
    topic_names = {0: 'general', 1: 'aging'}
    for _, val in jsonfile.items():
        lvl = val.get("level")
        for ndx, qry in enumerate(val.get("query")):
            ans = val.get("answer")[ndx] if "answer" in val else ""
            tpc = topic_names.get(val.get("topic")[ndx], 'diabetes')
            tskd = val.get("task_id")[ndx]
            bucket = buckets.get((lvl, tpc))
            if bucket is None:
                print('Somehow there is a query without a topic or expertise level')
            else:
                addToDataList(bucket, qry, ans, tskd)
    for bucket in buckets.values():
        create_datasets_from_taskid(bucket)
+
def addToDataList(data_lst, qry, ans, tskd):
    """Append one (query, task_id, answer) record to the bucket *data_lst*.

    The "answer" list is created on first use so buckets start without it.
    """
    data_lst["query"].append(qry)
    data_lst["task_id"].append(tskd)
    data_lst.setdefault("answer", []).append(ans)
+
+
def create_datasets_from_taskid(info_dict):
    """Re-fetch responses for previously submitted task ids and write them
    to numbered dataset files, 5 responses per file.

    info_dict keys: "level", "domain", "query", "task_id", optional "answer".
    Output path pattern: {human_dir}dataset_{level}_{domain}_{N}_two.json

    Bug fix: the original indexed answers[ndx] unconditionally, which raised
    IndexError whenever "answer" was absent (answers == []) but task ids
    existed; a length guard now treats missing entries as "".
    """
    print('Creating dataset of questions from {0} in the topic of {1}'.format(info_dict["level"], info_dict["domain"]))
    responses = reset_responses()
    query_list = info_dict["query"]
    # Ensure the key exists (preserves the original's side effect on info_dict).
    answers = info_dict.setdefault("answer", [])

    ndx = 0
    for task_id in info_dict["task_id"]:
        _, an_answer, refs = get_response_from_taskid(config['key.api']['fahamuai'], task_id)
        responses['question'].append(query_list[ndx])
        # Prefer the recorded human answer; fall back to the API answer when
        # none was captured (empty string or list shorter than task_id).
        recorded = answers[ndx] if ndx < len(answers) else ""
        if recorded == "":
            responses['answer'].append(an_answer)
        else:
            responses['answer'].append(recorded)
        responses['task_id'].append(task_id)
        responses['contexts'].append(simplifyContext(refs))
        ndx += 1
        time.sleep(10)  # sleep a bit to not overtask the api
        if ndx % 5 == 0:
            outp_file = '{0}dataset_{1}_{2}_{3}_two.json'.format(config['out.response.dataset']['human_dir'], info_dict["level"], info_dict["domain"], str(int(ndx/5)))
            writeDatasetFile(responses, outp_file)
            responses = reset_responses()
    # Flush the remainder (fewer than 5 responses) into one last file.
    if len(responses['question']) > 0:
        outp_file = '{0}dataset_{1}_{2}_{3}_two.json'.format(config['out.response.dataset']['human_dir'], info_dict["level"], info_dict["domain"], str(int(ndx/5)+1))
        writeDatasetFile(responses, outp_file)
+
# --- Script entry: dispatch on the second argument's file type ---
#
# Arguments:
#   1: path to the input JSON (query list or human-response dump)
#   2: "gpt4o" to treat it as a query-list document, anything else as
#      human responses keyed by task id
try:
    read_file = str(sys.argv[1])
    file_type = str(sys.argv[2])
except IndexError:
    # Was a bare `except:`, which also swallowed SystemExit and
    # KeyboardInterrupt; only a missing argv entry is expected here.
    sys.exit('Example use "python3 retrieve_context.py data/queries/qlist.json human/gpt4o"')


print('Read input file')
with open(read_file, "r") as r_file:
    file_lst = json.load(r_file)
if file_type == "gpt4o":
    parse_document(file_lst)
else:
    parse_responses(file_lst)