One of the main concerns with modern AI chatbots is hallucination: they can give answers that are wrong or made up.
AI chatbots and text generators can be pretty unpredictable, even once you learn how to prompt effectively.
If you give AI models too much freedom, they might provide inaccurate or even contradictory information.
In this article, we’ll look closely at what Corrective RAG is, how the Corrective Retrieval-Augmented Generation (CRAG) process works, how it differs from plain RAG, and how to use LangGraph, Corrective RAG, and Mistral AI to create a powerful RAG chatbot.
Before we start! 🦸🏻♀️
If you like this topic and you want to support me:
- Clap my article 50 times; that will really help me out.👏
- Follow me on Medium and subscribe to get my latest article🫶
- Let me know what content you’d like to see me share next.
For those who may not be as familiar with technical details, let me provide a bit of background to make things easier to understand.
What is Corrective RAG
Corrective RAG is a comprehensive framework that combines retrieval evaluation, corrective actions, web search, and generative model integration to improve the accuracy, reliability, and robustness of text generation by ensuring that only accurate and relevant knowledge is used.
In simple terms, Corrective RAG grades the retrieved documents for relevance to the question. If they are relevant, the process proceeds to generation; otherwise, the framework seeks additional data sources and uses web search to supplement retrieval.
Workflow
- Retrieval Evaluator: Before utilizing the retrieved documents, CRAG employs a retrieval evaluator to assess the overall quality of the retrieved information. This evaluator helps determine the relevance and reliability of the retrieved documents for a given query. It plays a crucial role in ensuring that only accurate and relevant information is used for text generation.
- Knowledge Retrieval Actions: Based on the assessment by the retrieval evaluator, different knowledge retrieval actions are triggered (a minimal sketch of this decision logic follows this list):
- Correct: If the retrieved documents are deemed accurate, they undergo a refinement process to extract more precise knowledge strips. This refinement operation involves knowledge decomposition, filtering, and recomposition to enhance the quality of the information.
- Incorrect: In cases where the retrieved documents are deemed inaccurate or irrelevant, they are discarded. Instead, CRAG resorts to large-scale web searches to find complementary knowledge sources for corrections.
- Ambiguous: When the system cannot confidently determine whether the retrieved documents are correct or incorrect, a soft action called “Ambiguous” is triggered, combining elements of both correct and incorrect actions.
- Generative Model Integration: After optimizing the retrieval results through the corrective actions, any generative model can be adapted to generate the final text output. CRAG ensures that the generative model receives refined and accurate information for text generation.
- Plug-and-Play Adaptability: CRAG is designed to be plug-and-play, meaning it can be seamlessly integrated into existing Retrieval-Augmented Generation (RAG) frameworks. It has been experimentally implemented with standard RAG and Self-RAG models, demonstrating its adaptability and effectiveness in improving text generation performance across various datasets and tasks.
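The code walkthrough later in this article only implements the correct/incorrect branches. As a rough sketch of the full three-way decision, assuming the evaluator returns a confidence score (the thresholds and function name below are illustrative, not taken from the original CRAG implementation):
def corrective_action(relevance_score: float, upper: float = 0.7, lower: float = 0.3) -> str:
    # Map an evaluator confidence score to one of CRAG's three actions.
    # The 0.7 / 0.3 thresholds are illustrative; the paper tunes them per dataset.
    if relevance_score >= upper:
        return "correct"    # refine retrieved documents into knowledge strips
    if relevance_score <= lower:
        return "incorrect"  # discard retrieval and fall back to web search
    return "ambiguous"      # combine refined documents with web search results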
RAG vs. CRAG
RAG focuses on integrating external knowledge into the generation process, while CRAG goes a step further by evaluating and refining that knowledge before integrating it, improving the accuracy and reliability of language models.
See also: Five Techniques: vLLM + Torch + Flash_Attention = Super Local LLM
Before we dive into our application, let's set up the environment the code needs. For this, we install the dependencies listed in requirements.txt:
pip install -r requirements.txt
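If you don't already have a requirements.txt, the imports used below roughly translate to the following packages (this list is an approximation; pin the versions you have actually tested):
langchain
langchain-core
langchain-community
langchain-openai
langchain-google-genai
langgraph
chromadb
gpt4all
tiktoken
tavily-python
python-dotenv
beautifulsoup4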
Once everything is installed, we import what we need from LangChain, LangChain Core, LangChain Community, LangChain Google GenAI, LangChain OpenAI, and LangGraph, along with dotenv, os, typing, and pprint.
from dotenv import load_dotenv
from langchain import hub
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langchain.schema import Document
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.vectorstores import Chroma
from langchain_community.chat_models import ChatOllama
from langchain_community.embeddings import GPT4AllEmbeddings
from langchain_google_genai import (
    ChatGoogleGenerativeAI,
    GoogleGenerativeAIEmbeddings,
)
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langgraph.graph import END, StateGraph
from typing import Any, Dict, TypedDict
from langchain.prompts import PromptTemplate
import pprint
import os
We set run_local to 'No' to decide whether to run the language model (LLM) locally or not. For models, you have the choice between two paid API models, Gemini Pro and OpenAI; choose whichever you prefer. For local_llm, I'm currently using 'Solar', but you can pick any local LLM you prefer from the Ollama model library. Lastly, we set the Tavily API key.
run_local = 'No'
models = "Google"
openai_api_key = "Your_api"
google_api_key = "Your_api"
local_llm = 'Solar'
os.environ["TAVILY_API_KEY"] = "Your_api"
We load documents from a URL, split them into smaller chunks, generate embeddings with the chosen model, and index those embeddings for retrieval. The choice of embedding model depends on the configuration variables run_local and models, which determine whether to use local embeddings, OpenAI, or Google.
# Load documents
url = 'https://lilianweng.github.io/posts/2023-06-23-agent/'
loader = WebBaseLoader(url)
docs = loader.load()
# Split
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=500, chunk_overlap=100
)
all_splits = text_splitter.split_documents(docs)
# Embed and index
if run_local == 'Yes':
embeddings = GPT4AllEmbeddings()
elif models == 'openai':
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
else:
embeddings = GoogleGenerativeAIEmbeddings(
model="models/embedding-001", google_api_key=google_api_key
)
# Index
vectorstore = Chroma.from_documents(
documents=all_splits,
collection_name="rag-chroma",
embedding=embeddings,
)
retriever = vectorstore.as_retriever()
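Optionally, a quick sanity check confirms the index returns chunks before we wire the retriever into the graph (the number of results depends on the retriever's default k):
# Optional sanity check on the vector store
sample_docs = retriever.get_relevant_documents("What is task decomposition?")
print(len(sample_docs))                    # usually 4 with the default retriever settings
print(sample_docs[0].page_content[:200])   # preview the first chunk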
Next, we define a class called GraphState, which represents the state of our graph.
class GraphState(TypedDict):
"""
Represents the state of our graph.
Attributes:
keys: A dictionary where each key is a string.
"""
keys: Dict[str, Any]
The retrieve function fetches documents relevant to the question in the current graph state, adds them to the state under the documents key, and returns the updated state dictionary.
def retrieve(state):
"""
Retrieve documents
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, documents,
that contains retrieved documents
"""
print("---RETRIEVE---")
state_dict = state["keys"]
question = state_dict["question"]
local = state_dict["local"]
documents = retriever.get_relevant_documents(question)
return {"keys": {"documents": documents, "local": local,
"question": question}}
The generate function takes a dictionary called state, representing the current graph state, and generates an answer from the question and the retrieved documents.
def generate(state):
"""
Generate answer
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, generation,
that contains LLM generation
"""
print("---GENERATE---")
state_dict = state["keys"]
question = state_dict["question"]
documents = state_dict["documents"]
# Prompt
prompt = hub.pull("rlm/rag-prompt")
# LLM Setup
if run_local == "Yes":
llm = ChatOllama(model=local_llm,
temperature=0)
elif models == "openai" :
llm = ChatOpenAI(
model="gpt-4-0125-preview",
temperature=0 ,
openai_api_key=openai_api_key
)
else:
llm = ChatGoogleGenerativeAI(model="gemini-pro",
google_api_key=google_api_key,
convert_system_message_to_human = True,
verbose = True,
)
# Post-processing
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
# Chain
rag_chain = prompt | llm | StrOutputParser()
# Run
generation = rag_chain.invoke({"context": format_docs(documents),
"question": question})
return {
"keys": {"documents": documents, "question": question,
"generation": generation}
}
This is a key function in the code: it implements the corrective step and determines whether the retrieved documents are relevant to the question. If a retrieved document is relevant, we keep it for generation; if not, we fall back to a web search to find relevant information.
def grade_documents(state):
"""
Determines whether the retrieved documents are relevant to the question.
Args:
state (dict): The current graph state
Returns:
state (dict): Updates documents key with relevant documents
"""
print("---CHECK RELEVANCE---")
state_dict = state["keys"]
question = state_dict["question"]
documents = state_dict["documents"]
local = state_dict["local"]
# LLM
if run_local == "Yes":
llm = ChatOllama(model=local_llm,
temperature=0)
elif models == "openai" :
llm = ChatOpenAI(
model="gpt-4-0125-preview",
temperature=0 ,
openai_api_key=openai_api_key
)
else:
llm = ChatGoogleGenerativeAI(model="gemini-pro",
google_api_key=google_api_key,
convert_system_message_to_human = True,
verbose = True,
)
# Data model
class grade(BaseModel):
"""Binary score for relevance check."""
score: str = Field(description="Relevance score 'yes' or 'no'")
# Set up a JSON parser and inject its format instructions into the prompt template.
parser = JsonOutputParser(pydantic_object=grade)
prompt = PromptTemplate(
template="""You are a grader assessing relevance of a retrieved
document to a user question. \n
Here is the retrieved document: \n\n {context} \n\n
Here is the user question: {question} \n
If the document contains keywords related to the user question,
grade it as relevant. \n
It does not need to be a stringent test. The goal is to filter out
erroneous retrievals. \n
Give a binary score 'yes' or 'no' to indicate whether the
document is relevant to the question. \n
Provide the binary score as a JSON with no preamble or
explanation, and use these instructions to format the output:
{format_instructions}""",
input_variables=["question", "context"],
partial_variables={"format_instructions":
parser.get_format_instructions()},
)
chain = prompt | llm | parser
# Score
filtered_docs = []
search = "No" # Default do not opt for web search to supplement retrieval
for d in documents:
score = chain.invoke(
{
"question": question,
"context": d.page_content,
"format_instructions": parser.get_format_instructions(),
}
)
grade = score["score"]
if grade == "yes":
print("---GRADE: DOCUMENT RELEVANT---")
filtered_docs.append(d)
else:
print("---GRADE: DOCUMENT NOT RELEVANT---")
search = "Yes" # Perform web search
continue
return {
"keys": {
"documents": filtered_docs,
"question": question,
"local": local,
"run_web_search": search,
}
}
The grade_documents function plays a critical role in ensuring that the retrieved information aligns with the question's context: depending on the relevance assessment, the graph either uses the retrieved documents for generation or falls back to a web search. This dynamic approach makes the system's responses more robust and accurate. Next, the transform_query function re-phrases the original question into a form better suited for retrieval and web search.
def transform_query(state):
"""
Transform the query to produce a better question.
Args:
state (dict): The current graph state
Returns:
state (dict): Updates question key with a re-phrased question
"""
print("---TRANSFORM QUERY---")
state_dict = state["keys"]
question = state_dict["question"]
documents = state_dict["documents"]
local = state_dict["local"]
# Create a prompt template with format instructions and the query
prompt = PromptTemplate(
template="""You are generating questions that is well optimized for
retrieval. \n
Look at the input and try to reason about the underlying sematic
intent / meaning. \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Provide an improved question without any preamble, only respond
with the updated question: """,
input_variables=["question"],
)
# Grader
# LLM
if run_local == "Yes":
llm = ChatOllama(model=local_llm,
temperature=0)
elif models == "openai" :
llm = ChatOpenAI(
model="gpt-4-0125-preview",
temperature=0 ,
openai_api_key=openai_api_key
)
else:
llm = ChatGoogleGenerativeAI(model="gemini-pro",
google_api_key=google_api_key,
convert_system_message_to_human = True,
verbose = True,
)
# Prompt
chain = prompt | llm | StrOutputParser()
better_question = chain.invoke({"question": question})
return {
"keys": {"documents": documents, "question": better_question,
"local": local}
}
This function performs a web search using the Tavily API based on a reformulated question to enrich the existing documents with additional information retrieved from the web.
def web_search(state):
"""
Web search based on the re-phrased question using Tavily API.
Args:
state (dict): The current graph state
Returns:
state (dict): Web results appended to documents.
"""
print("---WEB SEARCH---")
state_dict = state["keys"]
question = state_dict["question"]
documents = state_dict["documents"]
local = state_dict["local"]
try:
tool = TavilySearchResults()
docs = tool.invoke({"query": question})
web_results = "\n".join([d["content"] for d in docs])
web_results = Document(page_content=web_results)
documents.append(web_results)
except Exception as error:
print(error)
return {"keys": {"documents": documents, "local": local,
"question": question}}
This function is crucial for deciding the agent's next step. It takes the current state, including the question, the filtered documents, and the web-search flag, and decides either to transform the query and run a web search, or to proceed to generating an answer when relevant information is available. It returns a string naming the next node to call, guiding the agent's decision-making based on the current state.
def decide_to_generate(state):
"""
Determines whether to generate an answer or re-generate a question
for web search.
Args:
state (dict): The current state of the agent, including all keys.
Returns:
str: Next node to call
"""
print("---DECIDE TO GENERATE---")
state_dict = state["keys"]
question = state_dict["question"]
filtered_documents = state_dict["documents"]
search = state_dict["run_web_search"]
if search == "Yes":
# All documents have been filtered check_relevance
# We will re-generate a new query
print("---DECISION: TRANSFORM QUERY and RUN WEB SEARCH---")
return "transform_query"
else:
# We have relevant documents, so generate answer
print("---DECISION: GENERATE---")
return "generate"
We create a state graph (workflow) using GraphState to manage stages such as retrieval, grading, generation, query transformation, and web search. Starting from “retrieve,” transitions occur between nodes. Conditional edges from “grade_documents” decide whether to proceed to “transform_query” or “generate.” Subsequently, “transform_query” leads to “web_search” and then “generate.” Finally, “generate” marks the end of the workflow, which is compiled into an application (app).
See also: LangGraph: Create Your Hyper AI Agent
workflow = StateGraph(GraphState)
# Define the nodes
workflow.add_node("retrieve", retrieve) # retrieve
workflow.add_node("grade_documents", grade_documents) # grade documents
workflow.add_node("generate", generate) # generatae
workflow.add_node("transform_query", transform_query) # transform_query
workflow.add_node("web_search", web_search) # web search
# Build graph
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
"grade_documents",
decide_to_generate,
{
"transform_query": "transform_query",
"generate": "generate",
},
)
workflow.add_edge("transform_query", "web_search")
workflow.add_edge("web_search", "generate")
workflow.add_edge("generate", END)
# Compile
app = workflow.compile()
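Optionally, depending on your langgraph and langchain-core versions, you can inspect the compiled graph before running it (the drawing helpers vary between releases, so treat this as a sketch):
# Optional: print a Mermaid diagram of the compiled graph
# (availability of draw_mermaid depends on your installed versions)
print(app.get_graph().draw_mermaid())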
Let's prepare the inputs, containing the question and a flag indicating whether to run the LLM locally, and stream them through the application (app). For each output received, we iterate through the items and print the node's name; optionally, we can print the full state at each node. After streaming finishes, we print the generated answer.
# Run
inputs = {
"keys": {
"question": 'Explain how the different types of agent memory work?',
"local": run_local,
}
}
for output in app.stream(inputs):
for key, value in output.items():
# Node
print(f"Node '{key}':")
# Optional: print full state at each node
# pprint.pprint(value["keys"], indent=2, width=80, depth=None)
pprint.pprint("\n---\n")
# Final generation
pprint.pprint(value['keys']['generation'])
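If you don't need the node-by-node trace, the compiled graph is a regular runnable, so you can also run it in one shot (a small convenience, not part of the original walkthrough):
# Alternative: run the whole graph at once and read the final state
final_state = app.invoke(inputs)
print(final_state["keys"]["generation"])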
Results: OpenAI
tags=['Chroma', 'OpenAIEmbeddings'] vectorstore=<langchain_community.vectorstores.chroma.Chroma object at 0x000001FB133B9760>
---RETRIEVE---
Node 'retrieve':
'\n---\n'
---CHECK RELEVANCE---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
Node 'grade_documents':
'\n---\n'
---DECIDE TO GENERATE---
---DECISION: TRANSFORM QUERY and RUN WEB SEARCH---
---TRANSFORM QUERY---
Node 'transform_query':
'\n---\n'
---WEB SEARCH---
Node 'web_search':
'\n---\n'
---GENERATE---
Node 'generate':
'\n---\n'
Node '__end__':
'\n---\n'
('Short-term memory in agents functions through in-context learning, utilizing '
"the model's current context to learn and make decisions. Long-term memory is "
'achieved by retaining and recalling information over extended periods, often '
'through an external vector store and fast retrieval mechanisms, allowing the '
'agent to access a vast amount of information beyond its immediate context. '
'Additionally, agents can enhance their memory and decision-making '
'capabilities by using external APIs, memory streams for recording '
'experiences, and retrieval models to surface relevant information based on '
'recency, importance, and relevance.')
Results: Gemini Pro
tags=['Chroma', 'GoogleGenerativeAIEmbeddings'] vectorstore=<langchain_community.vectorstores.chroma.Chroma object at 0x000002043E4C5370>
---RETRIEVE---
Node 'retrieve':
'\n---\n'
---CHECK RELEVANCE---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
Node 'grade_documents':
'\n---\n'
---DECIDE TO GENERATE---
---DECISION: GENERATE---
---GENERATE---
Node 'generate':
'\n---\n'
Node '__end__':
'\n---\n'
('Short-term memory is used for in-context learning and is restricted by the '
'context window length of the Transformer. Long-term memory is an external '
'vector store that can be accessed quickly during query time. Memory stream '
"records a comprehensive list of agents' experiences in natural language.")
Results: Local LLM
---RETRIEVE---
"Node 'retrieve':"
'\n---\n'
---CHECK RELEVANCE---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
"Node 'grade_documents':"
'\n---\n'
---DECIDE TO GENERATE---
---DECISION: GENERATE---
---GENERATE---
"Node 'generate':"
'\n---\n'
"Node '__end__':"
'\n---\n'
(' In an LLM (large language model)-powered autonomous agent system, LLM '
'functions as the agent’s brain, complemented by several key components: '
'planning and memory.\n'
'\n'
'Planning involves breaking down large tasks into smaller subgoals for '
'efficient handling of complex tasks and self-criticism and refinement to '
'improve results.\n'
'\n'
'Memory includes short-term memory, which utilizes in-context learning, and '
'long-term memory, providing the agent with the capability to retain and '
'recall information over extended periods using an external vector store and '
'fast retrieval. The agent also learns to call external APIs for missing '
'information.\n'
'\n'
'Types of Memory:\n'
'1. Sensory Memory: retains impressions of sensory information for a few '
'seconds.\n'
'2. Short-Term Memory (STM) or Working Memory: stores information needed for '
'complex cognitive tasks and lasts for 20-30 seconds.\n'
'3. Long-Term Memory (LTM): stores information for a remarkably long time, '
'with two subtypes: explicit/declarative memory and implicit/procedural '
'memory.\n'
'\n'
'The agent uses LLM as its core controller, which can be extended beyond '
'generating well-written copies, stories, essays, and programs to a powerful '
'general problem solver.')
Conclusion:
Corrective RAG is a cutting-edge framework that enhances text generation models by addressing issues arising from inaccurate or irrelevant retrieved information.
It employs a retrieval evaluator to assess the quality of retrieved documents and triggers corrective actions as needed.
By integrating web searches and a decomposition-recomposition algorithm, Corrective RAG aims to significantly enhance the accuracy and robustness of text generation processes.
References: