Testing with Apache Spark and Python

Apache Spark, and PySpark in particular, is a fantastically powerful framework for large-scale data processing and analytics.  In the past I’ve written about Flink’s Python API a couple of times, but my day-to-day work is in PySpark, not Flink.  With any data processing pipeline, thorough testing is critical to ensuring the veracity of the end result, so along the way I’ve learned a few rules of thumb and built some tooling for testing PySpark projects.

Breaking out lambdas

When building out a job in PySpark, it can be very tempting to overuse lambda functions.  For instance, a simple map that takes an RDD of lists and converts each element to a string could be written two ways:

rdd = rdd.map(lambda x: [str(y) for y in x])

or:

def stringify(x):
    return [str(y) for y in x]

...
rdd = rdd.map(stringify)

The second version, while more verbose, exposes a pure Python function that we can re-use and unit test.
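
For example, stringify can be exercised with nothing but plain Python.  A minimal check, assuming pytest as the test runner (the test name is mine, not from any project):

def test_stringify():
    assert stringify([1, 2.5, None]) == ['1', '2.5', 'None']
    assert stringify([]) == []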

For a more involved example, let’s write a similar function that takes an RDD of dictionaries and stringifies the values of any key in a set of keys, keylist.  The first way:

rdd = rdd.map(lambda x: {k: str(v) if k in keylist else v for k, v in x.items()})

or the more testable way:

def stringify_values(x, keylist):
    return {k: str(v) if k in keylist else v for k, v in x.items()}

...

rdd = rdd.map(lambda x: stringify_values(x, keylist))
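
As before, the logic can now be checked without ever touching a SparkContext.  A minimal sketch, again assuming pytest:

def test_stringify_values():
    row = {'a': 1, 'b': 2}
    # only the keys named in keylist get stringified
    assert stringify_values(row, keylist={'a'}) == {'a': '1', 'b': 2}
    # an empty keylist leaves the dictionary untouched
    assert stringify_values(row, keylist=set()) == row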

This simple rule of thumb goes a long way toward increasing the testability of a PySpark codebase, but sometimes you do need to test the Spark-y portions of the code.

DummyRDD

One way to do that, for larger-scale tests, is to run a local instance of Spark just for the tests, but this can be slow, especially if you have to spin Spark contexts up and down over and over for different tests (if you do want to go that route, here is a great example of how to).  To get around that, I’ve started a project to write a mock version of PySpark which uses pure Python data structures under the hood to replicate PySpark behavior.
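
If you do go the local-instance route, a session-scoped fixture at least limits you to one spin-up per test run.  A rough sketch, assuming pytest and a local pyspark install (spark_context and the test are hypothetical names, not from any of the projects mentioned here):

import pytest
from pyspark import SparkConf, SparkContext

def stringify(x):
    # the helper from earlier, repeated so this snippet stands alone
    return [str(y) for y in x]

@pytest.fixture(scope='session')
def spark_context():
    # one local SparkContext shared across the whole test session
    conf = SparkConf().setMaster('local[2]').setAppName('tests')
    sc = SparkContext(conf=conf)
    yield sc
    sc.stop()

def test_stringify_on_rdd(spark_context):
    rdd = spark_context.parallelize([[1, 2], [3]])
    assert rdd.map(stringify).collect() == [['1', '2'], ['3']]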

It is only intended for testing, and doesn’t begin to approach the full capabilities or API of PySpark (notably the DataFrame or Dataset APIs), but it is getting pretty close to having the RDD functionality implemented.   Check out the source here: https://github.com/wdm0006/DummyRDD.  DummyRDD works by implementing the underlying RDD data structure simply as a Python list, so that it can use Python’s built-in map, filter, etc. on that list as if it were an RDD.  Of course, Spark is lazily evaluated, so to get comparable outcomes we actually store copies of each intermediate step in memory; large Spark jobs run with the dummy backend will therefore consume large amounts of memory, but for testing this may be fine.
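
The core idea is simple enough to sketch in a few lines.  This is not DummyRDD’s actual code, just an illustration of the list-backed, eagerly-evaluated approach:

class FakeRDD(object):
    """Toy stand-in for an RDD: the 'distributed' data is just a Python list."""

    def __init__(self, data):
        self._data = list(data)

    def map(self, f):
        # eager rather than lazy: each step materializes a new list in memory
        return FakeRDD([f(x) for x in self._data])

    def filter(self, f):
        return FakeRDD([x for x in self._data if f(x)])

    def collect(self):
        return list(self._data)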

A quick example, showing off some of the methods that are implemented:

import os
import random

from dummy_spark import SparkContext, SparkConf
from dummy_spark.sql import SQLContext
from dummy_spark import RDD

__author__ = 'willmcginnis'

# make a spark conf
sconf = SparkConf()

# set some property (won't do anything)
sconf.set('spark.executor.extraClassPath', 'foo')

# use the spark conf to make a spark context
sc = SparkContext(master='', conf=sconf)

# set the log level (also doesn't do anything)
sc.setLogLevel('INFO')

# maybe make a useless sqlcontext (nothing implemented here yet)
sqlctx = SQLContext(sc)

# add pyfile just appends to the sys path
sc.addPyFile(os.path.dirname(__file__))

# do some hadoop configuration into the ether
sc._jsc.hadoopConfiguration().set('foo', 'bar')

# maybe make some data
rdd = sc.parallelize([1, 2, 3, 4, 5])

# map and collect
print('\nmap()')
rdd = rdd.map(lambda x: x ** 2)
print(rdd.collect())

# add some more in there
print('\nunion()')
rdd2 = sc.parallelize([2, 4, 10])
rdd = rdd.union(rdd2)
print(rdd.collect())

# filter and take
print('\nfilter()')
rdd = rdd.filter(lambda x: x > 4)
print(rdd.take(10))

# flatmap
print('\nflatMap()')
rdd = rdd.flatMap(lambda x: [x, x, x])
print(rdd.collect())

# group by key
print('\ngroupByKey()')
rdd = rdd.map(lambda x: (x, random.random()))
rdd = rdd.groupByKey()
print(rdd.collect())
rdd = rdd.mapValues(list)
print(rdd.collect())

# forEachPartition
print('\nforeachPartition()')
rdd.foreachPartition(lambda x: print('partition: ' + str(x)))
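
In a test, the dummy context can stand in wherever the real one would be injected.  A hypothetical pytest example (run_job and the test are illustrative names; stringify_values is the helper from earlier in the post, repeated here so the snippet stands alone):

from dummy_spark import SparkConf, SparkContext

def stringify_values(x, keylist):
    return {k: str(v) if k in keylist else v for k, v in x.items()}

def run_job(sc, rows, keylist):
    # an example job that takes its context as an argument, so a dummy can be passed in
    rdd = sc.parallelize(rows)
    return rdd.map(lambda x: stringify_values(x, keylist)).collect()

def test_run_job_with_dummy_backend():
    sc = SparkContext(master='', conf=SparkConf())
    result = run_job(sc, [{'a': 1}, {'a': 2}], keylist={'a'})
    assert result == [{'a': '1'}, {'a': '2'}]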

The README contains a list of all implemented methods, which will gradually grow over time.

Conclusion

So the two main points are: break the code out so that the logic can be unit-tested outside of a Spark context, and build larger-scale, integration-style tests against a mocked Spark backend where possible to keep the test suite quick to run.  Other than that, normal best practices prevail.
