Diffstat (limited to '.venv/lib/python3.12/site-packages/shared/abstractions/llm.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/shared/abstractions/llm.py | 325
1 file changed, 325 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/llm.py b/.venv/lib/python3.12/site-packages/shared/abstractions/llm.py
new file mode 100644
index 00000000..d71e279e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/llm.py
@@ -0,0 +1,325 @@
+"""Abstractions for the LLM model."""
+
+import json
+from enum import Enum
+from typing import TYPE_CHECKING, Any, ClassVar, Optional
+
+from openai.types.chat import ChatCompletionChunk
+from pydantic import BaseModel, Field
+
+from .base import R2RSerializable
+
+if TYPE_CHECKING:
+    from .search import AggregateSearchResult
+
+from typing_extensions import Literal
+
+
+class Function(BaseModel):
+    arguments: str
+    """
+    The arguments to call the function with, as generated by the model in JSON
+    format. Note that the model does not always generate valid JSON, and may
+    hallucinate parameters not defined by your function schema. Validate the
+    arguments in your code before calling your function.
+    """
+
+    name: str
+    """The name of the function to call."""
+
+
+class ChatCompletionMessageToolCall(BaseModel):
+    id: str
+    """The ID of the tool call."""
+
+    function: Function
+    """The function that the model called."""
+
+    type: Literal["function"]
+    """The type of the tool. Currently, only `function` is supported."""
+
+
+class FunctionCall(BaseModel):
+    arguments: str
+    """
+    The arguments to call the function with, as generated by the model in JSON
+    format. Note that the model does not always generate valid JSON, and may
+    hallucinate parameters not defined by your function schema. Validate the
+    arguments in your code before calling your function.
+    """
+
+    name: str
+    """The name of the function to call."""
+
+
+class ChatCompletionMessage(BaseModel):
+    content: Optional[str] = None
+    """The contents of the message."""
+
+    refusal: Optional[str] = None
+    """The refusal message generated by the model."""
+
+    role: Literal["assistant"]
+    """The role of the author of this message."""
+
+    # audio: Optional[ChatCompletionAudio] = None
+    """
+    If the audio output modality is requested, this object contains data about the
+    audio response from the model.
+    [Learn more](https://platform.openai.com/docs/guides/audio).
+    """
+
+    function_call: Optional[FunctionCall] = None
+    """Deprecated and replaced by `tool_calls`.
+
+    The name and arguments of a function that should be called, as generated by the
+    model.
+    """
+
+    tool_calls: Optional[list[ChatCompletionMessageToolCall]] = None
+    """The tool calls generated by the model, such as function calls."""
+
+    structured_content: Optional[list[dict]] = None
+
+
+class Choice(BaseModel):
+    finish_reason: Literal[
+        "stop",
+        "length",
+        "tool_calls",
+        "content_filter",
+        "function_call",
+        "max_tokens",
+    ]
+    """The reason the model stopped generating tokens.
+
+    This will be `stop` if the model hit a natural stop point or a provided stop
+    sequence, `length` if the maximum number of tokens specified in the request was
+    reached, `content_filter` if content was omitted due to a flag from our content
+    filters, `tool_calls` if the model called a tool, or `function_call`
+    (deprecated) if the model called a function.
+    """
+
+    index: int
+    """The index of the choice in the list of choices."""
+
+    # logprobs: Optional[ChoiceLogprobs] = None
+    """Log probability information for the choice."""
+
+    message: ChatCompletionMessage
+    """A chat completion message generated by the model."""
+
+
+class LLMChatCompletion(BaseModel):
+    id: str
+    """A unique identifier for the chat completion."""
+
+    choices: list[Choice]
+    """A list of chat completion choices.
+
+    Can be more than one if `n` is greater than 1.
+    """
+
+    created: int
+    """The Unix timestamp (in seconds) of when the chat completion was created."""
+
+    model: str
+    """The model used for the chat completion."""
+
+    object: Literal["chat.completion"]
+    """The object type, which is always `chat.completion`."""
+
+    service_tier: Optional[Literal["scale", "default"]] = None
+    """The service tier used for processing the request."""
+
+    system_fingerprint: Optional[str] = None
+    """This fingerprint represents the backend configuration that the model runs with.
+
+    Can be used in conjunction with the `seed` request parameter to understand when
+    backend changes have been made that might impact determinism.
+    """
+
+    usage: Optional[Any] = None
+    """Usage statistics for the completion request."""
+
+
+LLMChatCompletionChunk = ChatCompletionChunk
+
+
+class RAGCompletion:
+    completion: LLMChatCompletion
+    search_results: "AggregateSearchResult"
+
+    def __init__(
+        self,
+        completion: LLMChatCompletion,
+        search_results: "AggregateSearchResult",
+    ):
+        self.completion = completion
+        self.search_results = search_results
+
+
+class GenerationConfig(R2RSerializable):
+    _defaults: ClassVar[dict] = {
+        "model": None,
+        "temperature": 0.1,
+        "top_p": 1.0,
+        "max_tokens_to_sample": 1024,
+        "stream": False,
+        "functions": None,
+        "tools": None,
+        "add_generation_kwargs": None,
+        "api_base": None,
+        "response_format": None,
+        "extended_thinking": False,
+        "thinking_budget": None,
+        "reasoning_effort": None,
+    }
+
+    model: Optional[str] = Field(
+        default_factory=lambda: GenerationConfig._defaults["model"]
+    )
+    temperature: float = Field(
+        default_factory=lambda: GenerationConfig._defaults["temperature"]
+    )
+    top_p: Optional[float] = Field(
+        default_factory=lambda: GenerationConfig._defaults["top_p"],
+    )
+    max_tokens_to_sample: int = Field(
+        default_factory=lambda: GenerationConfig._defaults[
+            "max_tokens_to_sample"
+        ],
+    )
+    stream: bool = Field(
+        default_factory=lambda: GenerationConfig._defaults["stream"]
+    )
+    functions: Optional[list[dict]] = Field(
+        default_factory=lambda: GenerationConfig._defaults["functions"]
+    )
+    tools: Optional[list[dict]] = Field(
+        default_factory=lambda: GenerationConfig._defaults["tools"]
+    )
+    add_generation_kwargs: Optional[dict] = Field(
+        default_factory=lambda: GenerationConfig._defaults[
+            "add_generation_kwargs"
+        ],
+    )
+    api_base: Optional[str] = Field(
+        default_factory=lambda: GenerationConfig._defaults["api_base"],
+    )
+    response_format: Optional[dict | BaseModel] = None
+    extended_thinking: bool = Field(
+        default=False,
+        description="Flag to enable extended thinking mode (for Anthropic providers)",
+    )
+    thinking_budget: Optional[int] = Field(
+        default=None,
+        description=(
+            "Token budget for internal reasoning when extended thinking mode is enabled. "
+            "Must be less than max_tokens_to_sample."
+        ),
+    )
+    reasoning_effort: Optional[str] = Field(
+        default=None,
+        description=(
+            "Effort level for internal reasoning when extended thinking mode is enabled, `low`, `medium`, or `high`."
+            "Only applicable to OpenAI providers."
+        ),
+    )
+
+    @classmethod
+    def set_default(cls, **kwargs):
+        for key, value in kwargs.items():
+            if key in cls._defaults:
+                cls._defaults[key] = value
+            else:
+                raise AttributeError(
+                    f"No default attribute '{key}' in GenerationConfig"
+                )
+
+    def __init__(self, **data):
+        # Handle max_tokens mapping to max_tokens_to_sample
+        if "max_tokens" in data:
+            # Only set max_tokens_to_sample if it's not already provided
+            if "max_tokens_to_sample" not in data:
+                data["max_tokens_to_sample"] = data.pop("max_tokens")
+            else:
+                # If both are provided, max_tokens_to_sample takes precedence
+                data.pop("max_tokens")
+
+        if (
+            "response_format" in data
+            and isinstance(data["response_format"], type)
+            and issubclass(data["response_format"], BaseModel)
+        ):
+            model_class = data["response_format"]
+            data["response_format"] = {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": model_class.__name__,
+                    "schema": model_class.model_json_schema(),
+                },
+            }
+
+        model = data.pop("model", None)
+        if model is not None:
+            super().__init__(model=model, **data)
+        else:
+            super().__init__(**data)
+
+    def __str__(self):
+        return json.dumps(self.to_dict())
+
+    class Config:
+        populate_by_name = True
+        json_schema_extra = {
+            "example": {
+                "model": "openai/gpt-4o",
+                "temperature": 0.1,
+                "top_p": 1.0,
+                "max_tokens_to_sample": 1024,
+                "stream": False,
+                "functions": None,
+                "tools": None,
+                "add_generation_kwargs": None,
+                "api_base": None,
+            }
+        }
+
+
+class MessageType(Enum):
+    SYSTEM = "system"
+    USER = "user"
+    ASSISTANT = "assistant"
+    FUNCTION = "function"
+    TOOL = "tool"
+
+    def __str__(self):
+        return self.value
+
+
+class Message(R2RSerializable):
+    role: MessageType | str
+    content: Optional[Any] = None
+    name: Optional[str] = None
+    function_call: Optional[dict[str, Any]] = None
+    tool_calls: Optional[list[dict[str, Any]]] = None
+    tool_call_id: Optional[str] = None
+    metadata: Optional[dict[str, Any]] = None
+    structured_content: Optional[list[dict]] = None
+    image_url: Optional[str] = None  # For URL-based images
+    image_data: Optional[dict[str, str]] = (
+        None  # For base64 {media_type, data}
+    )
+
+    class Config:
+        populate_by_name = True
+        json_schema_extra = {
+            "example": {
+                "role": "user",
+                "content": "This is a test message.",
+                "name": None,
+                "function_call": None,
+                "tool_calls": None,
+            }
+        }