aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/inference/prompts/_patch.py
blob: 14ad4f62b4c1ad9555845ee8ce2e9550d0b91fa1 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
# pylint: disable=line-too-long,R
"""Customize generated code here.

Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize
"""

import traceback
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing_extensions import Self
from ._core import Prompty
from ._mustache import render
from ._parsers import invoke_parser
from ._prompty_utils import load, prepare
from ._utils import remove_leading_empty_space


class PromptTemplate:
    """Helper that accepts either a Prompty definition or a raw template string
    and renders it into a list of chat-message dictionaries."""

    @classmethod
    def from_prompty(cls, file_path: str) -> Self:
        """Initialize a PromptTemplate object from a prompty file.

        :param file_path: The path to the prompty file.
        :type file_path: str
        :return: The PromptTemplate object.
        :rtype: PromptTemplate
        """
        if not file_path:
            raise ValueError("Please provide file_path")

        # Resolve `file_path` relative to the *caller's* source file. In the
        # extracted stack the last frame is this function, so the frame at
        # index -2 is the caller whose directory anchors the relative path.
        caller_frame = traceback.extract_stack()[-2]
        caller_dir = Path(caller_frame.filename).parent
        resolved = (caller_dir / Path(file_path)).resolve().absolute()

        return cls(prompty=load(str(resolved)))

    @classmethod
    def from_string(cls, prompt_template: str, api: str = "chat", model_name: Optional[str] = None) -> Self:
        """Initialize a PromptTemplate object from a message template.

        :param prompt_template: The prompt template string.
        :type prompt_template: str
        :param api: The API type, e.g. "chat" or "completion".
        :type api: str
        :param model_name: The model name, e.g. "gpt-4o-mini".
        :type model_name: str
        :return: The PromptTemplate object.
        :rtype: PromptTemplate
        """
        return cls(
            prompty=None,
            api=api,
            prompt_template=prompt_template,
            model_name=model_name,
        )

    def __init__(
        self,
        *,
        api: str = "chat",
        prompty: Optional[Prompty] = None,
        prompt_template: Optional[str] = None,
        model_name: Optional[str] = None,
    ) -> None:
        # Exactly one of `prompty` / `prompt_template` is expected; `prompty`
        # wins when both are supplied (mirrors the branch order below).
        self.prompty = prompty
        if prompty is not None:
            configuration = prompty.model.configuration
            # Only Azure-style configurations carry a deployment name.
            self.model_name = configuration["azure_deployment"] if "azure_deployment" in configuration else None
            self.parameters = prompty.model.parameters
            self._config = {}
        elif prompt_template is not None:
            self.model_name = model_name
            self.parameters = {}
            # _config is a dict to hold the internal configuration
            self._config = {
                "api": "chat" if api is None else api,
                "prompt_template": prompt_template,
            }
        else:
            raise ValueError("Please pass valid arguments for PromptTemplate")

    def create_messages(self, data: Optional[Dict[str, Any]] = None, **kwargs) -> List[Dict[str, Any]]:
        """Render the prompt template with the given data.

        :param data: The data to render the prompt template with.
        :type data: Optional[Dict[str, Any]]
        :return: The rendered prompt template.
        :rtype: List[Dict[str, Any]]
        """
        # Keyword arguments act as the data payload only when no explicit
        # `data` dict was supplied (kwargs are ignored otherwise).
        values = kwargs if data is None else data

        if self.prompty is not None:
            return prepare(self.prompty, values)

        if "prompt_template" in self._config:
            template = remove_leading_empty_space(self._config["prompt_template"])
            rendered = render(template, values)
            return invoke_parser(None, rendered)

        raise ValueError("Please provide valid prompt template")

def patch_sdk():
    """Do not remove from this file.

    `patch_sdk` is a last resort escape hatch that allows you to do customizations
    you can't accomplish using the techniques described in
    https://aka.ms/azsdk/python/dpcodegen/python/customize
    """
    # Intentionally a no-op: the hook exists so the generated package can pick
    # up customizations; nothing needs patching here at the moment.
    return None