Avancé
🧠 Fondamentaux
30 XP
0 personnes ont réussi
Pipeline RAG avec FAISS et évaluation
Pour finir, on va construire un pipeline RAG de niveau production qui combine tout ce qu'on a appris : Documents LangChain, text splitting, TF-IDF, FAISS et évaluation.
- ingest(self, texts, sources) Decoupe les textes en chunks, vectorise avec TF-IDF, et indexe dans FAISS. Stocke aussi la correspondance entre index FAISS et chunks.
- search(self, query, k=3, filters=None) Cherche les k chunks les plus proches dans FAISS. Si filters est un dictionnaire, ne garde que les résultats dont les metadonnées correspondent (filtre applique APRES la recherche FAISS, sur un pool plus large). Pour gerer le filtrage post-recherche, cherche k*3 résultats dans FAISS puis filtre et prend les k premiers.
- evaluate(self, queries_and_relevant) Prend une liste de tuples (query, set_of_relevant_sources). Pour chaque query, fait une recherche et compare les sources des résultats avec les sources pertinentes. Renvoie {'mean_precision': float, 'mean_recall': float}.
rag = ProductionRAG(chunk_size=100, chunk_overlap=20)
texts = [
'Python est un langage de programmation populaire crée par Guido van Rossum en 1991. ' * 5,
'Django est un framework web en Python utilise pour créer des applications web robustes. ' * 5,
'Les bases de données SQL permettent de stocker et interroger des données structurees. ' * 5,
]
sources = ['python.txt', 'django.txt', 'sql.txt']
rag.ingest(texts, sources)
assert len(rag.chunks) > 3, 'Les textes doivent etre decoupes'
assert rag.index.ntotal == len(rag.chunks), 'L index FAISS doit contenir tous les chunks'
results = rag.search('programmation Python', k=2)
assert len(results) == 2, f'Attendu 2 résultats, obtenu {len(results)}'
results_filtered = rag.search('Python', k=2, filters={'source': 'django.txt'})
assert all(r.metadata['source'] == 'django.txt' for r in results_filtered), 'Le filtre doit etre respecte'
metrics = rag.evaluate([('Python programmation', {'python.txt'}), ('framework web', {'django.txt'})])
assert 'mean_precision' in metrics, 'Les metriques doivent contenir mean_precision'
assert 'mean_recall' in metrics, 'Les metriques doivent contenir mean_recall'
assert 0 <= metrics['mean_precision'] <= 1, 'Precision doit etre entre 0 et 1'
Indices (3 disponibles)
Solution officielle
import faiss
import numpy as np
from langchain_core.documents import Document
from langchain_text_splitters import RécursiveCharacterTextSplitter
from sklearn.feature_extraction.text import TfidfVectorizer
class ProductionRAG:
def __init__(self, chunk_size=200, chunk_overlap=50):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.chunks = []
self.vectorizer = None
self.index = None
def ingest(self, texts, sources):
documents = [
Document(page_content=text, metadata={'source': source})
for text, source in zip(texts, sources)
]
splitter = RécursiveCharacterTextSplitter(
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap
)
self.chunks = splitter.split_documents(documents)
chunk_texts = [c.page_content for c in self.chunks]
self.vectorizer = TfidfVectorizer()
tfidf_matrix = self.vectorizer.fit_transform(chunk_texts)
vectors = np.array(tfidf_matrix.toarray(), dtype='float32')
self.index = faiss.IndexFlatL2(vectors.shape[1])
self.index.add(vectors)
def search(self, query, k=3, filters=None):
search_k = k * 3 if filters else k
query_vec = self.vectorizer.transform([query]).toarray().astype('float32')
search_k = min(search_k, self.index.ntotal)
distances, indices = self.index.search(query_vec, search_k)
candidates = [self.chunks[i] for i in indices[0]]
if filters:
candidates = [
c for c in candidates
if all(c.metadata.get(key) == val for key, val in filters.items())
]
return candidates[:k]
def evaluate(self, queries_and_relevant):
precisions = []
recalls = []
for query, relevant_sources in queries_and_relevant:
results = self.search(query)
retrieved_sources = [r.metadata.get('source') for r in results]
hits = sum(1 for s in retrieved_sources if s in relevant_sources)
precision = hits / len(retrieved_sources) if retrieved_sources else 0.0
recall = hits / len(relevant_sources) if relevant_sources else 1.0
precisions.append(precision)
recalls.append(recall)
return {
'mean_precision': sum(precisions) / len(precisions),
'mean_recall': sum(recalls) / len(recalls)
}