aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/r2r/mcp.py
blob: 33490ea138331d02ea44fe9078eaef9bf74f22d2 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# Add to your local machine with `mcp install r2r/mcp.py -v R2R_API_URL=http://localhost:7272` or so.
import asyncio

from r2r import R2RClient


def id_to_shorthand(id: str) -> str:
    """
    Return a short display form of an identifier.

    Args:
        id: The identifier to shorten. It is converted with `str()` first,
            so non-string values are accepted as well.

    Returns:
        The first seven characters of the identifier's string form (or the
        whole string when it is shorter than that).
    """
    shorthand_length = 7
    as_text = str(id)
    return as_text[:shorthand_length]


def format_search_results_for_llm(
    results,
) -> str:
    """
    Render aggregated search results as a plain-text block.

    Up to four sections are emitted, in this order and only when the
    corresponding attribute of `results` is non-empty: chunk (vector)
    results, graph results, web results and document results. Every source
    is labelled with the shortened form of its ID.

    Args:
        results: Object exposing `chunk_search_results`,
            `graph_search_results`, `web_search_results` and
            `document_search_results`.

    Returns:
        The section lines joined with newlines; an empty string when every
        section is empty.
    """
    out = []

    # Section 1: vector / chunk hits
    chunk_hits = results.chunk_search_results
    if chunk_hits:
        out.append("Vector Search Results:")
        for hit in chunk_hits:
            out.append(f"Source ID [{id_to_shorthand(hit.id)}]:")
            # Missing text is rendered as an empty line
            out.append(hit.text or "")

    # Section 2: graph hits; the kind of content is inferred from the
    # attributes it carries
    graph_hits = results.graph_search_results
    if graph_hits:
        out.append("Graph Search Results:")
        for hit in graph_hits:
            out.append(f"Source ID [{id_to_shorthand(hit.id)}]:")
            payload = hit.content
            if hasattr(payload, "summary"):
                # Content with a summary is reported as a community
                out.append(f"Community Name: {payload.name}")
                out.append(f"ID: {payload.id}")
                out.append(f"Summary: {payload.summary}")
            elif hasattr(payload, "name") and hasattr(payload, "description"):
                # Content with name + description is reported as an entity
                out.append(f"Entity Name: {payload.name}")
                out.append(f"Description: {payload.description}")
            elif all(
                hasattr(payload, field)
                for field in ("subject", "predicate", "object")
            ):
                # Content with a full triple is reported as a relationship
                out.append(
                    f"Relationship: {payload.subject}-{payload.predicate}-{payload.object}"
                )

    # Section 3: web hits
    web_hits = results.web_search_results
    if web_hits:
        out.append("Web Search Results:")
        for hit in web_hits:
            out.append(f"Source ID [{id_to_shorthand(hit.id)}]:")
            out.append(f"Title: {hit.title}")
            out.append(f"Link: {hit.link}")
            out.append(f"Snippet: {hit.snippet}")

    # Section 4: whole-document hits, each optionally followed by its chunks
    document_hits = results.document_search_results
    if document_hits:
        out.append("Local Context Documents:")
        for document in document_hits:
            title = document.title or "Untitled Document"
            full_id = document.id
            overview = document.summary

            out.append(f"Full Document ID: {full_id}")
            out.append(f"Shortened Document ID: {id_to_shorthand(full_id)}")
            out.append(f"Document Title: {title}")
            if overview:
                out.append(f"Summary: {overview}")

            # Chunks are indexed by key ('id' / 'text'), unlike the
            # attribute access used for the other result kinds
            for chunk in document.chunks or ():
                out.append(
                    f"\nChunk ID {id_to_shorthand(chunk['id'])}:\n{chunk['text']}"
                )

    return "\n".join(out)


# Create a FastMCP server

# Only the import is guarded: a missing dependency is reported with an
# install hint, while any error raised while constructing the server
# propagates unchanged instead of being mislabelled as "not installed".
try:
    from mcp.server.fastmcp import FastMCP
except ImportError as e:
    raise ImportError(
        "MCP is not installed. Please run `pip install mcp`"
    ) from e

# Single server instance; the @mcp.tool() decorators in this module
# register their functions on it.
mcp = FastMCP("R2R Retrieval System")


# Search tool
@mcp.tool()
async def search(query: str) -> str:
    """
    Search the knowledge base and return the formatted results

    Args:
        query: The search query to run against the knowledge base

    Returns:
        The search results rendered as plain text by
        `format_search_results_for_llm`
    """
    client = R2RClient()

    # The client call is blocking (it is not awaited), so run it in a worker
    # thread to keep the server's event loop responsive while it waits.
    search_response = await asyncio.to_thread(
        client.retrieval.search,
        query=query,
    )
    return format_search_results_for_llm(search_response.results)


# RAG query tool
@mcp.tool()
async def rag(query: str) -> str:
    """
    Perform a Retrieval-Augmented Generation query

    Args:
        query: The question to answer using the knowledge base

    Returns:
        The `generated_answer` field of the RAG response's results
    """
    client = R2RClient()

    # The client call is blocking (it is not awaited), so run it in a worker
    # thread to keep the server's event loop responsive while it waits.
    rag_response = await asyncio.to_thread(
        client.retrieval.rag,
        query=query,
    )

    return rag_response.results.generated_answer  # type: ignore


# Start the MCP server only when this file is executed as a script;
# the guard keeps `mcp.run()` from being called on a plain import.
if __name__ == "__main__":
    mcp.run()