diff options
Diffstat (limited to 'R2R/r2r/examples/scripts/run_hyde.py')
-rwxr-xr-x | R2R/r2r/examples/scripts/run_hyde.py | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/R2R/r2r/examples/scripts/run_hyde.py b/R2R/r2r/examples/scripts/run_hyde.py new file mode 100755 index 00000000..c82ce525 --- /dev/null +++ b/R2R/r2r/examples/scripts/run_hyde.py @@ -0,0 +1,33 @@ +import fire + +from r2r import R2RBuilder, R2RConfig +from r2r.base.abstractions.llm import GenerationConfig +from r2r.main.assembly.factory_extensions import R2RPipeFactoryWithMultiSearch + + +def main(task_prompt_name="hyde", query="Who was aristotle?"): + # Load the configuration file + config = R2RConfig.from_json() + + app = ( + R2RBuilder(config) + .with_pipe_factory(R2RPipeFactoryWithMultiSearch) + .build( + # Add optional override arguments which propagate to the pipe factory + task_prompt_name=task_prompt_name, + ) + ) + + # Run the RAG pipeline through the R2R application + result = app.rag( + query, + query_transform_generation_config=GenerationConfig(model="gpt-4o"), + rag_generation_config=GenerationConfig(model="gpt-3.5-turbo"), + ) + + print(f"Search Results:\n\n{result.search_results}") + print(f"RAG Results:\n\n{result.completion}") + + +if __name__ == "__main__": + fire.Fire(main) |