---
config:
  theme: default
  look: classic
---
flowchart LR
    A["Reference<br>documents<br>📄📄📄"] --> n1["Embedding<br>🔢🔢🔢"]
    n1 --> n3["Vector database<br>🗄️🗄️🗄️"]
    n3 --> n4["Retriever<br>🔎"]
    n5["User query<br>👤💬"] --> n6["Embedding<br>🔢🔢🔢"]
    n5 -- 💬 --> n8["Augmented<br>query 📄💬<br>"]
    n6 --> n4
    n8 --> n9["Large Language<br>Model 🧠"]
    n9 --> n10["Response<br>🗨️"]
    n4 -- 📄 --> n8
What is RAG?
RAG (Retrieval-Augmented Generation) is a technique that combines searching external sources with text generation by large language models. It allows for more accurate and informative responses by using information retrieved from a corpus of text to generate contextually relevant answers. RAG is particularly useful for question-answering tasks where the answer requires knowledge beyond what is present in the language model's training data.
How does RAG work?
RAG consists of two main components: a retriever and a generator. The retriever is responsible for searching a corpus of text to find relevant information based on the input query. The generator then uses the retrieved information to generate a response. By combining these two components, RAG can produce more informative and contextually relevant responses compared to traditional language models.
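To make the two roles concrete, here is a minimal sketch that uses only the standard library. It is not the system built in this tutorial: the retriever ranks by shared words instead of embeddings, and `toy_generate` is a hypothetical stand-in for an LLM call. All names and the three-sentence corpus are assumptions made for illustration.

```python
import re
from collections import Counter


def tokenize(text):
    # Lowercase and keep alphanumeric runs only.
    return re.findall(r"[a-z0-9]+", text.lower())


def overlap_score(query, doc):
    # Number of tokens the query and the document share (multiset intersection).
    q, d = Counter(tokenize(query)), Counter(tokenize(doc))
    return sum((q & d).values())


def toy_retrieve(query, corpus, top_k=1):
    # Retriever: rank documents by word overlap with the query.
    ranked = sorted(corpus, key=lambda doc: overlap_score(query, doc), reverse=True)
    return ranked[:top_k]


def toy_generate(query, context_docs):
    # Generator stand-in: a real system would send the context to an LLM here.
    return "Based on the retrieved text: " + " ".join(context_docs)


corpus = [
    "K-nearest neighbors classifies a point by a vote of its nearest neighbors.",
    "Backpropagation computes gradients of the loss with respect to the weights.",
    "Q-learning is a value-based reinforcement learning algorithm.",
]
query = "How does backpropagation update weights?"
top_docs = toy_retrieve(query, corpus)
print(toy_generate(query, top_docs))
```

The point of the sketch is the division of labor: the retriever only selects text, and the generator only sees what the retriever hands it.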
How to Implement a RAG System
In this tutorial, we will implement a simple RAG system using a large language model (LLM) and a text retrieval system. We will use the Hugging Face Transformers library to load a pre-trained LLM and the Faiss library to build a vector database for text retrieval. We will then combine these components to create a question-answering system that retrieves relevant information from a corpus of text and generates responses based on the retrieved information.
Let’s get started by installing the necessary libraries and setting up the components for our RAG system.
!pip install bitsandbytes faiss-cpu langchain langchain_community \
    langchain_huggingface sentence-transformers --quiet
import os
import requests
import json
import numpy as np
import pickle
from bs4 import BeautifulSoup, SoupStrainer
from sentence_transformers import SentenceTransformer
from transformers import pipeline
from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
Data Collection
We will parse the Wikipedia category page to extract links to articles.
from urllib.parse import urlparse
def fetch_links(url):
    links = []
    response = requests.get(url)
    soup = BeautifulSoup(response.text, 'html.parser')
    domain = urlparse(url).netloc

    for ul in soup.find_all('ul'):
        for li in ul.find_all('li'):
            link = li.find('a')
            if link and "href" in link.attrs:
                href = link.attrs["href"]
                # Keep only internal article links such as /wiki/...
                if "/wiki" in href[:5]:
                    links.append(f"https://{domain}{href}")
    return links
Set the url variable and get the links.
url = 'https://en.wikipedia.org/wiki/Category:Machine_learning_algorithms'
links = fetch_links(url)
Next, we will download the articles and load them as documents.
os.environ["USER_AGENT"] = (
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
    "AppleWebKit/537.36 (KHTML, like Gecko) "
    "Chrome/134.0.0.0 Safari/537.36 Edg/134.0.0.0"
)

loader = WebBaseLoader(
    links[20:],
    bs_kwargs={
        "parse_only": SoupStrainer("div", {"class": "mw-body-content"}),
    },
    bs_get_text_kwargs={"separator": " ", "strip": True},
)
docs = loader.load()
Text Splitting and Embedding
Here we split the documents into shorter, overlapping chunks that will be provided to the LLM as context.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
split_docs = text_splitter.split_documents(docs)
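To show what chunk size and overlap mean, here is a simplified sliding-window sketch. It is not how RecursiveCharacterTextSplitter works internally (that splitter prefers to break on separators such as paragraphs and spaces). The function name `split_with_overlap` is hypothetical, and the tiny sizes are chosen only so the output is easy to read.

```python
def split_with_overlap(text, chunk_size=1000, chunk_overlap=200):
    # Fixed-size windows; each window starts chunk_size - chunk_overlap
    # characters after the previous one, so neighbors share chunk_overlap chars.
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the last window already reached the end of the text
    return chunks


sample = "abcdefghijklmnopqrstuvwxyz"
chunks = split_with_overlap(sample, chunk_size=10, chunk_overlap=4)
print(chunks)
# ['abcdefghij', 'ghijklmnop', 'mnopqrstuv', 'stuvwxyz']
```

The overlap means a sentence cut at a chunk boundary still appears whole in at least one neighboring chunk, at the cost of storing some text twice.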
We need to search for relevant information quickly, so let’s transform the texts into embeddings and load them into a vector database.
from langchain_huggingface import HuggingFaceEmbeddings

model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
model_kwargs = {"device": "cuda"}
encode_kwargs = {"normalize_embeddings": False}
embedding = HuggingFaceEmbeddings(
    model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs
)
vector_store = FAISS.from_documents(split_docs, embedding=embedding)
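The `normalize_embeddings` flag can matter for ranking. The sketch below assumes the index ranks by Euclidean (L2) distance, which is worth verifying for your installed version. It uses toy 2-D vectors, not real embeddings, and all names are hypothetical. It shows that L2 distance on unnormalized vectors can disagree with cosine similarity, while on unit-length vectors the two always agree, because the squared L2 distance then equals 2 - 2·cos.

```python
import math


def l2_normalize(v):
    # Scale the vector to unit length.
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def squared_l2(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def nearest(query_vec, docs):
    # Name of the document with the smallest squared L2 distance to the query.
    return min(docs, key=lambda name: squared_l2(query_vec, docs[name]))


query_vec = [1.0, 0.0]
raw_docs = {
    "doc_a": [10.0, 1.0],  # long vector, almost the same direction as the query
    "doc_b": [1.0, 1.0],   # short vector, 45 degrees away from the query
}

nearest_raw = nearest(query_vec, raw_docs)
unit_docs = {name: l2_normalize(vec) for name, vec in raw_docs.items()}
nearest_normalized = nearest(l2_normalize(query_vec), unit_docs)

print(nearest_raw)         # doc_b: vector length dominates the raw distance
print(nearest_normalized)  # doc_a: after normalization only direction matters
```

Whether normalization helps depends on the embedding model. The sketch only shows that the choice is not neutral when distances are Euclidean.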
Question-Answering Pipeline
Here we define a function for retrieving relevant documents from the database.
def retrieve(query, top_k=2):
    documents = vector_store.search(query, "similarity")
    return documents[:top_k]
We will use a local LLM. Let’s log in to Hugging Face, which is mandatory for downloading certain models.
from huggingface_hub import login
from google.colab import userdata

login(token=userdata.get("HF_TOKEN"))
Download the model and define the tokenizer.
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

MODEL_NAME = "Qwen/Qwen2.5-7B"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    load_in_8bit=True,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
Let’s put the context retrieval and generation pipeline into a function.
gen_pipeline = pipeline(
    "text-generation", model=model, tokenizer=tokenizer, return_full_text=False
)

def generate_response(query):
    relevant_texts = retrieve(query)
    context = " ".join([t.model_dump()["page_content"] for t in relevant_texts])
    prompt = f"""Answer question using only information provided in the context.
    If the context contains no relevant information, say "I couldn't find the information".
    Context: '''{context}'''
    Question: {query}
    Answer:
    """
    response = gen_pipeline(prompt)
    return response[0]["generated_text"]
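The prompt assembly can also be kept separate from the model call, which makes it possible to inspect the exact prompt without a GPU. The sketch below is one possible arrangement, not the tutorial's own code: `build_prompt`, `answer`, `fake_retrieve`, and `echo_generate` are hypothetical names, and the two stubs stand in for the vector store and the LLM.

```python
# Same wording as the tutorial's prompt, written as a plain template so that
# no indentation from the surrounding function leaks into the text.
PROMPT_TEMPLATE = (
    "Answer question using only information provided in the context.\n"
    'If the context contains no relevant information, say "I couldn\'t find the information".\n'
    "Context: '''{context}'''\n"
    "Question: {query}\n"
    "Answer:\n"
)


def build_prompt(query, chunks):
    # Join the retrieved chunks into one context string and fill the template.
    context = " ".join(chunks)
    return PROMPT_TEMPLATE.format(context=context, query=query)


def answer(query, retrieve_fn, generate_fn):
    # retrieve_fn: query -> list of text chunks; generate_fn: prompt -> text.
    chunks = retrieve_fn(query)
    return generate_fn(build_prompt(query, chunks))


def fake_retrieve(query):
    # Stub retriever (assumption): always returns the same two chunks.
    return ["Chunk one.", "Chunk two."]


def echo_generate(prompt):
    # Stub generator (assumption): returns the prompt so it can be inspected.
    return prompt


prompt = answer("What is RAG?", fake_retrieve, echo_generate)
print(prompt)
```

Passing the retriever and generator as arguments lets the same function run with the real vector store and pipeline, or with stubs like these during testing.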
Testing the RAG System
Let’s start our Q&A session.
query = "What is the Actor-critic algorithm in reinforcement learning?"

answer = generate_response(query)
print(answer)
The actor-critic algorithm (AC) is a family of reinforcement learning (RL) algorithms that combine policy-based RL algorithms such as policy gradient methods, and value-based RL algorithms such as value iteration, Q-learning, SARSA, and TD learning. An AC algorithm consists of two main components: an "actor" that determines which actions to take according to a policy function, and a "critic" that evaluates those actions according to a value function. Some AC algorithms are on-policy, some are off-policy. Some apply to either continuous or discrete action spaces. Some work in both cases.
query = "What is the purpose of backpropagation in neural networks?"

answer = generate_response(query)
print(answer)
The purpose of backpropagation in neural networks is to adjust the weights of the connections between neurons in order to minimize the error between the predicted output and the actual output. This is done by propagating the error backwards through the network, starting from the output layer and moving towards the input layer, hence the name "backpropagation." The goal is to find the optimal set of weights that will allow the network to make accurate predictions on new, unseen data.
query = "Explain the concept of Curriculum learning in machine learning."

answer = generate_response(query)
print(answer)
Curriculum learning in machine learning is a technique that involves gradually introducing more complex concepts or data to a model as it learns. This approach is inspired by the way humans learn, starting with simple concepts and building upon them. In the context provided, it is mentioned that this technique has its roots in the early study of neural networks, particularly in Jeffrey Elman's 1993 paper. Bengio et al. demonstrated successful application of curriculum learning in image classification tasks, such as identifying geometric shapes with increasingly complex forms, and language modeling tasks, such as training with a gradually expanding vocabulary. The authors conclude that curriculum strategies can be beneficial for machine learning models, especially when dealing with complex or large-scale problems.
query = "How does K-nearest neighbors (K-NN) algorithm classify data?"

answer = generate_response(query)
print(answer)
The K-nearest neighbors (K-NN) algorithm classifies data by a plurality vote of its neighbors, with the object being assigned to the class most common among its K nearest neighbors (K is a positive integer, typically small). If K = 1, then the object is simply assigned to the class of that single nearest neighbor.
query = "What is Federated Learning of Cohorts and how does it improve data privacy?"

answer = generate_response(query)
print(answer)
Federated Learning of Cohorts (FLoC) is a type of web tracking that groups people into "cohorts" based on their browsing history for the purpose of interest-based advertising. It was being developed as a part of Google's Privacy Sandbox initiative, which includes several other advertising-related technologies with bird-themed names. FLoC was being tested in Chrome 89 as a replacement for third-party cookies. Despite "federated learning" in the name, FLoC does not utilize any federated learning. FLoC improves data privacy by grouping people into cohorts based on their browsing history, rather than tracking individual users. This means that advertisers can still target users based on their interests, but without the need for individual user data.
Looks good. Let’s ask something that is not in the context. For instance, there were no articles on the Transformer architecture among the downloaded Wikipedia articles.
Out-of-Context Questions
query = "How does the Transformer architecture improve upon traditional RNNs and LSTMs in NLP tasks?"

answer = generate_response(query)
print(answer)
The Transformer architecture improves upon traditional RNNs and LSTMs in NLP tasks by using self-attention mechanisms to capture long-range dependencies between words in a sentence. This allows the model to process entire sentences at once, rather than sequentially like RNNs and LSTMs. Additionally, the Transformer architecture uses a fixed-size attention mechanism, which makes it more efficient and scalable than RNNs and LSTMs.
That’s interesting. To be sure that there’s no information on this topic, let’s check the retrieved context.
retrieve(query)
[
Document(
id='a2ae5aee-3b78-4804-a983-25d08fb8f5d3',
metadata={'source': 'https://en.wikipedia.org/wiki/Loss_functions_for_classification'},
page_content='Andrew Ng Fei-Fei Li Alex Krizhevsky Ilya Sutskever Demis Hassabis David Silver Ian Goodfellow Andrej Karpathy Architectures Neural Turing machine Differentiable neural computer Transformer Vision transformer (ViT) Recurrent neural network (RNN) Long short-term memory (LSTM) Gated recurrent unit (GRU) Echo state network Multilayer perceptron (MLP) Convolutional neural network (CNN) Residual neural network (RNN) Highway network Mamba Autoencoder Variational autoencoder (VAE) Generative adversarial network (GAN) Graph neural network (GNN) Portals Technology Category Artificial neural networks Machine learning List Companies Projects Retrieved from " https://en.wikipedia.org/w/index.php?title=Loss_functions_for_classification&oldid=1261562183 "'
),
Document(
id='b267b523-9330-4b33-bc3a-b4e6edec109f',
metadata={'source': 'https://en.wikipedia.org/wiki/Policy_gradient_method'},
page_content='neural computer Transformer Vision transformer (ViT) Recurrent neural network (RNN) Long short-term memory (LSTM) Gated recurrent unit (GRU) Echo state network Multilayer perceptron (MLP) Convolutional neural network (CNN) Residual neural network (RNN) Highway network Mamba Autoencoder Variational autoencoder (VAE) Generative adversarial network (GAN) Graph neural network (GNN) Portals Technology Category Artificial neural networks Machine learning List Companies Projects Retrieved from " https://en.wikipedia.org/w/index.php?title=Policy_gradient_method&oldid=1280215280 "'
)
]
It appears that the query retrieved random parts of pages mentioning transformers. However, as they contained no valuable information, the answer was fully generated by the LLM. Although the response was accurate, we may want to enhance the retrieval function by setting a threshold for relevancy to minimize the risk of hallucinations.
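A relevance threshold of the kind mentioned above could look like the following sketch. It is an illustration under assumptions, not a drop-in replacement for `retrieve`: the 4-D vectors are toy values, `retrieve_with_threshold` and `min_similarity=0.5` are hypothetical, and scores here are cosine similarities (higher is better). With the LangChain FAISS store, `similarity_search_with_score` returns (document, score) pairs, but the score is a distance by default, so the comparison would need to be flipped and the cutoff tuned on real queries.

```python
import math


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def retrieve_with_threshold(query_vec, indexed_chunks, top_k=2, min_similarity=0.5):
    # Score every chunk, keep the top_k best, then drop those below the cutoff.
    scored = [(cosine(query_vec, vec), text) for text, vec in indexed_chunks]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [(text, score) for score, text in scored[:top_k] if score >= min_similarity]


# Toy index (assumption): each chunk is paired with a made-up embedding.
indexed_chunks = [
    ("policy gradient chunk", [0.9, 0.1, 0.0, 0.1]),
    ("navigation box chunk", [0.2, 0.1, 0.95, 0.1]),
    ("loss function chunk", [0.1, 0.9, 0.1, 0.1]),
]

on_topic = retrieve_with_threshold([1.0, 0.0, 0.0, 0.0], indexed_chunks)
off_topic = retrieve_with_threshold([0.0, 0.0, 0.0, 1.0], indexed_chunks)

print([text for text, _ in on_topic])  # only the chunk that clears the cutoff
print(off_topic)                       # empty list: nothing relevant enough
```

When the filtered list is empty, the calling code can return "I couldn't find the information" directly and skip the LLM call, which removes the chance of an answer generated without supporting context.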
Let’s ask a question from a completely different domain.
query = "How does the process of photosynthesis work in plants?"

answer = generate_response(query)
print(answer)
I couldn't find the information.
This question was left unanswered. What about another one?
query = "How does blockchain technology ensure security and decentralization?"

answer = generate_response(query)
print(answer)
Blockchain technology ensures security and decentralization through its decentralized nature and cryptographic algorithms. It operates on a distributed network of nodes, where each node maintains a copy of the entire blockchain. This means that no single entity has control over the entire system, making it resistant to tampering and censorship. Additionally, blockchain uses cryptographic algorithms to secure transactions and data, ensuring that only authorized parties can access and modify the information. This combination of decentralization and cryptographic security makes blockchain technology highly secure and decentralized.
Unexpectedly, one of the documents contained information on this topic, so the answer was generated based on the retrieved context.
retrieve(query)
[
Document(
id='3d968d3b-1889-4435-b329-c9081400e8c4',
metadata={'source': 'https://en.wikipedia.org/wiki/Augmented_Analytics'},
page_content='to democratising data: Data Parameterisation and Characterisation. Data Decentralisation using an OS of blockchain and DLT technologies, as well as an independently governed secure data exchange to enable trust. Consent Market-driven Data Monetisation. When it comes to connecting assets, there are two features that will accelerate the adoption and usage of data democratisation: decentralized identity management and business data object monetization of data ownership. It enables multiple individuals and organizations to identify, authenticate, and authorize participants and organizations, enabling them to access services, data or systems across multiple networks, organizations, environments, and use cases. It empowers users and enables a personalized, self-service digital onboarding system so that users can self-authenticate without relying on a central administration function to process their information. Simultaneously, decentralized identity management ensures the user is authorized'
),
Document(
id='98578608-2a8d-4533-a655-b556202dda7d',
metadata={'source': 'https://en.wikipedia.org/wiki/Augmented_Analytics'},
page_content='so that users can self-authenticate without relying on a central administration function to process their information. Simultaneously, decentralized identity management ensures the user is authorized to perform actions subject to the system’s policies based on their attributes (role, department, organization, etc.) and/ or physical location. [ 10 ] Use cases [ edit ] Agriculture – Farmers collect data on water use, soil temperature, moisture content and crop growth, augmented analytics can be used to make sense of this data and possibly identify insights that the user can then use to make business decisions. [ 11 ] Smart Cities – Many cities across the United States, known as Smart Cities collect large amounts of data on a daily basis. Augmented analytics can be used to simplify this data in order to increase effectiveness in city management (transportation, natural disasters, etc.). [ 11 ] Analytic Dashboards – Augmented analytics has the ability to take large data sets and create'
)
]
One more question from another domain.
query = "What are the fundamental principles of classical mechanics?"

answer = generate_response(query)
print(answer)
I couldn't find the information.
This question was left unanswered as expected.
Conclusion
In this tutorial, we implemented a simple RAG system using a large language model and a text retrieval system. We collected articles from Wikipedia, split them into shorter chunks, and transformed them into embeddings for quick search. We then used a local LLM to generate responses based on the retrieved information. The RAG system successfully answered questions related to machine learning algorithms and reinforcement learning. The RAG system demonstrates the potential of combining retrieval and generation techniques to produce informative and contextually relevant answers.
Source code for this tutorial is available on GitHub.