aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/openai/cli/_api/chat/completions.py
blob: 344eeff37cc2675da69174016ffb5879126fce7b (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
from __future__ import annotations

import sys
from typing import TYPE_CHECKING, List, Optional, cast
from argparse import ArgumentParser
from typing_extensions import Literal, NamedTuple

from ..._utils import get_client
from ..._models import BaseModel
from ...._streaming import Stream
from ....types.chat import (
    ChatCompletionRole,
    ChatCompletionChunk,
    CompletionCreateParams,
)
from ....types.chat.completion_create_params import (
    CompletionCreateParamsStreaming,
    CompletionCreateParamsNonStreaming,
)

if TYPE_CHECKING:
    from argparse import _SubParsersAction


def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
    """Register the ``chat.completions.create`` sub-command on *subparser*.

    Declares the sub-command's flags and sets the ``func`` / ``args_model``
    parser defaults to ``CLIChatCompletion.create`` and
    ``CLIChatCompletionCreateArgs``. Presumably the CLI dispatcher (outside this
    module) builds the args model from the namespace and calls ``func`` —
    confirm against the dispatcher.
    """
    sub = subparser.add_parser("chat.completions.create")

    # Drop the parser's last default argument group (argparse's built-in
    # optional-arguments group) so the help output shows only the explicit
    # "required arguments" / "optional arguments" groups created below.
    # NOTE(review): this relies on argparse internals and also hides -h/--help
    # from the help listing (the flag itself still works).
    sub._action_groups.pop()
    req = sub.add_argument_group("required arguments")
    opt = sub.add_argument_group("optional arguments")

    req.add_argument(
        "-g",
        "--message",
        action="append",
        nargs=2,
        metavar=("ROLE", "CONTENT"),
        help="A message in `{role} {content}` format. Use this argument multiple times to add multiple messages.",
        required=True,
    )
    req.add_argument(
        "-m",
        "--model",
        help="The model to use.",
        required=True,
    )

    opt.add_argument(
        "-n",
        "--n",
        help="How many completions to generate for the conversation.",
        type=int,
    )
    opt.add_argument("-M", "--max-tokens", help="The maximum number of tokens to generate.", type=int)
    opt.add_argument(
        "-t",
        "--temperature",
        help="""What sampling temperature to use. Higher values mean the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.

It is generally recommended to set this or `top_p`, but not both.""",
        type=float,
    )
    opt.add_argument(
        "-P",
        "--top_p",
        # `%%` is required: argparse applies %-formatting to help strings.
        help="""An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%% probability mass are considered.

It is generally recommended to set this or `temperature`, but not both.""",
        type=float,
    )
    opt.add_argument(
        "--stop",
        help="A stop sequence at which to stop generating tokens for the message.",
    )
    opt.add_argument("--stream", help="Stream messages as they're ready.", action="store_true")
    sub.set_defaults(func=CLIChatCompletion.create, args_model=CLIChatCompletionCreateArgs)


class CLIMessage(NamedTuple):
    """One ``ROLE CONTENT`` pair given via ``-g/--message``.

    Field order matches the argparse ``metavar=("ROLE", "CONTENT")`` of that
    flag, so a positional two-item value maps onto (role, content).
    """

    role: ChatCompletionRole  # first value of -g/--message; not validated beyond the model's typing
    content: str  # second value of -g/--message, sent verbatim as the message text


class CLIChatCompletionCreateArgs(BaseModel):
    """Arguments of the ``chat.completions.create`` sub-command.

    Field names match the argparse destinations declared in ``register``.
    Optional fields left as ``None`` are omitted from the API request by
    ``CLIChatCompletion.create``, so the API's own defaults apply.
    """

    message: List[CLIMessage]  # -g/--message; one entry per occurrence, in command-line order
    model: str  # -m/--model
    n: Optional[int] = None  # -n/--n: number of completions to request
    max_tokens: Optional[int] = None  # -M/--max-tokens
    temperature: Optional[float] = None  # -t/--temperature
    top_p: Optional[float] = None  # -P/--top_p
    stop: Optional[str] = None  # --stop: a single stop sequence
    stream: bool = False  # --stream (store_true flag)


class CLIChatCompletion:
    """Implementation of the ``chat.completions.create`` CLI sub-command.

    ``register`` installs :meth:`create` as the sub-command's ``func``; the
    remaining methods are private helpers that call the API and write the
    result to stdout.
    """

    @staticmethod
    def create(args: CLIChatCompletionCreateArgs) -> None:
        """Build request parameters from the parsed CLI args and send the request.

        Only options the user actually supplied are forwarded. Dispatches to the
        streaming or non-streaming printer depending on ``--stream``.
        """
        params: CompletionCreateParams = {
            "model": args.model,
            "messages": [
                # The cast only satisfies the type checker; the role given on the
                # command line is forwarded as-is.
                {"role": cast(Literal["user"], message.role), "content": message.content} for message in args.message
            ],
            # type checkers are not good at inferring union types so we have to set stream afterwards
            "stream": False,
        }
        if args.temperature is not None:
            params["temperature"] = args.temperature
        if args.stop is not None:
            params["stop"] = args.stop
        if args.top_p is not None:
            params["top_p"] = args.top_p
        if args.n is not None:
            params["n"] = args.n
        if args.stream:
            params["stream"] = args.stream  # type: ignore
        if args.max_tokens is not None:
            params["max_tokens"] = args.max_tokens

        if args.stream:
            return CLIChatCompletion._stream_create(cast(CompletionCreateParamsStreaming, params))

        return CLIChatCompletion._create(cast(CompletionCreateParamsNonStreaming, params))

    @staticmethod
    def _write_choice(index: int, content: str, *, with_header: bool) -> None:
        """Write one completion's text to stdout and flush.

        When *with_header* is true, a ``===== Chat Completion {index} =====`` line
        precedes the text and a newline always follows it. Otherwise a newline is
        added only if *content* does not already end with one.
        """
        if with_header:
            sys.stdout.write("===== Chat Completion {} =====\n".format(index))

        sys.stdout.write(content)

        if with_header or not content.endswith("\n"):
            sys.stdout.write("\n")

        sys.stdout.flush()

    @staticmethod
    def _create(params: CompletionCreateParamsNonStreaming) -> None:
        """Request a complete (non-streamed) chat completion and print every choice.

        Headers are printed only when more than one choice came back. A message
        whose content is ``None`` is printed as the literal text ``None``.
        """
        completion = get_client().chat.completions.create(**params)
        should_print_header = len(completion.choices) > 1
        for choice in completion.choices:
            content = choice.message.content if choice.message.content is not None else "None"
            CLIChatCompletion._write_choice(choice.index, content, with_header=should_print_header)

    @staticmethod
    def _stream_create(params: CompletionCreateParamsStreaming) -> None:
        """Request a streamed chat completion and print it as it arrives.

        With a single completion (``n`` unset or <= 1), every content delta is
        written and flushed immediately, followed by a final newline.

        With ``n > 1``, deltas belonging to different choices can arrive
        interleaved, so writing them directly would mix the completions' text.
        Instead, each choice's deltas are buffered. A choice is printed as one
        block (header + text, the same layout as the non-streaming output) as
        soon as its ``finish_reason`` is set. Choices the stream ended without
        finishing are printed afterwards in index order.
        """
        # cast is required for mypy
        stream = cast(  # pyright: ignore[reportUnnecessaryCast]
            Stream[ChatCompletionChunk], get_client().chat.completions.create(**params)
        )

        # Decide from the request, not from `len(chunk.choices)`: nothing
        # guarantees that one chunk lists every requested choice, so a per-chunk
        # count cannot tell whether several completions are being streamed.
        n = params.get("n")
        if n is None or n <= 1:
            for chunk in stream:
                for choice in chunk.choices:
                    sys.stdout.write(choice.delta.content or "")
                    sys.stdout.flush()
            sys.stdout.write("\n")
            sys.stdout.flush()
            return

        pending: dict[int, List[str]] = {}
        for chunk in stream:
            for choice in chunk.choices:
                pending.setdefault(choice.index, []).append(choice.delta.content or "")
                if choice.finish_reason is not None:
                    text = "".join(pending.pop(choice.index))
                    CLIChatCompletion._write_choice(choice.index, text, with_header=True)

        # Anything left never reported a finish_reason (e.g. the stream was cut short).
        for index in sorted(pending):
            CLIChatCompletion._write_choice(index, "".join(pending[index]), with_header=True)