Diffstat (limited to '.venv/lib/python3.12/site-packages/shared/abstractions/llm.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/shared/abstractions/llm.py  325
1 file changed, 325 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/llm.py b/.venv/lib/python3.12/site-packages/shared/abstractions/llm.py
new file mode 100644
index 00000000..d71e279e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/llm.py
@@ -0,0 +1,325 @@
+"""Abstractions for the LLM model."""
+
+import json
+from enum import Enum
+from typing import TYPE_CHECKING, Any, ClassVar, Optional
+
+from openai.types.chat import ChatCompletionChunk
+from pydantic import BaseModel, Field
+
+from .base import R2RSerializable
+
+if TYPE_CHECKING:
+ from .search import AggregateSearchResult
+
+from typing_extensions import Literal
+
+
+class Function(BaseModel):
+ arguments: str
+ """
+ The arguments to call the function with, as generated by the model in JSON
+ format. Note that the model does not always generate valid JSON, and may
+ hallucinate parameters not defined by your function schema. Validate the
+ arguments in your code before calling your function.
+ """
+
+ name: str
+ """The name of the function to call."""
+
+
+class ChatCompletionMessageToolCall(BaseModel):
+ id: str
+ """The ID of the tool call."""
+
+ function: Function
+ """The function that the model called."""
+
+ type: Literal["function"]
+ """The type of the tool. Currently, only `function` is supported."""
+
+
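+# Hypothetical helper, added here only to illustrate the docstring advice above:
+# tool-call arguments arrive as a JSON string that the model may have malformed
+# or padded with unexpected keys, so parse and validate them before dispatching
+# to a callable. This sketch is not part of the published API.
+def _parse_tool_call_arguments(
+    call: ChatCompletionMessageToolCall,
+) -> dict[str, Any]:
+    """Illustrative sketch: return the parsed arguments, or {} if invalid."""
+    try:
+        parsed = json.loads(call.function.arguments)
+    except json.JSONDecodeError:
+        return {}
+    # The model may emit a bare value instead of an object; keep only dicts.
+    return parsed if isinstance(parsed, dict) else {}
+
+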
+class FunctionCall(BaseModel):
+ arguments: str
+ """
+ The arguments to call the function with, as generated by the model in JSON
+ format. Note that the model does not always generate valid JSON, and may
+ hallucinate parameters not defined by your function schema. Validate the
+ arguments in your code before calling your function.
+ """
+
+ name: str
+ """The name of the function to call."""
+
+
+class ChatCompletionMessage(BaseModel):
+ content: Optional[str] = None
+ """The contents of the message."""
+
+ refusal: Optional[str] = None
+ """The refusal message generated by the model."""
+
+ role: Literal["assistant"]
+ """The role of the author of this message."""
+
+    # audio: Optional[ChatCompletionAudio] = None
+    # """
+    # If the audio output modality is requested, this object contains data about
+    # the audio response from the model.
+    # [Learn more](https://platform.openai.com/docs/guides/audio).
+    # """
+
+ function_call: Optional[FunctionCall] = None
+ """Deprecated and replaced by `tool_calls`.
+
+ The name and arguments of a function that should be called, as generated by the
+ model.
+ """
+
+ tool_calls: Optional[list[ChatCompletionMessageToolCall]] = None
+ """The tool calls generated by the model, such as function calls."""
+
+ structured_content: Optional[list[dict]] = None
+
+
+class Choice(BaseModel):
+ finish_reason: Literal[
+ "stop",
+ "length",
+ "tool_calls",
+ "content_filter",
+ "function_call",
+ "max_tokens",
+ ]
+ """The reason the model stopped generating tokens.
+
+ This will be `stop` if the model hit a natural stop point or a provided stop
+ sequence, `length` if the maximum number of tokens specified in the request was
+ reached, `content_filter` if content was omitted due to a flag from our content
+ filters, `tool_calls` if the model called a tool, or `function_call`
+ (deprecated) if the model called a function.
+ """
+
+ index: int
+ """The index of the choice in the list of choices."""
+
+    # logprobs: Optional[ChoiceLogprobs] = None
+    # """Log probability information for the choice."""
+
+ message: ChatCompletionMessage
+ """A chat completion message generated by the model."""
+
+
+class LLMChatCompletion(BaseModel):
+ id: str
+ """A unique identifier for the chat completion."""
+
+ choices: list[Choice]
+ """A list of chat completion choices.
+
+ Can be more than one if `n` is greater than 1.
+ """
+
+ created: int
+ """The Unix timestamp (in seconds) of when the chat completion was created."""
+
+ model: str
+ """The model used for the chat completion."""
+
+ object: Literal["chat.completion"]
+ """The object type, which is always `chat.completion`."""
+
+ service_tier: Optional[Literal["scale", "default"]] = None
+ """The service tier used for processing the request."""
+
+ system_fingerprint: Optional[str] = None
+ """This fingerprint represents the backend configuration that the model runs with.
+
+ Can be used in conjunction with the `seed` request parameter to understand when
+ backend changes have been made that might impact determinism.
+ """
+
+ usage: Optional[Any] = None
+ """Usage statistics for the completion request."""
+
+
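+# Hypothetical helper, shown only to illustrate how the models above compose:
+# a consumer typically reads the first Choice and uses its text content when
+# present, falling back to the tool calls otherwise. Not part of the published API.
+def _first_reply_text(completion: LLMChatCompletion) -> str:
+    """Illustrative sketch: text of the first choice, or the tool-call name."""
+    message = completion.choices[0].message
+    if message.content:
+        return message.content
+    if message.tool_calls:
+        # Surface the requested function name when the model produced a tool call.
+        return message.tool_calls[0].function.name
+    return ""
+
+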
+LLMChatCompletionChunk = ChatCompletionChunk
+
+
+class RAGCompletion:
+ completion: LLMChatCompletion
+ search_results: "AggregateSearchResult"
+
+ def __init__(
+ self,
+ completion: LLMChatCompletion,
+ search_results: "AggregateSearchResult",
+ ):
+ self.completion = completion
+ self.search_results = search_results
+
+
+class GenerationConfig(R2RSerializable):
+ _defaults: ClassVar[dict] = {
+ "model": None,
+ "temperature": 0.1,
+ "top_p": 1.0,
+ "max_tokens_to_sample": 1024,
+ "stream": False,
+ "functions": None,
+ "tools": None,
+ "add_generation_kwargs": None,
+ "api_base": None,
+ "response_format": None,
+ "extended_thinking": False,
+ "thinking_budget": None,
+ "reasoning_effort": None,
+ }
+
+ model: Optional[str] = Field(
+ default_factory=lambda: GenerationConfig._defaults["model"]
+ )
+ temperature: float = Field(
+ default_factory=lambda: GenerationConfig._defaults["temperature"]
+ )
+ top_p: Optional[float] = Field(
+ default_factory=lambda: GenerationConfig._defaults["top_p"],
+ )
+ max_tokens_to_sample: int = Field(
+ default_factory=lambda: GenerationConfig._defaults[
+ "max_tokens_to_sample"
+ ],
+ )
+ stream: bool = Field(
+ default_factory=lambda: GenerationConfig._defaults["stream"]
+ )
+ functions: Optional[list[dict]] = Field(
+ default_factory=lambda: GenerationConfig._defaults["functions"]
+ )
+ tools: Optional[list[dict]] = Field(
+ default_factory=lambda: GenerationConfig._defaults["tools"]
+ )
+ add_generation_kwargs: Optional[dict] = Field(
+ default_factory=lambda: GenerationConfig._defaults[
+ "add_generation_kwargs"
+ ],
+ )
+ api_base: Optional[str] = Field(
+ default_factory=lambda: GenerationConfig._defaults["api_base"],
+ )
+ response_format: Optional[dict | BaseModel] = None
+ extended_thinking: bool = Field(
+ default=False,
+ description="Flag to enable extended thinking mode (for Anthropic providers)",
+ )
+ thinking_budget: Optional[int] = Field(
+ default=None,
+ description=(
+ "Token budget for internal reasoning when extended thinking mode is enabled. "
+ "Must be less than max_tokens_to_sample."
+ ),
+ )
+ reasoning_effort: Optional[str] = Field(
+ default=None,
+ description=(
+ "Effort level for internal reasoning when extended thinking mode is enabled, `low`, `medium`, or `high`."
+ "Only applicable to OpenAI providers."
+ ),
+ )
+
+ @classmethod
+ def set_default(cls, **kwargs):
+ for key, value in kwargs.items():
+ if key in cls._defaults:
+ cls._defaults[key] = value
+ else:
+ raise AttributeError(
+ f"No default attribute '{key}' in GenerationConfig"
+ )
+
+ def __init__(self, **data):
+ # Handle max_tokens mapping to max_tokens_to_sample
+ if "max_tokens" in data:
+ # Only set max_tokens_to_sample if it's not already provided
+ if "max_tokens_to_sample" not in data:
+ data["max_tokens_to_sample"] = data.pop("max_tokens")
+ else:
+ # If both are provided, max_tokens_to_sample takes precedence
+ data.pop("max_tokens")
+
+ if (
+ "response_format" in data
+ and isinstance(data["response_format"], type)
+ and issubclass(data["response_format"], BaseModel)
+ ):
+ model_class = data["response_format"]
+ data["response_format"] = {
+ "type": "json_schema",
+ "json_schema": {
+ "name": model_class.__name__,
+ "schema": model_class.model_json_schema(),
+ },
+ }
+
+ model = data.pop("model", None)
+ if model is not None:
+ super().__init__(model=model, **data)
+ else:
+ super().__init__(**data)
+
+ def __str__(self):
+ return json.dumps(self.to_dict())
+
+ class Config:
+ populate_by_name = True
+ json_schema_extra = {
+ "example": {
+ "model": "openai/gpt-4o",
+ "temperature": 0.1,
+ "top_p": 1.0,
+ "max_tokens_to_sample": 1024,
+ "stream": False,
+ "functions": None,
+ "tools": None,
+ "add_generation_kwargs": None,
+ "api_base": None,
+ }
+ }
+
+
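+# Hypothetical usage sketch for the class above: the constructor maps the common
+# `max_tokens` keyword onto `max_tokens_to_sample`, converts a Pydantic model
+# class passed as `response_format` into a JSON-schema dict, and `set_default`
+# rewrites the class-wide defaults used by later instances. The model string
+# mirrors the example in Config and is an assumption, not a requirement.
+def _example_generation_config() -> GenerationConfig:
+    """Illustrative sketch, not part of the published API."""
+
+    class Answer(BaseModel):
+        text: str
+
+    # Changes the default temperature for every GenerationConfig created later.
+    GenerationConfig.set_default(temperature=0.0)
+    return GenerationConfig(
+        model="openai/gpt-4o",
+        max_tokens=512,  # stored as max_tokens_to_sample=512
+        response_format=Answer,  # becomes {"type": "json_schema", ...}
+    )
+
+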
+class MessageType(Enum):
+ SYSTEM = "system"
+ USER = "user"
+ ASSISTANT = "assistant"
+ FUNCTION = "function"
+ TOOL = "tool"
+
+ def __str__(self):
+ return self.value
+
+
+class Message(R2RSerializable):
+ role: MessageType | str
+ content: Optional[Any] = None
+ name: Optional[str] = None
+ function_call: Optional[dict[str, Any]] = None
+ tool_calls: Optional[list[dict[str, Any]]] = None
+ tool_call_id: Optional[str] = None
+ metadata: Optional[dict[str, Any]] = None
+ structured_content: Optional[list[dict]] = None
+ image_url: Optional[str] = None # For URL-based images
+ image_data: Optional[dict[str, str]] = (
+ None # For base64 {media_type, data}
+ )
+
+ class Config:
+ populate_by_name = True
+ json_schema_extra = {
+ "example": {
+ "role": "user",
+ "content": "This is a test message.",
+ "name": None,
+ "function_call": None,
+ "tool_calls": None,
+ }
+ }
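+
+
+# Hypothetical usage sketch for the class above: `role` accepts either a
+# MessageType member or its plain string value, and the optional fields can be
+# omitted entirely. Not part of the published API.
+def _example_user_message() -> Message:
+    """Illustrative sketch: a minimal user message."""
+    return Message(
+        role=MessageType.USER,
+        content="This is a test message.",
+    )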