author	S. Solomon Darnell	2025-03-28 21:52:21 -0500
committer	S. Solomon Darnell	2025-03-28 21:52:21 -0500
commit	4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree	ee3dc5af3b6313e921cd920906356f5d4febc4ed	/.venv/lib/python3.12/site-packages/litellm/llms/fireworks_ai/chat/transformation.py
parent	cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/fireworks_ai/chat/transformation.py')
-rw-r--r--	.venv/lib/python3.12/site-packages/litellm/llms/fireworks_ai/chat/transformation.py	262
1 file changed, 262 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/fireworks_ai/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/fireworks_ai/chat/transformation.py
new file mode 100644
index 00000000..1c82f24a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/fireworks_ai/chat/transformation.py
@@ -0,0 +1,262 @@
+from typing import List, Literal, Optional, Tuple, Union, cast
+
+import litellm
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionImageObject
+from litellm.types.utils import ProviderSpecificModelInfo
+
+from ...openai.chat.gpt_transformation import OpenAIGPTConfig
+
+
+class FireworksAIConfig(OpenAIGPTConfig):
+ """
+ Reference: https://docs.fireworks.ai/api-reference/post-chatcompletions
+
+ The class `FireworksAIConfig` provides configuration for the Fireworks's Chat Completions API interface. Below are the parameters:
+ """
+
+ tools: Optional[list] = None
+ tool_choice: Optional[Union[str, dict]] = None
+ max_tokens: Optional[int] = None
+ temperature: Optional[int] = None
+ top_p: Optional[int] = None
+ top_k: Optional[int] = None
+ frequency_penalty: Optional[int] = None
+ presence_penalty: Optional[int] = None
+ n: Optional[int] = None
+ stop: Optional[Union[str, list]] = None
+ response_format: Optional[dict] = None
+ user: Optional[str] = None
+ logprobs: Optional[int] = None
+
+ # Non OpenAI parameters - Fireworks AI only params
+ prompt_truncate_length: Optional[int] = None
+ context_length_exceeded_behavior: Optional[Literal["error", "truncate"]] = None
+
+    def __init__(
+        self,
+        tools: Optional[list] = None,
+        tool_choice: Optional[Union[str, dict]] = None,
+        max_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        frequency_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        n: Optional[int] = None,
+        stop: Optional[Union[str, list]] = None,
+        response_format: Optional[dict] = None,
+        user: Optional[str] = None,
+        logprobs: Optional[int] = None,
+        prompt_truncate_length: Optional[int] = None,
+        context_length_exceeded_behavior: Optional[Literal["error", "truncate"]] = None,
+    ) -> None:
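+        # litellm config pattern: values passed explicitly here are persisted as
+        # class-level attributes so that get_config() can surface them later.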
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return super().get_config()
+
+    def get_supported_openai_params(self, model: str):
+        return [
+            "stream",
+            "tools",
+            "tool_choice",
+            "max_completion_tokens",
+            "max_tokens",
+            "temperature",
+            "top_p",
+            "top_k",
+            "frequency_penalty",
+            "presence_penalty",
+            "n",
+            "stop",
+            "response_format",
+            "user",
+            "logprobs",
+            "prompt_truncate_length",
+            "context_length_exceeded_behavior",
+        ]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        supported_openai_params = self.get_supported_openai_params(model=model)
+        is_tools_set = any(
+            param == "tools" and value is not None
+            for param, value in non_default_params.items()
+        )
+
+        for param, value in non_default_params.items():
+            if param == "tool_choice":
+                if value == "required":
+                    # relevant issue: https://github.com/BerriAI/litellm/issues/4416
+                    optional_params["tool_choice"] = "any"
+                else:
+                    # pass through the value of tool_choice
+                    optional_params["tool_choice"] = value
+            elif param == "response_format":
+                if (
+                    is_tools_set
+                ):  # fireworks ai doesn't support tools and response_format together
+                    optional_params = self._add_response_format_to_tools(
+                        optional_params=optional_params,
+                        value=value,
+                        is_response_format_supported=False,
+                        enforce_tool_choice=False,  # tools and response_format are both set, don't enforce tool_choice
+                    )
+                elif "json_schema" in value:
+                    optional_params["response_format"] = {
+                        "type": "json_object",
+                        "schema": value["json_schema"]["schema"],
+                    }
+                else:
+                    optional_params["response_format"] = value
+            elif param == "max_completion_tokens":
+                optional_params["max_tokens"] = value
+            elif param in supported_openai_params:
+                if value is not None:
+                    optional_params[param] = value
+
+        return optional_params
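+
+    # Illustrative sketch of the mapping, assuming cfg = FireworksAIConfig():
+    #   cfg.map_openai_params(
+    #       non_default_params={"tool_choice": "required", "max_completion_tokens": 256},
+    #       optional_params={},
+    #       model="llama-v3p1-8b-instruct",
+    #       drop_params=False,
+    #   )
+    # returns {"tool_choice": "any", "max_tokens": 256}.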
+
+    def _add_transform_inline_image_block(
+        self,
+        content: ChatCompletionImageObject,
+        model: str,
+        disable_add_transform_inline_image_block: Optional[bool],
+    ) -> ChatCompletionImageObject:
+        """
+        Append `#transform=inline` to the image_url (allows non-vision models to parse documents/images/etc.)
+        - skip if model is a vision model
+        - skip if the user has disabled this feature
+        """
+        if (
+            "vision" in model or disable_add_transform_inline_image_block
+        ):  # allow user to toggle this feature.
+            return content
+        if isinstance(content["image_url"], str):
+            content["image_url"] = f"{content['image_url']}#transform=inline"
+        elif isinstance(content["image_url"], dict):
+            content["image_url"][
+                "url"
+            ] = f"{content['image_url']['url']}#transform=inline"
+        return content
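+
+    # Illustrative behaviour, with a hypothetical URL: for a non-vision model,
+    #   {"type": "image_url", "image_url": "https://example.com/doc.pdf"}
+    # becomes
+    #   {"type": "image_url", "image_url": "https://example.com/doc.pdf#transform=inline"},
+    # while content for models whose name contains "vision" is returned unchanged.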
+
+    def _transform_messages_helper(
+        self, messages: List[AllMessageValues], model: str, litellm_params: dict
+    ) -> List[AllMessageValues]:
+        """
+        Append '#transform=inline' to the url of each image_url in user messages.
+        """
+        disable_add_transform_inline_image_block = cast(
+            Optional[bool],
+            litellm_params.get("disable_add_transform_inline_image_block")
+            or litellm.disable_add_transform_inline_image_block,
+        )
+        for message in messages:
+            if message["role"] == "user":
+                _message_content = message.get("content")
+                if _message_content is not None and isinstance(_message_content, list):
+                    for content in _message_content:
+                        if content["type"] == "image_url":
+                            content = self._add_transform_inline_image_block(
+                                content=content,
+                                model=model,
+                                disable_add_transform_inline_image_block=disable_add_transform_inline_image_block,
+                            )
+        return messages
+
+    def get_provider_info(self, model: str) -> ProviderSpecificModelInfo:
+        provider_specific_model_info = ProviderSpecificModelInfo(
+            supports_function_calling=True,
+            supports_prompt_caching=True,  # https://docs.fireworks.ai/guides/prompt-caching
+            supports_pdf_input=True,  # via document inlining
+            supports_vision=True,  # via document inlining
+        )
+        return provider_specific_model_info
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        if not model.startswith("accounts/"):
+            model = f"accounts/fireworks/models/{model}"
+        messages = self._transform_messages_helper(
+            messages=messages, model=model, litellm_params=litellm_params
+        )
+        return super().transform_request(
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            headers=headers,
+        )
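+
+    # Sketch of the model rewrite (hypothetical model name):
+    #   "llama-v3p1-8b-instruct" -> "accounts/fireworks/models/llama-v3p1-8b-instruct"
+    # before the request is handed to the OpenAI-compatible transform.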
+
+    def _get_openai_compatible_provider_info(
+        self, api_base: Optional[str], api_key: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
+        api_base = (
+            api_base
+            or get_secret_str("FIREWORKS_API_BASE")
+            or "https://api.fireworks.ai/inference/v1"
+        )  # type: ignore
+        dynamic_api_key = api_key or (
+            get_secret_str("FIREWORKS_API_KEY")
+            or get_secret_str("FIREWORKS_AI_API_KEY")
+            or get_secret_str("FIREWORKSAI_API_KEY")
+            or get_secret_str("FIREWORKS_AI_TOKEN")
+        )
+        return api_base, dynamic_api_key
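+
+    # Resolution sketch: explicit arguments win; otherwise the FIREWORKS_API_BASE
+    # and FIREWORKS_API_KEY (or alias) env vars are consulted, with the base URL
+    # falling back to the public "https://api.fireworks.ai/inference/v1" endpoint.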
+
+    def get_models(self, api_key: Optional[str] = None, api_base: Optional[str] = None):
+        api_base, api_key = self._get_openai_compatible_provider_info(
+            api_base=api_base, api_key=api_key
+        )
+        if api_base is None or api_key is None:
+            raise ValueError(
+                "FIREWORKS_API_BASE or FIREWORKS_API_KEY is not set. Please set the environment variables to query Fireworks AI's `/models` endpoint."
+            )
+
+        account_id = get_secret_str("FIREWORKS_ACCOUNT_ID")
+        if account_id is None:
+            raise ValueError(
+                "FIREWORKS_ACCOUNT_ID is not set. Please set the environment variable to query Fireworks AI's `/models` endpoint."
+            )
+
+        response = litellm.module_level_client.get(
+            url=f"{api_base}/v1/accounts/{account_id}/models",
+            headers={"Authorization": f"Bearer {api_key}"},
+        )
+
+        if response.status_code != 200:
+            raise ValueError(
+                f"Failed to fetch models from Fireworks AI. Status code: {response.status_code}, Response: {response.json()}"
+            )
+
+        models = response.json()["models"]
+
+        return ["fireworks_ai/" + model["name"] for model in models]
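+
+    # Usage sketch (assumes FIREWORKS_API_KEY and FIREWORKS_ACCOUNT_ID are set;
+    # the model name shown is hypothetical):
+    #   FireworksAIConfig().get_models()
+    #   # -> ["fireworks_ai/accounts/fireworks/models/llama-v3p1-8b-instruct", ...]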
+
+    @staticmethod
+    def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
+        return api_key or (
+            get_secret_str("FIREWORKS_API_KEY")
+            or get_secret_str("FIREWORKS_AI_API_KEY")
+            or get_secret_str("FIREWORKSAI_API_KEY")
+            or get_secret_str("FIREWORKS_AI_TOKEN")
+        )
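+
+
+# Minimal end-to-end sketch (hypothetical model name; requires a valid
+# FIREWORKS_AI_API_KEY in the environment):
+#
+#   import litellm
+#
+#   response = litellm.completion(
+#       model="fireworks_ai/llama-v3p1-8b-instruct",
+#       messages=[{"role": "user", "content": "Hello!"}],
+#   )
+#   print(response.choices[0].message.content)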