Hide code cell source

import sys
import os
# Put ../shared (relative to the current working directory) at the front of the
# import path so the project's setup_code helper can be imported, then take the
# stroke_data table it exposes.
_shared_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'shared'))
sys.path.insert(0, _shared_dir)
import setup_code
stroke_data = setup_code.stroke_data

Importing packages#

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

Explore the dataset#

# Peek at the first five rows of the raw dataset.
stroke_data.head(n=5)
id gender age hypertension heart_disease ever_married work_type Residence_type avg_glucose_level bmi smoking_status stroke
0 9046 Male 67.0 0 1 Yes Private Urban 228.69 36.6 formerly smoked 1
1 51676 Female 61.0 0 0 Yes Self-employed Rural 202.21 NaN never smoked 1
2 31112 Male 80.0 0 1 Yes Private Rural 105.92 32.5 never smoked 1
3 60182 Female 49.0 0 0 Yes Private Urban 171.23 34.4 smokes 1
4 1665 Female 79.0 1 0 Yes Self-employed Rural 174.12 24.0 never smoked 1

A little bit of data preprocessing#

# Count missing values per column (isna is the same check as isnull).
stroke_data.isna().sum()
id                     0
gender                 0
age                    0
hypertension           0
heart_disease          0
ever_married           0
work_type              0
Residence_type         0
avg_glucose_level      0
bmi                  201
smoking_status         0
stroke                 0
dtype: int64
### Impute missing bmi values with the mean of the observed bmi values.
# Series.mean() skips NaN by default (skipna=True), so no explicit non-null filter is needed.
bmi_mean = stroke_data['bmi'].mean()
# fillna replaces only the NaN entries; the observed values are left untouched.
stroke_data['bmi'] = stroke_data['bmi'].fillna(bmi_mean)
# Confirm that no missing values remain.
stroke_data.isnull().sum()
id                   0
gender               0
age                  0
hypertension         0
heart_disease        0
ever_married         0
work_type            0
Residence_type       0
avg_glucose_level    0
bmi                  0
smoking_status       0
stroke               0
dtype: int64

Matplotlib and histogram#

# Histogram of average glucose level. density=True normalizes the bar heights so the
# total bar area is 1; without it the bars are raw counts and the
# "Probability density" label would be wrong.
plt.hist(stroke_data.avg_glucose_level, bins=30, density=True, alpha=0.6, edgecolor='black')
plt.xlabel('Average glucose level')
plt.ylabel('Probability density')
plt.title('Normalized Histogram (Density)')
plt.show()
../_images/9ec8108223d271e1d6b95e68d984f990594d7fe7f6cc537b95a9e7950ac1ea85.png

Changing font size and plot size#

# Normalized histogram of average glucose level in a 6x4-inch figure, with larger
# label, tick and title fonts. density=True makes the y-axis a probability density,
# matching its label.
plt.figure(figsize=(6, 4))
plt.hist(stroke_data.avg_glucose_level, bins=30, density=True, alpha=0.6, edgecolor='black')
plt.xlabel('Average glucose level', fontsize=14)
plt.ylabel('Probability density', fontsize=14)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.title('Normalized Histogram (Density)', fontsize=18)
plt.show()
../_images/7757782bdace636cb17c2d7dae7293019f2e7de77fe98c5f069504f65d09e117.png

Saving the plot#

# Normalized histogram of average glucose level, saved to a PDF before being shown.
# density=True makes the y-axis a probability density, matching its label.
plt.figure(figsize=(6, 4))
plt.hist(stroke_data.avg_glucose_level, bins=30, density=True, alpha=0.6, edgecolor='black')
plt.xlabel('Average glucose level', fontsize=14)
plt.ylabel('Probability density', fontsize=14)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.title('Normalized Histogram (Density)', fontsize=18)
# Save before plt.show(): with non-interactive/inline backends the figure may be
# closed after show(), which would produce an empty file.
# bbox_inches='tight' trims surrounding whitespace from the saved figure.
plt.savefig('avg_glucose_level_histogram.pdf', dpi=300, bbox_inches='tight')
plt.show()
../_images/7757782bdace636cb17c2d7dae7293019f2e7de77fe98c5f069504f65d09e117.png

Seaborn and scatter plot#

# Basic seaborn scatter plot of bmi against age in a 6x4-inch figure.
plt.figure(figsize=(6, 4))
scatter_ax = sns.scatterplot(data=stroke_data, x='age', y='bmi')
scatter_ax.set_title('Scatter Plot of bmi vs Age')
plt.show()
../_images/2ecb5dfea3f99cf445feb474dbb34dea3670771d578dcc3048e824b0c1ce4d89.png

Color#

# Same scatter, with points coloured by the stroke column; alpha=0.6 makes
# overlapping points partly visible.
plt.figure(figsize=(6, 4))
hue_ax = sns.scatterplot(
    data=stroke_data,
    x='age',
    y='bmi',
    hue='stroke',
    alpha=0.6,
)
hue_ax.set_title('Scatter Plot of bmi vs Age')
plt.show()
../_images/dd463b69f3b4f7e28fe75c2db01491c694c5578d0f98656afedbbd9ffdbbb416.png
# Draw each stroke group as its own layer; stroke == 1 is drawn second, so it sits
# on top. Each layer gets a label so the legend says which colour is which group;
# without labels the two colours cannot be told apart.
plt.figure(figsize=(6, 4))
sns.scatterplot(data=stroke_data[stroke_data['stroke'] == 0], x='age', y='bmi',
                alpha=0.6, label='No stroke (0)')
sns.scatterplot(data=stroke_data[stroke_data['stroke'] == 1], x='age', y='bmi',
                alpha=0.6, label='Stroke (1)')
plt.title('Scatter Plot of bmi vs Age')
plt.legend(title='stroke')
plt.show()
../_images/e658fc6b7101336702d8304f8a0e5c44cf87700ec688d1fcea43810b1e583680.png

Multiple plots#

# Two vertically stacked panels, both with age on the x-axis:
# top panel = bmi, bottom panel = avg_glucose_level.
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(5, 10))
for panel_ax, y_column in zip(axes, ['bmi', 'avg_glucose_level']):
    sns.scatterplot(data=stroke_data, x='age', y=y_column, alpha=0.6, ax=panel_ax)
plt.show()
../_images/debbdfbd94ce4ab21d5170e1dc635768cdebf20f7746d62f05c00834f58766ad.png
# Two stacked scatter plots against age. sharex=True gives both panels the same
# x-axis, each panel gets its own title, and tight_layout() keeps the titles and
# labels from overlapping.
fig, axes = plt.subplots(2, 1, figsize=(5, 8), sharex=True)

sns.scatterplot(data=stroke_data, x='age', y='bmi', alpha=0.6, ax=axes[0])
axes[0].set_title('scatter plot of bmi vs age')

sns.scatterplot(data=stroke_data, x='age', y='avg_glucose_level', alpha=0.6, ax=axes[1])
axes[1].set_title('scatter plot of avg_glucose_level vs age')

plt.tight_layout()
plt.show()
../_images/6f421471d6003ab01dc9dbed2d98c292835cdca1297c9393d9c54478a964561b.png

Seaborn and distribution visualization#

Box plot#

# Box plot of bmi for each heart_disease level, split by stroke status (hue).
# Boxes are drawn unfilled, with a small gap between the hue groups.
sns.boxplot(
    data=stroke_data,
    x="heart_disease",
    y="bmi",
    hue="stroke",
    fill=False,
    gap=0.1,
)
plt.title('Box Plot of BMI by Heart Disease and Stroke Status')
plt.show()
../_images/d84a68e427e78ab675e6703b17480bf5c62340fd99a14ce54b4d35bec6104aad.png

Violin plot#

# Split violins of bmi for each heart_disease level: each half of a violin is one
# stroke group (hue). inner="quart" draws the quartiles inside each half.
sns.violinplot(
    data=stroke_data,
    x="heart_disease",
    y="bmi",
    hue="stroke",
    split=True,
    gap=0.1,
    inner="quart",
)
plt.title('Violin Plot of BMI by Heart Disease and Stroke Status')
plt.show()
../_images/fc0b8276f6373ee5223f38b012e1645a1c4d822535217d66e33f9ce9d50be5db.png

Visualizing correlation#

# Look at the first rows again; row 1, whose bmi was missing, now holds the imputed mean.
stroke_data.head(n=5)
id gender age hypertension heart_disease ever_married work_type Residence_type avg_glucose_level bmi smoking_status stroke
0 9046 Male 67.0 0 1 Yes Private Urban 228.69 36.600000 formerly smoked 1
1 51676 Female 61.0 0 0 Yes Self-employed Rural 202.21 28.893237 never smoked 1
2 31112 Male 80.0 0 1 Yes Private Rural 105.92 32.500000 never smoked 1
3 60182 Female 49.0 0 0 Yes Private Urban 171.23 34.400000 smokes 1
4 1665 Female 79.0 1 0 Yes Self-employed Rural 174.12 24.000000 never smoked 1

Pairplot#

# Pair plot of the numeric features coloured by stroke outcome: KDE curves on the
# diagonal; off the diagonal, small scatter markers ('o' / 's' per stroke class).
df_pairplot = stroke_data[['age', 'bmi', 'avg_glucose_level', 'stroke']]
sns.pairplot(
    df_pairplot,
    hue="stroke",
    diag_kind="kde",
    markers=["o", "s"],
    plot_kws={'s': 10},
)
# y slightly above 1 lifts the title clear of the top row of axes.
plt.suptitle("Pairwise Relationships in the Stroke Dataset", y=1.02)
plt.show()
../_images/fad69acd499d02807f233a7655c6b82ade04219dd182785eb1ce6ed20e71efff.png
# Put every stroke == 0 row before every stroke == 1 row, keeping the original order
# within each group, and renumber the index from 0. Presumably the aim is to draw the
# rarer stroke == 1 points last, on top of the others — confirm.
df_pairplot_reordered = pd.concat(
    [df_pairplot[df_pairplot['stroke'] == value] for value in (0, 1)],
    ignore_index=True,
)
df_pairplot_reordered.head()
age bmi avg_glucose_level stroke
0 3.0 18.0 95.12 0
1 58.0 39.2 87.96 0
2 8.0 17.6 110.89 0
3 70.0 35.9 69.04 0
4 14.0 19.1 161.28 0
# Same pair plot as before, but drawn from the reordered frame, so the stroke == 1
# rows come last in the data.
sns.pairplot(
    df_pairplot_reordered,
    hue="stroke",
    diag_kind="kde",
    markers=["o", "s"],
    plot_kws={'s': 10},
)
# y slightly above 1 lifts the title clear of the top row of axes.
plt.suptitle("Pairwise Relationships in the Stroke Dataset", y=1.02)
plt.show()
../_images/f16d44b96a85cc140274d96eebe70dfaf2b55d33328d565031214fe772214ca8.png

Heatmap#

# Keep only the float64 columns (in this dataset: age, avg_glucose_level, bmi).
# Integer-coded columns such as id, hypertension and stroke are excluded.
stroke_numerical = stroke_data.select_dtypes(include='float64')
stroke_numerical.head()
age avg_glucose_level bmi
0 67.0 228.69 36.600000
1 61.0 202.21 28.893237
2 80.0 105.92 32.500000
3 49.0 171.23 34.400000
4 79.0 174.12 24.000000
# Pairwise (Pearson, the DataFrame.corr default) correlations between the float columns.
correlations_df = stroke_numerical.corr()
plt.figure(figsize=(8, 6))

# Pin the colour scale to the full correlation range [-1, 1], centred on 0, with a
# diverging colormap. The default sequential map stretches to the observed min/max,
# which can make weak correlations look strong.
sns.heatmap(
    correlations_df,
    annot=True,
    fmt=".2f",
    cmap="coolwarm",
    vmin=-1,
    vmax=1,
    center=0,
)
plt.title("Correlation Matrix")
plt.show()
../_images/04868a93a24168dcd4c3a116bc7fbec5cf0058e3d3687d01e0acf6601ce63748.png

Linear regression and line plot#

from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[28], line 1
----> 1 from sklearn.linear_model import LinearRegression
      2 from sklearn.model_selection import train_test_split

ModuleNotFoundError: No module named 'sklearn'
# Hold out 20% of the rows for testing. test_size must be a float for a fraction;
# an int such as 20 means "exactly 20 rows". Shuffle before splitting: the first rows
# shown by head() all have stroke == 1, so the file appears to be ordered, and an
# unshuffled split would be positionally biased. A fixed random_state keeps the
# split, and therefore the fitted line, reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    stroke_data['age'].values,
    stroke_data['bmi'].values,
    test_size=0.2,
    shuffle=True,
    random_state=42,
)
model = LinearRegression()

# scikit-learn expects a 2-D feature matrix, hence reshape(-1, 1) for the single
# feature. The target can stay 1-D.
model.fit(X_train.reshape(-1, 1), y_train)
LinearRegression()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
# Inspect the range of the age column as a (max, min) tuple.
age_column = stroke_data['age']
age_column.max(), age_column.min()
(82.0, 0.08)
# Evaluate the fitted model on 50 evenly spaced ages covering the observed age range.
# A hard-coded 0..80 stops short of the maximum age of 82.
X_values = np.linspace(stroke_data['age'].min(), stroke_data['age'].max(), 50)
# predict() needs a 2-D (n_samples, 1) feature matrix.
predicted_bmi = model.predict(X_values.reshape(-1, 1))
plt.plot(X_values, predicted_bmi, label='Regression Line')
plt.xlabel('Age')
plt.ylabel('predicted BMI')
plt.title('Line plot of predicted BMI vs Age')
plt.legend()
plt.show()
../_images/993466fb5fc2683cbec9816ef7f520bdfb961fbcb6a8ed92f319c83a176e8427.png
# Observed data points with the fitted regression line overlaid. Rows whose bmi was
# imputed with the mean show up as a horizontal band at that value.
plt.scatter(stroke_data['age'], stroke_data['bmi'], label='Data Points', alpha=0.5, s=10)
plt.plot(X_values, predicted_bmi, label='Regression Line', color='red')
plt.xlabel('Age')
# The y-axis carries observed BMI (scatter) as well as the prediction (line), so
# label it plainly as BMI rather than "predicted BMI".
plt.ylabel('BMI')
plt.title('BMI vs Age with fitted regression line')
plt.legend()
plt.show()
../_images/d259ab6593af4852d535e7f6eaba611e7146763b5c2a76b376b3ca168b998008.png
# seaborn's regplot fits and draws its own least-squares line, with a confidence band
# (ci=95 by default), from the full dataset. It does not use the scikit-learn model
# trained above on the training split only.
sns.regplot(data=stroke_data, x="age", y="bmi", scatter_kws={'s': 10}, line_kws={'color': 'red'})
plt.title('BMI vs Age with seaborn regression fit')
plt.show()
../_images/09bbbb0ddda2c750f9604cfe03c5683cc6eb3743a29112a9bd887472f6739065.png