about summary refs log tree commit diff
path: root/gnqa/paper1_eval/src/apis/gnqaclient.py
diff options
context:
space:
mode:
Diffstat (limited to 'gnqa/paper1_eval/src/apis/gnqaclient.py')
-rw-r--r-- gnqa/paper1_eval/src/apis/gnqaclient.py 226
1 files changed, 226 insertions, 0 deletions
diff --git a/gnqa/paper1_eval/src/apis/gnqaclient.py b/gnqa/paper1_eval/src/apis/gnqaclient.py
new file mode 100644
index 00000000..0024c314
--- /dev/null
+++ b/gnqa/paper1_eval/src/apis/gnqaclient.py
@@ -0,0 +1,226 @@
+# pylint: skip-file
+import json
+import string
+import os
+import datetime
+import time
+import requests
+
+from requests.adapters                    import HTTPAdapter
+from requests.packages.urllib3.util.retry import Retry
+from requests                             import HTTPError,Session
+from urllib.parse                         import urljoin,quote
+from urllib.request                       import urlretrieve
+from errors.rag_err                       import UnprocessableEntity, LLMError
+
+basedir = os.path.join(os.path.dirname(__file__))
+
+
+class TimeoutHTTPAdapter(HTTPAdapter):
+    def __init__(self, timeout, *args, **kwargs):
+        """TimeoutHTTPAdapter constructor.
+        Args:
+            timeout (int): How many seconds to wait for the server to send data before
+                giving up.
+        """
+        self.timeout = timeout
+        super().__init__(*args, **kwargs)
+
+    def send(self, request, **kwargs):
+        """Override :obj:`HTTPAdapter` send method to add a default timeout."""
+        timeout = kwargs.get("timeout")
+        if timeout is None:
+            kwargs["timeout"] = self.timeout
+
+        return super().send(request, **kwargs)
+
+
+class GeneNetworkQAClient(Session):
+    """GeneNetworkQA Client
+
+    This class provides a client object interface to the GeneNetworkQA API.
+    It extends the `requests.Session` class and includes authorization, base URL,
+    request timeouts, and request retries.
+
+    Args:
+        account (str): Base address subdomain.
+        api_key (str): API key.
+        version (str, optional): API version, defaults to "v3".
+        timeout (int, optional): Timeout value, defaults to 5.
+        total_retries (int, optional): Total retries value, defaults to 5.
+        backoff_factor (int, optional): Retry backoff factor value, defaults to 30.
+
+    Usage:
+        from genenetworkqa import GeneNetworkQAClient
+        gnqa = GeneNetworkQAClient(account="account-name", api_key="XXXXXXXXXXXXXXXXXXX...")
+    """
+
+    BASE_URL = 'https://genenetwork.fahamuai.com/api/tasks'
+
+    def __init__(self, account, api_key, version="v3", timeout=30, total_retries=5, backoff_factor=30):
+        super().__init__()
+        self.headers.update(
+            {"Authorization": "Bearer " + api_key})
+        self.answer_url = f"{self.BASE_URL}/answers"
+        self.feedback_url = f"{self.BASE_URL}/feedback"
+
+        adapter = TimeoutHTTPAdapter(
+            timeout=timeout,
+            max_retries=Retry(
+                total=total_retries,
+                status_forcelist=[429, 500, 502, 503, 504],
+                backoff_factor=backoff_factor,
+            ),
+        )
+
+        self.mount("https://", adapter)
+        self.mount("http://", adapter)
+
+    @staticmethod
+    def format_bibliography_info(bib_info):
+
+        if isinstance(bib_info, str):
+            # Remove '.txt'
+            bib_info = bib_info.removesuffix('.txt')
+        elif isinstance(bib_info, dict):
+            # Format string bibliography information
+            bib_info = "{0}.{1}.{2}.{3} ".format(bib_info.get('author', ''),
+                                                 bib_info.get('title', ''),
+                                                 bib_info.get('year', ''),
+                                                 bib_info.get('doi', ''))
+        return bib_info
+
+    @staticmethod
+    def ask_the_documents(extend_url, my_auth):
+        try:
+            response = requests.post(
+                base_url + extend_url, data={}, headers=my_auth)
+            response.raise_for_status()
+        except requests.exceptions.RequestException as e:
+            # Handle the exception appropriately, e.g., log the error
+            raise RuntimeError(f"Error making the request: {e}")
+
+        if response.status_code != 200:
+            return negative_status_msg(response), 0
+
+        task_id = get_task_id_from_result(response)
+        response = get_answer_using_task_id(task_id, my_auth)
+
+        if response.status_code != 200:
+
+            return negative_status_msg(response), 0
+
+        return response, 1
+
+    @staticmethod
+    def negative_status_msg(response):
+        return f"Error: Status code -{response.status_code}- Reason::{response.reason}"
+      #  return f"Problems\n\tStatus code => {response.status_code}\n\tReason => {response.reason}"
+
+    def ask(self, exUrl, *args, **kwargs):
+        askUrl = self.BASE_URL + exUrl
+        res = self.custom_request('POST', askUrl, *args, **kwargs)
+        if (res.status_code != 200):
+            return self.negative_status_msg(res), 0
+        task_id = self.getTaskIDFromResult(res)
+        return res, task_id
+
+    def answer(self, taskid, *args, **kwargs):
+        query = self.answer_url + self.extendForTaskID(taskid)
+        res = self.custom_request('GET', query, *args, **kwargs)
+        if (res.status_code != 200):
+            print('The result is {0}',format(res))
+            return self.negative_status_msg(res), 0
+        return res, 1
+
+    def get_answer(self, taskid, *args, **kwargs):
+        query = self.answer_url + self.extendTaskID(taskid)
+        res = self.custom_request('GET', query, *args, **kwargs)
+        if (res.status_code != 200):
+            print('The result is {0}',format(res))
+            return self.negative_status_msg(res), 0
+        return res, 1
+
+    def custom_request(self, method, url, *args, **kwargs):
+
+        max_retries = 50
+        retry_delay = 3
+
+        for i in range(max_retries):
+            try:
+                response = super().request(method, url, *args, **kwargs)
+                response.raise_for_status()
+
+            except requests.exceptions.HTTPError as error:
+                if error.response.status_code ==500:
+                    raise LLMError(error.request, error.response, f"Response Error,status_code:{error.response.status_code},Reason: Use of Invalid Token")
+                elif error.response.status_code ==404:
+                    raise LLMError(error.request,error.response,f"404 Client Error: Not Found for url: {self.BASE_URL}")
+                raise error
+
+            except requests.exceptions.RequestException as error:
+                raise error
+
+
+            if response.ok:
+                if method.lower() == "get" and response.json().get("data") is None:
+                    time.sleep(retry_delay)
+                    continue
+                else:
+                    return response
+            else:
+                time.sleep(retry_delay)
+        return response
+
+    @staticmethod
+    def get_task_id_from_result(response):
+        task_id = json.loads(response.text)
+        result = f"?task_id={task_id.get('task_id', '')}"
+        return result
+
+    @staticmethod
+    def get_answer_using_task_id(extend_url, my_auth):
+        try:
+            response = requests.get(
+                answer_url + extend_url, data={}, headers=my_auth)
+            response.raise_for_status()
+            return response
+        except requests.exceptions.RequestException as error:
+            # Handle the exception appropriately, e.g., log the error
+            raise error
+
+    @staticmethod
+    def filter_response_text(val):
+        """
+        Filters out non-printable characters from the input string and parses it as JSON.
+
+        Args:
+            val (str): Input string to be filtered and parsed.
+
+        Returns:
+            dict: Parsed JSON object.
+        # remove  this
+        """
+        return json.loads(''.join([str(char) for char in val if char in string.printable]))
+
+    def getTaskIDFromResult(self, res):
+        return json.loads(res.text)
+
+    def extendTaskID(self, task_id):
+        return '?task_id=' + str(task_id['task_id'])
+
+    def extendForTaskID(self, task_id):
+        return '?task_id=' + str(task_id)
+
+    def get_gnqa(self, query):
+        qstr = quote(query)
+        res, task_id = api_client.ask('?ask=' + qstr)
+        res, success = api_client.get_answer(task_id)
+
+        if success == 1:
+            resp_text = filter_response_text(res.text)
+            answer = resp_text.get('data', {}).get('answer', '')
+            context = resp_text.get('data', {}).get('context', '')
+            return answer, context
+        else:
+            return res, "Unfortunately, I have nothing."