author     S. Solomon Darnell   2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell   2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/triton/completion
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/triton/completion')
-rw-r--r--   .venv/lib/python3.12/site-packages/litellm/llms/triton/completion/handler.py           5
-rw-r--r--   .venv/lib/python3.12/site-packages/litellm/llms/triton/completion/transformation.py    343
2 files changed, 348 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/triton/completion/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/triton/completion/handler.py
new file mode 100644
index 00000000..cd1b7a62
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/triton/completion/handler.py
@@ -0,0 +1,5 @@
+"""
+Triton Completion - uses `llm_http_handler.py` to make httpx requests
+
+Request/Response transformation is handled in `transformation.py`
+"""
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/triton/completion/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/triton/completion/transformation.py
new file mode 100644
index 00000000..56151f89
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/triton/completion/transformation.py
@@ -0,0 +1,343 @@
+"""
+Translates from OpenAI's `/v1/chat/completions` endpoint to Triton's `/generate` endpoint.
+"""
+
+import json
+from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union
+
+from httpx import Headers, Response
+
+from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory
+from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
+from litellm.llms.base_llm.chat.transformation import (
+    BaseConfig,
+    BaseLLMException,
+    LiteLLMLoggingObj,
+)
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    Choices,
+    GenericStreamingChunk,
+    Message,
+    ModelResponse,
+)
+
+from ..common_utils import TritonError
+
+
+class TritonConfig(BaseConfig):
+    """
+    Base class for Triton configurations.
+
+    Handles routing between the Triton /infer and /generate completion endpoints.
+    """
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[Dict, Headers]
+    ) -> BaseLLMException:
+        return TritonError(
+            status_code=status_code, message=error_message, headers=headers
+        )
+
+    def validate_environment(
+        self,
+        headers: Dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: Dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> Dict:
+        return {"Content-Type": "application/json"}
+
+    def get_supported_openai_params(self, model: str) -> List:
+        return ["max_tokens", "max_completion_tokens"]
+
+    def map_openai_params(
+        self,
+        non_default_params: Dict,
+        optional_params: Dict,
+        model: str,
+        drop_params: bool,
+    ) -> Dict:
+        for param, value in non_default_params.items():
+            if param == "max_tokens" or param == "max_completion_tokens":
+                optional_params[param] = value
+        return optional_params
+
+    def get_complete_url(
+        self,
+        api_base: Optional[str],
+        model: str,
+        optional_params: dict,
+        litellm_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        if api_base is None:
+            raise ValueError("api_base is required")
+        llm_type = self._get_triton_llm_type(api_base)
+        if llm_type == "generate" and stream:
+            return api_base + "_stream"
+        return api_base
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: Dict,
+        messages: List[AllMessageValues],
+        optional_params: Dict,
+        litellm_params: Dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        api_base = litellm_params.get("api_base", "")
+        llm_type = self._get_triton_llm_type(api_base)
+        if llm_type == "generate":
+            return TritonGenerateConfig().transform_response(
+                model=model,
+                raw_response=raw_response,
+                model_response=model_response,
+                logging_obj=logging_obj,
+                request_data=request_data,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                encoding=encoding,
+                api_key=api_key,
+                json_mode=json_mode,
+            )
+        elif llm_type == "infer":
+            return TritonInferConfig().transform_response(
+                model=model,
+                raw_response=raw_response,
+                model_response=model_response,
+                logging_obj=logging_obj,
+                request_data=request_data,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                encoding=encoding,
+                api_key=api_key,
+                json_mode=json_mode,
+            )
+        return model_response
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        api_base = litellm_params.get("api_base", "")
+        llm_type = self._get_triton_llm_type(api_base)
+        if llm_type == "generate":
+            return TritonGenerateConfig().transform_request(
+                model=model,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                headers=headers,
+            )
+        elif llm_type == "infer":
+            return TritonInferConfig().transform_request(
+                model=model,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                headers=headers,
+            )
+        return {}
+
+    def _get_triton_llm_type(self, api_base: str) -> Literal["generate", "infer"]:
+        if api_base.endswith("/generate"):
+            return "generate"
+        elif api_base.endswith("/infer"):
+            return "infer"
+        else:
+            raise ValueError(f"Invalid Triton API base: {api_base}")
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        sync_stream: bool,
+        json_mode: Optional[bool] = False,
+    ) -> Any:
+        return TritonResponseIterator(
+            streaming_response=streaming_response,
+            sync_stream=sync_stream,
+            json_mode=json_mode,
+        )
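A minimal sketch of the routing behavior defined above; the endpoint URLs are placeholders that assume a standard Triton model path:

    from litellm.llms.triton.completion.transformation import TritonConfig

    cfg = TritonConfig()

    # /generate endpoints get a "_stream" suffix when streaming is requested
    url = cfg.get_complete_url(
        api_base="http://localhost:8000/v2/models/llama/generate",  # placeholder
        model="llama",
        optional_params={},
        litellm_params={},
        stream=True,
    )
    print(url)  # http://localhost:8000/v2/models/llama/generate_stream

    # routing is decided purely by the api_base suffix
    print(cfg._get_triton_llm_type("http://localhost:8000/v2/models/llama/infer"))  # "infer"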
+
+
+class TritonGenerateConfig(TritonConfig):
+    """
+    Transformations for the Triton /generate endpoint (i.e. a TensorRT-LLM model).
+    """
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        inference_params = optional_params.copy()
+        stream = inference_params.pop("stream", False)
+        data_for_triton: Dict[str, Any] = {
+            "text_input": prompt_factory(model=model, messages=messages),
+            "parameters": {
+                "max_tokens": int(optional_params.get("max_tokens", 2000)),
+                "bad_words": [""],
+                "stop_words": [""],
+            },
+            "stream": bool(stream),
+        }
+        data_for_triton["parameters"].update(inference_params)
+        return data_for_triton
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: Dict,
+        messages: List[AllMessageValues],
+        optional_params: Dict,
+        litellm_params: Dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        try:
+            raw_response_json = raw_response.json()
+        except Exception:
+            raise TritonError(
+                message=raw_response.text, status_code=raw_response.status_code
+            )
+        model_response.choices = [
+            Choices(index=0, message=Message(content=raw_response_json["text_output"]))
+        ]
+
+        return model_response
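To illustrate the /generate payload shape produced above, a sketch; the model name is a placeholder and the exact prompt string depends on what prompt_factory builds for it:

    from litellm.llms.triton.completion.transformation import TritonGenerateConfig

    payload = TritonGenerateConfig().transform_request(
        model="llama",  # placeholder; prompt_factory falls back to a default template
        messages=[{"role": "user", "content": "Hello"}],
        optional_params={"max_tokens": 64},
        litellm_params={},
        headers={},
    )
    # payload roughly:
    # {
    #   "text_input": "<prompt built by prompt_factory>",
    #   "parameters": {"max_tokens": 64, "bad_words": [""], "stop_words": [""]},
    #   "stream": False,
    # }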
+
+
+class TritonInferConfig(TritonConfig):
+    """
+    Transformations for the Triton /infer endpoint (i.e. a custom model served on Triton).
+    """
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+
+        text_input = messages[0].get("content", "")
+        data_for_triton = {
+            "inputs": [
+                {
+                    "name": "text_input",
+                    "shape": [1],
+                    "datatype": "BYTES",
+                    "data": [text_input],
+                }
+            ]
+        }
+
+        for k, v in optional_params.items():
+            if not (k == "stream" or k == "max_retries"):
+                datatype = "INT32" if isinstance(v, int) else "BYTES"
+                datatype = "FP32" if isinstance(v, float) else datatype
+                data_for_triton["inputs"].append(
+                    {"name": k, "shape": [1], "datatype": datatype, "data": [v]}
+                )
+
+        if "max_tokens" not in optional_params:
+            data_for_triton["inputs"].append(
+                {
+                    "name": "max_tokens",
+                    "shape": [1],
+                    "datatype": "INT32",
+                    "data": [20],
+                }
+            )
+        return data_for_triton
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: Dict,
+        messages: List[AllMessageValues],
+        optional_params: Dict,
+        litellm_params: Dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        try:
+            raw_response_json = raw_response.json()
+        except Exception:
+            raise TritonError(
+                message=raw_response.text, status_code=raw_response.status_code
+            )
+
+        _triton_response_data = raw_response_json["outputs"][0]["data"]
+        triton_response_data: Optional[str] = None
+        if isinstance(_triton_response_data, list):
+            triton_response_data = "".join(_triton_response_data)
+        else:
+            triton_response_data = _triton_response_data
+
+        model_response.choices = [
+            Choices(
+                index=0,
+                message=Message(content=triton_response_data),
+            )
+        ]
+
+        return model_response
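Likewise, a sketch of the KServe-style /infer payload built above; the model name and parameter values are placeholders:

    from litellm.llms.triton.completion.transformation import TritonInferConfig

    payload = TritonInferConfig().transform_request(
        model="my-custom-model",  # placeholder
        messages=[{"role": "user", "content": "Hello"}],
        optional_params={"temperature": 0.2},
        litellm_params={},
        headers={},
    )
    # payload roughly:
    # {"inputs": [
    #   {"name": "text_input", "shape": [1], "datatype": "BYTES", "data": ["Hello"]},
    #   {"name": "temperature", "shape": [1], "datatype": "FP32", "data": [0.2]},
    #   {"name": "max_tokens", "shape": [1], "datatype": "INT32", "data": [20]},
    # ]}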
+
+
+class TritonResponseIterator(BaseModelResponseIterator):
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        try:
+            text = ""
+            tool_use: Optional[ChatCompletionToolCallChunk] = None
+            is_finished = False
+            finish_reason = ""
+            usage: Optional[ChatCompletionUsageBlock] = None
+            provider_specific_fields = None
+            index = int(chunk.get("index", 0))
+
+            # set values
+            text = chunk.get("text_output", "")
+            finish_reason = chunk.get("stop_reason", "")
+            is_finished = chunk.get("is_finished", False)
+
+            return GenericStreamingChunk(
+                text=text,
+                tool_use=tool_use,
+                is_finished=is_finished,
+                finish_reason=finish_reason,
+                usage=usage,
+                index=index,
+                provider_specific_fields=provider_specific_fields,
+            )
+        except json.JSONDecodeError:
+            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
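Finally, a sketch of how a single streaming chunk is mapped by the iterator above; the chunk contents are made-up sample values:

    from litellm.llms.triton.completion.transformation import TritonResponseIterator

    iterator = TritonResponseIterator(streaming_response=iter(()), sync_stream=True)
    chunk = iterator.chunk_parser(
        {"text_output": "Hello", "stop_reason": "stop", "is_finished": True}
    )
    # chunk is a GenericStreamingChunk with text="Hello", finish_reason="stop",
    # is_finished=True, and no tool_use or usage information.
    print(chunk)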