diff options
author | Alexander_Kabui | 2024-04-25 19:14:12 +0300 |
---|---|---|
committer | Alexander_Kabui | 2024-05-16 12:53:51 +0300 |
commit | f6acfd3d6024ad36ef82a8e27918b03f6538cccc (patch) | |
tree | 6d5086559a87a670414914148418f8d9b9f2e0e4 /gn3/llms | |
parent | 7aa31cf63e17efe194e501bc37068a2207ab8f38 (diff) | |
download | genenetwork3-f6acfd3d6024ad36ef82a8e27918b03f6538cccc.tar.gz |
Code refactoring
* this commit removes ununsed imports and also refactor
GenenetworkQAclient Class
Diffstat (limited to 'gn3/llms')
-rw-r--r-- | gn3/llms/client.py | 67 |
1 files changed, 31 insertions, 36 deletions
diff --git a/gn3/llms/client.py b/gn3/llms/client.py index 042becd..b843907 100644 --- a/gn3/llms/client.py +++ b/gn3/llms/client.py @@ -2,19 +2,13 @@ import json import string import os -import datetime import time import requests from requests import Session -from urllib.parse import urljoin from requests.packages.urllib3.util.retry import Retry -from requests import HTTPError -from requests import Session from requests.adapters import HTTPAdapter -from urllib.request import urlretrieve from urllib.parse import quote -from gn3.llms.errors import UnprocessableEntity from gn3.llms.errors import LLMError basedir = os.path.join(os.path.dirname(__file__)) @@ -24,7 +18,8 @@ class TimeoutHTTPAdapter(HTTPAdapter): def __init__(self, timeout, *args, **kwargs): """TimeoutHTTPAdapter constructor. Args: - timeout (int): How many seconds to wait for the server to send data before + timeout (int): How many seconds to wait for the server to + send data before giving up. """ self.timeout = timeout @@ -43,7 +38,8 @@ class GeneNetworkQAClient(Session): """GeneNetworkQA Client This class provides a client object interface to the GeneNetworkQA API. - It extends the `requests.Session` class and includes authorization, base URL, + It extends the `requests.Session` class and includes authorization, + base URL, request timeouts, and request retries. Args: @@ -52,16 +48,19 @@ class GeneNetworkQAClient(Session): version (str, optional): API version, defaults to "v3". timeout (int, optional): Timeout value, defaults to 5. total_retries (int, optional): Total retries value, defaults to 5. - backoff_factor (int, optional): Retry backoff factor value, defaults to 30. + backoff_factor (int, optional): Retry backoff factor value, + defaults to 30. Usage: from genenetworkqa import GeneNetworkQAClient - gnqa = GeneNetworkQAClient(account="account-name", api_key="XXXXXXXXXXXXXXXXXXX...") + gnqa = GeneNetworkQAClient(account="account-name", + api_key="XXXXXXXXXXXXXXXXXXX...") """ BASE_URL = 'https://genenetwork.fahamuai.com/api/tasks' - def __init__(self, account, api_key, version="v3", timeout=30, total_retries=5, backoff_factor=30): + def __init__(self, account, api_key, version="v3", timeout=30, + total_retries=5, backoff_factor=30): super().__init__() self.headers.update( {"Authorization": "Bearer " + api_key}) @@ -95,31 +94,31 @@ class GeneNetworkQAClient(Session): return bib_info @staticmethod - def ask_the_documents(extend_url, my_auth): + def ask_the_documents(self, extend_url, my_auth): try: response = requests.post( - base_url + extend_url, data={}, headers=my_auth) + self.base_url + extend_url, data={}, headers=my_auth) response.raise_for_status() except requests.exceptions.RequestException as e: # Handle the exception appropriately, e.g., log the error raise RuntimeError(f"Error making the request: {e}") if response.status_code != 200: - return negative_status_msg(response), 0 + return GeneNetworkQAClient.negative_status_msg(response), 0 - task_id = get_task_id_from_result(response) - response = get_answer_using_task_id(task_id, my_auth) + task_id = GeneNetworkQAClient.get_task_id_from_result(response) + response = GeneNetworkQAClient.get_answer_using_task_id(task_id, + my_auth) if response.status_code != 200: - return negative_status_msg(response), 0 + return GeneNetworkQAClient.negative_status_msg(response), 0 return response, 1 @staticmethod def negative_status_msg(response): return f"Error: Status code -{response.status_code}- Reason::{response.reason}" - # return f"Problems\n\tStatus code => {response.status_code}\n\tReason => {response.reason}" def ask(self, exUrl, *args, **kwargs): askUrl = self.BASE_URL + exUrl @@ -147,18 +146,13 @@ class GeneNetworkQAClient(Session): response.raise_for_status() except requests.exceptions.HTTPError as error: - if error.response.status_code ==500: - raise LLMError(error.request, error.response, f"Response Error,status_code:{error.response.status_code},Reason: Use of Invalid Token") - elif error.response.status_code ==404: - raise LLMError(error.request,error.response,f"404 Client Error: Not Found for url: {self.BASE_URL}") + if error.response.status_code == 500: + raise LLMError(error.request, error.response, f"Response Error with:status_code:{error.response.status_code},Reason for error: Use of Invalid Fahamu Token") + elif error.response.status_code == 404: + raise LLMError(error.request, error.response, f"404 Client Error: Not Found for url: {self.BASE_URL}") raise error - except requests.exceptions.RequestException as error: - raise error - - - - + raise error if response.ok: if method.lower() == "get" and response.json().get("data") is None: time.sleep(retry_delay) @@ -175,11 +169,10 @@ class GeneNetworkQAClient(Session): result = f"?task_id={task_id.get('task_id', '')}" return result - @staticmethod - def get_answer_using_task_id(extend_url, my_auth): + def get_answer_using_task_id(self, extend_url, my_auth): try: response = requests.get( - answer_url + extend_url, data={}, headers=my_auth) + self.answer_url + extend_url, data={}, headers=my_auth) response.raise_for_status() return response except requests.exceptions.RequestException as error: @@ -189,7 +182,8 @@ class GeneNetworkQAClient(Session): @staticmethod def filter_response_text(val): """ - Filters out non-printable characters from the input string and parses it as JSON. + Filters out non-printable characters from + the input string and parses it as JSON. Args: val (str): Input string to be filtered and parsed. @@ -198,7 +192,8 @@ class GeneNetworkQAClient(Session): dict: Parsed JSON object. # remove this """ - return json.loads(''.join([str(char) for char in val if char in string.printable])) + return json.loads(''.join([str(char) for char in val if char + in string.printable])) def getTaskIDFromResult(self, res): return json.loads(res.text) @@ -208,11 +203,11 @@ class GeneNetworkQAClient(Session): def get_gnqa(self, query): qstr = quote(query) - res, task_id = api_client.ask('?ask=' + qstr) - res, success = api_client.get_answer(task_id) + res, task_id = GeneNetworkQAClient.ask('?ask=' + qstr) + res, success = GeneNetworkQAClient.get_answer(task_id) if success == 1: - resp_text = filter_response_text(res.text) + resp_text = GeneNetworkQAClient.filter_response_text(res.text) answer = resp_text.get('data', {}).get('answer', '') context = resp_text.get('data', {}).get('context', '') return answer, context |