Diffstat (limited to '.venv/lib/python3.12/site-packages/shared/abstractions/llm.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/abstractions/llm.py  325
1 file changed, 325 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/llm.py b/.venv/lib/python3.12/site-packages/shared/abstractions/llm.py
new file mode 100644
index 00000000..d71e279e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/llm.py
@@ -0,0 +1,325 @@
+"""Abstractions for the LLM model."""
+
+import json
+from enum import Enum
+from typing import TYPE_CHECKING, Any, ClassVar, Optional
+
+from openai.types.chat import ChatCompletionChunk
+from pydantic import BaseModel, Field
+
+from .base import R2RSerializable
+
+if TYPE_CHECKING:
+    from .search import AggregateSearchResult
+
+from typing_extensions import Literal
+
+
+class Function(BaseModel):
+    arguments: str
+    """
+    The arguments to call the function with, as generated by the model in JSON
+    format. Note that the model does not always generate valid JSON, and may
+    hallucinate parameters not defined by your function schema. Validate the
+    arguments in your code before calling your function.
+    """
+
+    name: str
+    """The name of the function to call."""
+
+
+class ChatCompletionMessageToolCall(BaseModel):
+    id: str
+    """The ID of the tool call."""
+
+    function: Function
+    """The function that the model called."""
+
+    type: Literal["function"]
+    """The type of the tool. Currently, only `function` is supported."""
+
+
+class FunctionCall(BaseModel):
+    arguments: str
+    """
+    The arguments to call the function with, as generated by the model in JSON
+    format. Note that the model does not always generate valid JSON, and may
+    hallucinate parameters not defined by your function schema. Validate the
+    arguments in your code before calling your function.
+    """
+
+    name: str
+    """The name of the function to call."""
+
+
+class ChatCompletionMessage(BaseModel):
+    content: Optional[str] = None
+    """The contents of the message."""
+
+    refusal: Optional[str] = None
+    """The refusal message generated by the model."""
+
+    role: Literal["assistant"]
+    """The role of the author of this message."""
+
+    # audio: Optional[ChatCompletionAudio] = None
+    """
+    If the audio output modality is requested, this object contains data about the
+    audio response from the model.
+    [Learn more](https://platform.openai.com/docs/guides/audio).
+    """
+
+    function_call: Optional[FunctionCall] = None
+    """Deprecated and replaced by `tool_calls`.
+
+    The name and arguments of a function that should be called, as generated by the
+    model.
+    """
+
+    tool_calls: Optional[list[ChatCompletionMessageToolCall]] = None
+    """The tool calls generated by the model, such as function calls."""
+
+    structured_content: Optional[list[dict]] = None
+
+
+class Choice(BaseModel):
+    finish_reason: Literal[
+        "stop",
+        "length",
+        "tool_calls",
+        "content_filter",
+        "function_call",
+        "max_tokens",
+    ]
+    """The reason the model stopped generating tokens.
+
+    This will be `stop` if the model hit a natural stop point or a provided stop
+    sequence, `length` if the maximum number of tokens specified in the request was
+    reached, `content_filter` if content was omitted due to a flag from our content
+    filters, `tool_calls` if the model called a tool, or `function_call`
+    (deprecated) if the model called a function. `max_tokens` may also be
+    reported by providers that use that name for a token-limit stop.
+    """
+
+    index: int
+    """The index of the choice in the list of choices."""
+
+    # logprobs: Optional[ChoiceLogprobs] = None
+    """Log probability information for the choice."""
+
+    message: ChatCompletionMessage
+    """A chat completion message generated by the model."""
+
+
+class LLMChatCompletion(BaseModel):
+    id: str
+    """A unique identifier for the chat completion."""
+
+    choices: list[Choice]
+    """A list of chat completion choices.
+
+    Can be more than one if `n` is greater than 1.
+    """
+
+    created: int
+    """The Unix timestamp (in seconds) of when the chat completion was created."""
+
+    model: str
+    """The model used for the chat completion."""
+
+    object: Literal["chat.completion"]
+    """The object type, which is always `chat.completion`."""
+
+    service_tier: Optional[Literal["scale", "default"]] = None
+    """The service tier used for processing the request."""
+
+    system_fingerprint: Optional[str] = None
+    """This fingerprint represents the backend configuration that the model runs with.
+
+    Can be used in conjunction with the `seed` request parameter to understand when
+    backend changes have been made that might impact determinism.
+    """
+
+    usage: Optional[Any] = None
+    """Usage statistics for the completion request."""
+
+
+LLMChatCompletionChunk = ChatCompletionChunk
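+
+# Illustrative only (not part of the original module): LLMChatCompletion mirrors
+# the OpenAI chat-completion response shape, so a provider-style response dict
+# can be validated directly. The payload below is a made-up example.
+#
+#   completion = LLMChatCompletion.model_validate(
+#       {
+#           "id": "chatcmpl-123",
+#           "object": "chat.completion",
+#           "created": 1700000000,
+#           "model": "gpt-4o",
+#           "choices": [
+#               {
+#                   "index": 0,
+#                   "finish_reason": "stop",
+#                   "message": {"role": "assistant", "content": "Hello!"},
+#               }
+#           ],
+#       }
+#   )
+#   completion.choices[0].message.content  # -> "Hello!"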
+
+
+class RAGCompletion:
+    completion: LLMChatCompletion
+    search_results: "AggregateSearchResult"
+
+    def __init__(
+        self,
+        completion: LLMChatCompletion,
+        search_results: "AggregateSearchResult",
+    ):
+        self.completion = completion
+        self.search_results = search_results
+
+
+class GenerationConfig(R2RSerializable):
+    _defaults: ClassVar[dict] = {
+        "model": None,
+        "temperature": 0.1,
+        "top_p": 1.0,
+        "max_tokens_to_sample": 1024,
+        "stream": False,
+        "functions": None,
+        "tools": None,
+        "add_generation_kwargs": None,
+        "api_base": None,
+        "response_format": None,
+        "extended_thinking": False,
+        "thinking_budget": None,
+        "reasoning_effort": None,
+    }
+
+    model: Optional[str] = Field(
+        default_factory=lambda: GenerationConfig._defaults["model"]
+    )
+    temperature: float = Field(
+        default_factory=lambda: GenerationConfig._defaults["temperature"]
+    )
+    top_p: Optional[float] = Field(
+        default_factory=lambda: GenerationConfig._defaults["top_p"],
+    )
+    max_tokens_to_sample: int = Field(
+        default_factory=lambda: GenerationConfig._defaults[
+            "max_tokens_to_sample"
+        ],
+    )
+    stream: bool = Field(
+        default_factory=lambda: GenerationConfig._defaults["stream"]
+    )
+    functions: Optional[list[dict]] = Field(
+        default_factory=lambda: GenerationConfig._defaults["functions"]
+    )
+    tools: Optional[list[dict]] = Field(
+        default_factory=lambda: GenerationConfig._defaults["tools"]
+    )
+    add_generation_kwargs: Optional[dict] = Field(
+        default_factory=lambda: GenerationConfig._defaults[
+            "add_generation_kwargs"
+        ],
+    )
+    api_base: Optional[str] = Field(
+        default_factory=lambda: GenerationConfig._defaults["api_base"],
+    )
+    response_format: Optional[dict | BaseModel] = None
+    extended_thinking: bool = Field(
+        default=False,
+        description="Flag to enable extended thinking mode (for Anthropic providers)",
+    )
+    thinking_budget: Optional[int] = Field(
+        default=None,
+        description=(
+            "Token budget for internal reasoning when extended thinking mode is enabled. "
+            "Must be less than max_tokens_to_sample."
+        ),
+    )
+    reasoning_effort: Optional[str] = Field(
+        default=None,
+        description=(
+            "Effort level for internal reasoning when extended thinking mode is enabled, `low`, `medium`, or `high`."
+            "Only applicable to OpenAI providers."
+        ),
+    )
+
+    @classmethod
+    def set_default(cls, **kwargs):
+        for key, value in kwargs.items():
+            if key in cls._defaults:
+                cls._defaults[key] = value
+            else:
+                raise AttributeError(
+                    f"No default attribute '{key}' in GenerationConfig"
+                )
+
+    def __init__(self, **data):
+        # Handle max_tokens mapping to max_tokens_to_sample
+        if "max_tokens" in data:
+            # Only set max_tokens_to_sample if it's not already provided
+            if "max_tokens_to_sample" not in data:
+                data["max_tokens_to_sample"] = data.pop("max_tokens")
+            else:
+                # If both are provided, max_tokens_to_sample takes precedence
+                data.pop("max_tokens")
+
+        # Convert a Pydantic model class passed as response_format into the
+        # equivalent JSON-schema response_format dict.
+        if (
+            "response_format" in data
+            and isinstance(data["response_format"], type)
+            and issubclass(data["response_format"], BaseModel)
+        ):
+            model_class = data["response_format"]
+            data["response_format"] = {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": model_class.__name__,
+                    "schema": model_class.model_json_schema(),
+                },
+            }
+
+        # Omit `model` when it is None so the field's default_factory (which
+        # reads the class-level defaults) applies instead.
+        model = data.pop("model", None)
+        if model is not None:
+            super().__init__(model=model, **data)
+        else:
+            super().__init__(**data)
+
+    def __str__(self):
+        return json.dumps(self.to_dict())
+
+    class Config:
+        populate_by_name = True
+        json_schema_extra = {
+            "example": {
+                "model": "openai/gpt-4o",
+                "temperature": 0.1,
+                "top_p": 1.0,
+                "max_tokens_to_sample": 1024,
+                "stream": False,
+                "functions": None,
+                "tools": None,
+                "add_generation_kwargs": None,
+                "api_base": None,
+            }
+        }
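+
+# Illustrative only (not part of the original module): a sketch of typical
+# GenerationConfig usage. `max_tokens` is accepted as an alias for
+# `max_tokens_to_sample`, a Pydantic model class passed as `response_format`
+# is converted to a JSON-schema dict, and `set_default` changes the class-wide
+# defaults picked up by later instances.
+#
+#   class Answer(BaseModel):
+#       text: str
+#
+#   config = GenerationConfig(
+#       model="openai/gpt-4o",
+#       max_tokens=512,            # stored as max_tokens_to_sample=512
+#       response_format=Answer,    # becomes {"type": "json_schema", ...}
+#   )
+#
+#   GenerationConfig.set_default(temperature=0.0)
+#   GenerationConfig().temperature  # -> 0.0 for instances created afterwards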
+
+
+class MessageType(Enum):
+    SYSTEM = "system"
+    USER = "user"
+    ASSISTANT = "assistant"
+    FUNCTION = "function"
+    TOOL = "tool"
+
+    def __str__(self):
+        return self.value
+
+
+class Message(R2RSerializable):
+    role: MessageType | str
+    content: Optional[Any] = None
+    name: Optional[str] = None
+    function_call: Optional[dict[str, Any]] = None
+    tool_calls: Optional[list[dict[str, Any]]] = None
+    tool_call_id: Optional[str] = None
+    metadata: Optional[dict[str, Any]] = None
+    structured_content: Optional[list[dict]] = None
+    image_url: Optional[str] = None  # For URL-based images
+    image_data: Optional[dict[str, str]] = (
+        None  # For base64 {media_type, data}
+    )
+
+    class Config:
+        populate_by_name = True
+        json_schema_extra = {
+            "example": {
+                "role": "user",
+                "content": "This is a test message.",
+                "name": None,
+                "function_call": None,
+                "tool_calls": None,
+            }
+        }
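+
+
+# Illustrative only (not part of the original module): Message accepts the role
+# either as a MessageType member or as a plain string, and MessageType renders
+# as its string value.
+#
+#   Message(role=MessageType.USER, content="This is a test message.")
+#   Message(role="assistant", content="Hi!")
+#   str(MessageType.TOOL)  # -> "tool"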