aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/ollama/completion/handler.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/ollama/completion/handler.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/ollama/completion/handler.py102
1 files changed, 102 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/ollama/completion/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/ollama/completion/handler.py
new file mode 100644
index 00000000..208a9d81
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/ollama/completion/handler.py
@@ -0,0 +1,102 @@
+"""
+Ollama /chat/completion calls handled in llm_http_handler.py
+
+[TODO]: migrate embeddings to a base handler as well.
+"""
+
+import asyncio
+from typing import Any, Dict, List
+
+import litellm
+from litellm.types.utils import EmbeddingResponse
+
+# ollama wants plain base64 jpeg/png files as images. strip any leading dataURI
+# and convert to jpeg if necessary.
+
+
+async def ollama_aembeddings(
+ api_base: str,
+ model: str,
+ prompts: List[str],
+ model_response: EmbeddingResponse,
+ optional_params: dict,
+ logging_obj: Any,
+ encoding: Any,
+):
+ if api_base.endswith("/api/embed"):
+ url = api_base
+ else:
+ url = f"{api_base}/api/embed"
+
+ ## Load Config
+ config = litellm.OllamaConfig.get_config()
+ for k, v in config.items():
+ if (
+ k not in optional_params
+ ): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
+ optional_params[k] = v
+
+ data: Dict[str, Any] = {"model": model, "input": prompts}
+ special_optional_params = ["truncate", "options", "keep_alive"]
+
+ for k, v in optional_params.items():
+ if k in special_optional_params:
+ data[k] = v
+ else:
+ # Ensure "options" is a dictionary before updating it
+ data.setdefault("options", {})
+ if isinstance(data["options"], dict):
+ data["options"].update({k: v})
+ total_input_tokens = 0
+ output_data = []
+
+ response = await litellm.module_level_aclient.post(url=url, json=data)
+
+ response_json = response.json()
+
+ embeddings: List[List[float]] = response_json["embeddings"]
+ for idx, emb in enumerate(embeddings):
+ output_data.append({"object": "embedding", "index": idx, "embedding": emb})
+
+ input_tokens = response_json.get("prompt_eval_count") or len(
+ encoding.encode("".join(prompt for prompt in prompts))
+ )
+ total_input_tokens += input_tokens
+
+ model_response.object = "list"
+ model_response.data = output_data
+ model_response.model = "ollama/" + model
+ setattr(
+ model_response,
+ "usage",
+ litellm.Usage(
+ prompt_tokens=total_input_tokens,
+ completion_tokens=total_input_tokens,
+ total_tokens=total_input_tokens,
+ prompt_tokens_details=None,
+ completion_tokens_details=None,
+ ),
+ )
+ return model_response
+
+
+def ollama_embeddings(
+ api_base: str,
+ model: str,
+ prompts: list,
+ optional_params: dict,
+ model_response: EmbeddingResponse,
+ logging_obj: Any,
+ encoding=None,
+):
+ return asyncio.run(
+ ollama_aembeddings(
+ api_base=api_base,
+ model=model,
+ prompts=prompts,
+ model_response=model_response,
+ optional_params=optional_params,
+ logging_obj=logging_obj,
+ encoding=encoding,
+ )
+ )