diff options
| -rw-r--r-- | topics/ai/gn_ai.gmi | 395 |
1 files changed, 394 insertions, 1 deletions
diff --git a/topics/ai/gn_ai.gmi b/topics/ai/gn_ai.gmi index 5fd43b7..657e996 100644 --- a/topics/ai/gn_ai.gmi +++ b/topics/ai/gn_ai.gmi @@ -20,6 +20,7 @@ This work is an extension of the GNQA system initiated by Shelby and Pjotr. * [X] Build a RAG system and test with small corpus of mapping results * [X] Experiment with actual mapping results and metadata * [X] Move from RAG to agent +* [] Optimize AI system * [] Scale analysis to more data * [] Compare performance of open LLMs with Claude in the system @@ -167,6 +168,9 @@ The locus associated with this phenotype is Rsm10000001653. It was time to proceed to testing. The results I show below are not exactly for the RAG system I explained above. I was improving the RAG in parallel. So watch out. I will explain everything in the next task :) +For now, you can have a look at the first implementation at: +=> https://github.com/johanmed/gn-rag/commit/2cf0b74442e8f7e3a67d563b882f3ab25a4ceb6d + The goal was to try a complex query the previous system failed on. You can see that the question was indeed not atomic. ``` @@ -189,4 +193,393 @@ ata={}, page_content='\nIn plain English, this data refers to a mapped locus ass ### Move from RAG to agent -This is where I made the system more autonomous i.e agentic. I am now going to explain how I did it. \ No newline at end of file +This is where I made the system more autonomous i.e agentic. I am now going to explain how I did it. I read a couple of sources and found that RAG system built with LangChain could be made agentic by using LangGraph. This creates a graph structure which splits the task among different nodes or agents. Each agent achieves a specific subtasks and a final node manages the integration. + +Checkout this commit to see the results: +=> https://github.com/johanmed/gn-rag/commit/ecde30a31588605358007cc39df25976b9c2e295 + +You can clearly see differences between *rag_langchain.py* and *rag_langgraph.py* + +Basically, + +``` +ef ask_question(self, question: str): + start=time.time() + memory_var=self.memory.load_memory_variables({}) + chat_history=memory_var.get('chat_history', '') + result=self.retrieval_chain.invoke( + {'question': question, + 'input': question, + 'chat_history': chat_history}) + answer=result.get("answer") + citations=result.get("context") + self.memory.save_context( + {'input': question}, + {'answer': answer}) + # Close LLMs + GENERATIVE_MODEL.client.close() + SUMMARY_MODEL.client.close() + end=time.time() + print(f'ask_question: {end-start}') + return { + "question": question, + "answer": answer, + "citations": citations, + } +``` + +became: + +``` +def retrieve(self, state: State) -> dict: + # Define graph node for retrieval + prompt = f""" + You are powerful data retriever and you strictly return + what is asked for. + Retrieve relevant documents for the query below, + excluding these documents: {state.get('seen_documents', [])} + Query: {state['input']}""" + retrieved_docs = self.ensemble_retriever.invoke(prompt) + return {"input": state["input"], + "context": retrieved_docs, + "digested_context": state.get("digested_context", []), + "result_count": state.get("result_count", 0), + "target": state.get("target", 3), + "max_iterations": state.get("max_iterations", 5), + "should_continue": "naturalize", + "iterations": state.get("iterations", 0) + 1, # Add one per run + "chat_history": state.get("chat_history", []), + "answer": state.get("answer", ""), + "seen_documents": state.get("seen_documents", [])} + + def manage(self, state:State) -> dict: + # Define graph node for task orchestration + context = state.get("context", []) + digested_context = state.get("digested_context", []) + answer = state.get("answer", "") + iterations = state.get("iterations", 0) + chat_history = state.get("chat_history", []) + result_count = state.get("result_count", 0) + target = state.get("target", 3) + max_iterations = state.get("max_iterations", 5) + should_continue = state.get("should_continue", "retrieve") + # Orchestration logic + if iterations >= max_iterations or result_count >= target: + should_continue = "summarize" + elif should_continue == "retrieve": + # Reset fields + context = [] + digested_context = [] + answer = "" + elif should_continue == "naturalize" and not context: + should_continue = "retrieve" # Can't naturalize without context + context = [] + digested_context = [] + answer = "" + elif should_continue == "analyze" and \ + (not context or not digested_context): + should_continue = "retrieve" # Can't analyze without context + context = [] + digested_context = [] + answer = "" + elif should_continue == "check_relevance" and not answer: + should_continue = "analyze" # Can't check relevance without answer + elif should_continue not in ["retrieve", \ + "naturalize", "check_relevance", "analyze", "summarize"]: + should_continue = "summarize" # Fallback + return {"input": state["input"], + "should_continue": should_continue, + "result_count": result_count, + "target": target, + "iterations": iterations, + "max_iterations": max_iterations, + "context": context, + "digested_context": digested_context, + "chat_history": chat_history, + "answer": answer, + "seen_documents": state.get("seen_documents", [])} + + def analyze(self, state:State) -> dict: + # Define graph node for analysis and text generation + context = "\n".join(state.get("digested_context", [])) + existing_history="\n".join(state.get("chat_history", [])) \ + if state.get("chat_history") else "" + iterations = state.get("iterations", 0) + max_iterations = state.get("max_iterations", 5) + result_count = state.get("result_count", 0) + target = state.get("target", 3) + if not context: # Cannot proceed without context + should_continue = "summarize" if iterations >= max_iterations \ + or result_count >= target else "retrieve" + response = "" + else: + prompt = f""" + <|im_start|>system + You are an experienced analyst that can use available information + to provide accurate and concise feedback. + <|im_end|> + <|im_start|>user + Answer the question below using following information. + Context: {context} + History: {existing_history} + Question: {state["input"]} + Answer: + <|im_end|> + <|im_start|>assistant""" + response = GENERATIVE_MODEL.invoke(prompt) + if not response or not isinstance(response, str) or \ + response.strip() == "": # Need valid generation + should_continue = "summarize" if iterations >= max_iterations \ + or result_count >= target else "retrieve" + response = "" # Ensure a clean state + else: + should_continue = "check_relevance" + return {"input": state["input"], + "answer": response, + "should_continue": should_continue, + "context": state.get("context", []), + "digested_context": state.get("digested_context", []), + "iterations": iterations, + "max_iterations": max_iterations, + "result_count": result_count, + "target": target, + "chat_history": state.get("chat_history", []), + "seen_documents": state.get("seen_documents", [])} + + + def summarize(self, state:State) -> dict: + # Define node for summarization + existing_history = state.get("chat_history", []) + current_interaction=f""" + User: {state["input"]}\nAssistant: {state["answer"]}""" + full_context = "\n".join(existing_history) + "\n" + \ + current_interaction if existing_history else current_interaction + result_count = state.get("result_count", 0) + target = state.get("target", 3) + iterations = state.get("iterations", 0) + max_iterations = state.get("max_iterations", 5) + prompt = f""" + <|system|> + You are an excellent and concise summary maker. + <|end|> + <|user|> + Summarize in bullet points the conversation below. + Follow this format: input - answer + Conversation: {full_context} + <|end|> + <|assistant|>""" + summary = GENERATIVE_MODEL.invoke(prompt).strip() # central task + if not summary or not isinstance(summary, str) or summary.strip() == "": + summary = f"- {state['input']} - No valid answer generated" + should_continue="end" if result_count >= target or \ + iterations >= max_iterations else "retrieve" + updated_history = existing_history + [summary] # update chat_history + print(f"\nChat history in summarize: {updated_history}") + return {"input": state["input"], + "answer": summary, + "should_continue": should_continue, + "context": state.get("context", []), + "digested_context": state.get("digested_context", []), + "iterations": iterations, + "max_iterations": max_iterations, + "result_count": result_count, + "target": target, + "chat_history": updated_history, + "seen_documents": state.get("seen_documents", [])} + + def check_relevance(self, state:State) -> dict: + # Define node to check relevance of retrieved data + context = "\n".join(state.get("digested_context", [])) + result_count = state.get("result_count", 0) + target = state.get("target", 3) + iterations = state.get("iterations", 0) + max_iterations = state.get("max_iterations", 5) + seen_documents = state.get("seen_documents", []) + prompt = f""" + <|system|> + You are an expert in evaluating data relevance. You do it seriously. + <|end|> + <|user|> + Assess if the provided answer is relevant to the query. + Return only yes or no. Nothing else. + Answer: {state["answer"]} + Query: {state["input"]} + Context: {context} + <|end|> + <|assistant|>""" + assessment = GENERATIVE_MODEL.invoke(prompt).strip() + if assessment=="yes": + result_count = result_count + 1 + should_continue = "summarize" + elif result_count >= target or iterations >= max_iterations: + should_continue = "summarize" + else: + should_continue = "retrieve" + seen_documents.extend([doc.page_content for doc in \ + state.get("context", [])]) + return {"input": state["input"], + "context": state.get("context", []), + "digested_context": state.get("digested_context", []), + "iterations": iterations, + "max_iterations": max_iterations, + "answer": state["answer"], + "result_count": result_count, + "target": target, + "seen_documents": seen_documents, + "chat_history": state.get("chat_history", []), + "should_continue": should_continue} + + def route_manage(self, state: State) -> str: + should_continue = state.get("should_continue", "retrieve") + iterations = state.get("iterations", 0) + max_iterations = state.get("max_iterations", 5) + result_count = state.get("result_count", 0) + target = state.get("target", 3) + context = state.get("context", []) + digested_context = state.get("digested_context", []) + answer = state.get("answer", "") + # Validate state and enforce termination + if iterations >= max_iterations or result_count >= target: + return "summarize" + if should_continue not in ["retrieve", "naturalize", \ + "check_relevance", "analyze", "summarize"]: + return "summarize" # Fallback to summarize + return should_continue + + def initialize_langgraph_chain(self) -> Any: + graph_builder = StateGraph(State) + graph_builder.add_node("manage", self.manage) + graph_builder.add_node("retrieve", self.retrieve) + graph_builder.add_node("naturalize", self.naturalize) + graph_builder.add_node("check_relevance", self.check_relevance) + graph_builder.add_node("analyze", self.analyze) + graph_builder.add_node("summarize", self.summarize) + graph_builder.add_edge(START, "manage") + graph_builder.add_edge("retrieve", "naturalize") + graph_builder.add_edge("naturalize", "analyze") + graph_builder.add_edge("analyze", "check_relevance") + graph_builder.add_edge("check_relevance", "manage") + graph_builder.add_edge("summarize", END) + graph_builder.add_conditional_edges( + "manage", + self.route_manage, + {"retrieve": "retrieve", + "naturalize": "naturalize", + "check_relevance": "check_relevance", + "analyze": "analyze", + "summarize": "summarize"}) + graph=graph_builder.compile() + return graph + + async def invoke_langgraph(self, question: str) -> Any: + graph = self.initialize_langgraph_chain() + initial_state = { + "input": question, + "chat_history": [], + "context": [], + "digested_context": [], + "seen_documents": [], + "answer": "", + "iterations": 0, + "result_count": 0, + "should_continue": "retrieve", + "target": 3, # Explain magic number 3 + "max_iterations": 5 # Explain magic number 5 + } + result = await graph.ainvoke(initial_state) # Run graph asynchronously + return result + + + def answer_question(self, question: str) -> Any: + start = time.time() + result = asyncio.run(self.invoke_langgraph(question)) + end = time.time() + print(f'answer_question: {end-start}') + return {"result": result["chat_history"], + "state": result} +``` + +As mentioned above, we quickly spotted the need for the naturalization of RDF triples. This explains the addition of a naturalization node to the graph: + +``` +def naturalize(self, state: State) -> dict: + # Define graph node for RDF naturalization + prompt = f""" + <|im_start|>system + You are extremely good at naturalizing RDF and inferring meaning. + <|im_end|> + <|im_start|>user + Take element in the list of RDF triples one by one and + make it sounds like Plain English. Repeat for each the subject + which is at the start. You should return a list. Nothing else. + List: ["Entity http://genenetwork.org/id/traitBxd_20537 \ + \nhas http://purl.org/dc/terms/isReferencedBy of \ + http://genenetwork.org/id/unpublished22893", "has \ + http://genenetwork.org/term/locus of \ + http://genenetwork.org/id/Rsm10000002554"] + <|im_end|> + <|im_start|>assistant + New list: ["traitBxd_20537 isReferencedBy unpublished22893", \ + "traitBxd_20537 has a locus Rsm10000002554"] + <|im_end|> + <|im_start|>user + Take element in the list of RDF triples one by one and + make it sounds like Plain English. Repeat for each the subject + which is at the start. You should return a list. Nothing else. + List: {state.get("context", [])} + <|im_start|>end + <|im_start|>assistant""" + response = GENERATIVE_MODEL.invoke(prompt) + print(f"Response in naturalize: {response}") + if isinstance(response, str): + start=response.find("[") + end=response.rfind("]") + 1 # offset by 1 to make slicing + response=json.loads(response[start:end]) + else: + response=[] + return {"input": state["input"], + "context": state.get("context", []), + "digested_context": response, + "result_count": state.get("result_count", 0), + "target": state.get("target", 3), + "max_iterations": state.get("max_iterations", 5), + "should_continue": "analyze", + "iterations": state.get("iterations", 0), + "chat_history": state.get("chat_history", []), + "answer": state.get("answer", ""), + "seen_documents": state.get("seen_documents", [])} +``` + +Next step is to compare the RAG system (rag_langchain.py) to the agent system (rag_langgraph.py) on some queries. + +* Query 1 + +``` +Question: What is the lod score of BXDPublish_10187_gemma_GWA at D12mit280? + +RAG response: 3.01245 + +Agent response: 5.21 +``` + +Hmm, the result of the RAG system is dubious. I did not expect long floating number for the LOD scores. Fortunately, Pjotr's endpoint is useful to make quick query agains the RDF data. By the way, one could see how complementary SPARQL queries and LLM calls are. It was very thoughtful to pass through RDF :) + +Using the endpoint, I have fetched all lodscores for BXDPublish_10187_gemma_GWA at D12mit280 and only the result of the agent system appeared to be true. + +* Query 2 + +``` +Question: I want you look for D12mit280. You are allowed to initiate many rounds of retrieval until you get 20 relevant results. Next, extract the lod score and trait for each result. List for me traits that have a lod score > 4.0. Join to the list the corresponding lod scores so I can confirm. Show results using the following format: trait - lod score + +RAG response: It seems there's a misunderstanding. The provided data doesn't include any results related to "D12mit280". I can only provide information or perform tasks based on the given context and data. If you have another query or need clarification on something else, feel free to ask. + +Agent response: It appears there are no results for D12mit280 with a lod score > 4.0 in the provided entities. If you have more data or different criteria, please let me know so I can assist further. + +``` + +This was surprising. My hypotehsis was that the node in charge of retrieval was not fetching relevant documents for the LLM to elaborate a sound feedback. There is only one way to confirm: check the documents retrieved. + +Printing the documents after retrieval showed that I was actually right. I also noticed a couple of other problems in the execution of the program. Nodes were mostly not executed in the order I expected. Consequently, I decided to get into a long quest of incremental improvements :) + +### Optimize AI system + |
