Linear Regression with PyTorch

  1. PyTorch Basics: Tensors & Gradients
  2. Linear Regression & Gradient Descent
  3. Image Classification using Logistic Regression
  4. Training Deep Neural Networks on a GPU
  5. Image Classification using Convolutional Neural Networks
  6. Data Augmentation, Regularization and ResNets
  7. Generating Images using Generative Adversarial Networks

Continuing where the previous tutorial left off, this post covers one of the foundational algorithms of machine learning: linear regression. We'll create a model that predicts crop yields for apples and oranges (target variables) from the average temperature, rainfall and humidity (input variables or features) in a region. Here's the training data:

[Figure: linear regression training data]

In a linear regression model, each target variable is estimated as a weighted sum of the input variables, offset by a constant known as a bias:

yield_apple  = w11 * temp + w12 * rainfall + w13 * humidity + b1
yield_orange = w21 * temp + w22 * rainfall + w23 * humidity + b2

Visually, it means that the yield of apples is a linear or planar function of temperature, rainfall and humidity:

[Figure: linear regression graph]
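The two weighted-sum equations can also be written as a single matrix-vector product, which is how they are usually computed in practice. The snippet below is a minimal sketch in NumPy: the values in W and b are hypothetical numbers picked only for illustration (a trained model would learn them from the data), and the input row is one sample region.

```python
import numpy as np

# Hypothetical weights and biases, chosen only for illustration.
# Row 0 holds w11, w12, w13 (apples); row 1 holds w21, w22, w23 (oranges).
W = np.array([[0.5, 0.3, 0.2],
              [0.4, 0.6, 0.1]], dtype='float32')
b = np.array([1.0, 2.0], dtype='float32')  # b1, b2

# One region: temp, rainfall, humidity
x = np.array([73, 67, 43], dtype='float32')

# Weighted sums written out term by term, as in the equations
yield_apple = W[0, 0] * x[0] + W[0, 1] * x[1] + W[0, 2] * x[2] + b[0]
yield_orange = W[1, 0] * x[0] + W[1, 1] * x[1] + W[1, 2] * x[2] + b[1]

# The same computation as one matrix-vector product
yields = W @ x + b

print(yield_apple, yield_orange)
print(yields)
```

Both forms give the same two numbers. The matrix form scales more easily to many regions at once, since a whole table of inputs can be multiplied by the weights in one step.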

The learning part of linear regression is to figure out a set of weights w11, w12, ..., w23 and biases b1 and b2 from the training data, so that the model makes accurate predictions for new data (i.e., predicts the yields of apples and oranges in a new region from its average temperature, rainfall and humidity). This is done by adjusting the weights and biases slightly, many times over, so that each step improves the predictions a little. The optimization technique used for this is called gradient descent.
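To make the idea of "adjusting the weights slightly many times" concrete, here is a rough sketch of gradient descent on a one-variable toy problem. Everything in it is an assumption made for illustration: the toy data (generated from y = 2x + 1), the mean squared error used to measure how bad the predictions are, the learning rate, and the number of steps.

```python
import numpy as np

# Toy data generated from y = 2x + 1 (assumed, for illustration only)
x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
y = 2.0 * x + 1.0

w, b = 0.0, 0.0        # initial guesses for the weight and bias
learning_rate = 0.05   # step size (hypothetical choice)

def mse(w, b):
    """Mean squared error between predictions w*x + b and targets y."""
    return np.mean((w * x + b - y) ** 2)

initial_loss = mse(w, b)

for _ in range(2000):
    error = w * x + b - y
    # Gradients of the mean squared error with respect to w and b
    grad_w = 2 * np.mean(error * x)
    grad_b = 2 * np.mean(error)
    # Move each parameter a small step against its gradient
    w -= learning_rate * grad_w
    b -= learning_rate * grad_b

final_loss = mse(w, b)
print(w, b, initial_loss, final_loss)
```

With these settings, w and b should end up close to 2 and 1, the values used to generate the data, and the loss should shrink from its starting value toward zero.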

import torch
import numpy as np
# Input (temp, rainfall, humidity)
inputs = np.array([[73, 67, 43], 
                   [91, 88, 64], 
                   [87, 134, 58], 
                   [102, 43, 37], 
                   [69, 96, 70]], dtype='float32')
# Targets (apples, oranges)
targets = np.array([[56, 70], 
                    [81, 101], 
                    [119, 133], 
                    [22, 37], 
                    [103, 119]], dtype='float32')