about summary refs log tree commit diff
path: root/gn3
diff options
context:
space:
mode:
Diffstat (limited to 'gn3')
-rw-r--r--gn3/api/llm.py8
-rw-r--r--gn3/llms/process.py10
2 files changed, 8 insertions, 10 deletions
diff --git a/gn3/api/llm.py b/gn3/api/llm.py
index cb68526..c5946ed 100644
--- a/gn3/api/llm.py
+++ b/gn3/api/llm.py
@@ -24,9 +24,6 @@ from datetime import timedelta
 GnQNA = Blueprint("GnQNA", __name__)
 
 
-
-
-
 def handle_errors(func):
     @wraps(func)
     def decorated_function(*args, **kwargs):
@@ -39,6 +36,7 @@ def handle_errors(func):
 
 @GnQNA.route("/gnqna", methods=["POST"])
 def gnqa():
+    # todo  add auth
     query = request.json.get("querygnqa", "")
     if not query:
         return jsonify({"error": "querygnqa is missing in the request"}), 400
@@ -46,7 +44,7 @@ def gnqa():
     try:
         auth_token = current_app.config.get("FAHAMU_AUTH_TOKEN")
         task_id, answer, refs = get_gnqa(
-            query, auth_token)
+            query, auth_token, current_app.config.get("TMPDIR", "/tmp"))
 
         response = {
             "task_id": task_id,
@@ -78,7 +76,7 @@ def rating(task_id):
                                               results.get("answer"),
                                               results.get("weight", 0))
 
-            with db.connection(os.path.join(current_app.config["DATA_DIR"],"/llm.db")) as conn:
+            with db.connection(os.path.join(current_app.config["DATA_DIR"], "/llm.db")) as conn:
                 cursor = conn.cursor()
                 create_table = """CREATE TABLE IF NOT EXISTS Rating(
                       user_id INTEGER NOT NULL,
diff --git a/gn3/llms/process.py b/gn3/llms/process.py
index 549c7e6..e33d3bc 100644
--- a/gn3/llms/process.py
+++ b/gn3/llms/process.py
@@ -69,16 +69,16 @@ def rate_document(task_id, doc_id, rating, auth_token):
         raise RuntimeError(f"An error occurred: {str(error)}") from error
 
 
-def load_file(filename):
+def load_file(filename, dir_path):
     """function to open and load json file"""
-    file_path = os.path.join(TMPDIR, filename)
+    file_path = os.path.join(dir_path, f"/{filename}")
     if not os.path.isfile(file_path):
         raise FileNotFoundError(f"{filename} was not found or is a directory")
     with open(file_path, "rb") as file_handler:
         return json.load(file_handler)
 
 
-def fetch_pubmed(references, file_name):
+def fetch_pubmed(references, file_name, tmp_dir=""):
     """method to fetch and populate references with pubmed"""
 
     try:
@@ -92,7 +92,7 @@ def fetch_pubmed(references, file_name):
         return references
 
 
-def get_gnqa(query, auth_token):
+def get_gnqa(query, auth_token, tmp_dir=""):
     """entry function for the gn3 api endpoint()"""
 
     api_client = GeneNetworkQAClient(requests.Session(), api_key=auth_token)
@@ -108,7 +108,7 @@ def get_gnqa(query, auth_token):
         context = resp_text['data']['context']
         references = parse_context(
             context, DocIDs().getInfo, format_bibliography_info)
-        references = fetch_pubmed(references, "pubmed.json")
+        references = fetch_pubmed(references, "pubmed.json", tmp_dir)
 
         return task_id, answer, references
     else: