Python for AI (Part 5): Exploring Seaborn for Advanced Visualization
Introduction
This is Part 5 of your Python-for-AI journey. Having covered Matplotlib in Part 4, we now move up to Seaborn, a library that builds on Matplotlib to make statistical visualizations easier to produce and better looking. If you're moving to Python from Java, you'll appreciate Seaborn's integration with Pandas and its statistical plots, which let you explore your data before building ML models. Let's get started, using a detailed student dataset to showcase the power of Seaborn!
What is Seaborn and Why Use It?
Seaborn is a Python library built on top of Matplotlib and tailored to statistical graphics. It's ideal for AI work because it makes complicated plots simple through:
- Built-in themes for polished looks.
- Functions for pair plots, heatmaps, and more—tailored for data analysis.
- Direct Pandas DataFrame support, streamlining your workflow.
Install it with pip:
pip install seaborn
Then import it with required libraries:
import seaborn as sns  # Seaborn for advanced plotting
import matplotlib.pyplot as plt  # Matplotlib for underlying plot control
import pandas as pd  # Pandas for data handling
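The built-in themes listed above are applied with a single call. Here is a minimal sketch, assuming Seaborn 0.11 or newer (where `sns.set_theme` is available; older releases use `sns.set`). The `whitegrid` style and `muted` palette are only illustrative choices, and the `Agg` backend line is there so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; omit this line in a notebook
import matplotlib.pyplot as plt
import seaborn as sns

# Apply a built-in theme globally; every later plot picks it up
sns.set_theme(style="whitegrid", palette="muted")

# A theme works by updating Matplotlib's rcParams, e.g. switching the grid on
grid_enabled = plt.rcParams["axes.grid"]
print("Grid enabled:", grid_enabled)
```

Because the theme is global, calling it once at the top of a script or notebook is usually enough.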
Our Test Dataset
We’ll explore Seaborn with a student performance dataset, designed to test its features:
# Create a dictionary with student data
data = {
    'Name': ['Alice', 'Bob', 'Charlie', 'Dana', 'Eve', 'Frank', 'Grace', 'Hank', 'Ivy', 'Jack'],  # Student names
    'Math': [85, 92, 78, 95, 88, 65, 90, 82, 77, 94],  # Math scores (numerical)
    'Science': [88, 85, 90, 92, 80, 70, 87, 83, 79, 91],  # Science scores (numerical)
    'English': [90, 88, 85, 93, 87, 75, 89, 84, 80, 92],  # English scores (numerical)
    'Study_Hours': [5, 6, 4, 7, 5.5, 3, 6.5, 4.5, 4, 6],  # Hours studied (numerical)
    'Grade_Level': ['Freshman', 'Sophomore', 'Freshman', 'Sophomore', 'Freshman',
                    'Freshman', 'Sophomore', 'Sophomore', 'Freshman', 'Sophomore'],  # Grade category
    'Pass': [True, True, False, True, True, False, True, True, False, True]  # Pass/Fail status (categorical)
}
# Convert dictionary to a Pandas DataFrame for structured data handling
df = pd.DataFrame(data)
This dataset consists of 10 students, including numerical data (scores and study hours) and categorical data (Grade_Level and Pass). It’s large enough to demonstrate the sorts of things that Seaborn does well—plotting distributions and relationships—but small enough to maintain clarity in the plots. This is a practical example of how you could use this kind of dataset in AI to predict student success.
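Before plotting, it can help to confirm the mix of numerical and categorical columns described above. The following is a small sketch using only Pandas; the variable names `mean_by_grade` and `pass_rate` are illustrative and not part of the original tutorial.

```python
import pandas as pd

data = {
    'Name': ['Alice', 'Bob', 'Charlie', 'Dana', 'Eve', 'Frank', 'Grace', 'Hank', 'Ivy', 'Jack'],
    'Math': [85, 92, 78, 95, 88, 65, 90, 82, 77, 94],
    'Science': [88, 85, 90, 92, 80, 70, 87, 83, 79, 91],
    'English': [90, 88, 85, 93, 87, 75, 89, 84, 80, 92],
    'Study_Hours': [5, 6, 4, 7, 5.5, 3, 6.5, 4.5, 4, 6],
    'Grade_Level': ['Freshman', 'Sophomore', 'Freshman', 'Sophomore', 'Freshman',
                    'Freshman', 'Sophomore', 'Sophomore', 'Freshman', 'Sophomore'],
    'Pass': [True, True, False, True, True, False, True, True, False, True],
}
df = pd.DataFrame(data)

# Column types: numbers for scores and hours, object/bool for the categories
print(df.dtypes)

# Summary statistics for the numerical columns
print(df.describe())

# Average of each numerical column per grade level
numeric_cols = ['Math', 'Science', 'English', 'Study_Hours']
mean_by_grade = df.groupby('Grade_Level')[numeric_cols].mean()
print(mean_by_grade)

# Share of students who passed (True counts as 1)
pass_rate = df['Pass'].mean()
print(f"Pass rate: {pass_rate:.0%}")
```

These numbers give a baseline to compare against what the plots in the following sections show.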
Pair Plot: Multivariate Analysis
A pair plot shows relationships between multiple numerical variables and their distributions:
# Create a pair plot with Seaborn
sns.pairplot(
    df,  # DataFrame to plot
    hue='Grade_Level',  # Color points by Grade_Level (Freshman vs Sophomore)
    vars=['Math', 'Science', 'English', 'Study_Hours'],  # Select numerical columns to compare
    markers=['o', 's']  # Use circles for Freshman, squares for Sophomore
)
# Display the plot
plt.show()
This code creates a matrix of scatter plots, one for each pair of numerical variables (e.g., Math vs. Science, Study_Hours vs. English), with each variable's own distribution along the diagonal. Because hue is set, Seaborn draws the diagonal as density curves by default; pass diag_kind='hist' if you prefer histograms. The argument hue='Grade_Level' colors points by grade, and markers distinguishes Freshman (circles) from Sophomore (squares). This is particularly valuable in AI for checking correlations (e.g., does Study_Hours relate to scores?) and finding clusters that may guide feature selection for your machine learning models. It gives you a quick overview of your data, similar to how you would map out object relationships in Java before building a system.
Heatmap: Correlation Insights
A heatmap visualizes correlations between numerical variables:
# Calculate correlation matrix for numerical columns
corr = df[['Math', 'Science', 'English', 'Study_Hours']].corr()
# Create a heatmap
sns.heatmap(
    corr,  # Correlation matrix to plot
    annot=True,  # Show correlation values on the heatmap
    cmap='coolwarm',  # Use red (positive) to blue (negative) color scheme
    vmin=-1, vmax=1  # Set color scale range from -1 to 1
)
# Add a title
plt.title('Correlation Heatmap of Student Metrics')
# Display the plot
plt.show()
First we calculate the correlation matrix with .corr(), which gives a score from -1 to 1 showing how strongly each pair of numerical variables (e.g., Math and Study_Hours) changes together. The heatmap maps those values to colors (red for positive, blue for negative correlations), and annot=True adds the exact number in each cell. In AI, this helps you spot redundant features (if Math and Science are highly correlated, you might not need both in a model) and key predictors (like Study_Hours correlating strongly with scores). It's an important tool in machine learning for avoiding multicollinearity, much like optimizing data structures in Java to avoid redundancy.
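The same correlation matrix can also be read programmatically, which is handy once there are too many features to inspect by eye. The sketch below ranks every pair of columns by absolute correlation. The `THRESHOLD` cutoff is a hypothetical value chosen for illustration, not a rule from this tutorial.

```python
from itertools import combinations

import pandas as pd

df = pd.DataFrame({
    'Math': [85, 92, 78, 95, 88, 65, 90, 82, 77, 94],
    'Science': [88, 85, 90, 92, 80, 70, 87, 83, 79, 91],
    'English': [90, 88, 85, 93, 87, 75, 89, 84, 80, 92],
    'Study_Hours': [5, 6, 4, 7, 5.5, 3, 6.5, 4.5, 4, 6],
})

corr = df.corr()

# Hypothetical cutoff for "highly correlated"; tune it for your own data
THRESHOLD = 0.8

# Visit each unordered pair of columns once and rank by absolute correlation
pairs = sorted(
    ((a, b, float(corr.loc[a, b])) for a, b in combinations(corr.columns, 2)),
    key=lambda item: abs(item[2]),
    reverse=True,
)

for a, b, r in pairs:
    flag = "<- possibly redundant" if abs(r) >= THRESHOLD else ""
    print(f"{a:>12} vs {b:<12} r = {r:+.2f} {flag}")
```

A high correlation only suggests redundancy; whether to drop a feature still depends on the model and the task.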
Violin Plot: Distribution by Category
A violin plot shows the distribution of a numerical variable across categories:
# Create a violin plot
sns.violinplot(
    x='Grade_Level',  # Categorical variable for x-axis (Freshman vs Sophomore)
    y='Math',  # Numerical variable for y-axis (Math scores)
    data=df  # DataFrame source
)
# Add a title
plt.title('Math Score Distribution by Grade Level')
# Display the plot
plt.show()
This violin plot compares Math score distributions for Freshman and Sophomore students. The wider parts of each “violin” show where scores are concentrated, while the tails indicate the range (the smoothed curve can extend slightly past the observed minimum and maximum). Unlike a box plot, it also shows density, giving a fuller picture of the data's shape, like a histogram stretched vertically. In AI, this is handy for understanding how a feature (like scores) varies across groups: are Sophomores scoring higher, for example? This can inform whether Grade_Level is a useful feature for predicting outcomes like Pass status, akin to analyzing group differences in Java data processing.
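With only five students per group, it can be useful to see the raw observations inside each violin and to stop the density curve at the observed range. The sketch below assumes that goal and uses the `inner` and `cut` parameters of `sns.violinplot`. The `Agg` backend line is there so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; omit this line in a notebook
import pandas as pd
import seaborn as sns

df = pd.DataFrame({
    'Math': [85, 92, 78, 95, 88, 65, 90, 82, 77, 94],
    'Grade_Level': ['Freshman', 'Sophomore', 'Freshman', 'Sophomore', 'Freshman',
                    'Freshman', 'Sophomore', 'Sophomore', 'Freshman', 'Sophomore'],
})

ax = sns.violinplot(
    x='Grade_Level',
    y='Math',
    data=df,
    inner='point',  # draw each observation inside the violin
    cut=0,          # do not extend the density past the observed min and max
)
ax.set_title('Math Score Distribution by Grade Level')

# Numerical check of the question the plot raises: who scores higher?
medians = df.groupby('Grade_Level')['Math'].median()
print(medians)
```

The group medians give a quick numerical answer to the question of whether Sophomores score higher, alongside the visual one.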
Swarm Plot: Individual Points
A swarm plot displays individual data points without overlap:
# Create a swarm plot
sns.swarmplot(
    x='Pass',  # Categorical variable for x-axis (True vs False)
    y='Science',  # Numerical variable for y-axis (Science scores)
    data=df,  # DataFrame source
    hue='Grade_Level'  # Color points by Grade_Level
)
# Add a title
plt.title('Science Scores by Pass Status')
# Display the plot
plt.show()
This swarm plot shows each student's Science score, grouped by Pass status (True or False) and colored by Grade_Level. Unlike a scatter plot, it prevents overlap by spreading points horizontally within each category, making every score visible. Think of it as a detailed view of individual data points. In AI, this is excellent for spotting outliers (e.g., a failing student with a high score) or checking if Pass status aligns with scores across grades. It's a granular view that complements broader distribution plots, aiding in data quality checks before modeling, much like debugging individual records in Java.
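The "failing student with a high score" case can also be pulled out of the DataFrame directly. The sketch below uses one possible definition of "high", which is an assumption made here for illustration: a failing student whose Science score is at least as high as the lowest passing score.

```python
import pandas as pd

df = pd.DataFrame({
    'Name': ['Alice', 'Bob', 'Charlie', 'Dana', 'Eve', 'Frank', 'Grace', 'Hank', 'Ivy', 'Jack'],
    'Science': [88, 85, 90, 92, 80, 70, 87, 83, 79, 91],
    'Pass': [True, True, False, True, True, False, True, True, False, True],
})

# Lowest Science score among students who passed
lowest_passing = df.loc[df['Pass'], 'Science'].min()

# Failing students who nevertheless reach that score (assumed definition of "surprising")
surprising = df[(~df['Pass']) & (df['Science'] >= lowest_passing)]

print(f"Lowest passing Science score: {lowest_passing}")
print(surprising[['Name', 'Science']])
```

Rows found this way are worth a second look: Pass may depend on other subjects, or the record may contain an error.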
Box Plot (Bonus): Comparing Distributions
A box plot summarizes numerical data distributions across multiple variables:
# Create a horizontal box plot for multiple columns
sns.boxplot(
    data=df[['Math', 'Science', 'English']],  # Select numerical columns to plot
    orient='h'  # Horizontal orientation for readability
)
# Add a title
plt.title('Box Plot of Student Scores Across Subjects')
# Display the plot
plt.show()
This box plot displays the distribution of Math, Science, and English scores in a horizontal layout (orient='h'). Each box shows the median (central line) and quartiles (box edges), the whiskers reach to the most extreme points within 1.5 times the interquartile range of the box, and anything beyond the whiskers is drawn as a potential outlier. By passing multiple columns via df[['Math', 'Science', 'English']], it compares all subjects at once, with no need to specify x or y axes individually. In AI, this is perfect for quickly assessing spread and central tendencies across features: is English more consistent than Math, for instance? It helps identify data variability or outliers that might affect model performance, similar to how you'd profile data ranges in Java before processing.
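The numbers behind each box can be reproduced with Pandas. The sketch below computes the quartiles and applies the 1.5 × IQR rule that Seaborn's box plot uses by default for its whiskers. It relies on Pandas' default linear interpolation for quantiles, which should match what the plot draws.

```python
import pandas as pd

scores = pd.DataFrame({
    'Math': [85, 92, 78, 95, 88, 65, 90, 82, 77, 94],
    'Science': [88, 85, 90, 92, 80, 70, 87, 83, 79, 91],
    'English': [90, 88, 85, 93, 87, 75, 89, 84, 80, 92],
})

q1 = scores.quantile(0.25)  # lower edge of each box
q3 = scores.quantile(0.75)  # upper edge of each box
iqr = q3 - q1               # box length (interquartile range)

# Points outside these fences are drawn as individual outlier markers
lower_fence = q1 - 1.5 * iqr
upper_fence = q3 + 1.5 * iqr

outliers = {
    col: scores.loc[
        (scores[col] < lower_fence[col]) | (scores[col] > upper_fence[col]), col
    ].tolist()
    for col in scores.columns
}

print(pd.DataFrame({'Q1': q1, 'Q3': q3, 'IQR': iqr}))
print(outliers)
```

On this dataset English has the smallest interquartile range, and its tightness is exactly why the single score of 75 falls outside the lower fence while Math and Science show no outliers.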
Try It Yourself: An Exercise
Using the student dataset, try these tasks:
- Create a box plot of English scores by Grade_Level.
- Create a heatmap of correlations excluding Study_Hours.
- Create a violin plot of Study_Hours by Pass status.
Hint: Use sns.boxplot for ranges, df.corr() for correlations, and sns.violinplot for distributions. Give it a shot, then check the solution below!
Solution
# 1. Box plot of English scores by Grade_Level
sns.boxplot(
    x='Grade_Level',  # Categorical x-axis (Freshman vs Sophomore)
    y='English',  # Numerical y-axis (English scores)
    data=df  # DataFrame source
)
plt.title('English Scores by Grade Level')  # Add a descriptive title
plt.show()  # Display the plot

# 2. Heatmap of correlations excluding Study_Hours
corr = df[['Math', 'Science', 'English']].corr()  # Calculate correlations for selected columns
sns.heatmap(
    corr,  # Correlation matrix to plot
    annot=True,  # Display correlation values
    cmap='coolwarm'  # Color scheme for visual clarity
)
plt.title('Correlation Heatmap (Excluding Study_Hours)')  # Title for context
plt.show()  # Display the plot

# 3. Violin plot of Study_Hours by Pass status
sns.violinplot(
    x='Pass',  # Categorical x-axis (True vs False)
    y='Study_Hours',  # Numerical y-axis (study hours)
    data=df  # DataFrame source
)
plt.title('Study Hours by Pass Status')  # Add a title
plt.show()  # Display the plot
These solutions showcase different Seaborn strengths:
- Box Plot: Shows median, quartiles, and outliers of English scores per grade using sns.boxplot, which is great for comparing central tendencies and spread in AI data prep. It's like summarizing data ranges in Java.
- Heatmap: Focuses on score correlations with df.corr(), excluding Study_Hours to simplify. It's a quick way to check if subjects are redundant for modeling, similar to optimizing data in Java.
- Violin Plot: Reveals how Study_Hours is distributed for passing vs. failing students with sns.violinplot, highlighting density differences that are useful for feature importance analysis in AI workflows.
Experiment with these plots, for example by changing the palette or adding hue, to see how they adapt!
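As one way to try the palette and hue suggestion, the sketch below adds `hue='Pass'` and a named palette to the first exercise's box plot. The `Set2` palette is just an illustrative choice, and the `Agg` backend line is there so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; omit this line in a notebook
import pandas as pd
import seaborn as sns

df = pd.DataFrame({
    'English': [90, 88, 85, 93, 87, 75, 89, 84, 80, 92],
    'Grade_Level': ['Freshman', 'Sophomore', 'Freshman', 'Sophomore', 'Freshman',
                    'Freshman', 'Sophomore', 'Sophomore', 'Freshman', 'Sophomore'],
    'Pass': [True, True, False, True, True, False, True, True, False, True],
})

ax = sns.boxplot(
    x='Grade_Level',
    y='English',
    hue='Pass',      # split each grade level into passing and failing students
    palette='Set2',  # swap in a different built-in color palette
    data=df,
)
ax.set_title('English Scores by Grade Level and Pass Status')
```

In this dataset every Sophomore passed, so that grade level shows a single box while Freshman shows two.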
Next Steps
You just harnessed Seaborn's power for AI visualization! In Part 6, we'll dive into scikit-learn and start building machine learning models. Get familiar with these plots, experimenting with styles or adding more variables to your visualizations, to solidify your skills. Its ease and elegance make Seaborn an essential part of your AI toolkit.
Code Demo
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# print(sns.__file__)
# print(plt.__file__)
# print(pd.__file__)

# Let's start with a simple example using the student data from previous parts:
data = {
    'Name': ['Alice', 'Bob', 'Charlie', 'Dana'],
    'Math': [85, 92, 78, 95],
    'Science': [88, 85, 90, 92]
}
df = pd.DataFrame(data)

# Simple scatter plot using Seaborn
sns.lmplot(x='Math', y='Science', data=df, fit_reg=False)
plt.title('Math vs. Science Scores')
plt.show()

# Sample dataset for testing Seaborn features
data = {
    'Name': ['Alice', 'Bob', 'Charlie', 'Dana', 'Eve', 'Frank', 'Grace', 'Hank', 'Ivy', 'Jack'],
    'Math': [85, 92, 78, 95, 88, 65, 90, 82, 77, 94],
    'Science': [88, 85, 90, 92, 80, 70, 87, 83, 79, 91],
    'English': [90, 88, 85, 93, 87, 75, 89, 84, 80, 92],
    'Study_Hours': [5, 6, 4, 7, 5.5, 3, 6.5, 4.5, 4, 6],
    'Grade_Level': ['Freshman', 'Sophomore', 'Freshman', 'Sophomore', 'Freshman', 'Freshman', 'Sophomore', 'Sophomore', 'Freshman', 'Sophomore'],
    'Pass': [True, True, False, True, True, False, True, True, False, True]
}
df = pd.DataFrame(data)

# 1. Pair Plot
# Create a pair plot with Seaborn
sns.pairplot(
    df,  # DataFrame to plot
    hue='Grade_Level',  # Color points by Grade_Level (Freshman vs Sophomore)
    vars=['Math', 'Science', 'English', 'Study_Hours'],  # Select numerical columns to compare
    markers=['o', 's']  # Use circles for Freshman, squares for Sophomore
)
# Display the plot
plt.show()

# 2. Heatmap: Display a correlation matrix:
# Calculate correlation matrix for numerical columns
corr = df[['Math', 'Science', 'English', 'Study_Hours']].corr()
# Create a heatmap
sns.heatmap(
    corr,  # Correlation matrix to plot
    annot=True,  # Show correlation values on the heatmap
    cmap='coolwarm',  # Use red (positive) to blue (negative) color scheme
    vmin=-1, vmax=1  # Set color scale range from -1 to 1
)
# Add a title
plt.title('Correlation Heatmap of Student Metrics')
# Display the plot
plt.show()

# 3. Violin Plot: Distribution by Category
# A violin plot shows the distribution of a numerical variable across categories:
# Create a violin plot
sns.violinplot(
    x='Grade_Level',  # Categorical variable for x-axis (Freshman vs Sophomore)
    y='Math',  # Numerical variable for y-axis (Math scores)
    data=df  # DataFrame source
)
# Add a title
plt.title('Math Score Distribution by Grade Level')
# Display the plot
plt.show()

# 4. Swarm Plot: Individual Points
# A swarm plot displays individual data points without overlap:
# Create a swarm plot
sns.swarmplot(
    x='Pass',  # Categorical variable for x-axis (True vs False)
    y='Science',  # Numerical variable for y-axis (Science scores)
    data=df,  # DataFrame source
    hue='Grade_Level'  # Color points by Grade_Level
)
# Add a title
plt.title('Science Scores by Pass Status')
# Display the plot
plt.show()

# 5. Box Plot (Bonus): Comparing Distributions
# A box plot summarizes numerical data distributions across multiple variables:
# Create a horizontal box plot for multiple columns
sns.boxplot(
    data=df[['Math', 'Science', 'English']],  # Select numerical columns to plot
    orient='h'  # Horizontal orientation for readability
)
# Add a title
plt.title('Box Plot of Student Scores Across Subjects')
# Display the plot
plt.show()

# An Exercise:
# Using the student dataset, try these tasks:
# 1. Create a box plot of English scores by Grade_Level.
# 2. Create a heatmap of correlations excluding Study_Hours.
# 3. Create a violin plot of Study_Hours by Pass status.
# Hint: Use sns.boxplot for ranges, df.corr() for correlations,
# and sns.violinplot for distributions.

# 1. Box plot of English scores by Grade_Level
sns.boxplot(
    x='Grade_Level',  # Categorical x-axis (Freshman vs Sophomore)
    y='English',  # Numerical y-axis (English scores)
    data=df  # DataFrame source
)
plt.title('English Scores by Grade Level')  # Add a descriptive title
plt.show()  # Display the plot

# 2. Heatmap of correlations excluding Study_Hours
corr = df[['Math', 'Science', 'English']].corr()  # Calculate correlations for selected columns
sns.heatmap(
    corr,  # Correlation matrix to plot
    annot=True,  # Display correlation values
    cmap='coolwarm'  # Color scheme for visual clarity
)
plt.title('Correlation Heatmap (Excluding Study_Hours)')  # Title for context
plt.show()  # Display the plot

# 3. Violin plot of Study_Hours by Pass status
sns.violinplot(
    x='Pass',  # Categorical x-axis (True vs False)
    y='Study_Hours',  # Numerical y-axis (study hours)
    data=df  # DataFrame source
)
plt.title('Study Hours by Pass Status')  # Add a title
plt.show()  # Display the plot
Conclusion
Seaborn takes your data visualization skills to another level, which is especially useful for AI. From pair plots that reveal relationships among features to box plots that summarize distributions, you can now analyze datasets with confidence. Play with this student data some more, and you'll be set to take on machine learning! Happy plotting!