
Linear regression

Linear regression is a training procedure based on a linear model. The model makes a prediction by simply computing a weighted sum of the input features, plus a constant term called the bias term (also called the intercept term):

$$ \hat{y}=\theta_0 + \theta_1 x_1 + \theta_2 x_2 + \cdots + \theta_n x_n$$

This can be written more compactly in vector form for the $m$ instances of the data set. The model then becomes:

$$ \begin{bmatrix} \hat{y}^{(1)} \\ \hat{y}^{(2)}\\ \vdots \\ \hat{y}^{(m)} \end{bmatrix} = \begin{bmatrix} 1 & x_1^{(1)} & x_2^{(1)} & \cdots &x_n^{(1)}\\ 1 & x_1^{(2)} & x_2^{(2)} & \cdots & x_n^{(2)}\\ \vdots & \vdots &\vdots & \ddots & \vdots\\ 1 & x_1^{(m)} & x_2^{(m)} & \cdots & x_n^{(m)} \end{bmatrix} \begin{bmatrix} \theta_0 \\ \theta_1 \\ \theta_2 \\ \vdots \\ \theta_n \end{bmatrix} $$

Resulting in the compact form:

$$\hat{y}= h_\theta(X) = X \theta $$

Now that we have our model, how do we train it?

Training the model means adjusting its parameters to reduce the error, that is, to minimize the cost function. The most common performance measure for a regression model is the Mean Squared Error (MSE). Therefore, to train a linear regression model, we need to find the value of $\theta$ that minimizes the MSE:

$$ MSE(X,h_\theta) = \frac{1}{m} \sum_{i=1}^{m} \left(\hat{y}^{(i)}-y^{(i)} \right)^2$$

$$ MSE(X,h_\theta) = \frac{1}{m} \sum_{i=1}^{m} \left( x^{(i)}\theta-y^{(i)} \right)^2$$

$$ MSE(X,h_\theta) = \frac{1}{m} \left( X\theta-y \right)^T \left( X\theta-y \right)$$
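
This vectorized MSE maps directly onto a few lines of NumPy. Below is a minimal sketch, assuming X, theta and y are NumPy arrays of shapes (m, n+1), (n+1,) and (m,); the helper name mse is only for illustration:

In [ ]:
import numpy as np

# Vectorized MSE from the formula above.  X, theta and y are assumed to be
# NumPy arrays of shapes (m, n+1), (n+1,) and (m,).
def mse(X, theta, y):
    residual = X.dot(theta) - y              # prediction error for every instance
    return residual.dot(residual) / len(y)   # (1/m) (X theta - y)^T (X theta - y)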

The normal equation

To find the value of $\theta$ that minimizes the cost function, there is a closed-form solution that gives the result directly. This is called the Normal Equation; it is obtained by taking the derivative of the MSE with respect to $\theta$ and setting it equal to zero:

$$\hat{\theta} = (X^T X)^{-1} X^{T} y $$

As a working example, we fit the series stored in data.csv with a model that is linear in the time index $t$:

$$ \mathrm{Temp} = \theta_0 + \theta_1\, t $$

In [2]:
import pandas as pd
df = pd.read_csv('data.csv')
y = df['0']        # the file has a single data column, named '0'
y
Out[2]:
0      24.218
1      23.154
2      24.347
3      24.411
4      24.411
        ...  
295    46.357
296    46.551
297    46.519
298    46.551
299    46.583
Name: 0, Length: 300, dtype: float64
In [4]:
import matplotlib.pyplot as plt
plt.plot(y,'.r')
Out[4]:
[<matplotlib.lines.Line2D at 0x10dca4890>]
In [6]:
import numpy as np
n = len(y)
x = np.linspace(0, n-1, n)      # time index 0..n-1, used as the single feature
X = np.c_[np.ones(n), x]        # design matrix: a column of ones (bias term) plus the feature
X
Out[6]:
array([[  1.,   0.],
       [  1.,   1.],
       [  1.,   2.],
       ...,
       [  1., 297.],
       [  1., 298.],
       [  1., 299.]])
In [8]:
theta = np.linalg.inv(X.T.dot(X)).dot(X.T).dot(y)   # normal equation: (X^T X)^-1 X^T y
theta
Out[8]:
array([25.70275643,  0.07850281])
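
Inverting $X^T X$ directly can be numerically unstable when the matrix is ill-conditioned. As a quick check, the same least-squares problem can be solved with NumPy's lstsq routine, which should return the same coefficients up to numerical precision:

In [ ]:
# Same least-squares fit without forming the matrix inverse explicitly.
theta_lstsq, residuals, rank, sv = np.linalg.lstsq(X, y, rcond=None)
theta_lstsq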
In [9]:
ypre = X.dot(theta)
plt.plot(x, ypre, '*-r', label='model')
plt.plot(x,y, '.k', label='data')
plt.legend()
plt.show()

Polynomial model

The same procedure also works for models that are polynomial in $t$: powers of $t$ are simply added as extra features. With a quadratic term the model becomes:

$$ y = \theta_0+\theta_1 t+\theta_2 t^2$$

In [11]:
X = np.c_[np.ones(len(x)), x, x*x]                    # add a quadratic term as an extra feature
theta = np.linalg.inv(X.T.dot(X)).dot(X.T).dot(y)     # normal equation, as before
theta
Out[11]:
array([ 2.28848082e+01,  1.35240024e-01, -1.89756565e-04])
In [13]:
Xnew1 = np.linspace(0, 300, 50)                       # evenly spaced points for a smooth prediction curve
Xnew = np.c_[np.ones(len(Xnew1)), Xnew1, Xnew1*Xnew1]
ypred = Xnew.dot(theta)
plt.plot(Xnew1, ypred, '*-r', label='model')
plt.plot(x,y, '.g', label='data')
plt.legend()
plt.show()
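
To compare the linear and the quadratic fits quantitatively, we can evaluate the MSE defined above for both models. A minimal sketch; the names X1 and X2 are only for illustration, and both coefficient vectors are recomputed because X and theta were overwritten for the quadratic fit:

In [ ]:
# Compare both models with the MSE used as the cost function above.
X1 = np.c_[np.ones(n), x]           # linear design matrix: bias + t
X2 = np.c_[np.ones(n), x, x * x]    # quadratic design matrix: bias + t + t^2
for Xd, label in [(X1, 'linear'), (X2, 'quadratic')]:
    th = np.linalg.inv(Xd.T.dot(Xd)).dot(Xd.T).dot(y)   # normal equation
    print(label, 'MSE:', np.mean((Xd.dot(th) - y) ** 2))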

Batch gradient descent

Gradient descent searches for the minimum of the cost function iteratively: starting from some $\theta$, it repeatedly takes a step against the gradient of the MSE:

$$ \theta_{\text{new}} = \theta - \eta \, \nabla_{\theta}\, \mathrm{MSE}(\theta) $$

$$ \nabla_{\theta}\, \mathrm{MSE}(\theta) = \frac{2}{m}X^T(X\theta-y) $$

where $\eta$ is the learning rate.
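
This update rule translates directly into NumPy. Below is a minimal sketch that refits the simple linear model; the learning rate eta = 0.1 and the 1000 iterations are assumed values, and the time feature is standardized first so that those values actually converge (without scaling, the raw t values would require a much smaller step size):

In [ ]:
# Batch gradient descent sketch for the linear model Temp = theta0 + theta1 * t.
# eta and n_iterations are assumed values, not tuned for this data set.
t_scaled = (x - x.mean()) / x.std()      # standardize the feature so the steps converge
Xgd = np.c_[np.ones(n), t_scaled]        # design matrix in the scaled-feature space

eta = 0.1
n_iterations = 1000
m = len(y)

theta_gd = np.zeros(2)                   # start from theta = [0, 0]
for _ in range(n_iterations):
    gradients = (2 / m) * Xgd.T.dot(Xgd.dot(theta_gd) - y)
    theta_gd = theta_gd - eta * gradients

theta_gd                                 # intercept and slope for the standardized feature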
