Diffstat (limited to '.venv/lib/python3.12/site-packages/core/base/providers/llm.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/core/base/providers/llm.py  200
1 file changed, 200 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/llm.py b/.venv/lib/python3.12/site-packages/core/base/providers/llm.py
new file mode 100644
index 00000000..669dfc4f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/llm.py
@@ -0,0 +1,200 @@
+import asyncio
+import logging
+import random
+import time
+from abc import abstractmethod
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, AsyncGenerator, Generator, Optional
+
+from litellm import AuthenticationError
+
+from core.base.abstractions import (
+ GenerationConfig,
+ LLMChatCompletion,
+ LLMChatCompletionChunk,
+)
+
+from .base import Provider, ProviderConfig
+
+logger = logging.getLogger(__name__)
+
+
+class CompletionConfig(ProviderConfig):
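+    """Configuration for completion providers, including retry/backoff settings."""
+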
+ provider: Optional[str] = None
+ generation_config: Optional[GenerationConfig] = None
+ concurrent_request_limit: int = 256
+ max_retries: int = 3
+ initial_backoff: float = 1.0
+ max_backoff: float = 64.0
+
+ def validate_config(self) -> None:
+ if not self.provider:
+ raise ValueError("Provider must be set.")
+ if self.provider not in self.supported_providers:
+ raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["anthropic", "litellm", "openai", "r2r"]
+
+
+class CompletionProvider(Provider):
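+    """Provider that executes LLM completion requests with a concurrency limit and retries."""
+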
+ def __init__(self, config: CompletionConfig) -> None:
+ if not isinstance(config, CompletionConfig):
+ raise ValueError(
+ "CompletionProvider must be initialized with a `CompletionConfig`."
+ )
+ logger.info(f"Initializing CompletionProvider with config: {config}")
+ super().__init__(config)
+ self.config: CompletionConfig = config
+ self.semaphore = asyncio.Semaphore(config.concurrent_request_limit)
+ self.thread_pool = ThreadPoolExecutor(
+ max_workers=config.concurrent_request_limit
+ )
+
+ async def _execute_with_backoff_async(self, task: dict[str, Any]):
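+        """Run `_execute_task` under the concurrency semaphore, retrying
+        failures with exponential backoff and full jitter (capped at
+        `max_backoff`). `AuthenticationError` is re-raised immediately.
+        """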
+ retries = 0
+ backoff = self.config.initial_backoff
+ while retries < self.config.max_retries:
+ try:
+ async with self.semaphore:
+ return await self._execute_task(task)
+ except AuthenticationError:
+ raise
+ except Exception as e:
+ logger.warning(
+ f"Request failed (attempt {retries + 1}): {str(e)}"
+ )
+ retries += 1
+ if retries == self.config.max_retries:
+ raise
+ await asyncio.sleep(random.uniform(0, backoff))
+ backoff = min(backoff * 2, self.config.max_backoff)
+
+ async def _execute_with_backoff_async_stream(
+ self, task: dict[str, Any]
+ ) -> AsyncGenerator[Any, None]:
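+        """Stream chunks from `_execute_task`, retrying failures with
+        exponential backoff. Note: a retry restarts the stream from the
+        beginning, so consumers may see duplicate chunks after a
+        mid-stream failure.
+        """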
+ retries = 0
+ backoff = self.config.initial_backoff
+ while retries < self.config.max_retries:
+ try:
+ async with self.semaphore:
+ async for chunk in await self._execute_task(task):
+ yield chunk
+ return # Successful completion of the stream
+ except AuthenticationError:
+ raise
+ except Exception as e:
+ logger.warning(
+ f"Streaming request failed (attempt {retries + 1}): {str(e)}"
+ )
+ retries += 1
+ if retries == self.config.max_retries:
+ raise
+ await asyncio.sleep(random.uniform(0, backoff))
+ backoff = min(backoff * 2, self.config.max_backoff)
+
+ def _execute_with_backoff_sync(self, task: dict[str, Any]):
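+        """Synchronous counterpart of `_execute_with_backoff_async`."""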
+ retries = 0
+ backoff = self.config.initial_backoff
+ while retries < self.config.max_retries:
+ try:
+ return self._execute_task_sync(task)
+ except Exception as e:
+ logger.warning(
+ f"Request failed (attempt {retries + 1}): {str(e)}"
+ )
+ retries += 1
+ if retries == self.config.max_retries:
+ raise
+ time.sleep(random.uniform(0, backoff))
+ backoff = min(backoff * 2, self.config.max_backoff)
+
+ def _execute_with_backoff_sync_stream(
+ self, task: dict[str, Any]
+ ) -> Generator[Any, None, None]:
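+        """Synchronous counterpart of `_execute_with_backoff_async_stream`;
+        the same duplicate-chunk caveat applies on retry.
+        """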
+ retries = 0
+ backoff = self.config.initial_backoff
+ while retries < self.config.max_retries:
+ try:
+ yield from self._execute_task_sync(task)
+ return # Successful completion of the stream
+ except Exception as e:
+ logger.warning(
+ f"Streaming request failed (attempt {retries + 1}): {str(e)}"
+ )
+ retries += 1
+ if retries == self.config.max_retries:
+ raise
+ time.sleep(random.uniform(0, backoff))
+ backoff = min(backoff * 2, self.config.max_backoff)
+
+ @abstractmethod
+ async def _execute_task(self, task: dict[str, Any]):
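+        """Execute a single (optionally streaming) request; provider-specific."""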
+ pass
+
+ @abstractmethod
+ def _execute_task_sync(self, task: dict[str, Any]):
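+        """Synchronous counterpart of `_execute_task`."""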
+ pass
+
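+    # A minimal sketch (hypothetical, not part of this file) of how a
+    # concrete provider might implement the two hooks above via litellm's
+    # `acompletion`/`completion` APIs, assuming `GenerationConfig` exposes
+    # a `model` field (its `stream` field is used later in this file):
+    #
+    #     from litellm import acompletion
+    #
+    #     class LiteLLMCompletionProvider(CompletionProvider):
+    #         async def _execute_task(self, task: dict[str, Any]):
+    #             cfg = task["generation_config"]
+    #             return await acompletion(
+    #                 model=cfg.model,
+    #                 messages=task["messages"],
+    #                 stream=cfg.stream,
+    #                 **task["kwargs"],
+    #             )
+    #
+    #         def _execute_task_sync(self, task: dict[str, Any]):
+    #             ...  # same call through litellm.completion
+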
+ async def aget_completion(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> LLMChatCompletion:
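+        """Request a chat completion, retrying with backoff on transient failures."""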
+ task = {
+ "messages": messages,
+ "generation_config": generation_config,
+ "kwargs": kwargs,
+ }
+ response = await self._execute_with_backoff_async(task)
+ return LLMChatCompletion(**response.dict())
+
+ async def aget_completion_stream(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> AsyncGenerator[LLMChatCompletionChunk, None]:
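+        """Stream chat completion chunks. Note: this sets `stream=True` on
+        the caller's `generation_config` in place.
+        """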
+ generation_config.stream = True
+ task = {
+ "messages": messages,
+ "generation_config": generation_config,
+ "kwargs": kwargs,
+ }
+ async for chunk in self._execute_with_backoff_async_stream(task):
+ if isinstance(chunk, dict):
+ yield LLMChatCompletionChunk(**chunk)
+ continue
+
+            # Normalize provider-specific finish_reason conventions: an
+            # empty string means the stream is still in progress (None),
+            # and `eos` maps to the OpenAI-style `stop`.
+            finish_reason = chunk.choices[0].finish_reason
+            if finish_reason == "":  # handle error output conventions
+                finish_reason = None
+            elif finish_reason == "eos":  # hardcode `eos` to `stop` for consistency
+                finish_reason = "stop"
+            chunk.choices[0].finish_reason = finish_reason
+ try:
+ yield LLMChatCompletionChunk(**(chunk.dict()))
+ except Exception as e:
+ logger.error(f"Error parsing chunk: {e}")
+ yield LLMChatCompletionChunk(**(chunk.as_dict()))
+
+ def get_completion_stream(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> Generator[LLMChatCompletionChunk, None, None]:
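+        """Synchronously stream chat completion chunks. Note: this sets
+        `stream=True` on the caller's `generation_config` in place.
+        """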
+ generation_config.stream = True
+ task = {
+ "messages": messages,
+ "generation_config": generation_config,
+ "kwargs": kwargs,
+ }
+ for chunk in self._execute_with_backoff_sync_stream(task):
+ yield LLMChatCompletionChunk(**chunk.dict())