about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/integrations/prompt_management_base.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/integrations/prompt_management_base.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/integrations/prompt_management_base.py118
1 files changed, 118 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/prompt_management_base.py b/.venv/lib/python3.12/site-packages/litellm/integrations/prompt_management_base.py
new file mode 100644
index 00000000..3fe3b31e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/prompt_management_base.py
@@ -0,0 +1,118 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Tuple, TypedDict
+
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import StandardCallbackDynamicParams
+
+
+class PromptManagementClient(TypedDict):
+    prompt_id: str
+    prompt_template: List[AllMessageValues]
+    prompt_template_model: Optional[str]
+    prompt_template_optional_params: Optional[Dict[str, Any]]
+    completed_messages: Optional[List[AllMessageValues]]
+
+
+class PromptManagementBase(ABC):
+
+    @property
+    @abstractmethod
+    def integration_name(self) -> str:
+        pass
+
+    @abstractmethod
+    def should_run_prompt_management(
+        self,
+        prompt_id: str,
+        dynamic_callback_params: StandardCallbackDynamicParams,
+    ) -> bool:
+        pass
+
+    @abstractmethod
+    def _compile_prompt_helper(
+        self,
+        prompt_id: str,
+        prompt_variables: Optional[dict],
+        dynamic_callback_params: StandardCallbackDynamicParams,
+    ) -> PromptManagementClient:
+        pass
+
+    def merge_messages(
+        self,
+        prompt_template: List[AllMessageValues],
+        client_messages: List[AllMessageValues],
+    ) -> List[AllMessageValues]:
+        return prompt_template + client_messages
+
+    def compile_prompt(
+        self,
+        prompt_id: str,
+        prompt_variables: Optional[dict],
+        client_messages: List[AllMessageValues],
+        dynamic_callback_params: StandardCallbackDynamicParams,
+    ) -> PromptManagementClient:
+        compiled_prompt_client = self._compile_prompt_helper(
+            prompt_id=prompt_id,
+            prompt_variables=prompt_variables,
+            dynamic_callback_params=dynamic_callback_params,
+        )
+
+        try:
+            messages = compiled_prompt_client["prompt_template"] + client_messages
+        except Exception as e:
+            raise ValueError(
+                f"Error compiling prompt: {e}. Prompt id={prompt_id}, prompt_variables={prompt_variables}, client_messages={client_messages}, dynamic_callback_params={dynamic_callback_params}"
+            )
+
+        compiled_prompt_client["completed_messages"] = messages
+        return compiled_prompt_client
+
+    def _get_model_from_prompt(
+        self, prompt_management_client: PromptManagementClient, model: str
+    ) -> str:
+        if prompt_management_client["prompt_template_model"] is not None:
+            return prompt_management_client["prompt_template_model"]
+        else:
+            return model.replace("{}/".format(self.integration_name), "")
+
+    def get_chat_completion_prompt(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        non_default_params: dict,
+        prompt_id: str,
+        prompt_variables: Optional[dict],
+        dynamic_callback_params: StandardCallbackDynamicParams,
+    ) -> Tuple[
+        str,
+        List[AllMessageValues],
+        dict,
+    ]:
+        if not self.should_run_prompt_management(
+            prompt_id=prompt_id, dynamic_callback_params=dynamic_callback_params
+        ):
+            return model, messages, non_default_params
+
+        prompt_template = self.compile_prompt(
+            prompt_id=prompt_id,
+            prompt_variables=prompt_variables,
+            client_messages=messages,
+            dynamic_callback_params=dynamic_callback_params,
+        )
+
+        completed_messages = prompt_template["completed_messages"] or messages
+
+        prompt_template_optional_params = (
+            prompt_template["prompt_template_optional_params"] or {}
+        )
+
+        updated_non_default_params = {
+            **non_default_params,
+            **prompt_template_optional_params,
+        }
+
+        model = self._get_model_from_prompt(
+            prompt_management_client=prompt_template, model=model
+        )
+
+        return model, completed_messages, updated_non_default_params