diff options
Diffstat (limited to 'R2R/r2r/base/abstractions/prompt.py')
-rwxr-xr-x | R2R/r2r/base/abstractions/prompt.py | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/R2R/r2r/base/abstractions/prompt.py b/R2R/r2r/base/abstractions/prompt.py new file mode 100755 index 00000000..e37eeb5f --- /dev/null +++ b/R2R/r2r/base/abstractions/prompt.py @@ -0,0 +1,31 @@ +"""Abstraction for a prompt that can be formatted with inputs.""" + +from typing import Any + +from pydantic import BaseModel + + +class Prompt(BaseModel): + """A prompt that can be formatted with inputs.""" + + name: str + template: str + input_types: dict[str, str] + + def format_prompt(self, inputs: dict[str, Any]) -> str: + self._validate_inputs(inputs) + return self.template.format(**inputs) + + def _validate_inputs(self, inputs: dict[str, Any]) -> None: + for var, expected_type_name in self.input_types.items(): + expected_type = self._convert_type(expected_type_name) + if var not in inputs: + raise ValueError(f"Missing input: {var}") + if not isinstance(inputs[var], expected_type): + raise TypeError( + f"Input '{var}' must be of type {expected_type.__name__}, got {type(inputs[var]).__name__} instead." + ) + + def _convert_type(self, type_name: str) -> type: + type_mapping = {"int": int, "str": str} + return type_mapping.get(type_name, str) |