
I have been using dask to speed up some larger-scale analyses. Dask is a great drop-in tool for parallelizing analyses built on the PyData stack, such as numpy, pandas, and even scikit-learn.

However, I recently ran into an interesting case where using the same syntax on a dask.dataframe as on a pandas.DataFrame does not achieve what I want. In this post, I document how to work around it for my future self.

As usual, let's import the libraries we need:

import pandas as pd
import dask.dataframe as dd

I will use the famous Titanic dataset as an example to show how dask can behave unexpectedly under groupby + apply operations.

titanic = pd.read_csv('')
   Survived  Pclass  Name                                               Sex      Age  Siblings/Spouses Aboard  Parents/Children Aboard     Fare
0         0       3  Mr. Owen Harris Braund                             male    22.0                        1                        0   7.2500
1         1       1  Mrs. John Bradley (Florence Briggs Thayer) Cum...  female  38.0                        1                        0  71.2833
2         1       3  Miss. Laina Heikkinen                              female  26.0                        0                        0   7.9250
3         1       1  Mrs. Jacques Heath (Lily May Peel) Futrelle        female  35.0                        1                        0  53.1000
4         0       3  Mr. William Henry Allen                            male    35.0                        0                        0   8.0500

I will illustrate the problem by counting the survivors in each age and sex group, using the following function:

def count_survival(d):
    '''
    Sum the survivors in a group and return a one-row
    DataFrame holding that single value
    '''
    return pd.DataFrame({'survived':[d.Survived.sum()]})
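To see what count_survival does to a single group, here is a quick check on a small, hypothetical group (the values are made up for illustration). The function collapses the whole group into a one-row DataFrame with a single survived column:

```python
import pandas as pd

def count_survival(d):
    '''Sum the survivors in a group and return a one-row DataFrame.'''
    return pd.DataFrame({'survived': [d.Survived.sum()]})

# Hypothetical group of three passengers, two of whom survived
group = pd.DataFrame({'Survived': [1, 0, 1], 'Sex': ['female', 'male', 'female']})

out = count_survival(group)
print(out)
#    survived
# 0         2
```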

The regular pandas way to do this would be:

titanic \
    .groupby(['Age','Sex']) \
    .apply(count_survival) \
    .head()
               survived
Age  Sex
0.42 male   0         1
0.67 male   0         1
0.75 female 0         2
0.83 male   0         2
0.92 male   0         1
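The extra 0 between the group keys and the count is the row label of the one-row DataFrame that count_survival returns; pandas keeps it as a third index level. The sketch below uses hypothetical toy data to show that level, and cross-checks the counts against groupby(...)['Survived'].sum(), a more direct aggregation swapped in here for comparison rather than the method used in this post:

```python
import pandas as pd

def count_survival(d):
    '''Sum the survivors in a group and return a one-row DataFrame.'''
    return pd.DataFrame({'survived': [d.Survived.sum()]})

# Hypothetical toy data; not the real Titanic rows
toy = pd.DataFrame({
    'Survived': [0, 1, 1, 1, 0, 1],
    'Age':      [22.0, 38.0, 26.0, 35.0, 35.0, 38.0],
    'Sex':      ['male', 'female', 'female', 'female', 'male', 'female'],
})

via_apply = toy.groupby(['Age', 'Sex']).apply(count_survival)
print(via_apply.index.nlevels)  # 3: Age, Sex, and the inner row label 0

# Dropping the innermost level leaves a plain (Age, Sex) index
flat = via_apply.droplevel(-1)['survived']

# The direct aggregation yields the same counts without the extra level
direct = toy.groupby(['Age', 'Sex'])['Survived'].sum()
print(flat.tolist() == direct.tolist())  # True
```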

Let's convert the pandas DataFrame to a dask DataFrame and do the same:

dask_job = titanic \
    .pipe(dd.from_pandas, npartitions=24)\
    .groupby(['Age','Sex']) \
    .apply(count_survival, meta={'survived':'f8'}) 

Dask is lazy, so this does not return any result until we call dask_job.compute(). Dask also includes a visualize method that draws the task graph:



The resulting task graph is much more complicated than I expected. This is because dask shuffles data between partitions behind the scenes, so that all rows belonging to a group end up in the same partition before count_survival runs. As the dask documentation suggests, this can be avoided by setting a groupby key as the index:
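Why does the index help? If the frame is sorted and partitioned on Age so that no Age value is spread over two partitions, every (Age, Sex) group lives entirely inside one partition and can be aggregated there, with no rows moving between partitions. The pure-pandas sketch below imitates that reasoning with a hypothetical helper, partition_by_key, and hypothetical toy data; it illustrates the idea and is not how dask partitions internally:

```python
import numpy as np
import pandas as pd

def partition_by_key(df, key, npartitions):
    '''Hypothetical helper: sort df by key and split it into at most
    npartitions chunks without spreading any key value over two chunks.'''
    df = df.sort_values(key)
    # unique() keeps order of first appearance, which is sorted here
    chunks = np.array_split(df[key].unique(), npartitions)
    return [df[df[key].isin(c)] for c in chunks if len(c) > 0]

# Hypothetical toy data; not the real Titanic rows
toy = pd.DataFrame({
    'Survived': [0, 1, 1, 1, 0, 1],
    'Age':      [22.0, 38.0, 26.0, 35.0, 35.0, 38.0],
    'Sex':      ['male', 'female', 'female', 'female', 'male', 'female'],
})

whole = toy.groupby(['Age', 'Sex'])['Survived'].sum()

# Key-aligned partitions: aggregate each one on its own, then just concatenate
parts = partition_by_key(toy, 'Age', npartitions=3)
per_partition = pd.concat([p.groupby(['Age', 'Sex'])['Survived'].sum() for p in parts])
print(per_partition.equals(whole))  # True: no group was split

# An arbitrary row split can cut a group in two (here the two 38-year-old women)
s = toy.sort_values('Age')
naive = pd.concat([c.groupby(['Age', 'Sex'])['Survived'].sum() for c in (s.iloc[:5], s.iloc[5:])])
print(naive.index.has_duplicates)  # True: (38.0, 'female') appears twice
```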

dask_job = titanic \
    .set_index('Age') \
    .pipe(dd.from_pandas, npartitions=24)\
    .groupby(['Age','Sex']) \
    .apply(count_survival, meta={'survived':'f8'})
