author     S. Solomon Darnell    2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell    2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/base/abstractions/llm.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to 'R2R/r2r/base/abstractions/llm.py')
-rwxr-xr-x  R2R/r2r/base/abstractions/llm.py  119
1 file changed, 119 insertions(+), 0 deletions(-)
diff --git a/R2R/r2r/base/abstractions/llm.py b/R2R/r2r/base/abstractions/llm.py
new file mode 100755
index 00000000..3178d8dc
--- /dev/null
+++ b/R2R/r2r/base/abstractions/llm.py
@@ -0,0 +1,119 @@
+"""Abstractions for LLM completions and generation configuration."""
+
+from typing import TYPE_CHECKING, ClassVar, Optional
+
+from openai.types.chat import ChatCompletion, ChatCompletionChunk
+from pydantic import BaseModel, Field
+
+if TYPE_CHECKING:
+    from .search import AggregateSearchResult
+
+LLMChatCompletion = ChatCompletion
+LLMChatCompletionChunk = ChatCompletionChunk
+
+
+class RAGCompletion:
+    """Pair an LLM completion with the search results that informed it."""
+
+    completion: LLMChatCompletion
+    search_results: "AggregateSearchResult"
+
+    def __init__(
+        self,
+        completion: LLMChatCompletion,
+        search_results: "AggregateSearchResult",
+    ):
+        self.completion = completion
+        self.search_results = search_results
+
+
+class GenerationConfig(BaseModel):
+    """Generation parameters with runtime-adjustable class-wide defaults."""
+
+    _defaults: ClassVar[dict] = {
+        "model": "gpt-4o",
+        "temperature": 0.1,
+        "top_p": 1.0,
+        "top_k": 100,
+        "max_tokens_to_sample": 1024,
+        "stream": False,
+        "functions": None,
+        "skip_special_tokens": False,
+        "stop_token": None,
+        "num_beams": 1,
+        "do_sample": True,
+        "generate_with_chat": False,
+        "add_generation_kwargs": None,
+        "api_base": None,
+    }
+
+    model: str = Field(
+        default_factory=lambda: GenerationConfig._defaults["model"]
+    )
+    temperature: float = Field(
+        default_factory=lambda: GenerationConfig._defaults["temperature"]
+    )
+    top_p: float = Field(
+        default_factory=lambda: GenerationConfig._defaults["top_p"]
+    )
+    top_k: int = Field(
+        default_factory=lambda: GenerationConfig._defaults["top_k"]
+    )
+    max_tokens_to_sample: int = Field(
+        default_factory=lambda: GenerationConfig._defaults[
+            "max_tokens_to_sample"
+        ]
+    )
+    stream: bool = Field(
+        default_factory=lambda: GenerationConfig._defaults["stream"]
+    )
+    functions: Optional[list[dict]] = Field(
+        default_factory=lambda: GenerationConfig._defaults["functions"]
+    )
+    skip_special_tokens: bool = Field(
+        default_factory=lambda: GenerationConfig._defaults[
+            "skip_special_tokens"
+        ]
+    )
+    stop_token: Optional[str] = Field(
+        default_factory=lambda: GenerationConfig._defaults["stop_token"]
+    )
+    num_beams: int = Field(
+        default_factory=lambda: GenerationConfig._defaults["num_beams"]
+    )
+    do_sample: bool = Field(
+        default_factory=lambda: GenerationConfig._defaults["do_sample"]
+    )
+    generate_with_chat: bool = Field(
+        default_factory=lambda: GenerationConfig._defaults[
+            "generate_with_chat"
+        ]
+    )
+    add_generation_kwargs: Optional[dict] = Field(
+        default_factory=lambda: GenerationConfig._defaults[
+            "add_generation_kwargs"
+        ]
+    )
+    api_base: Optional[str] = Field(
+        default_factory=lambda: GenerationConfig._defaults["api_base"]
+    )
+
+    @classmethod
+    def set_default(cls, **kwargs):
+        """Update class-wide defaults for subsequently created instances."""
+        for key, value in kwargs.items():
+            if key in cls._defaults:
+                cls._defaults[key] = value
+            else:
+                raise AttributeError(
+                    f"No default attribute '{key}' in GenerationConfig"
+                )
+
+    def __init__(self, **data):
+        # An explicit model=None is treated as "use the class default": the
+        # key is popped, so pydantic falls back to the default_factory.
+        model = data.pop("model", None)
+        if model is not None:
+            super().__init__(model=model, **data)
+        else:
+            super().__init__(**data)
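
For context, a minimal usage sketch of the runtime-default mechanism implemented above. This is illustrative and not part of the commit; it assumes the r2r package is importable under the path shown in the diff header (with its pydantic and openai dependencies installed), and the substitute model name passed to set_default is an arbitrary example.

# A minimal sketch of GenerationConfig's mutable class-level defaults.
# Assumes `r2r` is importable; the import path mirrors the file path above.
from r2r.base.abstractions.llm import GenerationConfig

# Fields not supplied explicitly are filled from GenerationConfig._defaults
# through each field's default_factory at instantiation time.
config = GenerationConfig()
assert config.model == "gpt-4o"
assert config.temperature == 0.1

# set_default mutates the class-level _defaults dict, so it affects only
# instances created afterwards; existing instances keep their values.
GenerationConfig.set_default(model="gpt-4o-mini", temperature=0.0)  # example override
assert GenerationConfig().model == "gpt-4o-mini"
assert config.model == "gpt-4o"

# Unknown keys are rejected rather than silently added to the defaults.
try:
    GenerationConfig.set_default(not_a_field=1)
except AttributeError as exc:
    print(exc)  # No default attribute 'not_a_field' in GenerationConfig

Because each Field's default_factory closure reads _defaults when an instance is created rather than when the class is defined, set_default takes effect without redefining or reloading the model class.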