Purely Functional Red Black Trees in Python

I can already hear the trolls coming for me. Why on earth would I write a purely functional data structure in quite possibly the slowest language for the task? It all came about because of a real-world problem: someone looked me in the eye and told me I wasn’t crazy enough to write a purely functional red-black tree in Python.

Now, did this actually happen? No. But it could happen any minute. It’s like lottery winner’s insurance for employers: it might never pay out, but when one of your best and most productive employees quits after winning the lottery, you’ll laugh and say to yourself, “Man, I’m really glad I got that lottery winner’s insurance.”

The Inspiration

Back at the Academy, one of my professors was Dr. Chris “I’m a Celebrity in the Functional Programming Game but I’m so Humble so I won’t Mention It Or Wear Tie-Dye Shirts” Okasaki. I recently revisited his book Purely Functional Data Structures, but when it was first introduced to me 10 years ago, I didn’t see the practical benefits of his ideas. Sitting in a college classroom working through mind-bendy recursive problems on a blackboard day after day felt more like a programmer’s playground than the kind of practical exercise you’d actually use out on the real-world streets of white-collar San Francisco, struggling to survive.

But after a few years on the streets of the developer game, where talk about scalability is whispered on every corner, those wacky ideas about a complete absence of assignment statements suddenly become plausibly good ones.

The Benefits

Aside from being ready to put that random passerby in the hallway in his or her place when they try to call me out, I’ve found functional programming concepts increasingly appealing. Functional programming can be defined in many different ways, but you don’t need a purely functional language to write a purely or predominantly functional program. In particular, the ideas of referential transparency and immutability are simple but powerful. Referential transparency means that a function called with the same arguments evaluates to the same result in any context, so any call can be replaced by its value. Such stateless functions have no side effects, which significantly shrinks the room for bugs. Couple that with consistent typing, minimal parameters, and a solid test suite, and pretty soon your coworkers will be calling you “The Exterminator” because your code will have no bugs.
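
Here’s a tiny illustration of the difference (both functions are made up for this example): `add_tax` is referentially transparent, while `add_tax_and_accumulate` reads and mutates hidden module-level state, so identical calls return different results.

```python
# Hypothetical example: a pure function versus one with hidden state.

def add_tax(price, rate):
    """Pure: the result depends only on the arguments."""
    return round(price * (1 + rate), 2)


_running_total = [0.0]  # module-level mutable state


def add_tax_and_accumulate(price, rate):
    """Impure: reads and mutates state outside its arguments."""
    _running_total[0] += price * (1 + rate)
    return round(_running_total[0], 2)


# The pure call could be replaced by its value anywhere in the program.
assert add_tax(100, 0.1) == add_tax(100, 0.1) == 110.0

# The impure call returns a different value for identical arguments.
first = add_tax_and_accumulate(100, 0.1)
second = add_tax_and_accumulate(100, 0.1)
assert first != second
```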

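Immutable objects are completely thread-safe, since no thread can change an object that another thread is reading, and that opens up the possibility of fully parallelizing a program. Combined with referentially transparent functions, immutability makes caching a trivially simple problem: a cached result can never go stale.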
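
A minimal standard-library sketch of that combination, assuming Python 3.7+ for `dataclasses`; `Point` and `manhattan` are hypothetical names. Because `Point` is frozen (and therefore hashable), threads can share it without a lock, and `functools.lru_cache` can memoize `manhattan` with no invalidation logic at all.

```python
from dataclasses import FrozenInstanceError, dataclass
from functools import lru_cache


@dataclass(frozen=True)
class Point:
    """An immutable value: assigning to a field raises FrozenInstanceError."""
    x: int
    y: int


@lru_cache(maxsize=None)
def manhattan(a, b):
    """Pure function of immutable, hashable arguments, so memoizing it is safe."""
    return abs(a.x - b.x) + abs(a.y - b.y)


origin, p = Point(0, 0), Point(3, 4)
assert manhattan(origin, p) == 7   # computed
assert manhattan(origin, p) == 7   # served from the cache
assert manhattan.cache_info().hits == 1

try:
    p.x = 10                       # mutation is rejected...
except FrozenInstanceError:
    pass                           # ...so the cached answer can never go stale
```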

So now imagine a real-world context: you have a few dozen Celery workers that all share access to a Redis server. Suppose you implemented your own data structure made up of nodes, where each node is stored under a distinct ID that doubles as its Redis key and can be fetched individually, and a pointer to another node is simply that node’s ID. You now have a distributed data structure in Redis. And if you put a proxy in front of Redis that shards your key space, a single data structure can be spread out in memory across multiple machines.
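
A rough sketch of that node layout, with a plain dict standing in for Redis (with a real client, the dict reads and writes would become `GET` and `SET` on the same keys). `put_node` and `contains` are hypothetical helpers, not part of the code later in this post.

```python
import uuid

# A plain dict stands in for Redis here; each key maps to one immutable node.
store = {}


def put_node(value, left_id=None, right_id=None):
    """Store a node under a fresh ID; children are referenced by ID, not by object."""
    node_id = str(uuid.uuid4())
    store[node_id] = {"value": value, "left": left_id, "right": right_id}
    return node_id


def contains(node_id, value):
    """Binary-search-tree lookup that follows ID 'pointers' through the store."""
    while node_id is not None:
        node = store[node_id]
        if value == node["value"]:
            return True
        node_id = node["left"] if value < node["value"] else node["right"]
    return False


# Build 10 <- 11 -> 12 bottom-up: a node's children must exist before it does.
root_id = put_node(11, put_node(10), put_node(12))
assert contains(root_id, 12)
assert not contains(root_id, 13)
```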

Now take it one step further: add caching logic that backfills nodes from Redis into an in-memory least recently used (LRU) cache. Each machine now keeps a partial copy of the data structure in memory, and sections of it are loaded from Redis as needed.
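
One way the backfilling might look, as a sketch: `LRUNodeCache` is a hypothetical class, and a dict again stands in for the shared Redis store. The interesting design point is what’s missing: since a node stored under an ID never changes, a cached copy can never go stale, so the cache only evicts for space and never needs invalidation.

```python
from collections import OrderedDict


class LRUNodeCache:
    """In-process LRU cache that backfills misses from a shared node store."""

    def __init__(self, backing_store, capacity=1024):
        self._store = backing_store   # a dict here; with Redis, a lookup is a GET
        self._capacity = capacity
        self._cache = OrderedDict()   # ordered from least to most recently used
        self.misses = 0

    def get(self, node_id):
        if node_id in self._cache:
            self._cache.move_to_end(node_id)   # mark as most recently used
            return self._cache[node_id]
        self.misses += 1
        node = self._store[node_id]            # backfill from the shared store
        self._cache[node_id] = node
        if len(self._cache) > self._capacity:
            self._cache.popitem(last=False)    # evict the least recently used node
        return node


shared = {"a": {"value": 1}, "b": {"value": 2}, "c": {"value": 3}}
cache = LRUNodeCache(shared, capacity=2)
cache.get("a")
cache.get("b")
cache.get("a")   # "a" becomes the most recently used entry
cache.get("c")   # evicts "b", the least recently used entry
assert cache.misses == 3
```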

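Since the data structure is immutable, readers never need locks: a node, once written under its ID, never changes. A writer builds its new version off to the side and then only has to swap the reference to the current root atomically (a compare-and-swap). If two writers race, one of them wins and the other retries against the new root, so at any given time at least one process is guaranteed to make progress, with no risk of deadlock.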
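
The no-locks argument leans on one atomic primitive for the root pointer. Below is a sketch of the retry loop under some loud assumptions: a Python `frozenset` plays the role of the immutable tree, and `AtomicRef.compare_and_set` uses a tiny internal lock purely to simulate an atomic compare-and-swap (against Redis you might get the same effect with `WATCH`/`MULTI`/`EXEC` on the key holding the root ID). `AtomicRef`, `persistent_add`, and `add_with_retry` are hypothetical names.

```python
import threading


class AtomicRef:
    """Holds the current root. The lock only simulates an atomic CAS; it is
    never held while a new version is being built."""

    def __init__(self, value):
        self._value = value
        self._lock = threading.Lock()

    def get(self):
        return self._value

    def compare_and_set(self, expected, new):
        with self._lock:
            if self._value is expected:
                self._value = new
                return True
            return False


def persistent_add(root, item):
    """Return a new immutable frozenset; the old version is left untouched."""
    return root | {item}


def add_with_retry(ref, item):
    while True:
        old = ref.get()                     # read a snapshot, no lock needed
        new = persistent_add(old, item)     # build the next version off to the side
        if ref.compare_and_set(old, new):   # publish only if nobody beat us to it
            return                          # otherwise another writer progressed; retry


root = AtomicRef(frozenset())
workers = [threading.Thread(target=add_with_retry, args=(root, i)) for i in range(50)]
for worker in workers:
    worker.start()
for worker in workers:
    worker.join()
assert root.get() == frozenset(range(50))
```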

Now, before the trolls start coming for me, I should point out that the idea above comes with many drawbacks. Since Redis is not durable by default, losing even a single node (say, to a restart) breaks every path through it, so the entire structure can become corrupted at any time. Also, after we go through some code samples, you might notice that inserts are fairly slow. Furthermore, a purely functional data structure generally relies on a language with garbage collection to remove nodes that nothing references anymore. Hence, if the example above were carried through, you would also have to either associate a time to live with every node or manually clean up after yourself when a node was no longer referenced. In the former case, your data structure goes from not reliably durable to definitely not durable after a certain amount of time. In the latter case, your interactions with the tree will be slower and no longer purely functional. In either case, a purely functional data structure is fairly expensive on the memory front, but the steadily falling cost of memory has tracked the rise in popularity of functional programming.
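
For the manual-cleanup option, one possible approach is a mark-and-sweep pass over the node store. This is only a sketch: `reachable_ids` and `sweep` are hypothetical helpers, a dict stands in for Redis, and a real version would have to make sure no writer is still building a new version from a root that’s being dropped.

```python
def reachable_ids(store, root_ids):
    """Mark phase: collect every node ID reachable from any live root."""
    seen, stack = set(), list(root_ids)
    while stack:
        node_id = stack.pop()
        if node_id is None or node_id in seen:
            continue
        seen.add(node_id)
        node = store[node_id]
        stack.extend((node["left"], node["right"]))
    return seen


def sweep(store, root_ids):
    """Sweep phase: delete nodes no live root can reach; return the deleted IDs."""
    live = reachable_ids(store, root_ids)
    dead = [node_id for node_id in store if node_id not in live]
    for node_id in dead:
        del store[node_id]
    return dead


# Version 2 inserted 13: it path-copied the root and the 12 node but shares 10.
store = {
    "a": {"value": 10, "left": None, "right": None},
    "b": {"value": 12, "left": None, "right": None},
    "r1": {"value": 11, "left": "a", "right": "b"},     # root of version 1
    "c": {"value": 13, "left": None, "right": None},
    "b2": {"value": 12, "left": None, "right": "c"},
    "r2": {"value": 11, "left": "a", "right": "b2"},    # root of version 2
}
assert sweep(store, ["r2"]) == ["b", "r1"]   # only version 1 referenced these
assert sorted(store) == ["a", "b2", "c", "r2"]
```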

Purely Functional Red Black Trees

The exercise I went through was to implement a purely functional red-black tree (albeit in a non-purely functional language). A red-black tree is a self-balancing binary search tree that restores its balance after every insert or delete, which keeps the gap between best- and worst-case performance small: lookups and inserts are O(log n). There are plenty of informative write-ups on maintaining a red-black tree, but the essential idea is that every node is colored red or black, no red node may have a red child, and every path from the root down to a leaf must pass through the same number of black nodes. A newly inserted node is colored red, which leaves those black counts untouched; if its parent is also red, the red-red violation is repaired by rotating and recoloring at the grandparent, and that repair can propagate further up toward the root.
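
Those two rules can be checked mechanically, which is handy when testing any implementation. A minimal sketch, assuming nodes are plain `(color, left, value, right)` tuples with `None` as the empty leaf rather than the `RedBlackTree` class in the next section; `black_height` is a hypothetical helper.

```python
RED, BLACK = "R", "B"


def black_height(node):
    """Return the black height of a red-black tree, raising ValueError on a violation.

    A node is a tuple (color, left, value, right); None is an empty, black leaf.
    """
    if node is None:
        return 1
    color, left, _, right = node
    if color == RED:
        for child in (left, right):
            if child is not None and child[0] == RED:
                raise ValueError("red node with a red child")
    left_height, right_height = black_height(left), black_height(right)
    if left_height != right_height:
        raise ValueError("paths with different numbers of black nodes")
    return left_height + (1 if color == BLACK else 0)


valid = (BLACK, (BLACK, None, 10, None), 11, (BLACK, None, 12, None))
assert black_height(valid) == 3

red_red = (BLACK, (RED, (RED, None, 9, None), 10, None), 11, None)
try:
    black_height(red_red)
except ValueError as error:
    print(error)  # red node with a red child
```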

The purely functional aspect means that rather than modifying the tree in place, we create immutable copies of every node on the path from the root down to the new node. All of the nodes off that path are unchanged and are simply referenced by the new copy, so no variable is ever destructively updated, and the old version of the tree remains intact.
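
Path copying is easiest to see without the balancing logic in the way. This sketch uses an unbalanced binary search tree of `(left, value, right)` tuples (`persistent_insert` is a hypothetical name); the identity checks show that the untouched subtree is shared between versions rather than copied.

```python
def persistent_insert(node, value):
    """Persistent BST insert on (left, value, right) tuples; None is the empty tree.

    Only the nodes on the root-to-insertion path are rebuilt; every other
    subtree is reused by reference from the old version.
    """
    if node is None:
        return (None, value, None)
    left, current, right = node
    if value < current:
        return (persistent_insert(left, value), current, right)  # right side shared
    return (left, current, persistent_insert(right, value))      # left side shared


old = persistent_insert(persistent_insert(persistent_insert(None, 11), 10), 12)
new = persistent_insert(old, 13)

assert old == ((None, 10, None), 11, (None, 12, None))  # old version unchanged
assert new[0] is old[0]       # the untouched left subtree is shared, not copied
assert new[2] is not old[2]   # the right spine was copied to reach the new leaf
```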

The Code

import uuid


class Color(object):
    RED = 0
    BLACK = 1


class RedBlackTree(object):

    def __init__(self, left, value, right, color=Color.RED):
        self._color = color
        self._left = left
        self._right = right
        self._value = value
        self._count = 1 + len(left) + len(right)
        self._node_uuid = uuid.uuid4()

    def __len__(self):
        return self._count

    @property
    def uuid(self):
        return self._node_uuid

    @property
    def color(self):
        return self._color

    @property
    def value(self):
        return self._value

    @property
    def right(self):
        return self._right

    @property
    def left(self):
        return self._left

    def blacken(self):
        if self.is_red():
            return RedBlackTree(
                self.left,
                self.value,
                self.right,
                color=Color.BLACK,
            )
        return self

    def is_empty(self):
        return False

    def is_black(self):
        return self._color == Color.BLACK

    def is_red(self):
        return self._color == Color.RED

    def rotate_left(self):
        # Promote the right child; its old left subtree becomes our new right subtree.
        return RedBlackTree(
            RedBlackTree(
                self.left,
                self.value,
                self.right.left,
                color=self.color,
            ),
            self.right.value,
            self.right.right,
            color=self.right.color,
        )

    def rotate_right(self):
        # Mirror image: promote the left child; its old right subtree moves under us.
        return RedBlackTree(
            self.left.left,
            self.left.value,
            RedBlackTree(
                self.left.right,
                self.value,
                self.right,
                color=self.color,
            ),
            color=self.left.color,
        )

    def recolored(self):
        return RedBlackTree(
            self.left.blacken(),
            self.value,
            self.right.blacken(),
            color=Color.RED,
        )

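    def balance(self):
        # Red nodes are left alone; a red-red violation below them is fixed
        # one level up, at the black grandparent.
        if self.is_red():
            return self

        if self.left.is_red():
            # Both children red: flip colors (red parent, black children).
            if self.right.is_red():
                return self.recolored()
            # Left-left case: one right rotation, then recolor.
            if self.left.left.is_red():
                return self.rotate_right().recolored()
            # Left-right case: rotate the left child left to reduce it to left-left.
            if self.left.right.is_red():
                return RedBlackTree(
                    self.left.rotate_left(),
                    self.value,
                    self.right,
                    color=self.color,
                ).rotate_right().recolored()
            return self

        if self.right.is_red():
            # Right-right case: mirror image of left-left.
            if self.right.right.is_red():
                return self.rotate_left().recolored()
            # Right-left case: mirror image of left-right.
            if self.right.left.is_red():
                return RedBlackTree(
                    self.left,
                    self.value,
                    self.right.rotate_right(),
                    color=self.color,
                ).rotate_left().recolored()
        return self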

    def update(self, node):
        if node.is_empty():
            return self
        if node.value < self.value:
            return RedBlackTree(
                self.left.update(node).balance(),
                self.value,
                self.right,
                color=self.color,
            ).balance()
        return RedBlackTree(
            self.left,
            self.value,
            self.right.update(node).balance(),
            color=self.color,
        ).balance()

    def insert(self, value):
        return self.update(
            RedBlackTree(
                EmptyRedBlackTree(),
                value,
                EmptyRedBlackTree(),
                color=Color.RED,
            )
        ).blacken()

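    def is_member(self, value):
        # Go through the `value` property rather than `_value` so subclasses
        # that load their value lazily (RedisRedBlackTree below) work too.
        if value < self.value:
            return self.left.is_member(value)
        if value > self.value:
            return self.right.is_member(value)
        return True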


class EmptyRedBlackTree(RedBlackTree):

    def __init__(self):
        self._color = Color.BLACK

    def is_empty(self):
        return True

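    def insert(self, value):
        # The new node becomes the root, and a red-black tree's root is black,
        # matching what RedBlackTree.insert produces via blacken().
        return RedBlackTree(
            EmptyRedBlackTree(),
            value,
            EmptyRedBlackTree(),
            color=Color.BLACK,
        )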

    def update(self, node):
        return node

    def is_member(self, value):
        return False

    @property
    def left(self):
        return EmptyRedBlackTree()

    @property
    def right(self):
        return EmptyRedBlackTree()

    def __len__(self):
        return 0

You’ll notice that assignment statements appear only in the constructors and nowhere else. Every method called on a tree returns a new tree, which shares its unchanged subtrees with the old one, rather than updating existing data. Now I’ll break my self-imposed rule of no assignment statements to demonstrate usage:

new_tree = EmptyRedBlackTree().insert(10)
assert isinstance(new_tree, RedBlackTree)
new_tree = new_tree.insert(11)
new_tree = new_tree.insert(12)
assert new_tree.value == 11
assert new_tree.left.value == 10
assert new_tree.right.value == 12

I’ll avoid going into detail about the algorithm because there are lots of other sites and videos that explain in great detail how to implement a red-black tree. Hopefully the code above, paired with one of those explanations, is easy to follow as a working implementation. Again, the point of writing this in Python was not efficiency, but to learn how to implement a red-black tree in a readable manner.
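
For comparison, Okasaki’s book expresses insertion with a single `balance` function covering four red-red cases, each rewritten into the same shape. Here’s a rough tuple-based transcription of that formulation (the names `balance`, `insert`, and `to_list` are illustrative, and unlike the class above it ignores duplicate values); note that `RedBlackTree.balance` above additionally flips colors whenever a black node has two red children.

```python
R, B = "R", "B"


def is_red(tree):
    return tree is not None and tree[0] == R


def balance(color, left, value, right):
    """Okasaki's four red-red cases, each rebuilt as R(B a x b) y (B c z d).

    Trees are (color, left, value, right) tuples; None is an empty, black leaf.
    """
    if color == B:
        if is_red(left) and is_red(left[1]):      # B (R (R a x b) y c) z d
            _, (_, a, x, b), y, c = left
            return (R, (B, a, x, b), y, (B, c, value, right))
        if is_red(left) and is_red(left[3]):      # B (R a x (R b y c)) z d
            _, a, x, (_, b, y, c) = left
            return (R, (B, a, x, b), y, (B, c, value, right))
        if is_red(right) and is_red(right[1]):    # B a x (R (R b y c) z d)
            _, (_, b, y, c), z, d = right
            return (R, (B, left, value, b), y, (B, c, z, d))
        if is_red(right) and is_red(right[3]):    # B a x (R b y (R c z d))
            _, b, y, (_, c, z, d) = right
            return (R, (B, left, value, b), y, (B, c, z, d))
    return (color, left, value, right)


def insert(tree, value):
    def ins(node):
        if node is None:
            return (R, None, value, None)         # new nodes start out red
        color, left, current, right = node
        if value < current:
            return balance(color, ins(left), current, right)
        if value > current:
            return balance(color, left, current, ins(right))
        return node                               # already present
    _, left, current, right = ins(tree)
    return (B, left, current, right)              # the root is always black


def to_list(tree):
    """In-order traversal, so the output is sorted for a valid search tree."""
    if tree is None:
        return []
    _, left, value, right = tree
    return to_list(left) + [value] + to_list(right)


tree = None
for n in [10, 11, 12]:
    tree = insert(tree, n)
assert tree == (B, (B, None, 10, None), 11, (B, None, 12, None))
```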

I took this a step further and put together a sloppier subclass of the RedBlackTree above: a tree meant to be distributed across Redis keys. To serialize everything, I used the Schematics Python package to create type-safe objects that can be serialized and deserialized at will. Importantly, the entity class overrides less-than and greater-than and declares a primary key that determines which attributes are compared, in a manner conceptually similar to what SQL engines do when storing a row.
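
To show the same ordering idea without Schematics, here’s a standard-library sketch (not a drop-in replacement for the model below; `Entity` and its methods are hypothetical). A dataclass with `order=True` compares instances field by field in declaration order, so declaring the primary-key fields first and excluding everything else with `field(compare=False)` gives roughly the compare-by-primary-key behavior that `ArbitraryEntity` implements by hand, for equality as well as ordering.

```python
import json
import uuid
from dataclasses import asdict, dataclass, field


@dataclass(frozen=True, order=True)
class Entity:
    """Ordering and equality use only the primary-key fields, in this order."""
    some_uuid: uuid.UUID
    some_string: str
    payload: str = field(default="", compare=False)  # not part of the primary key

    def to_json(self):
        data = asdict(self)
        data["some_uuid"] = str(self.some_uuid)  # uuid.UUID is not JSON-serializable
        return json.dumps(data)

    @classmethod
    def from_json(cls, text):
        data = json.loads(text)
        return cls(uuid.UUID(data["some_uuid"]), data["some_string"], data["payload"])


a = Entity(uuid.UUID(int=1), "alpha")
b = Entity(uuid.UUID(int=2), "alpha")
assert a < b                                 # decided by the UUID first
assert Entity.from_json(a.to_json()) == a    # survives a JSON round trip
```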

import json
import redis
import uuid

from schematics.models import Model
from schematics.types import IntType, StringType, UUIDType

from .red_black_tree import (
    Color,
    EmptyRedBlackTree,
    RedBlackTree,
)

TTL = 300
REDIS_CLIENT = redis.from_url("redis://localhost:6379")
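# The nil UUID marks an empty child. uuid.UUID(int=0) works on Python 2 and 3,
# whereas bytes=chr(0) * 16 raises a TypeError on Python 3.
EMPTY_NODE = uuid.UUID(int=0)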

class ArbitraryEntity(Model):
    some_string = StringType(required=True)
    some_uuid = UUIDType(required=True)

    @classmethod
    def primary_key(cls):
        return (
            'some_uuid',
            'some_string',
        )

    def to_comparable_tuple(self):
        return tuple([getattr(self, attr) for attr in self.primary_key()])

    def __lt__(self, other_value):
        return self.to_comparable_tuple() < other_value.to_comparable_tuple()

    def __gt__(self, other_value):
        return self.to_comparable_tuple() > other_value.to_comparable_tuple()


class SerializableNode(Model):
    node_uuid = UUIDType(required=True)
    left = UUIDType(required=True)
    right = UUIDType(required=True)
    value_key = UUIDType(required=True)
    color = IntType(required=True)
    count = IntType(required=True)

    @classmethod
    def from_node(cls, node, value_key):
        return cls({
            "node_uuid": node.uuid,
            "left": EMPTY_NODE if node.left.is_empty() else node.left.uuid,
            "right": EMPTY_NODE if node.right.is_empty() else node.right.uuid,
            "value_key": value_key,
            "color": node.color,
            "count": len(node),
        })


class RedisRedBlackTree(RedBlackTree):

    def __init__(self, left_uuid, value_key, right_uuid, node_uuid, color=Color.RED, count=1):
        self._right_uuid = right_uuid
        self._value_key = value_key
        self._left_uuid = left_uuid
        self._color = color
        self._count = count
        self._node_uuid = node_uuid

    @property
    def right(self):
        if self._right_uuid == EMPTY_NODE:
            return EmptyRedBlackTree()
        return self._node_from_redis(self._right_uuid)

    @property
    def left(self):
        if self._left_uuid == EMPTY_NODE:
            return EmptyRedBlackTree()
        return self._node_from_redis(self._left_uuid)

    @property
    def value(self):
        return self._value_from_key(self._value_key)

    @staticmethod
    # @add_your_own_hardcore_cache_here
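    def _value_from_key(key):
        # value_key deserializes as a uuid.UUID, but recent redis-py versions
        # reject keys that aren't str/bytes/numbers, so convert it first.
        return ArbitraryEntity(
            json.loads(
                REDIS_CLIENT.get(str(key))
            )
        )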

    @staticmethod
    # @add_your_own_hardcore_cache_here
    def _node_from_redis(node_uuid):
        serializable_node = SerializableNode(
            json.loads(
                REDIS_CLIENT.get(str(node_uuid))
            )
        )
        return RedisRedBlackTree(
            serializable_node.left,
            serializable_node.value_key,
            serializable_node.right,
            node_uuid,
            color=serializable_node.color,
            count=serializable_node.count,
        )

    def insert(self, value):
        return self.propogate_changes_to_redis(
            super(RedisRedBlackTree, self).insert(value)
        )

    @classmethod
    def propogate_changes_to_redis(cls, root):
        cls._propogate_changes_to_redis(root)
        return cls._node_from_redis(root.uuid)

    @classmethod
    def _propogate_changes_to_redis(cls, root):

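        # Nodes already stored in Redis (and empty leaves) need no write,
        # so the recursion stops there.
        if isinstance(root, (RedisRedBlackTree, EmptyRedBlackTree)):
            return

        value_key = str(uuid.uuid4())
        # Keyword arguments keep setex correct across redis-py versions:
        # the legacy 2.x Redis class took (name, value, time), 3.x+ takes (name, time, value).
        REDIS_CLIENT.setex(
            name=str(root.uuid),
            time=TTL,
            value=json.dumps(
                SerializableNode.from_node(
                    root,
                    value_key,
                ).to_primitive()
            ),
        )
        REDIS_CLIENT.setex(
            name=value_key,
            time=TTL,
            value=json.dumps(
                root.value.to_primitive()
            ),
        )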

        cls._propogate_changes_to_redis(root.left)
        cls._propogate_changes_to_redis(root.right)


class EmptyRedisRedBlackTree(RedisRedBlackTree):

    def __init__(self):
        pass

    def insert(self, value):
        return self.propogate_changes_to_redis(
            EmptyRedBlackTree().insert(value)
        )

The End

So you might find this sort of idea practical in the gray area between the throwaway in-memory data structures provided by your language’s standard library and the highly optimized data structures used in persistence layers. It might prove useful when you want to quickly and easily examine a large collection of very recent historical data, for example, or whenever the data needs to be queryable directly from Redis.