aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--gn3/api/llm.py8
-rw-r--r--gn3/llms/process.py10
2 files changed, 8 insertions, 10 deletions
diff --git a/gn3/api/llm.py b/gn3/api/llm.py
index cb68526..c5946ed 100644
--- a/gn3/api/llm.py
+++ b/gn3/api/llm.py
@@ -24,9 +24,6 @@ from datetime import timedelta
GnQNA = Blueprint("GnQNA", __name__)
-
-
-
def handle_errors(func):
@wraps(func)
def decorated_function(*args, **kwargs):
@@ -39,6 +36,7 @@ def handle_errors(func):
@GnQNA.route("/gnqna", methods=["POST"])
def gnqa():
+ # todo add auth
query = request.json.get("querygnqa", "")
if not query:
return jsonify({"error": "querygnqa is missing in the request"}), 400
@@ -46,7 +44,7 @@ def gnqa():
try:
auth_token = current_app.config.get("FAHAMU_AUTH_TOKEN")
task_id, answer, refs = get_gnqa(
- query, auth_token)
+ query, auth_token, current_app.config.get("TMPDIR", "/tmp"))
response = {
"task_id": task_id,
@@ -78,7 +76,7 @@ def rating(task_id):
results.get("answer"),
results.get("weight", 0))
- with db.connection(os.path.join(current_app.config["DATA_DIR"],"/llm.db")) as conn:
+ with db.connection(os.path.join(current_app.config["DATA_DIR"], "/llm.db")) as conn:
cursor = conn.cursor()
create_table = """CREATE TABLE IF NOT EXISTS Rating(
user_id INTEGER NOT NULL,
diff --git a/gn3/llms/process.py b/gn3/llms/process.py
index 549c7e6..e33d3bc 100644
--- a/gn3/llms/process.py
+++ b/gn3/llms/process.py
@@ -69,16 +69,16 @@ def rate_document(task_id, doc_id, rating, auth_token):
raise RuntimeError(f"An error occurred: {str(error)}") from error
-def load_file(filename):
+def load_file(filename, dir_path):
"""function to open and load json file"""
- file_path = os.path.join(TMPDIR, filename)
+ file_path = os.path.join(dir_path, f"/{filename}")
if not os.path.isfile(file_path):
raise FileNotFoundError(f"{filename} was not found or is a directory")
with open(file_path, "rb") as file_handler:
return json.load(file_handler)
-def fetch_pubmed(references, file_name):
+def fetch_pubmed(references, file_name, tmp_dir=""):
"""method to fetch and populate references with pubmed"""
try:
@@ -92,7 +92,7 @@ def fetch_pubmed(references, file_name):
return references
-def get_gnqa(query, auth_token):
+def get_gnqa(query, auth_token, tmp_dir=""):
"""entry function for the gn3 api endpoint()"""
api_client = GeneNetworkQAClient(requests.Session(), api_key=auth_token)
@@ -108,7 +108,7 @@ def get_gnqa(query, auth_token):
context = resp_text['data']['context']
references = parse_context(
context, DocIDs().getInfo, format_bibliography_info)
- references = fetch_pubmed(references, "pubmed.json")
+ references = fetch_pubmed(references, "pubmed.json", tmp_dir)
return task_id, answer, references
else: