author     S. Solomon Darnell   2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell   2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/base/abstractions/llm.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813
Diffstat (limited to 'R2R/r2r/base/abstractions/llm.py')

-rwxr-xr-x  R2R/r2r/base/abstractions/llm.py | 112
1 file changed, 112 insertions(+), 0 deletions(-)
diff --git a/R2R/r2r/base/abstractions/llm.py b/R2R/r2r/base/abstractions/llm.py
new file mode 100755
index 00000000..3178d8dc
--- /dev/null
+++ b/R2R/r2r/base/abstractions/llm.py
@@ -0,0 +1,112 @@
+"""Abstractions for the LLM model."""
+
+from typing import TYPE_CHECKING, ClassVar, Optional
+
+from openai.types.chat import ChatCompletion, ChatCompletionChunk
+from pydantic import BaseModel, Field
+
+if TYPE_CHECKING:
+    from .search import AggregateSearchResult
+
+LLMChatCompletion = ChatCompletion
+LLMChatCompletionChunk = ChatCompletionChunk
+
+
+class RAGCompletion:
+    completion: LLMChatCompletion
+    search_results: "AggregateSearchResult"
+
+    def __init__(
+        self,
+        completion: LLMChatCompletion,
+        search_results: "AggregateSearchResult",
+    ):
+        self.completion = completion
+        self.search_results = search_results
+
+
+class GenerationConfig(BaseModel):
+    _defaults: ClassVar[dict] = {
+        "model": "gpt-4o",
+        "temperature": 0.1,
+        "top_p": 1.0,
+        "top_k": 100,
+        "max_tokens_to_sample": 1024,
+        "stream": False,
+        "functions": None,
+        "skip_special_tokens": False,
+        "stop_token": None,
+        "num_beams": 1,
+        "do_sample": True,
+        "generate_with_chat": False,
+        "add_generation_kwargs": None,
+        "api_base": None,
+    }
+
+    model: str = Field(
+        default_factory=lambda: GenerationConfig._defaults["model"]
+    )
+    temperature: float = Field(
+        default_factory=lambda: GenerationConfig._defaults["temperature"]
+    )
+    top_p: float = Field(
+        default_factory=lambda: GenerationConfig._defaults["top_p"]
+    )
+    top_k: int = Field(
+        default_factory=lambda: GenerationConfig._defaults["top_k"]
+    )
+    max_tokens_to_sample: int = Field(
+        default_factory=lambda: GenerationConfig._defaults[
+            "max_tokens_to_sample"
+        ]
+    )
+    stream: bool = Field(
+        default_factory=lambda: GenerationConfig._defaults["stream"]
+    )
+    functions: Optional[list[dict]] = Field(
+        default_factory=lambda: GenerationConfig._defaults["functions"]
+    )
+    skip_special_tokens: bool = Field(
+        default_factory=lambda: GenerationConfig._defaults[
+            "skip_special_tokens"
+        ]
+    )
+    stop_token: Optional[str] = Field(
+        default_factory=lambda: GenerationConfig._defaults["stop_token"]
+    )
+    num_beams: int = Field(
+        default_factory=lambda: GenerationConfig._defaults["num_beams"]
+    )
+    do_sample: bool = Field(
+        default_factory=lambda: GenerationConfig._defaults["do_sample"]
+    )
+    generate_with_chat: bool = Field(
+        default_factory=lambda: GenerationConfig._defaults[
+            "generate_with_chat"
+        ]
+    )
+    add_generation_kwargs: Optional[dict] = Field(
+        default_factory=lambda: GenerationConfig._defaults[
+            "add_generation_kwargs"
+        ]
+    )
+    api_base: Optional[str] = Field(
+        default_factory=lambda: GenerationConfig._defaults["api_base"]
+    )
+
+    @classmethod
+    def set_default(cls, **kwargs):
+        for key, value in kwargs.items():
+            if key in cls._defaults:
+                cls._defaults[key] = value
+            else:
+                raise AttributeError(
+                    f"No default attribute '{key}' in GenerationConfig"
+                )
+
+    def __init__(self, **data):
+        model = data.pop("model", None)
+        if model is not None:
+            super().__init__(model=model, **data)
+        else:
+            super().__init__(**data)
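
For context, a minimal usage sketch of the `GenerationConfig` added above (illustrative, not part of the commit): the class keeps its defaults in a mutable class-level `_defaults` dict, and every field's `default_factory` reads that dict at construction time, so `GenerationConfig.set_default(...)` re-points the defaults for all instances created afterwards without touching existing ones. The import path is assumed from the file's location in the diff.

```python
# Assumed import path, inferred from R2R/r2r/base/abstractions/llm.py.
from r2r.base.abstractions.llm import GenerationConfig

# Fields resolve through GenerationConfig._defaults at construction time.
before = GenerationConfig()
assert before.model == "gpt-4o" and before.temperature == 0.1

# set_default mutates the class-level dict; only *new* instances see it.
GenerationConfig.set_default(temperature=0.0, stream=True)
after = GenerationConfig()
assert after.temperature == 0.0 and after.stream is True
assert before.temperature == 0.1  # existing instance is unchanged

# Explicit keyword arguments still override the class-wide defaults.
custom = GenerationConfig(model="gpt-4o-mini", top_k=50)
assert custom.model == "gpt-4o-mini" and custom.top_k == 50

# Unknown keys are rejected rather than silently stored.
try:
    GenerationConfig.set_default(temprature=0.5)  # deliberate typo
except AttributeError as exc:
    print(exc)  # -> No default attribute 'temprature' in GenerationConfig
```

One design consequence worth noting: because `_defaults` is shared mutable state, `set_default` acts as process-wide configuration, which is convenient for one-shot setup but can surprise concurrent callers.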