diff options
Diffstat (limited to 'R2R/r2r/prompts/local/r2r_prompt_provider.py')
-rwxr-xr-x | R2R/r2r/prompts/local/r2r_prompt_provider.py | 69 |
1 files changed, 69 insertions, 0 deletions
diff --git a/R2R/r2r/prompts/local/r2r_prompt_provider.py b/R2R/r2r/prompts/local/r2r_prompt_provider.py new file mode 100755 index 00000000..830c4203 --- /dev/null +++ b/R2R/r2r/prompts/local/r2r_prompt_provider.py @@ -0,0 +1,69 @@ +import json +import logging +import os +from typing import Any, Optional + +from r2r.base import Prompt, PromptProvider + +logger = logging.getLogger(__name__) + + +class R2RPromptProvider(PromptProvider): + def __init__(self, file_path: Optional[str] = None): + self.prompts: dict[str, Prompt] = {} + self._load_prompts_from_jsonl(file_path=file_path) + + def _load_prompts_from_jsonl(self, file_path: Optional[str] = None): + if not file_path: + file_path = os.path.join( + os.path.dirname(__file__), "defaults.jsonl" + ) + try: + with open(file_path, "r") as file: + for line in file: + if line.strip(): + data = json.loads(line) + self.add_prompt( + data["name"], + data["template"], + data.get("input_types", {}), + ) + except json.JSONDecodeError as e: + error_msg = f"Error loading prompts from JSONL file: {e}" + logger.error(error_msg) + raise ValueError(error_msg) + + def add_prompt( + self, name: str, template: str, input_types: dict[str, str] + ) -> None: + if name in self.prompts: + raise ValueError(f"Prompt '{name}' already exists.") + self.prompts[name] = Prompt( + name=name, template=template, input_types=input_types + ) + + def get_prompt( + self, prompt_name: str, inputs: Optional[dict[str, Any]] = None + ) -> str: + if prompt_name not in self.prompts: + raise ValueError(f"Prompt '{prompt_name}' not found.") + prompt = self.prompts[prompt_name] + if inputs is None: + return prompt.template + return prompt.format_prompt(inputs) + + def update_prompt( + self, + name: str, + template: Optional[str] = None, + input_types: Optional[dict[str, str]] = None, + ) -> None: + if name not in self.prompts: + raise ValueError(f"Prompt '{name}' not found.") + if template: + self.prompts[name].template = template + if input_types: + self.prompts[name].input_types = input_types + + def get_all_prompts(self) -> dict[str, Prompt]: + return self.prompts |