import logging from typing import Any import litellm from litellm import acompletion, completion from core.base.abstractions import GenerationConfig from core.base.providers.llm import CompletionConfig, CompletionProvider logger = logging.getLogger() class LiteLLMCompletionProvider(CompletionProvider): def __init__(self, config: CompletionConfig, *args, **kwargs) -> None: super().__init__(config) litellm.modify_params = True self.acompletion = acompletion self.completion = completion # if config.provider != "litellm": # logger.error(f"Invalid provider: {config.provider}") # raise ValueError( # "LiteLLMCompletionProvider must be initialized with config with `litellm` provider." # ) def _get_base_args( self, generation_config: GenerationConfig ) -> dict[str, Any]: args: dict[str, Any] = { "model": generation_config.model, "temperature": generation_config.temperature, "top_p": generation_config.top_p, "stream": generation_config.stream, "max_tokens": generation_config.max_tokens_to_sample, "api_base": generation_config.api_base, } # Fix the type errors by properly typing these assignments if generation_config.functions is not None: args["functions"] = generation_config.functions if generation_config.tools is not None: args["tools"] = generation_config.tools if generation_config.response_format is not None: args["response_format"] = generation_config.response_format return args async def _execute_task(self, task: dict[str, Any]): messages = task["messages"] generation_config = task["generation_config"] kwargs = task["kwargs"] args = self._get_base_args(generation_config) args["messages"] = messages args = {**args, **kwargs} logger.debug( f"Executing LiteLLM task with generation_config={generation_config}" ) return await self.acompletion(**args) def _execute_task_sync(self, task: dict[str, Any]): messages = task["messages"] generation_config = task["generation_config"] kwargs = task["kwargs"] args = self._get_base_args(generation_config) args["messages"] = messages args = {**args, **kwargs} logger.debug( f"Executing LiteLLM task with generation_config={generation_config}" ) try: return self.completion(**args) except Exception as e: logger.error(f"Sync LiteLLM task execution failed: {str(e)}") raise