aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/base/abstractions/prompt.py
diff options
context:
space:
mode:
Diffstat (limited to 'R2R/r2r/base/abstractions/prompt.py')
-rwxr-xr-xR2R/r2r/base/abstractions/prompt.py31
1 files changed, 31 insertions, 0 deletions
diff --git a/R2R/r2r/base/abstractions/prompt.py b/R2R/r2r/base/abstractions/prompt.py
new file mode 100755
index 00000000..e37eeb5f
--- /dev/null
+++ b/R2R/r2r/base/abstractions/prompt.py
@@ -0,0 +1,31 @@
+"""Abstraction for a prompt that can be formatted with inputs."""
+
+from typing import Any
+
+from pydantic import BaseModel
+
+
+class Prompt(BaseModel):
+ """A prompt that can be formatted with inputs."""
+
+ name: str
+ template: str
+ input_types: dict[str, str]
+
+ def format_prompt(self, inputs: dict[str, Any]) -> str:
+ self._validate_inputs(inputs)
+ return self.template.format(**inputs)
+
+ def _validate_inputs(self, inputs: dict[str, Any]) -> None:
+ for var, expected_type_name in self.input_types.items():
+ expected_type = self._convert_type(expected_type_name)
+ if var not in inputs:
+ raise ValueError(f"Missing input: {var}")
+ if not isinstance(inputs[var], expected_type):
+ raise TypeError(
+ f"Input '{var}' must be of type {expected_type.__name__}, got {type(inputs[var]).__name__} instead."
+ )
+
+ def _convert_type(self, type_name: str) -> type:
+ type_mapping = {"int": int, "str": str}
+ return type_mapping.get(type_name, str)