4.1.2.4 Lab – Simple Linear Regression in Python Answers

4.1.2.4 Lab – Simple Linear Regression in Python (Instructor Version)

Objectives

In this lab, you will become familiar with the concepts of simple linear regression and working with the provided data to make a prediction.

  • Part 1: Import the Libraries and Data
  • Part 2: Plot the Data
  • Part 3: Perform Simple Linear Regression

Scenario / Background

In statistics, linear regression is a way to model a relationship between dependent variable y and independent variable x .

In this lab, you will analyze district sales data and perform a simple linear regression to predict annual net sales based on the number of stores in the district.

Required Resources

  • 1 PC with Internet access
  • Python libraries: pandas, numpy, scipy, and matplotlib
  • Datafiles: stores-dist.csv

Part 1: Import the Libraries and Data

In this part, you will import the libraries and the data from the file stores-dist.csv.

Step 1: Import the libraries.

In this step, you will import the following libraries:

  • matplotlib.pyplot as plt
  • numpy as np
  • pandas as pd
# Code Cell 1

Step 2: Import the data.

In this step, you will import the data from the file stores-dist.csv and verify that the file was imported correctly.

# Code Cell 2

# Import the file, stores-dist.csv
salesDist = pd.read_csv('./Data/stores-dist.csv')

# Verify the imported data
salesDist.head()

The column headings, annual net sales and number of stores in district are renamed to make it easier during data processing.

  • annual net sales to sales
  • number of stores in district to stores
# Code Cell 3
# The district column has no relevance at this time, so it can be dropped.
salesDist = salesDist.rename(columns={'annual net sales':'sales','number of stores in district':'stores'})
salesDist.head()

Part 2: Plot the Data

Step 1: Determine the correlation.

In this step, you will investigate the correlation of the data prior to regression analysis. You will also drop any unrelated columns as necessary.

# Code Cell 4
# Check correlation of data prior to doing the analysis
# # Hint: check lab 3.1.5.5

From the correlation coefficent, it appears that the column district has low correlation to the annual net sales and number of stores in the district. So the district column is not necessary as part of the regression analysis. The district column can be dropped from the dataframe.

# Code Cell 5
# The district column has no relevance at this time, so it can be dropped.
#sales = salesDist.drop(...)

sales.head()

From the correlation coefficent data, what type of correlation did you observe between annual net sales and number of stores in the district?

<font color = ‘gray’>Type your answer here.</font>

Step 2: Create the plot.

In this step, you will create a plot to visualize the data. You will also assign stores as the independent variable x and sales as the dependent variable y .

# Code Cell 6
# dependent variable for y axis
y = sales['sales']
# independent variable for x axis
x = sales.stores
# Code Cell 7
# Display the plot inline
%matplotlib inline

# Increase the size of the plot
plt.figure(figsize=(20,10))

# Create a scatter plot: Number of stores in the District vs. Annual Net Sales
plt.plot(x,y, 'o', markersize = 15)

# Add axis labels and increase the font size
plt.ylabel('Annual Net Sales', fontsize = 30)
plt.xlabel('Number of Stores in the District', fontsize = 30)

# Increase the font size on the ticks on the x and y axis
plt.xticks(fontsize = 20)
plt.yticks(fontsize = 20)

# Display the scatter plot
plt.show()

Part 3: Perform Simple Linear Regression

In this part, you will use numpy to generate a regression line for the analyzed data. You will also calculate the centroid for this dataset. The centrod is the mean for the dataset. The generated simple linear regression line must also pass through the centroid.

Step 1: Calculate the slope and y-intercept of the linear regression line.

# Code Cell 8
# Use numpy polyfit for linear regression to fit the data
# Generate the slope of the line (m)
# Generate the y-intercept (b)
m, b = np.polyfit(x,y,1) 
print ('The slope of line is {:.2f}.'.format(m))
print ('The y-intercept is {:.2f}.'.format(b))
print ('The best fit simple linear regression line is {:.2f}x + {:.2f}.'.format(m,b))

Step 2: Calculate the centroid.

The centroid of the dataset is calculated by using the mean function.

# Code Cell 9
# y coordinate for centroid
y_mean = y.mean() 
# x coordinate for centroid
x_mean = x.mean()
print ('The centroid for this dataset is x = {:.2f} and y = {:.2f}.'.format(x_mean, y_mean))

Step 3: Overlay the regression line and the centroid point on the plot.

# Code Cell 10
# Create the plot inline
%matplotlib inline

# Enlarge the plot size
plt.figure(figsize=(20,10))

# Plot the scatter plot of the data set
plt.plot(x,y, 'o', markersize = 14, label = "Annual Net Sales")

# Plot the centroid point
plt.plot(x_mean,y_mean, '*', markersize = 30, color = "r")

# Plot the linear regression line
plt.plot(x, m*x + b, '-', label = 'Simple Linear Regression Line', linewidth = 4)

# Create the x and y axis labels
plt.ylabel('Annual Net Sales', fontsize = 30)
plt.xlabel('Number of Stores in District', fontsize = 30)

# Enlarge x and y tick marks
plt.xticks(fontsize = 20)
plt.yticks(fontsize = 20)

# Point out the centroid point in the plot
plt.annotate('Centroid', xy=(x_mean-0.1, y_mean-5), xytext=(x_mean-3, y_mean-20), arrowprops=dict(facecolor='black', shrink=0.05), fontsize = 30)

# Create legend
plt.legend(loc = 'upper right', fontsize = 20)

Step 4: Prediction

Using the linear regression line, you can predict the annual net sales based on the number of stores in the district.

# Code Cell 11
# Function to predict the net sales from the regression line
def predict(query):
    if query >= 1:
         predict = m * query + b
         return predict
    else:
      print ("You must have at least 1 store in the district to predict the annual net sales.")
# Code Cell 12
# Enter the number of stores in the function to generate the net sales prediction.

What is the predicted net sales if there are 4 stores in the district?

 


guest

0 Comments
Inline Feedbacks
View all comments
0
Would love your thoughts, please comment.x