Diffstat (limited to '.venv/lib/python3.12/site-packages/core/base/providers/llm.py')
-rw-r--r-- .venv/lib/python3.12/site-packages/core/base/providers/llm.py | 200
1 file changed, 200 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/llm.py b/.venv/lib/python3.12/site-packages/core/base/providers/llm.py
new file mode 100644
index 00000000..669dfc4f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/llm.py
@@ -0,0 +1,200 @@
+import asyncio
+import logging
+import random
+import time
+from abc import abstractmethod
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, AsyncGenerator, Generator, Optional
+
+from litellm import AuthenticationError
+
+from core.base.abstractions import (
+    GenerationConfig,
+    LLMChatCompletion,
+    LLMChatCompletionChunk,
+)
+
+from .base import Provider, ProviderConfig
+
+logger = logging.getLogger(__name__)
+
+
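+# Configuration shared by all completion providers: which backend to use,
+# default generation settings, and the client-side concurrency/retry policy.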
+class CompletionConfig(ProviderConfig):
+    provider: Optional[str] = None
+    generation_config: Optional[GenerationConfig] = None
+    concurrent_request_limit: int = 256
+    max_retries: int = 3
+    initial_backoff: float = 1.0
+    max_backoff: float = 64.0
+
+    def validate_config(self) -> None:
+        if not self.provider:
+            raise ValueError("Provider must be set.")
+        if self.provider not in self.supported_providers:
+            raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+    @property
+    def supported_providers(self) -> list[str]:
+        return ["anthropic", "litellm", "openai", "r2r"]
+
+
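+# Base class for LLM completion providers. It owns the shared concurrency
+# controls (an asyncio semaphore for async calls and a thread pool available
+# to synchronous call paths) plus the retry machinery; subclasses supply the
+# actual API calls via `_execute_task` / `_execute_task_sync`.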
+class CompletionProvider(Provider):
+    def __init__(self, config: CompletionConfig) -> None:
+        if not isinstance(config, CompletionConfig):
+            raise ValueError(
+                "CompletionProvider must be initialized with a `CompletionConfig`."
+            )
+        logger.info(f"Initializing CompletionProvider with config: {config}")
+        super().__init__(config)
+        self.config: CompletionConfig = config
+        self.semaphore = asyncio.Semaphore(config.concurrent_request_limit)
+        self.thread_pool = ThreadPoolExecutor(
+            max_workers=config.concurrent_request_limit
+        )
+
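+    # Run `_execute_task` with up to `max_retries` attempts, sleeping a random
+    # interval in [0, backoff) between attempts (full jitter) and doubling
+    # `backoff` up to `max_backoff`. Authentication errors are re-raised
+    # immediately, since retrying them cannot succeed.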
+    async def _execute_with_backoff_async(self, task: dict[str, Any]):
+        retries = 0
+        backoff = self.config.initial_backoff
+        while retries < self.config.max_retries:
+            try:
+                async with self.semaphore:
+                    return await self._execute_task(task)
+            except AuthenticationError:
+                raise
+            except Exception as e:
+                logger.warning(
+                    f"Request failed (attempt {retries + 1}): {str(e)}"
+                )
+                retries += 1
+                if retries == self.config.max_retries:
+                    raise
+                await asyncio.sleep(random.uniform(0, backoff))
+                backoff = min(backoff * 2, self.config.max_backoff)
+
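+    # Streaming variant of the retry loop. Note that a failure after some
+    # chunks have already been yielded restarts the stream from the beginning,
+    # so a consumer may observe duplicate chunks across a retry.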
+    async def _execute_with_backoff_async_stream(
+        self, task: dict[str, Any]
+    ) -> AsyncGenerator[Any, None]:
+        retries = 0
+        backoff = self.config.initial_backoff
+        while retries < self.config.max_retries:
+            try:
+                async with self.semaphore:
+                    async for chunk in await self._execute_task(task):
+                        yield chunk
+                return  # Successful completion of the stream
+            except AuthenticationError:
+                raise
+            except Exception as e:
+                logger.warning(
+                    f"Streaming request failed (attempt {retries + 1}): {str(e)}"
+                )
+                retries += 1
+                if retries == self.config.max_retries:
+                    raise
+                await asyncio.sleep(random.uniform(0, backoff))
+                backoff = min(backoff * 2, self.config.max_backoff)
+
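+    # Synchronous counterpart of `_execute_with_backoff_async`; it applies the
+    # same jittered exponential backoff, but does not take the semaphore and,
+    # unlike the async path, retries authentication errors as well.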
+    def _execute_with_backoff_sync(self, task: dict[str, Any]):
+        retries = 0
+        backoff = self.config.initial_backoff
+        while retries < self.config.max_retries:
+            try:
+                return self._execute_task_sync(task)
+            except Exception as e:
+                logger.warning(
+                    f"Request failed (attempt {retries + 1}): {str(e)}"
+                )
+                retries += 1
+                if retries == self.config.max_retries:
+                    raise
+                time.sleep(random.uniform(0, backoff))
+                backoff = min(backoff * 2, self.config.max_backoff)
+
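+    # Synchronous streaming variant; as with the async stream, a mid-stream
+    # failure restarts the stream from its first chunk on retry.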
+    def _execute_with_backoff_sync_stream(
+        self, task: dict[str, Any]
+    ) -> Generator[Any, None, None]:
+        retries = 0
+        backoff = self.config.initial_backoff
+        while retries < self.config.max_retries:
+            try:
+                yield from self._execute_task_sync(task)
+                return  # Successful completion of the stream
+            except Exception as e:
+                logger.warning(
+                    f"Streaming request failed (attempt {retries + 1}): {str(e)}"
+                )
+                retries += 1
+                if retries == self.config.max_retries:
+                    raise
+                time.sleep(random.uniform(0, backoff))
+                backoff = min(backoff * 2, self.config.max_backoff)
+
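+    # Provider-specific request execution. A concrete subclass (e.g. one
+    # backed by litellm or the OpenAI client) performs the actual completion
+    # call here; when `generation_config.stream` is set, the return value is
+    # iterated as a stream of chunks.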
+    @abstractmethod
+    async def _execute_task(self, task: dict[str, Any]):
+        pass
+
+    @abstractmethod
+    def _execute_task_sync(self, task: dict[str, Any]):
+        pass
+
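+    # Public async API: bundle the request into a task dict, run it through
+    # the retry loop, and coerce the provider response into the common
+    # `LLMChatCompletion` model.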
+    async def aget_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> LLMChatCompletion:
+        task = {
+            "messages": messages,
+            "generation_config": generation_config,
+            "kwargs": kwargs,
+        }
+        response = await self._execute_with_backoff_async(task)
+        return LLMChatCompletion(**response.dict())
+
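+    # Public async streaming API. Forces `stream=True` on the generation
+    # config, then normalizes each provider chunk into an
+    # `LLMChatCompletionChunk`.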
+    async def aget_completion_stream(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> AsyncGenerator[LLMChatCompletionChunk, None]:
+        generation_config.stream = True
+        task = {
+            "messages": messages,
+            "generation_config": generation_config,
+            "kwargs": kwargs,
+        }
+        async for chunk in self._execute_with_backoff_async_stream(task):
+            if isinstance(chunk, dict):
+                yield LLMChatCompletionChunk(**chunk)
+                continue
+
+            # Normalize finish_reason: providers emit "" on error paths and
+            # some emit "eos" at end of stream; map these to None and "stop".
+            finish_reason = chunk.choices[0].finish_reason
+            if finish_reason == "":
+                finish_reason = None
+            elif finish_reason == "eos":
+                finish_reason = "stop"
+            chunk.choices[0].finish_reason = finish_reason
+            try:
+                yield LLMChatCompletionChunk(**(chunk.dict()))
+            except Exception as e:
+                # If building the chunk from `dict()` fails, log the error and
+                # retry with the chunk's `as_dict()` representation.
+                logger.error(f"Error parsing chunk: {e}")
+                yield LLMChatCompletionChunk(**(chunk.as_dict()))
+
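+    # Public sync streaming API, driven by the synchronous retry loop; unlike
+    # the async variant it does not normalize `finish_reason` values.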
+    def get_completion_stream(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> Generator[LLMChatCompletionChunk, None, None]:
+        generation_config.stream = True
+        task = {
+            "messages": messages,
+            "generation_config": generation_config,
+            "kwargs": kwargs,
+        }
+        for chunk in self._execute_with_backoff_sync_stream(task):
+            yield LLMChatCompletionChunk(**chunk.dict())