1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
|
import re
from typing import Dict, Optional

from litellm._logging import verbose_logger
from litellm.secret_managers.main import get_secret_str
class PassthroughEndpointRouter:
    """
    Use this class to Set/Get credentials for pass-through endpoints.

    Credentials are API keys held in an in-memory dict. Each key is stored under
    a name of the form ``<PROVIDER>_API_KEY`` or ``<PROVIDER>_<REGION>_API_KEY``.
    """

    # Splits an api_base into alphanumeric tokens (scheme, host labels, path
    # segments) so region markers are matched as whole words. A plain substring
    # test would misfire on e.g. "neuralcorp.com" or a "/queue" path.
    _TOKEN_SEPARATOR_RE = re.compile(r"[^a-z0-9]+")

    def __init__(self):
        # Credential name (see _get_credential_name_for_provider) -> API key.
        self.credentials: Dict[str, str] = {}

    def set_pass_through_credentials(
        self,
        custom_llm_provider: str,
        api_base: Optional[str],
        api_key: Optional[str],
    ):
        """
        Set credentials for a pass-through endpoint. Used when a user adds a pass-through LLM endpoint on the UI.

        The region (if any) is derived from ``api_base``. The key is stored
        under the matching provider/region credential name, overwriting any
        previous value.

        Args:
            custom_llm_provider: The provider of the pass-through endpoint
            api_base: The base URL of the pass-through endpoint
            api_key: The API key for the pass-through endpoint

        Raises:
            ValueError: if ``api_key`` is None.
        """
        # Guard first: nothing else is worth computing without a key.
        if api_key is None:
            raise ValueError("api_key is required for setting pass-through credentials")

        region_name = self._get_region_name_from_api_base(
            api_base=api_base, custom_llm_provider=custom_llm_provider
        )
        credential_name = self._get_credential_name_for_provider(
            custom_llm_provider=custom_llm_provider,
            region_name=region_name,
        )
        self.credentials[credential_name] = api_key

    def get_credentials(
        self,
        custom_llm_provider: str,
        region_name: Optional[str],
    ) -> Optional[str]:
        """
        Look up the API key for a provider and optional region.

        The in-memory credentials set via ``set_pass_through_credentials`` are
        checked first. If nothing is stored under the computed name, the value
        comes from ``get_secret_str`` for the provider's default variable name
        (``<PROVIDER>_API_KEY``). Note that this fallback name does NOT include
        the region.

        Args:
            custom_llm_provider: The provider of the pass-through endpoint
            region_name: Optional region qualifier (e.g. "eu")

        Returns:
            The stored API key, or the result of the secret lookup (may be None).
        """
        credential_name = self._get_credential_name_for_provider(
            custom_llm_provider=custom_llm_provider,
            region_name=region_name,
        )
        verbose_logger.debug(
            f"Pass-through llm endpoints router, looking for credentials for {credential_name}"
        )
        if credential_name in self.credentials:
            verbose_logger.debug(f"Found credentials for {credential_name}")
            return self.credentials[credential_name]

        verbose_logger.debug(
            f"No credentials found for {credential_name}, looking for env variable"
        )
        env_variable_name = self._get_default_env_variable_name_passthrough_endpoint(
            custom_llm_provider=custom_llm_provider,
        )
        return get_secret_str(env_variable_name)

    def _get_credential_name_for_provider(
        self,
        custom_llm_provider: str,
        region_name: Optional[str],
    ) -> str:
        """
        Build the credential name for a provider and optional region.

        Returns ``<PROVIDER>_API_KEY`` when no region is given (None or empty
        string), else ``<PROVIDER>_<REGION>_API_KEY``. Both parts are upper-cased.
        """
        # Treat "" like None: an empty region would otherwise yield
        # "<PROVIDER>__API_KEY", a name nothing ever stores credentials under.
        if not region_name:
            return f"{custom_llm_provider.upper()}_API_KEY"
        return f"{custom_llm_provider.upper()}_{region_name.upper()}_API_KEY"

    def _get_region_name_from_api_base(
        self,
        custom_llm_provider: str,
        api_base: Optional[str],
    ) -> Optional[str]:
        """
        Get the region name from the API base.

        Each provider might have a different way of specifying the region in the
        API base - this is where you can use conditional logic to handle that.

        Currently only "assemblyai" is handled. It returns "eu" when "eu" appears
        as a whole alphanumeric token of ``api_base`` (case-insensitive), e.g.
        ``https://api.eu.assemblyai.com``. Every other case returns None.
        """
        if custom_llm_provider == "assemblyai" and api_base:
            tokens = self._TOKEN_SEPARATOR_RE.split(api_base.lower())
            if "eu" in tokens:
                return "eu"
        return None

    @staticmethod
    def _get_default_env_variable_name_passthrough_endpoint(
        custom_llm_provider: str,
    ) -> str:
        """Return the region-agnostic fallback secret name ``<PROVIDER>_API_KEY``."""
        return f"{custom_llm_provider.upper()}_API_KEY"
|