The Absolutely Simplest Consistent Hashing Example

This post was originally published at techspot.zzzeek.org.

Lately I've been studying Redis a lot. When using key/value databases
like Redis, as well as caches like Memcached, if you want to scale keys
across multiple nodes, you need a consistent hashing algorithm. Consistent
hashing is what we use when we want to distribute a set of keys along a span
of key/value servers in a…well, consistent fashion, so that adding or removing
a server moves only a small share of the keys instead of reshuffling almost
all of them.
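For contrast, here is a minimal sketch (not from the original post) of the naive approach that consistent hashing replaces: picking a server with `hash(key) % number_of_servers`. The helpers `stable_hash` and `modulo_node` are hypothetical names, and md5 is used only because it gives the same hash on every machine. The sketch counts how many of 10,000 keys change servers when a fifth server joins four existing ones.

```python
import hashlib


def stable_hash(key):
    # md5 is stable across processes and machines, unlike Python's built-in hash()
    return int(hashlib.md5(key.encode("utf-8")).hexdigest(), 16)


def modulo_node(key, num_nodes):
    # naive placement: server index = hash modulo the number of servers
    return stable_hash(key) % num_nodes


keys = ["key%d" % i for i in range(10000)]

# count keys whose server changes when going from 4 servers to 5
moved = sum(1 for k in keys if modulo_node(k, 4) != modulo_node(k, 5))
fraction_moved = moved / float(len(keys))
print("fraction of keys moved: %.2f" % fraction_moved)
```

A key stays put only when its hash leaves the same remainder modulo 4 and modulo 5, which holds for 4 of every 20 residues, so roughly 80% of the keys move. Ideally only about one key in five, the new server's fair share, would move.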

If you Google around to learn what consistent hashing means,
the article that most directly tells you “the answer” without
a lot of handwringing is Consistent Hashing by Tom White.
Not only does it explain the concept very clearly, it even has a plain and
simple code example in Java.

The recipe in Tom's post depends on the capabilities of Java's TreeMap,
which has no equivalent in Python's standard library, but after some contemplation
it became apparent that the functionality of circle.tailMap(hash) is something we
already have using bisect. That is, we have a sorted array of integers and a new
number: where in the array does the new number go? bisect.bisect() will give you
that, with the same O(log n) lookup cost as TreeMap (inserting into the list with
bisect.insort() is O(n), but nodes are added far less often than keys are looked up).
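Here is a small illustration of that idea, using made-up ring positions rather than real hashes; the `find_point` helper is hypothetical. `bisect.bisect()` returns the index of the first element strictly greater than the probe value, which stands in for the first key of `tailMap()` (strictly, `tailMap()` also includes an exact match, which `bisect.bisect_left()` would reproduce). An index equal to the list length means the value falls past the last point, so we wrap around to index 0.

```python
import bisect

# sorted positions of points on the ring (stand-ins for real hash values)
ring_points = [10, 25, 40, 70]


def find_point(points, value):
    """Return the first ring point clockwise from value, wrapping at the end."""
    index = bisect.bisect(points, value)  # first index whose point is > value
    if index == len(points):              # past the last point: wrap to the start
        index = 0
    return points[index]


print(find_point(ring_points, 30))  # 40
print(find_point(ring_points, 25))  # 40: an exact match moves on to the next point
print(find_point(ring_points, 90))  # 10: wrapped around
```

Each lookup is a single binary search, so it costs O(log n) in the number of ring points.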

As a sanity check, I searched a bit more for Python implementations. I found
a recipe by Amir Salihefendic, which seems to be based on the Java recipe and is
pretty nice, but in the post he's searching the circle for hash values using
a linear search, ouch! It turns out Amir is in fact using bisect in his Python
Cheese Shop package hash_ring, but by then it was too late: I had already written
my own recipe as well as tests (which hash_ring doesn't appear to have, at least
in the downloaded distribution).
There's also Continuum, which takes a somewhat more heavy-handed approach (three
separate classes, and an expensive IndexError being caught to detect keys beyond
the end of the circle). Both systems, Continuum more so, seem to encourage using
hostnames directly as the node names on the ring. As Jeremy Zawodny has noted,
with a persistent system like Redis this is a bad idea, as it means you can't move
a particular key set to a new host.
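To make the hostname point concrete, here is a minimal sketch (our own construction, not code from hash_ring or Continuum) comparing a ring whose node names are hostnames with one whose node names are symbolic shard names, resolved through a separate name-to-host table. The helpers `ring_hash`, `build_ring` and `lookup`, and all host and shard names, are hypothetical. When host2's data moves to a new machine host3, the hostname ring itself changes, so keys shuffle between nodes and no longer follow their data; with symbolic names only the table entry changes.

```python
import bisect
import hashlib


def ring_hash(s):
    # stable md5-based hash, as in the recipe
    return int(hashlib.md5(s.encode("utf-8")).hexdigest(), 16)


def build_ring(names, replicas=100):
    # hash each name to `replicas` points; return parallel sorted lists
    points = sorted((ring_hash("%s:%d" % (name, i)), name)
                    for name in names for i in range(replicas))
    return [h for h, _ in points], [n for _, n in points]


def lookup(ring, key):
    hashes, names = ring
    # first point clockwise from the key's hash; % wraps past the end
    return names[bisect.bisect(hashes, ring_hash(key)) % len(hashes)]


keys = ["user:%d" % i for i in range(1000)]
# where each host's data should end up after host2 is replaced by host3
relocated = {"host1": "host1", "host2": "host3"}

# Hostnames as node names: replacing host2 changes the ring's own points.
before = build_ring(["host1", "host2"])
after = build_ring(["host1", "host3"])
misplaced_hostnames = sum(
    1 for k in keys if relocated[lookup(before, k)] != lookup(after, k))

# Symbolic names: the ring is untouched; only the name -> host table changes.
shards = build_ring(["shard1", "shard2"])
hosts = {"shard1": "host1", "shard2": "host2"}
host_before = {k: hosts[lookup(shards, k)] for k in keys}
hosts["shard2"] = "host3"
misplaced_symbolic = sum(
    1 for k in keys if relocated[host_before[k]] != hosts[lookup(shards, k)])

print("misplaced keys: hostnames=%d, symbolic=%d"
      % (misplaced_hostnames, misplaced_symbolic))
```

With symbolic names, every key keeps following its data to the new host, which is what makes moving a key set feasible with a persistent store.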

So, spending a bit of NIH capital, here's my recipe. It provides a dictionary
interface, so you can store hostnames or even actual client instances keyed to
symbolic names:

import bisect
import hashlib

class ConsistentHashRing(object):
    """Implement a consistent hashing ring."""

    def __init__(self, replicas=100):
        """Create a new ConsistentHashRing.

        :param replicas: number of replicas, i.e. how many points
         on the ring each node is hashed to.

        """
        self.replicas = replicas
        self._keys = []   # sorted list of replica hash values
        self._nodes = {}  # replica hash value -> node

    def _hash(self, key):
        """Given a string key, return a hash value."""

        # hashlib replaces the deprecated md5 module (removed in
        # Python 3); the key is encoded because md5 needs bytes
        return int(hashlib.md5(key.encode("utf-8")).hexdigest(), 16)

    def _repl_iterator(self, nodename):
        """Given a node name, return an iterable of replica hashes."""

        return (self._hash("%s:%s" % (nodename, i))
                for i in range(self.replicas))

    def __setitem__(self, nodename, node):
        """Add a node, given its name.

        The given nodename is hashed
        among the number of replicas.

        """
        for hash_ in self._repl_iterator(nodename):
            if hash_ in self._nodes:
                raise ValueError("Node name %r is "
                                 "already present" % nodename)
            self._nodes[hash_] = node
            bisect.insort(self._keys, hash_)

    def __delitem__(self, nodename):
        """Remove a node, given its name."""

        for hash_ in self._repl_iterator(nodename):
            # will raise KeyError for nonexistent node name
            del self._nodes[hash_]
            index = bisect.bisect_left(self._keys, hash_)
            del self._keys[index]

    def __getitem__(self, key):
        """Return a node, given a key.

        The node replica with the nearest hash value
        greater than that of the given key is returned.
        If the hash of the given key is greater than or
        equal to the greatest replica hash, the lowest
        hashed node is returned (wrapping around the ring).

        """
        hash_ = self._hash(key)
        start = bisect.bisect(self._keys, hash_)
        if start == len(self._keys):
            start = 0
        return self._nodes[self._keys[start]]

The ring is used like a dictionary of node names to whatever you want; here we use Redis clients:

import redis

cr = ConsistentHashRing(100)

cr["node1"] = redis.StrictRedis(host="host1")
cr["node2"] = redis.StrictRedis(host="host2")

client = cr["some key"]
data = client.get("some key")
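The payoff of the ring shows up when the cluster changes. This minimal sketch uses a compact stand-in for the class above (hypothetical helpers `ring_hash`, `build_ring` and `lookup`, with plain node names instead of Redis clients) and counts how many keys change node when a fifth node joins four existing ones. With consistent hashing, only the keys the new node takes over should move, roughly one in five, and every moved key should land on the new node.

```python
import bisect
import hashlib


def ring_hash(s):
    return int(hashlib.md5(s.encode("utf-8")).hexdigest(), 16)


def build_ring(names, replicas=100):
    # each node name is hashed to `replicas` points on the ring
    points = sorted((ring_hash("%s:%d" % (name, i)), name)
                    for name in names for i in range(replicas))
    return [h for h, _ in points], [n for _, n in points]


def lookup(ring, key):
    hashes, names = ring
    # first point clockwise from the key's hash, wrapping past the end
    return names[bisect.bisect(hashes, ring_hash(key)) % len(hashes)]


keys = ["key%d" % i for i in range(10000)]
old_ring = build_ring(["node1", "node2", "node3", "node4"])
new_ring = build_ring(["node1", "node2", "node3", "node4", "node5"])

moved = [k for k in keys if lookup(old_ring, k) != lookup(new_ring, k)]
fraction_moved = len(moved) / float(len(keys))
# a key can only change node if one of node5's new points now precedes it
all_to_new_node = all(lookup(new_ring, k) == "node5" for k in moved)
print("fraction moved: %.2f, all to node5: %s" % (fraction_moved, all_to_new_node))
```

Compared with the modulo approach, the existing nodes keep all of their keys except the ones handed to the newcomer.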

I wanted to validate that the ring in fact produces standard deviations like
those mentioned in the Java article, so I tested it like this:

import unittest
import collections
import random
import math

class ConsistentHashRingTest(unittest.TestCase):
    def test_get_distribution(self):
        ring = ConsistentHashRing(100)

        numnodes = 10
        numhits = 1000
        numvalues = 10000

        for i in range(1, 1 + numnodes):
            ring["node%d" % i] = "node_value%d" % i

        distributions = collections.defaultdict(int)
        for i in range(numhits):
            key = str(random.randint(1, numvalues))
            node = ring[key]
            distributions[node] += 1

        # count of hits matches what is observed
        self.assertEqual(sum(distributions.values()), numhits)

        # I've observed standard deviation for 10 nodes + 100
        # replicas to be between 10 and 15.   Play around with
        # the number of nodes / replicas to see how different
        # tunings work out.
        standard_dev = self._pop_std_dev(list(distributions.values()))
        self.assertLessEqual(standard_dev, 20)

        # if the stddev is good, it's safe to assume
        # all nodes were used
        self.assertEqual(len(distributions), numnodes)

        # just to test getting keys, see that we got the values
        # back and not keys or indexes or whatever.
        self.assertEqual(
                set(distributions.keys()),
                set("node_value%d" % i for i in range(1, 1 + numnodes))
            )

    def _pop_std_dev(self, population):
        # float() keeps Python 2 integer division from truncating the mean
        mean = float(sum(population)) / len(population)
        return math.sqrt(
                sum(pow(n - mean, 2) for n in population)
                / len(population)
            )
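Following the test comment's suggestion to play with tunings, here is a sketch of such an experiment, again using a compact stand-in ring (hypothetical `ring_hash`, `build_ring`, `lookup` and `pop_std_dev` helpers rather than the class above). It routes the same 10,000 keys through 10 nodes at several replica counts and reports the population standard deviation of the per-node hit counts; more replicas should generally give a more even spread. Because it uses 10,000 lookups rather than the test's 1,000, its numbers are not directly comparable with the range quoted in the test comment.

```python
import bisect
import collections
import hashlib
import math


def ring_hash(s):
    return int(hashlib.md5(s.encode("utf-8")).hexdigest(), 16)


def build_ring(names, replicas):
    points = sorted((ring_hash("%s:%d" % (name, i)), name)
                    for name in names for i in range(replicas))
    return [h for h, _ in points], [n for _, n in points]


def lookup(ring, key):
    hashes, names = ring
    # first point clockwise from the key's hash, wrapping past the end
    return names[bisect.bisect(hashes, ring_hash(key)) % len(hashes)]


def pop_std_dev(population):
    population = list(population)
    mean = float(sum(population)) / len(population)
    return math.sqrt(sum((n - mean) ** 2 for n in population) / len(population))


nodes = ["node%d" % i for i in range(1, 11)]
keys = [str(i) for i in range(10000)]

results = {}
for replicas in (1, 10, 100, 500):
    ring = build_ring(nodes, replicas)
    hits = collections.Counter(lookup(ring, k) for k in keys)
    # a node that received no keys counts as zero hits (Counter returns 0)
    results[replicas] = pop_std_dev(hits[n] for n in nodes)
    print("replicas=%4d  std dev=%.1f" % (replicas, results[replicas]))
```

The trade-off is memory and insertion time: each node contributes `replicas` entries to the sorted list.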

The content of this post belongs to its original author. To leave a comment, head over to the original article: http://techspot.zzzeek.org/2012/07/07/the-absolutely-simplest-consistent-hashing-example