author     S. Solomon Darnell   2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell   2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/base/abstractions/llm.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to 'R2R/r2r/base/abstractions/llm.py')
-rwxr-xr-x  R2R/r2r/base/abstractions/llm.py  117
1 file changed, 117 insertions, 0 deletions
diff --git a/R2R/r2r/base/abstractions/llm.py b/R2R/r2r/base/abstractions/llm.py
new file mode 100755
index 00000000..3178d8dc
--- /dev/null
+++ b/R2R/r2r/base/abstractions/llm.py
@@ -0,0 +1,117 @@
+"""Abstractions for the LLM model."""
+
+from typing import TYPE_CHECKING, ClassVar, Optional
+
+from openai.types.chat import ChatCompletion, ChatCompletionChunk
+from pydantic import BaseModel, Field
+
+if TYPE_CHECKING:
+    from .search import AggregateSearchResult
+
+LLMChatCompletion = ChatCompletion
+LLMChatCompletionChunk = ChatCompletionChunk
+
+
+class RAGCompletion:
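+    """Pairs an LLM chat completion with the aggregate search results that informed it."""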
+    completion: LLMChatCompletion
+    search_results: "AggregateSearchResult"
+
+    def __init__(
+        self,
+        completion: LLMChatCompletion,
+        search_results: "AggregateSearchResult",
+    ):
+        self.completion = completion
+        self.search_results = search_results
+
+
+class GenerationConfig(BaseModel):
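+    """Configuration for LLM generation; class-level defaults can be adjusted via set_default()."""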
+    _defaults: ClassVar[dict] = {
+        "model": "gpt-4o",
+        "temperature": 0.1,
+        "top_p": 1.0,
+        "top_k": 100,
+        "max_tokens_to_sample": 1024,
+        "stream": False,
+        "functions": None,
+        "skip_special_tokens": False,
+        "stop_token": None,
+        "num_beams": 1,
+        "do_sample": True,
+        "generate_with_chat": False,
+        "add_generation_kwargs": None,
+        "api_base": None,
+    }
+
+    model: str = Field(
+        default_factory=lambda: GenerationConfig._defaults["model"]
+    )
+    temperature: float = Field(
+        default_factory=lambda: GenerationConfig._defaults["temperature"]
+    )
+    top_p: float = Field(
+        default_factory=lambda: GenerationConfig._defaults["top_p"]
+    )
+    top_k: int = Field(
+        default_factory=lambda: GenerationConfig._defaults["top_k"]
+    )
+    max_tokens_to_sample: int = Field(
+        default_factory=lambda: GenerationConfig._defaults[
+            "max_tokens_to_sample"
+        ]
+    )
+    stream: bool = Field(
+        default_factory=lambda: GenerationConfig._defaults["stream"]
+    )
+    functions: Optional[list[dict]] = Field(
+        default_factory=lambda: GenerationConfig._defaults["functions"]
+    )
+    skip_special_tokens: bool = Field(
+        default_factory=lambda: GenerationConfig._defaults[
+            "skip_special_tokens"
+        ]
+    )
+    stop_token: Optional[str] = Field(
+        default_factory=lambda: GenerationConfig._defaults["stop_token"]
+    )
+    num_beams: int = Field(
+        default_factory=lambda: GenerationConfig._defaults["num_beams"]
+    )
+    do_sample: bool = Field(
+        default_factory=lambda: GenerationConfig._defaults["do_sample"]
+    )
+    generate_with_chat: bool = Field(
+        default_factory=lambda: GenerationConfig._defaults[
+            "generate_with_chat"
+        ]
+    )
+    add_generation_kwargs: Optional[dict] = Field(
+        default_factory=lambda: GenerationConfig._defaults[
+            "add_generation_kwargs"
+        ]
+    )
+    api_base: Optional[str] = Field(
+        default_factory=lambda: GenerationConfig._defaults["api_base"]
+    )
+
+    @classmethod
+    def set_default(cls, **kwargs):
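+        """Update the class-level defaults consumed by each field's default_factory."""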
+        for key, value in kwargs.items():
+            if key in cls._defaults:
+                cls._defaults[key] = value
+            else:
+                raise AttributeError(
+                    f"No default attribute '{key}' in GenerationConfig"
+                )
+
+    def __init__(self, **data):
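+        # Pop "model" so that an explicit model=None falls back to the class
+        # default instead of failing Pydantic's str validation.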
+        model = data.pop("model", None)
+        if model is not None:
+            super().__init__(model=model, **data)
+        else:
+            super().__init__(**data)
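
A minimal usage sketch of GenerationConfig (not part of the commit above; the import path is assumed from the repository layout), showing how the mutable class-level defaults and the custom __init__ behave:

from r2r.base.abstractions.llm import GenerationConfig  # assumed import path

# Raise the process-wide default temperature; this affects every config
# created afterwards, because each field's default_factory reads _defaults.
GenerationConfig.set_default(temperature=0.7)

config = GenerationConfig(stream=True)
assert config.temperature == 0.7   # picked up from the mutated class defaults
assert config.model == "gpt-4o"    # untouched default

# An explicit model=None is popped by __init__, so the default model is used
# rather than failing validation of the str-typed field.
assert GenerationConfig(model=None).model == "gpt-4o"

# Unknown keys are rejected rather than silently added.
try:
    GenerationConfig.set_default(max_tokens=2048)  # the known key is max_tokens_to_sample
except AttributeError as err:
    print(err)  # No default attribute 'max_tokens' in GenerationConfig

Because _defaults is shared mutable state, set_default changes every GenerationConfig created afterwards in the process, which suits app-wide configuration but is worth isolating in tests.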