aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/examples/scripts/run_hyde.py
blob: c82ce525309974748793ca792a89e34392f3d5b9 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import fire

from r2r import R2RBuilder, R2RConfig
from r2r.base.abstractions.llm import GenerationConfig
from r2r.main.assembly.factory_extensions import R2RPipeFactoryWithMultiSearch


def main(task_prompt_name="hyde", query="Who was aristotle?"):
    # Load the configuration file
    config = R2RConfig.from_json()

    app = (
        R2RBuilder(config)
        .with_pipe_factory(R2RPipeFactoryWithMultiSearch)
        .build(
            # Add optional override arguments which propagate to the pipe factory
            task_prompt_name=task_prompt_name,
        )
    )

    # Run the RAG pipeline through the R2R application
    result = app.rag(
        query,
        query_transform_generation_config=GenerationConfig(model="gpt-4o"),
        rag_generation_config=GenerationConfig(model="gpt-3.5-turbo"),
    )

    print(f"Search Results:\n\n{result.search_results}")
    print(f"RAG Results:\n\n{result.completion}")


if __name__ == "__main__":
    fire.Fire(main)