PySpark COVID Example

Objective

  1. Explain the core functionalities and use cases of PySpark for big data processing and Pandas for data manipulation.
  2. Set up the environment: Install and configure PySpark and Pandas to work together in a Python environment.
  3. Load and explore data: Import data into Pandas and PySpark DataFrames and perform basic data exploration.
  4. Convert between DataFrames: Convert a Pandas DataFrame to a Spark DataFrame for distributed processing.
  5. Perform data manipulation: Create new columns, filter data, and perform aggregations using PySpark.
  6. Utilize SQL queries: Use Spark SQL for querying data and leveraging user-defined functions (UDFs).

Setup


NOTE: This example was run in a Jupyter notebook and brought over into this document. Installing the packages at the start is only needed the first time; on later runs those requirements will already be satisfied. They are shown here for information purposes only.

Install Libraries

  • I am using Quarto for this document, so I will install these packages first using the commands below.
  • This will load the packages into the virtual environment located at
  • "C:/~/EMHRC/OneDrive/Documents/.virtualenvs/r-reticulate......"
!pip install pyspark
!pip install findspark
!pip install pandas

Local Drive

  • Using the Windows cmd terminal, I need to install these packages for Python.
  • From C:\~\, use pip to install:
  • pip install pyspark
  • pip install findspark
  • pip install pandas

Initialize Spark Session


A Spark session is crucial for working with PySpark. It enables DataFrame creation, data loading, and various operations.

Importing libraries

  • findspark is used to locate the Spark installation.
  • pandas is imported for data manipulation.
import findspark  # This helps us find and use Apache Spark
findspark.init()  # Initialize findspark to locate Spark

from pyspark.sql import SparkSession  
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, DateType
import pandas as pd  

Creating a Spark session

  • SparkSession.builder.appName("COVID-19 Data Analysis").getOrCreate() initializes a Spark session with the specified application name. The code below also sets the config option spark.sql.execution.arrow.pyspark.enabled to "true", which uses Apache Arrow to speed up conversions between Pandas and Spark DataFrames.

Checking Spark session

  • The code checks if the Spark session is active and prints an appropriate message.
# Initialize a Spark Session
spark = SparkSession\
        .builder\
        .appName("COVID-19 Data Analysis")\
        .config("spark.sql.execution.arrow.pyspark.enabled", "true")\
        .getOrCreate()
# Check if the Spark Session is active
if 'spark' in locals() and isinstance(spark, SparkSession):
    print("SparkSession is active and ready to use.")
else:
    print("SparkSession is not active. Please create a SparkSession.")
    
# ----- OUTPUT
(This is run in the Jupyter notebook, not in this Quarto document.)
WARN util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
SparkSession is active and ready to use.

Import Data to Pandas


Import from CSV

  • Let’s read the COVID-19 data from the URL
vaccination_data = pd.read_csv('https://cf-courses-data.s3.us.cloud-object-storage.appdomain.cloud/KpHDlIzdtR63BdTofl1mOg/owid-covid-latest.csv')

Query First 5

  • vaccination_data.head() retrieves the first five rows of the DataFrame vaccination_data. This gives us a quick look at the data contained within the dataset (see the sketch after this list).
  • The print function is used to display a message indicating what is being shown, followed by the actual data.
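  • A minimal sketch of that call (the exact rows shown will depend on the version of the CSV that is downloaded):
print("Displaying the first 5 records of the vaccination data:")
print(vaccination_data.head())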

Filter Columns

  • Let’s define a list called columns_to_display, which contains the column names: ['continent', 'total_cases', 'total_deaths', 'total_vaccinations', 'population'].
  • By using vaccination_data[columns_to_display].head(), let’s filter the DataFrame to only show the specified columns and again display the first five records of this subset.
  • The continent column is explicitly converted to string, while the numeric columns (total cases, total deaths, total vaccinations, population) are filled with zeros for NaN values and then converted to int64 (which is compatible with LongType in Spark).
  • The use of fillna(0) ensures that NaN values do not cause type issues during the Spark DataFrame creation.
print("Displaying the first 5 records of the vaccination data:")

columns_to_display = ['continent', 'total_cases', 'total_deaths', 'total_vaccinations', 'population']

# Show the first 5 records
print(vaccination_data[columns_to_display].head())

# OUTPUT
Displaying the first 5 records of the vaccination data:
  continent  total_cases  total_deaths  total_vaccinations    population
0      Asia     235214.0        7998.0                 NaN  4.112877e+07
1       NaN   13145380.0      259117.0                 NaN  1.426737e+09
2    Europe     335047.0        3605.0                 NaN  2.842318e+06
3    Africa     272139.0        6881.0                 NaN  4.490323e+07
4   Oceania       8359.0          34.0                 NaN  4.429500e+04

Pandas df to Spark DataFrame


Let’s convert the Pandas DataFrame, which contains our COVID-19 vaccination data, into a Spark DataFrame. This conversion is crucial as it allows us to utilize Spark’s distributed computing capabilities, enabling us to handle larger datasets and perform operations in a more efficient manner.

Defining the schema

  • StructType:
    • A class that defines a structure for a DataFrame.
    • StructField:
      • Represents a single field in the schema.

      • Parameters:

        1. Field name: The name of the field.

        2. Data type: The type of data for the field.

        3. Nullable: A boolean indicating whether null values are allowed.

    • Data types:
      • StringType(): Used for text fields.

      • LongType(): Used for numerical fields.

Transform

Data type conversion:

  • astype(str):
    • Used to convert the 'continent' column to string type.
  • fillna(0):
    • Replaces any NaN values with 0, ensuring that the numerical fields do not contain any missing data.
  • astype('int64'):
    • Converts the columns from potentially mixed types to 64-bit integers for consistent numerical representation.

Filter out Columns

  • Only the columns defined in the schema are passed to Spark; the subset is selected with vaccination_data[schema.fieldNames()].

Creating a Spark DataFrame

  • createDataFrame:
    • The createDataFrame method of the Spark session (spark) is called with vaccination_data (the Pandas DataFrame) as its argument.

    • Parameters:

      • It takes as input a subset of the pandas DataFrame that corresponds to the fields defined in the schema, accessed using schema.fieldNames().
  • This function automatically converts the Pandas DataFrame into a Spark DataFrame, which is designed to handle larger data sets across a distributed environment.
  • The resulting spark_df will have the defined schema, which ensures consistency and compatibility with Spark’s data processing capabilities.

Storing the result

  • The converted DataFrame is stored in the variable spark_df, which is used for all of the following operations.

# Convert to Spark DataFrame directly
# Define the schema
schema = StructType([
    StructField("continent", StringType(), True),
    StructField("total_cases", LongType(), True),
    StructField("total_deaths", LongType(), True),
    StructField("total_vaccinations", LongType(), True),
    StructField("population", LongType(), True)
])

# Convert the columns to the appropriate data types
vaccination_data['continent'] = vaccination_data['continent'].astype(str)  # Ensures continent is a string
vaccination_data['total_cases'] = vaccination_data['total_cases'].fillna(0).astype('int64')  # Fill NaNs and convert to int
vaccination_data['total_deaths'] = vaccination_data['total_deaths'].fillna(0).astype('int64')  # Fill NaNs and convert to int
vaccination_data['total_vaccinations'] = vaccination_data['total_vaccinations'].fillna(0).astype('int64')  # Fill NaNs and convert to int
vaccination_data['population'] = vaccination_data['population'].fillna(0).astype('int64')  # Fill NaNs and convert to int

# Filter out unwanted columns
spark_df = spark.createDataFrame(vaccination_data[schema.fieldNames()], schema=schema)  # Use only the fields defined in the schema
# Show the Spark DataFrame
spark_df.show()

# OUTPUT
+-------------+-----------+------------+------------------+----------+
|    continent|total_cases|total_deaths|total_vaccinations|population|
+-------------+-----------+------------+------------------+----------+
|         Asia|     235214|        7998|                 0|  41128772|
|          nan|   13145380|      259117|                 0|1426736614|
|       Europe|     335047|        3605|                 0|   2842318|
|       Africa|     272139|        6881|                 0|  44903228|
|      Oceania|       8359|          34|                 0|     44295|
|       Europe|      48015|         159|                 0|     79843|
|       Africa|     107481|        1937|                 0|  35588996|
|North America|       3904|          12|                 0|     15877|
|North America|       9106|         146|                 0|     93772|
|South America|   10101218|      130663|                 0|  45510324|
|         Asia|     452273|        8777|                 0|   2780472|
|North America|      44224|         292|                 0|    106459|
|          nan|  301499099|     1637249|        9104304615|4721383370|
|      Oceania|   11861161|       25236|                 0|  26177410|
|       Europe|    6082444|       22534|                 0|   8939617|
|         Asia|     835757|       10353|                 0|  10358078|
|North America|      39127|         849|                 0|    409989|
|         Asia|     696614|        1536|                 0|   1472237|
|         Asia|    2051348|       29499|                 0| 171186368|
|North America|     108582|         593|                 0|    281646|
+-------------+-----------+------------+------------------+----------+
only showing top 20 rows

Spark DataFrame Structure


Let’s examine the structure of the Spark DataFrame that we created from the Pandas DataFrame. Understanding the schema of a DataFrame is crucial as it provides insight into the data types of each column and helps ensure that the data is organized correctly for analysis.

Displaying the schema

  • The method spark_df.printSchema() is called to output the structure of the Spark DataFrame.
  • This method prints the names of the columns along with their data types (string, long, and so on), providing a clear view of how the data is organized.
print("Schema of the Spark DataFrame:")
spark_df.printSchema()
# Print the structure of the DataFrame (columns and types)

# OUTPUT
Schema of the Spark DataFrame:
root
 |-- continent: string (nullable = true)
 |-- total_cases: long (nullable = true)
 |-- total_deaths: long (nullable = true)
 |-- total_vaccinations: long (nullable = true)
 |-- population: long (nullable = true)
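
For a quicker look without the tree layout, the dtypes attribute lists each column with its Spark type as (name, type) pairs. A minimal sketch:

# Alternative: list (column, type) pairs
print(spark_df.dtypes)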

Explore DataFrame


View DataFrame contents

  • To view the contents in the DataFrame, use the following code:
# List the names of the columns you want to display
columns_to_display = ['continent', 'total_cases', 'total_deaths', 'total_vaccinations', 'population']
# Display the first 5 records of the specified columns
spark_df.select(columns_to_display).show(5)

# OUTPUT
+---------+-----------+------------+------------------+----------+
|continent|total_cases|total_deaths|total_vaccinations|population|
+---------+-----------+------------+------------------+----------+
|     Asia|     235214|        7998|                 0|  41128772|
|      nan|   13145380|      259117|                 0|1426736614|
|   Europe|     335047|        3605|                 0|   2842318|
|   Africa|     272139|        6881|                 0|  44903228|
|  Oceania|       8359|          34|                 0|     44295|
+---------+-----------+------------+------------------+----------+
only showing top 5 rows

Filter Columns

  • To view only specific columns:
print("Displaying the 'continent' and 'total_cases' columns:")
# Show only the 'continent' and 'total_cases' columns
spark_df.select('continent', 'total_cases').show(5)

# OUTPUT
Displaying the 'continent' and 'total_cases' columns:
+---------+-----------+
|continent|total_cases|
+---------+-----------+
|     Asia|     235214|
|      nan|   13145380|
|   Europe|     335047|
|   Africa|     272139|
|  Oceania|       8359|
+---------+-----------+
only showing top 5 rows

Filter On Conditions

  • To keep only rows that satisfy a condition, pass a boolean column expression to filter():
print("Filtering records where 'total_cases' is greater than 1,000,000:")
 # Show records with more than 1 million total cases
spark_df.filter(spark_df['total_cases'] > 1000000).show(5) 

# OUTPUT
Filtering records where 'total_cases' is greater than 1,000,000:
+-------------+-----------+------------+------------------+----------+
|    continent|total_cases|total_deaths|total_vaccinations|population|
+-------------+-----------+------------+------------------+----------+
|          nan|   13145380|      259117|                 0|1426736614|
|South America|   10101218|      130663|                 0|  45510324|
|          nan|  301499099|     1637249|        9104304615|4721383370|
|      Oceania|   11861161|       25236|                 0|  26177410|
|       Europe|    6082444|       22534|                 0|   8939617|
+-------------+-----------+------------+------------------+----------+
only showing top 5 rows

Create New Column


Create a new column called death_percentage, which calculates the death rate during the COVID-19 pandemic. This calculation is based on the total_deaths (the count of deaths) and the population (the total population) columns in our Spark DataFrame. This new metric will provide valuable insight into the impact of COVID-19 in different regions.

  • Let’s import the functions module from pyspark.sql as F, which contains built-in functions for DataFrame operations.

Calculating the death percentage:

  • Let’s create a new DataFrame spark_df_with_percentage by using the withColumn() method to add a new column called death_percentage.
  • The formula (spark_df['total_deaths'] / spark_df['population']) * 100 computes the death percentage by dividing the total deaths by the total population and multiplying by 100.

Formatting the percentage:

  • Let’s update the death_percentage column to format its values to two decimal places using F.format_number(), and concatenate a percentage symbol using F.concat() and F.lit('%').
  • This makes the death percentage easier to read and interpret.

Selecting relevant columns:

  • Let’s define a list columns_to_display that includes 'total_deaths', 'population', 'death_percentage', 'continent', 'total_vaccinations', and 'total_cases'.
  • Finally, let’s display the first five records of the modified DataFrame with the new column by calling spark_df_with_percentage.select(columns_to_display).show(5).
from pyspark.sql import functions as F

spark_df_with_percentage = spark_df.withColumn(
    'death_percentage', 
    (spark_df['total_deaths'] / spark_df['population']) * 100
)
spark_df_with_percentage = spark_df_with_percentage.withColumn(
    'death_percentage',
    F.concat(
        # Format to 2 decimal places
        F.format_number(spark_df_with_percentage['death_percentage'], 2), 
        # Append the percentage symbol 
        F.lit('%')  
    )
)
columns_to_display = ['total_deaths', 'population', 'death_percentage', 'continent', 'total_vaccinations', 'total_cases']
spark_df_with_percentage.select(columns_to_display).show(5)

# OUTPUT
+------------+----------+----------------+---------+------------------+-----------+
|total_deaths|population|death_percentage|continent|total_vaccinations|total_cases|
+------------+----------+----------------+---------+------------------+-----------+
|        7998|  41128772|           0.02%|     Asia|                 0|     235214|
|      259117|1426736614|           0.02%|      nan|                 0|   13145380|
|        3605|   2842318|           0.13%|   Europe|                 0|     335047|
|        6881|  44903228|           0.02%|   Africa|                 0|     272139|
|          34|     44295|           0.08%|  Oceania|                 0|       8359|
+------------+----------+----------------+---------+------------------+-----------+
only showing top 5 rows
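
Note that F.format_number() together with F.concat() turns death_percentage into a string column. If you would rather keep the value numeric (for sorting or further aggregation), a sketch using F.round() instead could look like this:

# Keep death_percentage numeric by rounding instead of formatting as text
spark_df_numeric_pct = spark_df.withColumn(
    'death_percentage',
    F.round((spark_df['total_deaths'] / spark_df['population']) * 100, 2)
)
spark_df_numeric_pct.select('continent', 'death_percentage').show(5)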

Group & Summarize


Let’s calculate the total number of deaths per continent using the data in our Spark DataFrame. Grouping and summarizing data is a crucial aspect of data analysis, as it allows us to aggregate information and identify trends across different categories.

Grouping the data

The spark_df.groupby(['continent']) method groups the data by the continent column. This means that all records associated with each continent will be aggregated together.

Aggregating the deaths

The agg({"total_deaths": "SUM"}) function is used to specify the aggregation operation. In this case, we want to calculate the sum of the total_deaths for each continent. This operation will create a new DataFrame where each continent is listed alongside the total number of deaths attributed to it.

Displaying the results

The show() method is called to display the results of the aggregation. This will output the total number of deaths for each continent in a tabular format.

print("Calculating the total deaths per continent:")
# Group by continent and sum total death rates
spark_df.groupby(['continent']).agg({"total_deaths": "SUM"}).show()  

# OUTPUT
Calculating the total deaths per continent:
+-------------+-----------------+
|    continent|sum(total_deaths)|
+-------------+-----------------+
|       Europe|          2102483|
|       Africa|           259117|
|          nan|         22430618|
|North America|          1671178|
|South America|          1354187|
|      Oceania|            32918|
|         Asia|          1637249|
+-------------+-----------------+
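
An equivalent aggregation with pyspark.sql.functions gives the summed column a friendlier name via alias(). A sketch, reusing the F alias imported earlier:

# Same aggregation, with an explicit alias for the summed column
spark_df.groupBy('continent').agg(F.sum('total_deaths').alias('total_deaths_sum')).show()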

User-defined functions (UDFs)


UDFs in PySpark allow us to create custom functions that can be applied to individual columns within a DataFrame. This feature provides increased flexibility and customization in data processing, enabling us to define specific transformations or calculations that are not available through built-in functions. In this section, let’s define a UDF that transforms the total_deaths column (here, simply doubling its value).

Importing pandas_udf

The pandas_udf decorator from pyspark.sql.functions lets you define vectorized UDFs that operate on whole Pandas Series. In this example we instead register a plain Python function with spark.udf.register; a pandas_udf sketch is shown after the output below for comparison.

Defining the UDF

This function convert_total_deaths() takes in a parameter total_deaths and returns double its value. You can replace the logic with any transformation you want to apply to the column data.

Registering the UDF

The line spark.udf.register("convert_total_deaths", convert_total_deaths, IntegerType()) registers the UDF with Spark, declaring that the function returns an integer. This allows us to use it in Spark SQL queries and DataFrame operations.

from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType
# Function definition: double the value of total_deaths
# (any transformation could be defined here instead)
def convert_total_deaths(total_deaths):
    return total_deaths * 2

# Register the UDF with Spark, declaring an integer return type
spark.udf.register("convert_total_deaths", convert_total_deaths, IntegerType())

# OUTPUT
<function __main__.convert_total_deaths(total_deaths)>
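
For comparison, a vectorized version of the same doubling logic using the pandas_udf decorator mentioned above might look like the sketch below; it operates on whole Pandas Series (and relies on pyarrow) and is not used in the rest of this example:

from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import LongType
import pandas as pd

# Vectorized UDF: receives and returns a Pandas Series
@pandas_udf(LongType())
def convert_total_deaths_vec(total_deaths: pd.Series) -> pd.Series:
    return total_deaths * 2

spark_df.select(
    'continent',
    convert_total_deaths_vec('total_deaths').alias('converted_total_deaths')
).show(5)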

SparkSQL


  • We can execute SQL queries directly on Spark DataFrames
# Drop the existing temporary view if it exists
spark.sql("DROP VIEW IF EXISTS data_v")

# Create a new temporary view
spark_df.createTempView('data_v')

# Execute the SQL query using the UDF
spark.sql('SELECT continent, total_deaths, convert_total_deaths(total_deaths) as converted_total_deaths FROM data_v').show()

# OUTPUT
+-------------+------------+----------------------+
|    continent|total_deaths|converted_total_deaths|
+-------------+------------+----------------------+
|         Asia|        7998|                 15996|
|          nan|      259117|                518234|
|       Europe|        3605|                  7210|
|       Africa|        6881|                 13762|
|      Oceania|          34|                    68|
|       Europe|         159|                   318|
|       Africa|        1937|                  3874|
|North America|          12|                    24|
|North America|         146|                   292|
|South America|      130663|                261326|
|         Asia|        8777|                 17554|
|North America|         292|                   584|
|          nan|     1637249|               3274498|
|      Oceania|       25236|                 50472|
|       Europe|       22534|                 45068|
|         Asia|       10353|                 20706|
|North America|         849|                  1698|
|         Asia|        1536|                  3072|
|         Asia|       29499|                 58998|
|North America|         593|                  1186|
+-------------+------------+----------------------+
only showing top 20 rows
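
Because the UDF was registered with spark.udf.register, it can also be used on the DataFrame side through selectExpr, which accepts SQL expressions. A short sketch:

# The registered UDF is also available in SQL expressions on the DataFrame
spark_df.selectExpr(
    "continent",
    "convert_total_deaths(total_deaths) AS converted_total_deaths"
).show(5)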

Query Data


Let’s execute SQL queries to retrieve specific records from the temporary view which was created earlier. Let’s demonstrate how to display all records from the data table and filter those records based on vaccination totals. This capability allows for efficient data exploration and analysis using SQL syntax.

Displaying All Records

The first query retrieves all records from the temporary view data_v using the SQL command SELECT * FROM data_v. The show() method is called to display the results in a tabular format. This is useful for getting an overview of the entire dataset.

spark.sql('SELECT * FROM data_v').show()

# OUTPUT
+-------------+-----------+------------+------------------+----------+
|    continent|total_cases|total_deaths|total_vaccinations|population|
+-------------+-----------+------------+------------------+----------+
|         Asia|     235214|        7998|                 0|  41128772|
|          nan|   13145380|      259117|                 0|1426736614|
|       Europe|     335047|        3605|                 0|   2842318|
|       Africa|     272139|        6881|                 0|  44903228|
|      Oceania|       8359|          34|                 0|     44295|
|       Europe|      48015|         159|                 0|     79843|
|       Africa|     107481|        1937|                 0|  35588996|
|North America|       3904|          12|                 0|     15877|
|North America|       9106|         146|                 0|     93772|
|South America|   10101218|      130663|                 0|  45510324|
|         Asia|     452273|        8777|                 0|   2780472|
|North America|      44224|         292|                 0|    106459|
|          nan|  301499099|     1637249|        9104304615|4721383370|
|      Oceania|   11861161|       25236|                 0|  26177410|
|       Europe|    6082444|       22534|                 0|   8939617|
|         Asia|     835757|       10353|                 0|  10358078|
|North America|      39127|         849|                 0|    409989|
|         Asia|     696614|        1536|                 0|   1472237|
|         Asia|    2051348|       29499|                 0| 171186368|
|North America|     108582|         593|                 0|    281646|
+-------------+-----------+------------+------------------+----------+
only showing top 20 rows

Filter Records

The second query is designed to filter the data set to show only those continents where the total vaccinations exceed 1 million. The SQL command used here is SELECT continent FROM data_v WHERE total_vaccinations > 1000000. The show() method is again used to display the results, specifically listing the continents that meet the filter criteria.

print("Displaying continent with total vaccinated more than 1 million:")
# SQL filtering
spark.sql("SELECT continent FROM data_v WHERE total_vaccinations > 1000000").show()

# OUTPUT
Displaying continent with total vaccinated more than 1 million:
+-------------+
|    continent|
+-------------+
|          nan|
|North America|
|       Europe|
|       Europe|
|          nan|
|          nan|
|          nan|
|         Asia|
|         Asia|
|       Europe|
|          nan|
|         Asia|
|      Oceania|
|          nan|
|          nan|
|          nan|
|          nan|
+-------------+
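
Because the view contains many individual rows (including rows whose continent was missing in the source data and became the string 'nan'), continents can appear several times in this result. To list each qualifying continent only once, a DISTINCT version of the query could be used, for example:

# List each qualifying continent once, skipping the 'nan' placeholder rows
spark.sql("""
    SELECT DISTINCT continent
    FROM data_v
    WHERE total_vaccinations > 1000000
      AND continent != 'nan'
""").show()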