about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/core/utils
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/utils
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/utils')
-rw-r--r--.venv/lib/python3.12/site-packages/core/utils/__init__.py182
-rw-r--r--.venv/lib/python3.12/site-packages/core/utils/logging_config.py164
-rw-r--r--.venv/lib/python3.12/site-packages/core/utils/sentry.py22
-rw-r--r--.venv/lib/python3.12/site-packages/core/utils/serper.py107
4 files changed, 475 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/utils/__init__.py b/.venv/lib/python3.12/site-packages/core/utils/__init__.py
new file mode 100644
index 00000000..e04db4b9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/utils/__init__.py
@@ -0,0 +1,182 @@
+import re
+from typing import Set, Tuple
+
+from shared.utils.base_utils import (
+    SearchResultsCollector,
+    SSEFormatter,
+    convert_nonserializable_objects,
+    decrement_version,
+    deep_update,
+    dump_collector,
+    dump_obj,
+    format_search_results_for_llm,
+    generate_default_user_collection_id,
+    generate_document_id,
+    generate_extraction_id,
+    generate_id,
+    generate_user_id,
+    increment_version,
+    num_tokens,
+    num_tokens_from_messages,
+    update_settings_from_dict,
+    validate_uuid,
+    yield_sse_event,
+)
+from shared.utils.splitter.text import (
+    RecursiveCharacterTextSplitter,
+    TextSplitter,
+)
+
+
+def extract_citations(text: str) -> list[str]:
+    """
+    Extract citation IDs enclosed in brackets like [abc1234].
+    Returns a list of citation IDs.
+    """
+    # Direct pattern to match IDs inside brackets with alphanumeric pattern
+    CITATION_PATTERN = re.compile(r"\[([A-Za-z0-9]{7,8})\]")
+
+    sids = []
+    for match in CITATION_PATTERN.finditer(text):
+        sid = match.group(1)
+        sids.append(sid)
+
+    return sids
+
+
+def extract_citation_spans(text: str) -> dict[str, list[Tuple[int, int]]]:
+    """
+    Extract citation IDs with their positions in the text.
+
+    Args:
+        text: The text to search for citations
+
+    Returns:
+        dictionary mapping citation IDs to lists of (start, end) position tuples
+    """
+    # Use the same pattern as the original extract_citations
+    CITATION_PATTERN = re.compile(r"\[([A-Za-z0-9]{7,8})\]")
+
+    citation_spans: dict = {}
+
+    for match in CITATION_PATTERN.finditer(text):
+        sid = match.group(1)
+        start = match.start()
+        end = match.end()
+
+        if sid not in citation_spans:
+            citation_spans[sid] = []
+
+        # Add the position span
+        citation_spans[sid].append((start, end))
+
+    return citation_spans
+
+
+class CitationTracker:
+    """
+    Tracks citation spans to ensure each one is only emitted once.
+    """
+
+    def __init__(self):
+        # Track which citation spans we've processed
+        # Format: {citation_id: {(start, end), (start, end), ...}}
+        self.processed_spans: dict[str, Set[Tuple[int, int]]] = {}
+
+        # Track which citation IDs we've seen
+        self.seen_citation_ids: Set[str] = set()
+
+    def is_new_citation(self, citation_id: str) -> bool:
+        """Check if this is the first occurrence of this citation ID."""
+        is_new = citation_id not in self.seen_citation_ids
+        if is_new:
+            self.seen_citation_ids.add(citation_id)
+        return is_new
+
+    def is_new_span(self, citation_id: str, span: Tuple[int, int]) -> bool:
+        """
+        Check if this span has already been processed for this citation ID.
+
+        Args:
+            citation_id: The citation ID
+            span: (start, end) position tuple
+
+        Returns:
+            True if this span hasn't been processed yet, False otherwise
+        """
+        # Initialize set for this citation ID if needed
+        if citation_id not in self.processed_spans:
+            self.processed_spans[citation_id] = set()
+
+        # Check if we've seen this span before
+        if span in self.processed_spans[citation_id]:
+            return False
+
+        # This is a new span, track it
+        self.processed_spans[citation_id].add(span)
+        return True
+
+    def get_all_spans(self) -> dict[str, list[Tuple[int, int]]]:
+        """Get all processed spans for final answer."""
+        return {
+            cid: list(spans) for cid, spans in self.processed_spans.items()
+        }
+
+
+def find_new_citation_spans(
+    text: str, tracker: CitationTracker
+) -> dict[str, list[Tuple[int, int]]]:
+    """
+    Extract citation spans that haven't been processed yet.
+
+    Args:
+        text: Text to search
+        tracker: The CitationTracker instance
+
+    Returns:
+        dictionary of citation IDs to lists of new (start, end) spans
+    """
+    # Get all citation spans in the text
+    all_spans = extract_citation_spans(text)
+
+    # Filter to only spans we haven't processed yet
+    new_spans: dict = {}
+
+    for cid, spans in all_spans.items():
+        for span in spans:
+            if tracker.is_new_span(cid, span):
+                if cid not in new_spans:
+                    new_spans[cid] = []
+                new_spans[cid].append(span)
+
+    return new_spans
+
+
+__all__ = [
+    "format_search_results_for_llm",
+    "generate_id",
+    "generate_document_id",
+    "generate_extraction_id",
+    "generate_user_id",
+    "increment_version",
+    "decrement_version",
+    "generate_default_user_collection_id",
+    "validate_uuid",
+    "yield_sse_event",
+    "dump_collector",
+    "dump_obj",
+    "convert_nonserializable_objects",
+    "num_tokens",
+    "num_tokens_from_messages",
+    "SSEFormatter",
+    "SearchResultsCollector",
+    "update_settings_from_dict",
+    "deep_update",
+    # Text splitter
+    "RecursiveCharacterTextSplitter",
+    "TextSplitter",
+    "extract_citations",
+    "extract_citation_spans",
+    "CitationTracker",
+    "find_new_citation_spans",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/utils/logging_config.py b/.venv/lib/python3.12/site-packages/core/utils/logging_config.py
new file mode 100644
index 00000000..9b989c51
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/utils/logging_config.py
@@ -0,0 +1,164 @@
+import logging
+import logging.config
+import os
+import re
+import sys
+from pathlib import Path
+
+
+class HTTPStatusFilter(logging.Filter):
+    """This filter inspects uvicorn.access log records. It uses
+    record.getMessage() to retrieve the fully formatted log message. Then it
+    searches for HTTP status codes and adjusts the.
+
+    record's log level based on that status:
+      - 4xx: WARNING
+      - 5xx: ERROR
+    All other logs remain unchanged.
+    """
+
+    # A broad pattern to find any 3-digit number in the message.
+    # This should capture the HTTP status code from a line like:
+    # '127.0.0.1:54946 - "GET /v2/relationships HTTP/1.1" 404'
+    STATUS_CODE_PATTERN = re.compile(r"\b(\d{3})\b")
+    HEALTH_ENDPOINT_PATTERN = re.compile(r'"GET /v3/health HTTP/\d\.\d"')
+
+    LEVEL_TO_ANSI = {
+        logging.INFO: "\033[32m",  # green
+        logging.WARNING: "\033[33m",  # yellow
+        logging.ERROR: "\033[31m",  # red
+    }
+    RESET = "\033[0m"
+
+    def filter(self, record: logging.LogRecord) -> bool:
+        if record.name != "uvicorn.access":
+            return True
+
+        message = record.getMessage()
+
+        # Filter out health endpoint requests
+        # FIXME: This should be made configurable in the future
+        if self.HEALTH_ENDPOINT_PATTERN.search(message):
+            return False
+
+        if codes := self.STATUS_CODE_PATTERN.findall(message):
+            status_code = int(codes[-1])
+            if 200 <= status_code < 300:
+                record.levelno = logging.INFO
+                record.levelname = "INFO"
+                color = self.LEVEL_TO_ANSI[logging.INFO]
+            elif 400 <= status_code < 500:
+                record.levelno = logging.WARNING
+                record.levelname = "WARNING"
+                color = self.LEVEL_TO_ANSI[logging.WARNING]
+            elif 500 <= status_code < 600:
+                record.levelno = logging.ERROR
+                record.levelname = "ERROR"
+                color = self.LEVEL_TO_ANSI[logging.ERROR]
+            else:
+                return True
+
+            # Wrap the status code in ANSI codes
+            colored_code = f"{color}{status_code}{self.RESET}"
+            # Replace the status code in the message
+            new_msg = message.replace(str(status_code), colored_code)
+
+            # Update record.msg and clear args to avoid formatting issues
+            record.msg = new_msg
+            record.args = ()
+
+        return True
+
+
+log_level = os.environ.get("R2R_LOG_LEVEL", "INFO").upper()
+log_console_formatter = os.environ.get(
+    "R2R_LOG_CONSOLE_FORMATTER", "colored"
+).lower()  # colored or json
+
+log_dir = Path.cwd() / "logs"
+log_dir.mkdir(exist_ok=True)
+log_file = log_dir / "app.log"
+
+log_config = {
+    "version": 1,
+    "disable_existing_loggers": False,
+    "filters": {
+        "http_status_filter": {
+            "()": HTTPStatusFilter,
+        }
+    },
+    "formatters": {
+        "default": {
+            "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+            "datefmt": "%Y-%m-%d %H:%M:%S",
+        },
+        "colored": {
+            "()": "colorlog.ColoredFormatter",
+            "format": "%(asctime)s - %(log_color)s%(levelname)s%(reset)s - %(message)s",
+            "datefmt": "%Y-%m-%d %H:%M:%S",
+            "log_colors": {
+                "DEBUG": "white",
+                "INFO": "green",
+                "WARNING": "yellow",
+                "ERROR": "red",
+                "CRITICAL": "bold_red",
+            },
+        },
+        "json": {
+            "()": "pythonjsonlogger.json.JsonFormatter",
+            "format": "%(name)s %(levelname)s %(message)s",  # these become keys in the JSON log
+            "rename_fields": {
+                "asctime": "time",
+                "levelname": "level",
+                "name": "logger",
+            },
+        },
+    },
+    "handlers": {
+        "file": {
+            "class": "logging.handlers.RotatingFileHandler",
+            "formatter": "colored",
+            "filename": log_file,
+            "maxBytes": 10485760,  # 10MB
+            "backupCount": 5,
+            "filters": ["http_status_filter"],
+            "level": log_level,  # Set handler level based on the environment variable
+        },
+        "console": {
+            "class": "logging.StreamHandler",
+            "formatter": log_console_formatter,
+            "stream": sys.stdout,
+            "filters": ["http_status_filter"],
+            "level": log_level,  # Set handler level based on the environment variable
+        },
+    },
+    "loggers": {
+        "": {  # Root logger
+            "handlers": ["console", "file"],
+            "level": log_level,  # Set logger level based on the environment variable
+        },
+        "uvicorn": {
+            "handlers": ["console", "file"],
+            "level": log_level,
+            "propagate": False,
+        },
+        "uvicorn.error": {
+            "handlers": ["console", "file"],
+            "level": log_level,
+            "propagate": False,
+        },
+        "uvicorn.access": {
+            "handlers": ["console", "file"],
+            "level": log_level,
+            "propagate": False,
+        },
+    },
+}
+
+
+def configure_logging() -> Path:
+    logging.config.dictConfig(log_config)
+
+    logging.info(f"Logging is configured at {log_level} level.")
+
+    return log_file
diff --git a/.venv/lib/python3.12/site-packages/core/utils/sentry.py b/.venv/lib/python3.12/site-packages/core/utils/sentry.py
new file mode 100644
index 00000000..9a4c09a1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/utils/sentry.py
@@ -0,0 +1,22 @@
+import contextlib
+import os
+
+import sentry_sdk
+
+
+def init_sentry():
+    dsn = os.getenv("R2R_SENTRY_DSN")
+    if not dsn:
+        return
+
+    with contextlib.suppress(Exception):
+        sentry_sdk.init(
+            dsn=dsn,
+            environment=os.getenv("R2R_SENTRY_ENVIRONMENT", "not_set"),
+            traces_sample_rate=float(
+                os.getenv("R2R_SENTRY_TRACES_SAMPLE_RATE", 1.0)
+            ),
+            profiles_sample_rate=float(
+                os.getenv("R2R_SENTRY_PROFILES_SAMPLE_RATE", 1.0)
+            ),
+        )
diff --git a/.venv/lib/python3.12/site-packages/core/utils/serper.py b/.venv/lib/python3.12/site-packages/core/utils/serper.py
new file mode 100644
index 00000000..8962565b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/utils/serper.py
@@ -0,0 +1,107 @@
+# TODO - relocate to a dedicated module
+import http.client
+import json
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+
+# TODO - Move process json to dedicated data processing module
+def process_json(json_object, indent=0):
+    """Recursively traverses the JSON object (dicts and lists) to create an
+    unstructured text blob."""
+    text_blob = ""
+    if isinstance(json_object, dict):
+        for key, value in json_object.items():
+            padding = "  " * indent
+            if isinstance(value, (dict, list)):
+                text_blob += (
+                    f"{padding}{key}:\n{process_json(value, indent + 1)}"
+                )
+            else:
+                text_blob += f"{padding}{key}: {value}\n"
+    elif isinstance(json_object, list):
+        for index, item in enumerate(json_object):
+            padding = "  " * indent
+            if isinstance(item, (dict, list)):
+                text_blob += f"{padding}Item {index + 1}:\n{process_json(item, indent + 1)}"
+            else:
+                text_blob += f"{padding}Item {index + 1}: {item}\n"
+    return text_blob
+
+
+# TODO - Introduce abstract "Integration" ABC.
+class SerperClient:
+    def __init__(self, api_base: str = "google.serper.dev") -> None:
+        api_key = os.getenv("SERPER_API_KEY")
+        if not api_key:
+            raise ValueError(
+                "Please set the `SERPER_API_KEY` environment variable to use `SerperClient`."
+            )
+
+        self.api_base = api_base
+        self.headers = {
+            "X-API-KEY": api_key,
+            "Content-Type": "application/json",
+        }
+
+    @staticmethod
+    def _extract_results(result_data: dict) -> list:
+        formatted_results = []
+
+        for key, value in result_data.items():
+            # Skip searchParameters as it's not a result entry
+            if key == "searchParameters":
+                continue
+
+            # Handle 'answerBox' as a single item
+            if key == "answerBox":
+                value["type"] = key  # Add the type key to the dictionary
+                formatted_results.append(value)
+            # Handle lists of results
+            elif isinstance(value, list):
+                for item in value:
+                    item["type"] = key  # Add the type key to the dictionary
+                    formatted_results.append(item)
+            # Handle 'peopleAlsoAsk' and potentially other single item formats
+            elif isinstance(value, dict):
+                value["type"] = key  # Add the type key to the dictionary
+                formatted_results.append(value)
+
+        return formatted_results
+
+    # TODO - Add explicit typing for the return value
+    def get_raw(self, query: str, limit: int = 10) -> list:
+        connection = http.client.HTTPSConnection(self.api_base)
+        payload = json.dumps({"q": query, "num_outputs": limit})
+        connection.request("POST", "/search", payload, self.headers)
+        response = connection.getresponse()
+        logger.debug("Received response {response} from Serper API.")
+        data = response.read()
+        json_data = json.loads(data.decode("utf-8"))
+        return SerperClient._extract_results(json_data)
+
+    @staticmethod
+    def construct_context(results: list) -> str:
+        # Organize results by type
+        organized_results = {}
+        for result in results:
+            result_type = result.metadata.pop(
+                "type", "Unknown"
+            )  # Pop the type and use as key
+            if result_type not in organized_results:
+                organized_results[result_type] = [result.metadata]
+            else:
+                organized_results[result_type].append(result.metadata)
+
+        context = ""
+        # Iterate over each result type
+        for result_type, items in organized_results.items():
+            context += f"# {result_type} Results:\n"
+            for index, item in enumerate(items, start=1):
+                # Process each item under the current type
+                context += f"Item {index}:\n"
+                context += process_json(item) + "\n"
+
+        return context