aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/core/utils
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/utils')
-rw-r--r--.venv/lib/python3.12/site-packages/core/utils/__init__.py182
-rw-r--r--.venv/lib/python3.12/site-packages/core/utils/logging_config.py164
-rw-r--r--.venv/lib/python3.12/site-packages/core/utils/sentry.py22
-rw-r--r--.venv/lib/python3.12/site-packages/core/utils/serper.py107
4 files changed, 475 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/utils/__init__.py b/.venv/lib/python3.12/site-packages/core/utils/__init__.py
new file mode 100644
index 00000000..e04db4b9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/utils/__init__.py
@@ -0,0 +1,182 @@
+import re
+from typing import Set, Tuple
+
+from shared.utils.base_utils import (
+ SearchResultsCollector,
+ SSEFormatter,
+ convert_nonserializable_objects,
+ decrement_version,
+ deep_update,
+ dump_collector,
+ dump_obj,
+ format_search_results_for_llm,
+ generate_default_user_collection_id,
+ generate_document_id,
+ generate_extraction_id,
+ generate_id,
+ generate_user_id,
+ increment_version,
+ num_tokens,
+ num_tokens_from_messages,
+ update_settings_from_dict,
+ validate_uuid,
+ yield_sse_event,
+)
+from shared.utils.splitter.text import (
+ RecursiveCharacterTextSplitter,
+ TextSplitter,
+)
+
+
def extract_citations(text: str) -> list[str]:
    """Return every short citation ID found in *text*, in order.

    A citation is 7-8 alphanumeric characters wrapped in square
    brackets, e.g. ``[abc1234]``. Duplicates are kept.
    """
    pattern = re.compile(r"\[([A-Za-z0-9]{7,8})\]")
    return [match.group(1) for match in pattern.finditer(text)]
+
+
def extract_citation_spans(text: str) -> dict[str, list[Tuple[int, int]]]:
    """Map each citation ID in *text* to its (start, end) spans.

    Spans cover the full bracketed token (brackets included) and are
    listed in order of appearance.

    Args:
        text: The text to search for citations

    Returns:
        dictionary mapping citation IDs to lists of (start, end) position tuples
    """
    # Same bracketed-ID pattern used by extract_citations.
    pattern = re.compile(r"\[([A-Za-z0-9]{7,8})\]")

    spans: dict = {}
    for match in pattern.finditer(text):
        spans.setdefault(match.group(1), []).append(
            (match.start(), match.end())
        )
    return spans
+
+
class CitationTracker:
    """Remembers which citation IDs and spans have already been emitted.

    Used to deduplicate citation events: each ID is reported once, and
    each concrete (start, end) span is reported once per ID.
    """

    def __init__(self):
        # citation_id -> set of (start, end) spans already handled
        self.processed_spans: dict[str, Set[Tuple[int, int]]] = {}
        # every citation ID observed so far
        self.seen_citation_ids: Set[str] = set()

    def is_new_citation(self, citation_id: str) -> bool:
        """Return True on the first sighting of *citation_id*, marking it seen."""
        if citation_id in self.seen_citation_ids:
            return False
        self.seen_citation_ids.add(citation_id)
        return True

    def is_new_span(self, citation_id: str, span: Tuple[int, int]) -> bool:
        """Return True if *span* is new for *citation_id*, recording it.

        Args:
            citation_id: The citation ID
            span: (start, end) position tuple

        Returns:
            True the first time a given (citation_id, span) pair is seen,
            False on any repeat.
        """
        known = self.processed_spans.setdefault(citation_id, set())
        if span in known:
            return False
        known.add(span)
        return True

    def get_all_spans(self) -> dict[str, list[Tuple[int, int]]]:
        """Return every processed span, as lists, keyed by citation ID."""
        return {
            cid: list(spans) for cid, spans in self.processed_spans.items()
        }
+
+
def find_new_citation_spans(
    text: str, tracker: CitationTracker
) -> dict[str, list[Tuple[int, int]]]:
    """Return only the citation spans in *text* the tracker has not seen.

    Every span found is registered with *tracker* as a side effect, so a
    second call over the same text yields nothing new.

    Args:
        text: Text to search
        tracker: The CitationTracker instance

    Returns:
        dictionary of citation IDs to lists of new (start, end) spans
    """
    fresh: dict = {}
    for cid, spans in extract_citation_spans(text).items():
        unseen = [span for span in spans if tracker.is_new_span(cid, span)]
        # Only add a key when at least one span was actually new.
        if unseen:
            fresh[cid] = unseen
    return fresh
+
+
# Public API of core.utils: re-exports from shared.utils plus the
# citation helpers defined in this module.
__all__ = [
    "format_search_results_for_llm",
    "generate_id",
    "generate_document_id",
    "generate_extraction_id",
    "generate_user_id",
    "increment_version",
    "decrement_version",
    "generate_default_user_collection_id",
    "validate_uuid",
    "yield_sse_event",
    "dump_collector",
    "dump_obj",
    "convert_nonserializable_objects",
    "num_tokens",
    "num_tokens_from_messages",
    "SSEFormatter",
    "SearchResultsCollector",
    "update_settings_from_dict",
    "deep_update",
    # Text splitter
    "RecursiveCharacterTextSplitter",
    "TextSplitter",
    # Citation utilities (defined in this module)
    "extract_citations",
    "extract_citation_spans",
    "CitationTracker",
    "find_new_citation_spans",
]
diff --git a/.venv/lib/python3.12/site-packages/core/utils/logging_config.py b/.venv/lib/python3.12/site-packages/core/utils/logging_config.py
new file mode 100644
index 00000000..9b989c51
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/utils/logging_config.py
@@ -0,0 +1,164 @@
+import logging
+import logging.config
+import os
+import re
+import sys
+from pathlib import Path
+
+
class HTTPStatusFilter(logging.Filter):
    """Adjust uvicorn.access records based on the HTTP status code.

    The filter reads the fully formatted message via
    ``record.getMessage()``, finds the status code, and rewrites the
    record's log level:

    - 2xx: INFO
    - 4xx: WARNING
    - 5xx: ERROR

    It also wraps the status code in ANSI color codes and drops
    health-check requests entirely. All other records pass through
    unchanged.
    """

    # A broad pattern to find any 3-digit number in the message.
    # This should capture the HTTP status code from a line like:
    # '127.0.0.1:54946 - "GET /v2/relationships HTTP/1.1" 404'
    STATUS_CODE_PATTERN = re.compile(r"\b(\d{3})\b")
    HEALTH_ENDPOINT_PATTERN = re.compile(r'"GET /v3/health HTTP/\d\.\d"')

    LEVEL_TO_ANSI = {
        logging.INFO: "\033[32m",  # green
        logging.WARNING: "\033[33m",  # yellow
        logging.ERROR: "\033[31m",  # red
    }
    RESET = "\033[0m"

    def filter(self, record: logging.LogRecord) -> bool:
        """Return False to drop the record, True to keep it (possibly mutated)."""
        if record.name != "uvicorn.access":
            return True

        message = record.getMessage()

        # Filter out health endpoint requests
        # FIXME: This should be made configurable in the future
        if self.HEALTH_ENDPOINT_PATTERN.search(message):
            return False

        if codes := self.STATUS_CODE_PATTERN.findall(message):
            # In an access line the status code is the last 3-digit number.
            status_code = int(codes[-1])
            if 200 <= status_code < 300:
                record.levelno = logging.INFO
                record.levelname = "INFO"
                color = self.LEVEL_TO_ANSI[logging.INFO]
            elif 400 <= status_code < 500:
                record.levelno = logging.WARNING
                record.levelname = "WARNING"
                color = self.LEVEL_TO_ANSI[logging.WARNING]
            elif 500 <= status_code < 600:
                record.levelno = logging.ERROR
                record.levelname = "ERROR"
                color = self.LEVEL_TO_ANSI[logging.ERROR]
            else:
                return True

            # Colorize only the LAST occurrence of the status code. A
            # plain str.replace would also rewrite the same digits when
            # they appear earlier in the line (client port, request path).
            code_str = str(status_code)
            idx = message.rfind(code_str)
            new_msg = (
                message[:idx]
                + f"{color}{code_str}{self.RESET}"
                + message[idx + len(code_str):]
            )

            # Update record.msg and clear args to avoid re-formatting.
            record.msg = new_msg
            record.args = ()

        return True
+
+
# Minimum severity emitted by every handler/logger below (e.g. DEBUG, INFO).
log_level = os.environ.get("R2R_LOG_LEVEL", "INFO").upper()
# Console formatter selector: "colored" (human-readable) or "json" (structured).
log_console_formatter = os.environ.get(
    "R2R_LOG_CONSOLE_FORMATTER", "colored"
).lower()  # colored or json

# File logging goes to ./logs/app.log relative to the process working directory;
# the directory is created eagerly at import time.
log_dir = Path.cwd() / "logs"
log_dir.mkdir(exist_ok=True)
log_file = log_dir / "app.log"
+
# logging.config.dictConfig schema: two handlers (rotating file + stdout
# console) both filtered through HTTPStatusFilter, and explicit uvicorn
# loggers with propagation disabled so access lines are not duplicated
# by the root logger.
log_config = {
    "version": 1,
    "disable_existing_loggers": False,
    "filters": {
        "http_status_filter": {
            # "()" tells dictConfig to instantiate this class as the filter.
            "()": HTTPStatusFilter,
        }
    },
    "formatters": {
        "default": {
            "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            "datefmt": "%Y-%m-%d %H:%M:%S",
        },
        "colored": {
            "()": "colorlog.ColoredFormatter",
            "format": "%(asctime)s - %(log_color)s%(levelname)s%(reset)s - %(message)s",
            "datefmt": "%Y-%m-%d %H:%M:%S",
            "log_colors": {
                "DEBUG": "white",
                "INFO": "green",
                "WARNING": "yellow",
                "ERROR": "red",
                "CRITICAL": "bold_red",
            },
        },
        "json": {
            "()": "pythonjsonlogger.json.JsonFormatter",
            "format": "%(name)s %(levelname)s %(message)s",  # these become keys in the JSON log
            "rename_fields": {
                "asctime": "time",
                "levelname": "level",
                "name": "logger",
            },
        },
    },
    "handlers": {
        "file": {
            "class": "logging.handlers.RotatingFileHandler",
            "formatter": "colored",
            "filename": log_file,
            "maxBytes": 10485760,  # 10MB
            "backupCount": 5,
            "filters": ["http_status_filter"],
            "level": log_level,  # Set handler level based on the environment variable
        },
        "console": {
            "class": "logging.StreamHandler",
            # Chosen by R2R_LOG_CONSOLE_FORMATTER: "colored" or "json".
            "formatter": log_console_formatter,
            "stream": sys.stdout,
            "filters": ["http_status_filter"],
            "level": log_level,  # Set handler level based on the environment variable
        },
    },
    "loggers": {
        "": {  # Root logger
            "handlers": ["console", "file"],
            "level": log_level,  # Set logger level based on the environment variable
        },
        # uvicorn loggers get their own entries with propagate=False so
        # each access/error line is emitted exactly once.
        "uvicorn": {
            "handlers": ["console", "file"],
            "level": log_level,
            "propagate": False,
        },
        "uvicorn.error": {
            "handlers": ["console", "file"],
            "level": log_level,
            "propagate": False,
        },
        "uvicorn.access": {
            "handlers": ["console", "file"],
            "level": log_level,
            "propagate": False,
        },
    },
}
+
+
def configure_logging() -> Path:
    """Install the logging configuration defined in this module.

    Applies ``log_config`` (console + rotating-file handlers, uvicorn
    loggers) via ``logging.config.dictConfig``.

    Returns:
        Path to the rotating log file the file handler writes to.
    """
    logging.config.dictConfig(log_config)

    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logging.info("Logging is configured at %s level.", log_level)

    return log_file
diff --git a/.venv/lib/python3.12/site-packages/core/utils/sentry.py b/.venv/lib/python3.12/site-packages/core/utils/sentry.py
new file mode 100644
index 00000000..9a4c09a1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/utils/sentry.py
@@ -0,0 +1,22 @@
+import contextlib
+import os
+
+import sentry_sdk
+
+
def init_sentry() -> None:
    """Initialize the Sentry SDK when a DSN is configured.

    Configuration comes from environment variables:

    - ``R2R_SENTRY_DSN``: DSN; if unset or empty, Sentry stays disabled.
    - ``R2R_SENTRY_ENVIRONMENT``: environment tag (default ``"not_set"``).
    - ``R2R_SENTRY_TRACES_SAMPLE_RATE``: trace sample rate (default 1.0).
    - ``R2R_SENTRY_PROFILES_SAMPLE_RATE``: profile sample rate (default 1.0).

    Any exception raised during initialization (bad DSN, malformed rate,
    SDK errors) is deliberately suppressed: telemetry must never prevent
    the service from starting.
    """
    dsn = os.getenv("R2R_SENTRY_DSN")
    if not dsn:
        return

    with contextlib.suppress(Exception):
        sentry_sdk.init(
            dsn=dsn,
            environment=os.getenv("R2R_SENTRY_ENVIRONMENT", "not_set"),
            # os.getenv defaults are strings; "1.0" keeps the declared
            # type while float() parses both defaults and env overrides.
            traces_sample_rate=float(
                os.getenv("R2R_SENTRY_TRACES_SAMPLE_RATE", "1.0")
            ),
            profiles_sample_rate=float(
                os.getenv("R2R_SENTRY_PROFILES_SAMPLE_RATE", "1.0")
            ),
        )
diff --git a/.venv/lib/python3.12/site-packages/core/utils/serper.py b/.venv/lib/python3.12/site-packages/core/utils/serper.py
new file mode 100644
index 00000000..8962565b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/utils/serper.py
@@ -0,0 +1,107 @@
+# TODO - relocate to a dedicated module
+import http.client
+import json
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+
+# TODO - Move process json to dedicated data processing module
def process_json(json_object, indent=0):
    """Flatten nested dicts/lists into an indented plain-text blob.

    Dict entries render as ``key: value`` and list entries as
    ``Item N: value``; nested containers recurse one indent level
    deeper. Anything that is neither dict nor list yields "".
    """
    pad = " " * indent
    pieces = []
    if isinstance(json_object, dict):
        for key, value in json_object.items():
            if isinstance(value, (dict, list)):
                pieces.append(
                    f"{pad}{key}:\n{process_json(value, indent + 1)}"
                )
            else:
                pieces.append(f"{pad}{key}: {value}\n")
    elif isinstance(json_object, list):
        for position, item in enumerate(json_object, start=1):
            if isinstance(item, (dict, list)):
                pieces.append(
                    f"{pad}Item {position}:\n{process_json(item, indent + 1)}"
                )
            else:
                pieces.append(f"{pad}Item {position}: {item}\n")
    return "".join(pieces)
+
+
+# TODO - Introduce abstract "Integration" ABC.
# TODO - Introduce abstract "Integration" ABC.
class SerperClient:
    """Thin client for the Serper (google.serper.dev) search API."""

    def __init__(self, api_base: str = "google.serper.dev") -> None:
        """Read SERPER_API_KEY from the environment and prepare headers.

        Args:
            api_base: Hostname of the Serper API endpoint.

        Raises:
            ValueError: If the SERPER_API_KEY environment variable is unset.
        """
        api_key = os.getenv("SERPER_API_KEY")
        if not api_key:
            raise ValueError(
                "Please set the `SERPER_API_KEY` environment variable to use `SerperClient`."
            )

        self.api_base = api_base
        self.headers = {
            "X-API-KEY": api_key,
            "Content-Type": "application/json",
        }

    @staticmethod
    def _extract_results(result_data: dict) -> list:
        """Flatten a raw Serper response into a list of tagged result dicts.

        Each entry gains a "type" key naming the response section it came
        from (e.g. "organic", "answerBox"); "searchParameters" is dropped.
        """
        formatted_results = []

        for key, value in result_data.items():
            # Skip searchParameters as it's not a result entry
            if key == "searchParameters":
                continue

            # Handle 'answerBox' as a single item
            if key == "answerBox":
                value["type"] = key  # Add the type key to the dictionary
                formatted_results.append(value)
            # Handle lists of results
            elif isinstance(value, list):
                for item in value:
                    item["type"] = key  # Add the type key to the dictionary
                    formatted_results.append(item)
            # Handle 'peopleAlsoAsk' and potentially other single item formats
            elif isinstance(value, dict):
                value["type"] = key  # Add the type key to the dictionary
                formatted_results.append(value)

        return formatted_results

    # TODO - Add explicit typing for the return value
    def get_raw(self, query: str, limit: int = 10) -> list:
        """POST *query* to /search and return the flattened result list.

        Args:
            query: Search query string.
            limit: Requested number of results.

        Returns:
            List of result dicts as produced by ``_extract_results``.
        """
        connection = http.client.HTTPSConnection(self.api_base)
        try:
            payload = json.dumps({"q": query, "num_outputs": limit})
            connection.request("POST", "/search", payload, self.headers)
            response = connection.getresponse()
            # Lazy %-args; the original logged the literal "{response}"
            # because the f-prefix was missing.
            logger.debug("Received response %s from Serper API.", response)
            data = response.read()
        finally:
            # Always release the socket, even when the request fails.
            connection.close()
        json_data = json.loads(data.decode("utf-8"))
        return SerperClient._extract_results(json_data)

    @staticmethod
    def construct_context(results: list) -> str:
        """Group results by their "type" and render an indented text blob.

        NOTE(review): this expects items exposing a ``.metadata`` dict,
        not the plain dicts returned by ``get_raw`` — confirm against
        callers before changing either side.
        """
        # Organize results by type
        organized_results = {}
        for result in results:
            result_type = result.metadata.pop(
                "type", "Unknown"
            )  # Pop the type and use as key
            if result_type not in organized_results:
                organized_results[result_type] = [result.metadata]
            else:
                organized_results[result_type].append(result.metadata)

        context = ""
        # Iterate over each result type
        for result_type, items in organized_results.items():
            context += f"# {result_type} Results:\n"
            for index, item in enumerate(items, start=1):
                # Process each item under the current type
                context += f"Item {index}:\n"
                context += process_json(item) + "\n"

        return context