diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/proxy/proxy_cli.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/proxy_cli.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/proxy/proxy_cli.py | 795 |
1 files changed, 795 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/proxy_cli.py b/.venv/lib/python3.12/site-packages/litellm/proxy/proxy_cli.py new file mode 100644 index 00000000..8196eb59 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/proxy_cli.py @@ -0,0 +1,795 @@ +import importlib +import json +import os +import random +import subprocess +import sys +import urllib.parse as urlparse +from typing import TYPE_CHECKING, Any, Optional, Union + +import click +import httpx +from dotenv import load_dotenv + +if TYPE_CHECKING: + from fastapi import FastAPI +else: + FastAPI = Any + +sys.path.append(os.getcwd()) + +config_filename = "litellm.secrets" + +litellm_mode = os.getenv("LITELLM_MODE", "DEV") # "PRODUCTION", "DEV" +if litellm_mode == "DEV": + load_dotenv() +from enum import Enum + +telemetry = None + + +class LiteLLMDatabaseConnectionPool(Enum): + database_connection_pool_limit = 10 + database_connection_pool_timeout = 60 + + +def append_query_params(url, params) -> str: + from litellm._logging import verbose_proxy_logger + + verbose_proxy_logger.debug(f"url: {url}") + verbose_proxy_logger.debug(f"params: {params}") + parsed_url = urlparse.urlparse(url) + parsed_query = urlparse.parse_qs(parsed_url.query) + parsed_query.update(params) + encoded_query = urlparse.urlencode(parsed_query, doseq=True) + modified_url = urlparse.urlunparse(parsed_url._replace(query=encoded_query)) + return modified_url # type: ignore + + +class ProxyInitializationHelpers: + @staticmethod + def _echo_litellm_version(): + pkg_version = importlib.metadata.version("litellm") # type: ignore + click.echo(f"\nLiteLLM: Current Version = {pkg_version}\n") + + @staticmethod + def _run_health_check(host, port): + print("\nLiteLLM: Health Testing models in config") # noqa + response = httpx.get(url=f"http://{host}:{port}/health") + print(json.dumps(response.json(), indent=4)) # noqa + + @staticmethod + def _run_test_chat_completion( + host: str, + port: int, + model: str, + test: Union[bool, str], + ): + request_model = model or "gpt-3.5-turbo" + click.echo( + f"\nLiteLLM: Making a test ChatCompletions request to your proxy. Model={request_model}" + ) + import openai + + api_base = f"http://{host}:{port}" + if isinstance(test, str): + api_base = test + else: + raise ValueError("Invalid test value") + client = openai.OpenAI(api_key="My API Key", base_url=api_base) + + response = client.chat.completions.create( + model=request_model, + messages=[ + { + "role": "user", + "content": "this is a test request, write a short poem", + } + ], + max_tokens=256, + ) + click.echo(f"\nLiteLLM: response from proxy {response}") + + print( # noqa + f"\n LiteLLM: Making a test ChatCompletions + streaming r equest to proxy. Model={request_model}" + ) + + stream_response = client.chat.completions.create( + model=request_model, + messages=[ + { + "role": "user", + "content": "this is a test request, write a short poem", + } + ], + stream=True, + ) + for chunk in stream_response: + click.echo(f"LiteLLM: streaming response from proxy {chunk}") + print("\n making completion request to proxy") # noqa + completion_response = client.completions.create( + model=request_model, prompt="this is a test request, write a short poem" + ) + print(completion_response) # noqa + + @staticmethod + def _get_default_unvicorn_init_args( + host: str, + port: int, + log_config: Optional[str] = None, + ) -> dict: + """ + Get the arguments for `uvicorn` worker + """ + import litellm + + uvicorn_args = { + "app": "litellm.proxy.proxy_server:app", + "host": host, + "port": port, + } + if log_config is not None: + print(f"Using log_config: {log_config}") # noqa + uvicorn_args["log_config"] = log_config + elif litellm.json_logs: + print("Using json logs. Setting log_config to None.") # noqa + uvicorn_args["log_config"] = None + return uvicorn_args + + @staticmethod + def _init_hypercorn_server( + app: FastAPI, + host: str, + port: int, + ssl_certfile_path: str, + ssl_keyfile_path: str, + ): + """ + Initialize litellm with `hypercorn` + """ + import asyncio + + from hypercorn.asyncio import serve + from hypercorn.config import Config + + print( # noqa + f"\033[1;32mLiteLLM Proxy: Starting server on {host}:{port} using Hypercorn\033[0m\n" # noqa + ) # noqa + config = Config() + config.bind = [f"{host}:{port}"] + + if ssl_certfile_path is not None and ssl_keyfile_path is not None: + print( # noqa + f"\033[1;32mLiteLLM Proxy: Using SSL with certfile: {ssl_certfile_path} and keyfile: {ssl_keyfile_path}\033[0m\n" # noqa + ) + config.certfile = ssl_certfile_path + config.keyfile = ssl_keyfile_path + + # hypercorn serve raises a type warning when passing a fast api app - even though fast API is a valid type + asyncio.run(serve(app, config)) # type: ignore + + @staticmethod + def _run_gunicorn_server( + host: str, + port: int, + app: FastAPI, + num_workers: int, + ssl_certfile_path: str, + ssl_keyfile_path: str, + ): + """ + Run litellm with `gunicorn` + """ + if os.name == "nt": + pass + else: + import gunicorn.app.base + + # Gunicorn Application Class + class StandaloneApplication(gunicorn.app.base.BaseApplication): + def __init__(self, app, options=None): + self.options = options or {} # gunicorn options + self.application = app # FastAPI app + super().__init__() + + _endpoint_str = ( + f"curl --location 'http://0.0.0.0:{port}/chat/completions' \\" + ) + curl_command = ( + _endpoint_str + + """ + --header 'Content-Type: application/json' \\ + --data ' { + "model": "gpt-3.5-turbo", + "messages": [ + { + "role": "user", + "content": "what llm are you" + } + ] + }' + \n + """ + ) + print() # noqa + print( # noqa + '\033[1;34mLiteLLM: Test your local proxy with: "litellm --test" This runs an openai.ChatCompletion request to your proxy [In a new terminal tab]\033[0m\n' + ) + print( # noqa + f"\033[1;34mLiteLLM: Curl Command Test for your local proxy\n {curl_command} \033[0m\n" + ) + print( # noqa + "\033[1;34mDocs: https://docs.litellm.ai/docs/simple_proxy\033[0m\n" + ) # noqa + print( # noqa + f"\033[1;34mSee all Router/Swagger docs on http://0.0.0.0:{port} \033[0m\n" + ) # noqa + + def load_config(self): + # note: This Loads the gunicorn config - has nothing to do with LiteLLM Proxy config + if self.cfg is not None: + config = { + key: value + for key, value in self.options.items() + if key in self.cfg.settings and value is not None + } + else: + config = {} + for key, value in config.items(): + if self.cfg is not None: + self.cfg.set(key.lower(), value) + + def load(self): + # gunicorn app function + return self.application + + print( # noqa + f"\033[1;32mLiteLLM Proxy: Starting server on {host}:{port} with {num_workers} workers\033[0m\n" # noqa + ) + gunicorn_options = { + "bind": f"{host}:{port}", + "workers": num_workers, # default is 1 + "worker_class": "uvicorn.workers.UvicornWorker", + "preload": True, # Add the preload flag, + "accesslog": "-", # Log to stdout + "timeout": 600, # default to very high number, bedrock/anthropic.claude-v2:1 can take 30+ seconds for the 1st chunk to come in + "access_log_format": '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s', + } + + if ssl_certfile_path is not None and ssl_keyfile_path is not None: + print( # noqa + f"\033[1;32mLiteLLM Proxy: Using SSL with certfile: {ssl_certfile_path} and keyfile: {ssl_keyfile_path}\033[0m\n" # noqa + ) + gunicorn_options["certfile"] = ssl_certfile_path + gunicorn_options["keyfile"] = ssl_keyfile_path + + StandaloneApplication(app=app, options=gunicorn_options).run() # Run gunicorn + + @staticmethod + def _run_ollama_serve(): + try: + command = ["ollama", "serve"] + + with open(os.devnull, "w") as devnull: + subprocess.Popen(command, stdout=devnull, stderr=devnull) + except Exception as e: + print( # noqa + f""" + LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve` + """ + ) # noqa + + @staticmethod + def _is_port_in_use(port): + import socket + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(("localhost", port)) == 0 + + @staticmethod + def _get_loop_type(): + """Helper function to determine the event loop type based on platform""" + if sys.platform in ("win32", "cygwin", "cli"): + return None # Let uvicorn choose the default loop on Windows + return "uvloop" + + +@click.command() +@click.option( + "--host", default="0.0.0.0", help="Host for the server to listen on.", envvar="HOST" +) +@click.option("--port", default=4000, help="Port to bind the server to.", envvar="PORT") +@click.option( + "--num_workers", + default=1, + help="Number of uvicorn / gunicorn workers to spin up. By default, 1 uvicorn is used.", + envvar="NUM_WORKERS", +) +@click.option("--api_base", default=None, help="API base URL.") +@click.option( + "--api_version", + default="2024-07-01-preview", + help="For azure - pass in the api version.", +) +@click.option( + "--model", "-m", default=None, help="The model name to pass to litellm expects" +) +@click.option( + "--alias", + default=None, + help='The alias for the model - use this to give a litellm model name (e.g. "huggingface/codellama/CodeLlama-7b-Instruct-hf") a more user-friendly name ("codellama")', +) +@click.option( + "--add_key", default=None, help="The model name to pass to litellm expects" +) +@click.option("--headers", default=None, help="headers for the API call") +@click.option("--save", is_flag=True, type=bool, help="Save the model-specific config") +@click.option( + "--debug", + default=False, + is_flag=True, + type=bool, + help="To debug the input", + envvar="DEBUG", +) +@click.option( + "--detailed_debug", + default=False, + is_flag=True, + type=bool, + help="To view detailed debug logs", + envvar="DETAILED_DEBUG", +) +@click.option( + "--use_queue", + default=False, + is_flag=True, + type=bool, + help="To use celery workers for async endpoints", +) +@click.option( + "--temperature", default=None, type=float, help="Set temperature for the model" +) +@click.option( + "--max_tokens", default=None, type=int, help="Set max tokens for the model" +) +@click.option( + "--request_timeout", + default=None, + type=int, + help="Set timeout in seconds for completion calls", +) +@click.option("--drop_params", is_flag=True, help="Drop any unmapped params") +@click.option( + "--add_function_to_prompt", + is_flag=True, + help="If function passed but unsupported, pass it as prompt", +) +@click.option( + "--config", + "-c", + default=None, + help="Path to the proxy configuration file (e.g. config.yaml). Usage `litellm --config config.yaml`", +) +@click.option( + "--max_budget", + default=None, + type=float, + help="Set max budget for API calls - works for hosted models like OpenAI, TogetherAI, Anthropic, etc.`", +) +@click.option( + "--telemetry", + default=True, + type=bool, + help="Helps us know if people are using this feature. Turn this off by doing `--telemetry False`", +) +@click.option( + "--log_config", + default=None, + type=str, + help="Path to the logging configuration file", +) +@click.option( + "--version", + "-v", + default=False, + is_flag=True, + type=bool, + help="Print LiteLLM version", +) +@click.option( + "--health", + flag_value=True, + help="Make a chat/completions request to all llms in config.yaml", +) +@click.option( + "--test", + flag_value=True, + help="proxy chat completions url to make a test request to", +) +@click.option( + "--test_async", + default=False, + is_flag=True, + help="Calls async endpoints /queue/requests and /queue/response", +) +@click.option( + "--iam_token_db_auth", + default=False, + is_flag=True, + help="Connects to RDS DB with IAM token", +) +@click.option( + "--num_requests", + default=10, + type=int, + help="Number of requests to hit async endpoint with", +) +@click.option( + "--run_gunicorn", + default=False, + is_flag=True, + help="Starts proxy via gunicorn, instead of uvicorn (better for managing multiple workers)", +) +@click.option( + "--run_hypercorn", + default=False, + is_flag=True, + help="Starts proxy via hypercorn, instead of uvicorn (supports HTTP/2)", +) +@click.option( + "--ssl_keyfile_path", + default=None, + type=str, + help="Path to the SSL keyfile. Use this when you want to provide SSL certificate when starting proxy", + envvar="SSL_KEYFILE_PATH", +) +@click.option( + "--ssl_certfile_path", + default=None, + type=str, + help="Path to the SSL certfile. Use this when you want to provide SSL certificate when starting proxy", + envvar="SSL_CERTFILE_PATH", +) +@click.option( + "--use_prisma_migrate", + is_flag=True, + default=False, + help="Use prisma migrate instead of prisma db push for database schema updates", +) +@click.option("--local", is_flag=True, default=False, help="for local debugging") +def run_server( # noqa: PLR0915 + host, + port, + api_base, + api_version, + model, + alias, + add_key, + headers, + save, + debug, + detailed_debug, + temperature, + max_tokens, + request_timeout, + drop_params, + add_function_to_prompt, + config, + max_budget, + telemetry, + test, + local, + num_workers, + test_async, + iam_token_db_auth, + num_requests, + use_queue, + health, + version, + run_gunicorn, + run_hypercorn, + ssl_keyfile_path, + ssl_certfile_path, + log_config, + use_prisma_migrate, +): + args = locals() + if local: + from proxy_server import ( + KeyManagementSettings, + ProxyConfig, + app, + save_worker_config, + ) + else: + try: + from .proxy_server import ( + KeyManagementSettings, + ProxyConfig, + app, + save_worker_config, + ) + except ImportError as e: + if "litellm[proxy]" in str(e): + # user is missing a proxy dependency, ask them to pip install litellm[proxy] + raise e + else: + # this is just a local/relative import error, user git cloned litellm + from proxy_server import ( + KeyManagementSettings, + ProxyConfig, + app, + save_worker_config, + ) + if version is True: + ProxyInitializationHelpers._echo_litellm_version() + return + if model and "ollama" in model and api_base is None: + ProxyInitializationHelpers._run_ollama_serve() + if health is True: + ProxyInitializationHelpers._run_health_check(host, port) + return + if test is True: + ProxyInitializationHelpers._run_test_chat_completion(host, port, model, test) + return + else: + if headers: + headers = json.loads(headers) + save_worker_config( + model=model, + alias=alias, + api_base=api_base, + api_version=api_version, + debug=debug, + detailed_debug=detailed_debug, + temperature=temperature, + max_tokens=max_tokens, + request_timeout=request_timeout, + max_budget=max_budget, + telemetry=telemetry, + drop_params=drop_params, + add_function_to_prompt=add_function_to_prompt, + headers=headers, + save=save, + config=config, + use_queue=use_queue, + ) + try: + import uvicorn + except Exception: + raise ImportError( + "uvicorn, gunicorn needs to be imported. Run - `pip install 'litellm[proxy]'`" + ) + + db_connection_pool_limit = 100 + db_connection_timeout = 60 + general_settings = {} + ### GET DB TOKEN FOR IAM AUTH ### + + if iam_token_db_auth: + from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token + + db_host = os.getenv("DATABASE_HOST") + db_port = os.getenv("DATABASE_PORT") + db_user = os.getenv("DATABASE_USER") + db_name = os.getenv("DATABASE_NAME") + db_schema = os.getenv("DATABASE_SCHEMA") + + token = generate_iam_auth_token( + db_host=db_host, db_port=db_port, db_user=db_user + ) + + # print(f"token: {token}") + _db_url = f"postgresql://{db_user}:{token}@{db_host}:{db_port}/{db_name}" + if db_schema: + _db_url += f"?schema={db_schema}" + + os.environ["DATABASE_URL"] = _db_url + os.environ["IAM_TOKEN_DB_AUTH"] = "True" + + ### DECRYPT ENV VAR ### + + from litellm.secret_managers.aws_secret_manager import decrypt_env_var + + if ( + os.getenv("USE_AWS_KMS", None) is not None + and os.getenv("USE_AWS_KMS") == "True" + ): + ## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV + new_env_var = decrypt_env_var() + + for k, v in new_env_var.items(): + os.environ[k] = v + + if config is not None: + """ + Allow user to pass in db url via config + + read from there and save it to os.env['DATABASE_URL'] + """ + try: + import asyncio + + except Exception: + raise ImportError( + "yaml needs to be imported. Run - `pip install 'litellm[proxy]'`" + ) + + proxy_config = ProxyConfig() + _config = asyncio.run(proxy_config.get_config(config_file_path=config)) + + ### LITELLM SETTINGS ### + litellm_settings = _config.get("litellm_settings", None) + if ( + litellm_settings is not None + and "json_logs" in litellm_settings + and litellm_settings["json_logs"] is True + ): + import litellm + + litellm.json_logs = True + + litellm._turn_on_json() + ### GENERAL SETTINGS ### + general_settings = _config.get("general_settings", {}) + if general_settings is None: + general_settings = {} + if general_settings: + ### LOAD SECRET MANAGER ### + key_management_system = general_settings.get( + "key_management_system", None + ) + proxy_config.initialize_secret_manager(key_management_system) + key_management_settings = general_settings.get( + "key_management_settings", None + ) + if key_management_settings is not None: + import litellm + + litellm._key_management_settings = KeyManagementSettings( + **key_management_settings + ) + database_url = general_settings.get("database_url", None) + if database_url is None: + # Check if all required variables are provided + database_host = os.getenv("DATABASE_HOST") + database_username = os.getenv("DATABASE_USERNAME") + database_password = os.getenv("DATABASE_PASSWORD") + database_name = os.getenv("DATABASE_NAME") + + if ( + database_host + and database_username + and database_password + and database_name + ): + # Construct DATABASE_URL from the provided variables + database_url = f"postgresql://{database_username}:{database_password}@{database_host}/{database_name}" + os.environ["DATABASE_URL"] = database_url + db_connection_pool_limit = general_settings.get( + "database_connection_pool_limit", + LiteLLMDatabaseConnectionPool.database_connection_pool_limit.value, + ) + db_connection_timeout = general_settings.get( + "database_connection_timeout", + LiteLLMDatabaseConnectionPool.database_connection_pool_timeout.value, + ) + if database_url and database_url.startswith("os.environ/"): + original_dir = os.getcwd() + # set the working directory to where this script is + sys.path.insert( + 0, os.path.abspath("../..") + ) # Adds the parent directory to the system path - for litellm local dev + import litellm + from litellm import get_secret_str + + database_url = get_secret_str(database_url, default_value=None) + os.chdir(original_dir) + if database_url is not None and isinstance(database_url, str): + os.environ["DATABASE_URL"] = database_url + + if ( + os.getenv("DATABASE_URL", None) is not None + or os.getenv("DIRECT_URL", None) is not None + ): + try: + from litellm.secret_managers.main import get_secret + + if os.getenv("DATABASE_URL", None) is not None: + ### add connection pool + pool timeout args + params = { + "connection_limit": db_connection_pool_limit, + "pool_timeout": db_connection_timeout, + } + database_url = get_secret("DATABASE_URL", default_value=None) + modified_url = append_query_params(database_url, params) + os.environ["DATABASE_URL"] = modified_url + if os.getenv("DIRECT_URL", None) is not None: + ### add connection pool + pool timeout args + params = { + "connection_limit": db_connection_pool_limit, + "pool_timeout": db_connection_timeout, + } + database_url = os.getenv("DIRECT_URL") + modified_url = append_query_params(database_url, params) + os.environ["DIRECT_URL"] = modified_url + ### + subprocess.run(["prisma"], capture_output=True) + is_prisma_runnable = True + except FileNotFoundError: + is_prisma_runnable = False + + if is_prisma_runnable: + from litellm.proxy.db.check_migration import check_prisma_schema_diff + from litellm.proxy.db.prisma_client import ( + PrismaManager, + should_update_prisma_schema, + ) + + if ( + should_update_prisma_schema( + general_settings.get("disable_prisma_schema_update") + ) + is False + ): + check_prisma_schema_diff(db_url=None) + else: + PrismaManager.setup_database(use_migrate=use_prisma_migrate) + else: + print( # noqa + f"Unable to connect to DB. DATABASE_URL found in environment, but prisma package not found." # noqa + ) + if port == 4000 and ProxyInitializationHelpers._is_port_in_use(port): + port = random.randint(1024, 49152) + + import litellm + + if detailed_debug is True: + litellm._turn_on_debug() + + # DO NOT DELETE - enables global variables to work across files + from litellm.proxy.proxy_server import app # noqa + + uvicorn_args = ProxyInitializationHelpers._get_default_unvicorn_init_args( + host=host, + port=port, + log_config=log_config, + ) + if run_gunicorn is False and run_hypercorn is False: + if ssl_certfile_path is not None and ssl_keyfile_path is not None: + print( # noqa + f"\033[1;32mLiteLLM Proxy: Using SSL with certfile: {ssl_certfile_path} and keyfile: {ssl_keyfile_path}\033[0m\n" # noqa + ) + uvicorn_args["ssl_keyfile"] = ssl_keyfile_path + uvicorn_args["ssl_certfile"] = ssl_certfile_path + + loop_type = ProxyInitializationHelpers._get_loop_type() + if loop_type: + uvicorn_args["loop"] = loop_type + + uvicorn.run( + **uvicorn_args, + workers=num_workers, + ) + elif run_gunicorn is True: + ProxyInitializationHelpers._run_gunicorn_server( + host=host, + port=port, + app=app, + num_workers=num_workers, + ssl_certfile_path=ssl_certfile_path, + ssl_keyfile_path=ssl_keyfile_path, + ) + elif run_hypercorn is True: + ProxyInitializationHelpers._init_hypercorn_server( + app=app, + host=host, + port=port, + ssl_certfile_path=ssl_certfile_path, + ssl_keyfile_path=ssl_keyfile_path, + ) + + +if __name__ == "__main__": + run_server() |