Exercices Machine Learning LinearRegression : predire une valeur continue
🎉

Bravo!

Intermédiaire 🧠 Fondamentaux 20 XP 0 personnes ont réussi

LinearRegression : predire une valeur continue

Predire le prix d'un appartement en fonction de sa surface, de son etage et de son quartier. Estimer le salaire d'un candidat selon son experience. La regression lineaire est le modele le plus fondamental pour predire une valeur continue.

L'idee : trouver la meilleure droite (ou hyperplan en plusieurs dimensions) qui passe au plus pres de tous les points. En 2D, c'est y = a*x + b. Le modèle apprend les coefficients (a) et l'ordonnée a l'origine (b).

from sklearn.linear_model import LinearRegression
model = LinearRegression()
model.fit(X_train, y_train)
predictions = model.predict(X_test)
model.coef_ # les coefficients (poids de chaque feature)
model.intercept_ # l'ordonnée a l'origine

Écris une fonction regression_lineaire(X_train, y_train, X_test) qui entraine une LinearRegression et renvoie :
'predictions' : les predictions sur X_test
'coefficients' : les coefficients du modèle (list)
'intercept' : l'ordonnée a l'origine (float)

Tests (4/4)

Predictions correctes
import numpy as np
X_train = np.array([[1], [2], [3], [4]])
y_train = np.array([2, 4, 6, 8])
X_test = np.array([[5]])
result = regression_lineaire(X_train, y_train, X_test)
assert abs(result['predictions'][0] - 10.0) < 1e-6
Coefficient correct
import numpy as np
X_train = np.array([[1], [2], [3], [4]])
y_train = np.array([2, 4, 6, 8])
result = regression_lineaire(X_train, y_train, np.array([[5]]))
assert abs(result['coefficients'][0] - 2.0) < 1e-6
Intercept correct
import numpy as np
X_train = np.array([[1], [2], [3]])
y_train = np.array([3, 5, 7])
result = regression_lineaire(X_train, y_train, np.array([[0]]))
assert abs(result['intercept'] - 1.0) < 1e-6
Plusieurs features
import numpy as np
X_train = np.array([[1, 1], [2, 2], [3, 3]])
y_train = np.array([3, 6, 9])
result = regression_lineaire(X_train, y_train, np.array([[4, 4]]))
assert len(result['coefficients']) == 2

Indices (3 disponibles)

solution.py