aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/base_llm/base_utils.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/base_llm/base_utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/base_llm/base_utils.py142
1 files changed, 142 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/base_llm/base_utils.py b/.venv/lib/python3.12/site-packages/litellm/llms/base_llm/base_utils.py
new file mode 100644
index 00000000..919cdbfd
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/base_llm/base_utils.py
@@ -0,0 +1,142 @@
+"""
+Utility functions for base LLM classes.
+"""
+
+import copy
+from abc import ABC, abstractmethod
+from typing import List, Optional, Type, Union
+
+from openai.lib import _parsing, _pydantic
+from pydantic import BaseModel
+
+from litellm._logging import verbose_logger
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ProviderSpecificModelInfo
+
+
+class BaseLLMModelInfo(ABC):
+ def get_provider_info(
+ self,
+ model: str,
+ ) -> Optional[ProviderSpecificModelInfo]:
+ return None
+
+ @abstractmethod
+ def get_models(self) -> List[str]:
+ pass
+
+ @staticmethod
+ @abstractmethod
+ def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
+ pass
+
+ @staticmethod
+ @abstractmethod
+ def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
+ pass
+
+ @staticmethod
+ @abstractmethod
+ def get_base_model(model: str) -> Optional[str]:
+ """
+ Returns the base model name from the given model name.
+
+ Some providers like bedrock - can receive model=`invoke/anthropic.claude-3-opus-20240229-v1:0` or `converse/anthropic.claude-3-opus-20240229-v1:0`
+ This function will return `anthropic.claude-3-opus-20240229-v1:0`
+ """
+ pass
+
+
+def _dict_to_response_format_helper(
+ response_format: dict, ref_template: Optional[str] = None
+) -> dict:
+ if ref_template is not None and response_format.get("type") == "json_schema":
+ # Deep copy to avoid modifying original
+ modified_format = copy.deepcopy(response_format)
+ schema = modified_format["json_schema"]["schema"]
+
+ # Update all $ref values in the schema
+ def update_refs(schema):
+ stack = [(schema, [])]
+ visited = set()
+
+ while stack:
+ obj, path = stack.pop()
+ obj_id = id(obj)
+
+ if obj_id in visited:
+ continue
+ visited.add(obj_id)
+
+ if isinstance(obj, dict):
+ if "$ref" in obj:
+ ref_path = obj["$ref"]
+ model_name = ref_path.split("/")[-1]
+ obj["$ref"] = ref_template.format(model=model_name)
+
+ for k, v in obj.items():
+ if isinstance(v, (dict, list)):
+ stack.append((v, path + [k]))
+
+ elif isinstance(obj, list):
+ for i, item in enumerate(obj):
+ if isinstance(item, (dict, list)):
+ stack.append((item, path + [i]))
+
+ update_refs(schema)
+ return modified_format
+ return response_format
+
+
+def type_to_response_format_param(
+ response_format: Optional[Union[Type[BaseModel], dict]],
+ ref_template: Optional[str] = None,
+) -> Optional[dict]:
+ """
+ Re-implementation of openai's 'type_to_response_format_param' function
+
+ Used for converting pydantic object to api schema.
+ """
+ if response_format is None:
+ return None
+
+ if isinstance(response_format, dict):
+ return _dict_to_response_format_helper(response_format, ref_template)
+
+ # type checkers don't narrow the negation of a `TypeGuard` as it isn't
+ # a safe default behaviour but we know that at this point the `response_format`
+ # can only be a `type`
+ if not _parsing._completions.is_basemodel_type(response_format):
+ raise TypeError(f"Unsupported response_format type - {response_format}")
+
+ if ref_template is not None:
+ schema = response_format.model_json_schema(ref_template=ref_template)
+ else:
+ schema = _pydantic.to_strict_json_schema(response_format)
+
+ return {
+ "type": "json_schema",
+ "json_schema": {
+ "schema": schema,
+ "name": response_format.__name__,
+ "strict": True,
+ },
+ }
+
+
+def map_developer_role_to_system_role(
+ messages: List[AllMessageValues],
+) -> List[AllMessageValues]:
+ """
+ Translate `developer` role to `system` role for non-OpenAI providers.
+ """
+ new_messages: List[AllMessageValues] = []
+ for m in messages:
+ if m["role"] == "developer":
+ verbose_logger.debug(
+ "Translating developer role to system role for non-OpenAI providers."
+ ) # ensure user knows what's happening with their input.
+ new_messages.append({"role": "system", "content": m["content"]})
+ else:
+ new_messages.append(m)
+ return new_messages