diff options
-rw-r--r-- | gn3/api/llm.py | 8 | ||||
-rw-r--r-- | gn3/llms/process.py | 10 |
2 files changed, 8 insertions, 10 deletions
diff --git a/gn3/api/llm.py b/gn3/api/llm.py index cb68526..c5946ed 100644 --- a/gn3/api/llm.py +++ b/gn3/api/llm.py @@ -24,9 +24,6 @@ from datetime import timedelta GnQNA = Blueprint("GnQNA", __name__) - - - def handle_errors(func): @wraps(func) def decorated_function(*args, **kwargs): @@ -39,6 +36,7 @@ def handle_errors(func): @GnQNA.route("/gnqna", methods=["POST"]) def gnqa(): + # todo add auth query = request.json.get("querygnqa", "") if not query: return jsonify({"error": "querygnqa is missing in the request"}), 400 @@ -46,7 +44,7 @@ def gnqa(): try: auth_token = current_app.config.get("FAHAMU_AUTH_TOKEN") task_id, answer, refs = get_gnqa( - query, auth_token) + query, auth_token, current_app.config.get("TMPDIR", "/tmp")) response = { "task_id": task_id, @@ -78,7 +76,7 @@ def rating(task_id): results.get("answer"), results.get("weight", 0)) - with db.connection(os.path.join(current_app.config["DATA_DIR"],"/llm.db")) as conn: + with db.connection(os.path.join(current_app.config["DATA_DIR"], "/llm.db")) as conn: cursor = conn.cursor() create_table = """CREATE TABLE IF NOT EXISTS Rating( user_id INTEGER NOT NULL, diff --git a/gn3/llms/process.py b/gn3/llms/process.py index 549c7e6..e33d3bc 100644 --- a/gn3/llms/process.py +++ b/gn3/llms/process.py @@ -69,16 +69,16 @@ def rate_document(task_id, doc_id, rating, auth_token): raise RuntimeError(f"An error occurred: {str(error)}") from error -def load_file(filename): +def load_file(filename, dir_path): """function to open and load json file""" - file_path = os.path.join(TMPDIR, filename) + file_path = os.path.join(dir_path, f"/{filename}") if not os.path.isfile(file_path): raise FileNotFoundError(f"{filename} was not found or is a directory") with open(file_path, "rb") as file_handler: return json.load(file_handler) -def fetch_pubmed(references, file_name): +def fetch_pubmed(references, file_name, tmp_dir=""): """method to fetch and populate references with pubmed""" try: @@ -92,7 +92,7 @@ def fetch_pubmed(references, file_name): return references -def get_gnqa(query, auth_token): +def get_gnqa(query, auth_token, tmp_dir=""): """entry function for the gn3 api endpoint()""" api_client = GeneNetworkQAClient(requests.Session(), api_key=auth_token) @@ -108,7 +108,7 @@ def get_gnqa(query, auth_token): context = resp_text['data']['context'] references = parse_context( context, DocIDs().getInfo, format_bibliography_info) - references = fetch_pubmed(references, "pubmed.json") + references = fetch_pubmed(references, "pubmed.json", tmp_dir) return task_id, answer, references else: |