Exercices Entraîner son LLM Splitter train et validation
🎉

Bravo!

Intermédiaire 🧠 Fondamentaux 20 XP 0 personnes ont réussi

Splitter train et validation

En machine learning, tu n'entraînes jamais un modèle sur 100% de tes données. Tu en gardes une partie de côté (le set de validation) pour vérifier que le modèle généralise bien et qu'il ne fait pas du par-coeur (overfitting).

Pour le fine-tuning d'un LLM, la convention c'est 80% pour l'entraînement et 20% pour la validation. Mais attention : tu ne prends pas les 80 premiers exemples pour le train et les 20 derniers pour la val. Il faut mélanger aléatoirement. Sinon, si tes données sont triées par sujet, ton modèle pourrait très bien marcher sur un sujet et rater complètement l'autre.

En Python, le module random te donne tout ce qu'il faut. random.seed() te permet de fixer la graine aléatoire pour que le mélange soit reproductible (important en recherche et en production pour comparer des runs).

import random
random.seed(42)
random.shuffle(ma_liste) # mélange en place

Écris une fonction splitter_dataset(exemples, ratio_train=0.8, seed=42) qui :
1. Fait une copie de la liste (pour ne pas modifier l'originale)
2. Mélange la copie avec la graine donnée
3. Coupe au bon index (arrondi à l'entier inférieur)
4. Renvoie un dictionnaire avec "train" et "validation"

Exemple :

data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
r = splitter_dataset(data, ratio_train=0.8, seed=42)
len(r["train"]) renvoie 8
len(r["validation"]) renvoie 2

Tests (4/5)

Proportions correctes
data = list(range(100))
r = splitter_dataset(data, ratio_train=0.8, seed=42)
assert len(r['train']) == 80
assert len(r['validation']) == 20
Pas de perte de données
data = list(range(50))
r = splitter_dataset(data, ratio_train=0.8, seed=42)
tous = sorted(r['train'] + r['validation'])
assert tous == list(range(50)), 'Aucun élément ne doit être perdu'
Ne modifie pas l'original
data = [1, 2, 3, 4, 5]
original = data.copy()
splitter_dataset(data, ratio_train=0.8, seed=42)
assert data == original, 'La liste originale ne doit pas être modifiée'
Reproductibilité avec seed
data = list(range(20))
r1 = splitter_dataset(data, seed=123)
r2 = splitter_dataset(data, seed=123)
assert r1['train'] == r2['train']
assert r1['validation'] == r2['validation']

+ 0 tests cachés

Indices (3 disponibles)

solution.py