aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/openai/cli/_cli.py
blob: fd165f48abf98611b30dd319a1e5052a96140bd8 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
from __future__ import annotations

import sys
import logging
import argparse
from typing import Any, List, Type, Optional
from typing_extensions import ClassVar

import httpx
import pydantic

import openai

from . import _tools
from .. import _ApiType, __version__
from ._api import register_commands
from ._utils import can_use_http2
from ._errors import CLIError, display_error
from .._compat import PYDANTIC_V2, ConfigDict, model_parse
from .._models import BaseModel
from .._exceptions import APIError

# Attach a stderr handler with a timestamped format to the root logger.
formatter = logging.Formatter(fmt="[%(asctime)s] %(message)s")
handler = logging.StreamHandler(stream=sys.stderr)
handler.setFormatter(formatter)

logger = logging.getLogger()
logger.addHandler(handler)


class Arguments(BaseModel):
    """Typed view of the top-level options accepted by the ``openai`` CLI.

    Instances are built in ``_parse_args`` from ``vars(parsed)``, i.e. the
    argparse namespace converted to a dict. Keys in that dict that have no
    matching field here are ignored rather than rejected (see the
    ``extra`` setting below).
    """

    if PYDANTIC_V2:
        # pydantic v2 style configuration: ignore unexpected keys
        model_config: ClassVar[ConfigDict] = ConfigDict(
            extra="ignore",
        )
    else:

        # pydantic v1 style configuration with the same "ignore extras" setting
        class Config(pydantic.BaseConfig):  # type: ignore
            extra: Any = pydantic.Extra.ignore  # type: ignore

    # number of times `-v/--verbose` was given (argparse `count` action, default 0)
    verbosity: int
    # NOTE(review): `-V/--version` is registered with argparse's `version` action,
    # so this does not appear to be populated from the command line - confirm
    version: Optional[str] = None

    # global connection options; declared without defaults here, the parser
    # built in `_build_parser` registers a flag for each of them
    api_key: Optional[str]
    api_base: Optional[str]
    organization: Optional[str]
    # a list because `-p/--proxy` is declared with `nargs="+"`
    proxy: Optional[List[str]]
    api_type: Optional[_ApiType] = None
    api_version: Optional[str] = None

    # azure
    azure_endpoint: Optional[str] = None
    azure_ad_token: Optional[str] = None

    # internal, set by subparsers to parse their specific args
    # (when set, `_main` validates the namespace against this model before calling `func`)
    args_model: Optional[Type[BaseModel]] = None

    # internal, used so that subparsers can forward unknown arguments
    unknown_args: List[str] = []
    # when False, `_parse_args` runs a strict second parse
    allow_unknown_args: bool = False


def _build_parser() -> argparse.ArgumentParser:
    """Assemble the top-level ``openai`` argument parser.

    Registers the global options, installs "print help" as the default
    ``func`` and attaches the ``api`` and ``tools`` sub-commands.

    Returns:
        The fully configured root parser.
    """
    root = argparse.ArgumentParser(prog="openai", description=None)

    # (flags, `add_argument` keyword arguments), listed in registration order
    global_options: list[tuple[tuple[str, ...], dict[str, Any]]] = [
        (
            ("-v", "--verbose"),
            {"action": "count", "dest": "verbosity", "default": 0, "help": "Set verbosity."},
        ),
        (("-b", "--api-base"), {"help": "What API base url to use."}),
        (("-k", "--api-key"), {"help": "What API key to use."}),
        (("-p", "--proxy"), {"nargs": "+", "help": "What proxy to use."}),
        (
            ("-o", "--organization"),
            {"help": "Which organization to run as (will use your default organization if not specified)"},
        ),
        (
            ("-t", "--api-type"),
            {
                "type": str,
                "choices": ("openai", "azure"),
                "help": "The backend API to call, must be `openai` or `azure`",
            },
        ),
        (
            ("--api-version",),
            {
                "help": "The Azure API version, e.g. 'https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning'"
            },
        ),
        # azure
        (
            ("--azure-endpoint",),
            {"help": "The Azure endpoint, e.g. 'https://endpoint.openai.azure.com'"},
        ),
        (
            ("--azure-ad-token",),
            {
                "help": "A token from Azure Active Directory, https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id"
            },
        ),
        # prints the package version
        (
            ("-V", "--version"),
            {"action": "version", "version": "%(prog)s " + __version__},
        ),
    ]
    for flags, settings in global_options:
        root.add_argument(*flags, **settings)

    # default `func`: print the usage text
    root.set_defaults(func=root.print_help)

    commands = root.add_subparsers()

    api_parser = commands.add_parser("api", help="Direct API calls")
    register_commands(api_parser)

    tools_parser = commands.add_parser("tools", help="Client side tools for convenience")
    _tools.register_commands(tools_parser, commands)

    return root


def main() -> int:
    """Run the CLI and translate its outcome into a process exit code.

    Returns:
        ``0`` when ``_main`` completes, ``1`` when it raises an ``APIError``,
        ``CLIError`` or pydantic ``ValidationError`` (reported through
        ``display_error``) or when it is interrupted with Ctrl-C.
    """
    exit_code = 0
    try:
        _main()
    except (APIError, CLIError, pydantic.ValidationError) as err:
        display_error(err)
        exit_code = 1
    except KeyboardInterrupt:
        # finish the current terminal line before exiting
        sys.stderr.write("\n")
        exit_code = 1
    return exit_code


def _parse_args(parser: argparse.ArgumentParser) -> tuple[argparse.Namespace, Arguments, list[str]]:
    """Parse ``sys.argv`` into a namespace, typed arguments and leftover arguments.

    Everything from a literal ``--`` onwards is withheld from the first
    parse and returned verbatim (``--`` included) at the end of the
    leftover list.

    Args:
        parser: the root parser to parse ``sys.argv`` with.

    Returns:
        A ``(namespace, arguments, leftovers)`` triple: the raw argparse
        namespace, that namespace validated as ``Arguments``, and the
        arguments argparse did not recognise followed by the ``--`` tail.
    """
    argv = sys.argv

    # argparse by default will strip out the `--` but we want to keep it for unknown arguments
    try:
        separator = argv.index("--")
    except ValueError:
        recognised, passthrough = argv[1:], []
    else:
        recognised, passthrough = argv[1:separator], argv[separator:]

    namespace, leftovers = parser.parse_known_args(recognised)

    # whatever argparse could not place comes first, then the withheld tail
    leftovers += passthrough

    arguments = model_parse(Arguments, vars(namespace))
    if arguments.allow_unknown_args:
        return namespace, arguments, leftovers

    # we have to parse twice to ensure any unknown arguments
    # result in an error if that behaviour is desired
    parser.parse_args()
    return namespace, arguments, leftovers


def _main() -> None:
    """Parse the command line, configure the ``openai`` module and run the command.

    The global options (proxy, organization, API key, base URL and the Azure
    settings) are copied onto the ``openai`` module, then the ``func``
    selected by the parser is invoked. When the selected command declared an
    ``args_model``, the parsed namespace is validated against that model and
    passed to ``func``; otherwise ``func`` is called without arguments.

    Raises:
        CLIError: if more than one proxy is given for the same scheme.
    """
    parser = _build_parser()
    parsed, args, unknown = _parse_args(parser)

    if args.verbosity != 0:
        # the flag is registered as `-v/--verbose`; `verbosity` is only its argparse dest
        sys.stderr.write("Warning: --verbose isn't supported yet\n")

    # at most one transport per scheme, keyed the way httpx expects for `mounts`
    proxies: dict[str, httpx.BaseTransport] = {}
    if args.proxy is not None:
        for proxy in args.proxy:
            key = "https://" if proxy.startswith("https") else "http://"
            if key in proxies:
                raise CLIError(f"Multiple {key} proxies given - only the last one would be used")

            proxies[key] = httpx.HTTPTransport(proxy=httpx.Proxy(httpx.URL(proxy)))

    http_client = httpx.Client(
        mounts=proxies or None,
        http2=can_use_http2(),
    )

    # everything after the client exists runs under `finally`, so the client
    # is closed even if configuring the module or running the command fails
    try:
        openai.http_client = http_client

        if args.organization:
            openai.organization = args.organization

        if args.api_key:
            openai.api_key = args.api_key

        if args.api_base:
            openai.base_url = args.api_base

        # azure
        if args.api_type is not None:
            openai.api_type = args.api_type

        if args.azure_endpoint is not None:
            openai.azure_endpoint = args.azure_endpoint

        if args.api_version is not None:
            openai.api_version = args.api_version

        if args.azure_ad_token is not None:
            openai.azure_ad_token = args.azure_ad_token

        if args.args_model:
            # we omit None values so that they can be defaulted to `NotGiven`
            # and we'll strip it from the API request
            payload = {key: value for key, value in vars(parsed).items() if value is not None}
            payload["unknown_args"] = unknown
            parsed.func(model_parse(args.args_model, payload))
        else:
            parsed.func()
    finally:
        try:
            http_client.close()
        except Exception:
            # closing is best-effort: never mask the command's own outcome,
            # but leave a trace for anyone running with debug logging enabled
            logger.debug("Failed to close the HTTP client", exc_info=True)


# Script entry point: exit the process with the status returned by `main`.
if __name__ == "__main__":
    raise SystemExit(main())