about summary refs log tree commit diff
path: root/gnqa/paper2_eval/src/parse_r2r_result.py
diff options
context:
space:
mode:
authorShelbySolomonDarnell2024-10-01 18:58:56 +0300
committerShelbySolomonDarnell2024-10-01 18:58:56 +0300
commit2c1e9099d34a0600918cfbe87b32d0a05003b3ef (patch)
treef3ff878edb210543f1a336b1913f84c8024f2f16 /gnqa/paper2_eval/src/parse_r2r_result.py
parent184339563b23627ca41bac8736f864d1c6bbfcba (diff)
downloadgn-ai-2c1e9099d34a0600918cfbe87b32d0a05003b3ef.tar.gz
Got R2R responses for all human questions.
Diffstat (limited to 'gnqa/paper2_eval/src/parse_r2r_result.py')
-rw-r--r--gnqa/paper2_eval/src/parse_r2r_result.py51
1 files changed, 31 insertions, 20 deletions
diff --git a/gnqa/paper2_eval/src/parse_r2r_result.py b/gnqa/paper2_eval/src/parse_r2r_result.py
index b30f2e76..a958629d 100644
--- a/gnqa/paper2_eval/src/parse_r2r_result.py
+++ b/gnqa/paper2_eval/src/parse_r2r_result.py
@@ -1,33 +1,45 @@
 import json
 import sys
 
+verbose = 1
+
 read_file = '/data/code/gn-ai/gnqa/paper2_eval/data/rag_out_1.json'
 
-def iterate_json(obj, thedict):
+values_key = {
+    "text" :           {"name": "contexts",      "append": 1},
+    "associatedQuery": {"name": "question",      "append": 0},
+    "id":              {"name": "id",            "append": 1},
+    "title":           {"name": "titles",        "append": 1},
+    "document_id":     {"name": "document_id",   "append": 1},
+    "extraction_id":   {"name": "extraction_id", "append": 1},
+    "content":         {"name": "answer",        "append": 0}
+}
+
+def get_ragas_out_dict():
+    return { "titles":        [],
+             "extraction_id": [],
+             "document_id":   [],
+             "id":            [],
+             "contexts":      [],
+             "answer":        "",
+             "question":      ""}
+
+def extract_response(obj, values_key, thedict):
     if isinstance(obj, dict):
         for key, val in obj.items():
-            if (key == "text"):
-                thedict["contexts"].append(val.replace("\n", " ").strip())
-                print("Key -> {0}\tValue -> {1}".format(key,val))
-            elif (key == "metadata"):
-                thedict["answer"] = val#.replace("\n", " ").strip()
-                print("Key -> {0}\tValue -> {1}".format(key,val))
-            elif (key == "id"):
-                print("Key -> {0}\tValue -> {1}".format(key,val))
-            elif (key == "associatedQuery"):
-                thedict["question"] = val.replace("\n", " ").strip()
-                print("Key -> {0}\tValue -> {1}".format(key,val))
-            elif (key == "title"):
-                print("Key -> {0}\tValue -> {1}".format(key,val))
-            elif (key == "document_id"):
-                print("Key -> {0}\tValue -> {1}".format(key,val))
+            if (key in values_key.keys()):
+                if (values_key[key]["append"]):
+                    thedict[values_key[key]["name"]].append(val.replace("\n", " ").strip())
+                else:
+                    thedict[values_key[key]["name"]] = val.replace("\n", " ").strip()
+                print(("", "Key -> {0}\tValue -> {1}".format(key,val)) [verbose])
             else:
                 if (len(obj.items()) == 1 ):
                     print(key, " --> ", val)
-            iterate_json(val, thedict)
+            extract_response(val, values_key, thedict)
     elif isinstance(obj, list):
         for item in obj:
-            iterate_json(item, thedict)
+            extract_response(item, values_key, thedict)
 
 # this should be a json file with a list of input files and an output file
 with open(read_file, "r") as r_file:
@@ -38,7 +50,6 @@ ragas_output = {
     "titles": [],
     "answer": "",
     "question": ""}
-vector_search_results = result_file["vector_search_results"]
-iterate_json(vector_search_results, ragas_output)
+extract_response(result_file, values_key, ragas_output)
 
 print(json.dumps(ragas_output, indent=2))
\ No newline at end of file