Purely Functional Data Structures in Python

True Story Follows

In my own personal goal to become a better programmer, I’m frequently searching for the newest and hottest technology to be privy to, or perhaps some earth-shatteringly erudite YouTube video that would allow me to passively become a stronger developer. More and more, though, I’ve found that this largely does not work, at least for the purpose of becoming a better craftsman. As should have been clear all along, the most fruitful sources of knowledge come in the form of information that takes effort to digest. One such source was the work of one of my most hardcore college professors: Purely Functional Data Structures by Dr. Chris Okasaki.

To draw a parallel between mental prowess and weight lifting, Dr. Okasaki could easily bench press 635 lbs if the weight room were a metaphor for programming. He would literally rep out 405 metaphorical lbs routinely in class just to demonstrate a mind blowing concept.

My initial goal was to take the Standard ML code for these data structures and translate it into Python so that it might become easier to understand. Typically these data structures achieve some operation in constant or logarithmic time, but in Python the programs become overwhelmingly slow because the constant factors hiding behind those bounds are large. Even so, I have some ideas for putting these data structures to use in a production environment to substantial benefit, but that proof will have to come in a follow-up blog post. For now, perhaps some thought-provoking concepts can be entertained.

The General Idea Behind Purely Functional Data Structures

In order to justify using a purely functional data structure in an imperative language such as Python in the first place, especially when the common data structures are already built-in types, we need to capitalize on the benefits a purely functional structure introduces. Most notably, every purely functional data structure is completely immutable, so any update returns an entirely new object, and the original data structure still exists until its reference count drops to 0. This means that caching becomes trivially easy, and metadata about a structure, such as its length, can be calculated once and never needs to be recalculated.
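To make the caching point concrete, here is a minimal sketch using a hypothetical `Stack` class (it is not part of the original code). Every push returns a new object that shares the old one, and the length is computed once in the constructor. Immutability here is by convention only, since Python does not stop you from reassigning the attributes.

```python
class Stack(object):
    """Hypothetical immutable stack: nothing is mutated after __init__."""

    __slots__ = ("top", "rest", "_length")

    def __init__(self, top, rest=None):
        self.top = top
        self.rest = rest  # shared with the previous version, never copied
        # Metadata is computed exactly once; since neither field is ever
        # reassigned, the cached value can never go stale.
        self._length = 1 + (len(rest) if rest is not None else 0)

    def push(self, value):
        # An "update" returns a brand-new object; `self` is left untouched.
        return Stack(value, self)

    def pop(self):
        # Returns the top value and the (already existing) rest of the stack.
        return self.top, self.rest

    def __len__(self):
        return self._length


s1 = Stack(1).push(2).push(3)
s2 = s1.push(4)
print(len(s1), len(s2))  # 3 4
print(s2.rest is s1)     # True: s1 is reused, not copied
```

Because `s1` can never change, anything derived from it (its length here, or any other computed result) can be cached alongside it safely.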

The “Hello World” of Purely Functional Data Structures

If we have a linked list of nodes and I want to prepend a node to the head of the list, I would simply return a new node whose next value is the current head. If I wanted to insert a node into the middle of the list, I would return a copy of the 1st node that points to a copy of the 2nd node that points to a copy of the 3rd node, and so on, until we reach the position where the update occurs. At that point we point to a new node that points to the next element in the original list. This doesn’t seem practical at all until we start representing data in structures where such operations occur in logarithmic time or better.
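Here is a minimal sketch of both operations, assuming a hypothetical `ListNode` class and helper functions that are not from the original post. `insert_at` copies only the nodes before the insertion point. Everything after it is shared with the original list, which stays intact.

```python
class ListNode(object):
    """Hypothetical immutable linked-list node."""

    __slots__ = ("value", "next_node")

    def __init__(self, value, next_node=None):
        self.value = value
        self.next_node = next_node


def prepend(head, value):
    # O(1): the entire existing list becomes the tail of the new node.
    return ListNode(value, head)


def insert_at(head, position, value):
    # Copy every node before `position`; the suffix after it is shared.
    if position == 0:
        return ListNode(value, head)
    if head is None:
        raise IndexError("position out of range")
    return ListNode(head.value, insert_at(head.next_node, position - 1, value))


def to_list(head):
    values = []
    while head is not None:
        values.append(head.value)
        head = head.next_node
    return values


original = prepend(prepend(prepend(None, 3), 2), 1)  # 1 -> 2 -> 3
updated = insert_at(original, 2, 99)                 # 1 -> 2 -> 99 -> 3
print(to_list(original), to_list(updated))           # [1, 2, 3] [1, 2, 99, 3]
```

Inserting at position k costs k copies, which is why the shape of the structure matters so much for the rest of this post.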

Red Black Trees

The first data structure I tried out was a red black tree, which allows insertion, deletion, and retrieval in logarithmic time. I actually wrote a separate blog post about that topic if you want to see the code.

Random Access Lists

My explanations won’t do it justice, so be sure to actually read Purely Functional Data Structures if you want more than the glossed-over details. Explaining my own understanding in depth would really just be repeating Dr. Okasaki’s text in a less accurate and less informative manner. But if I were in a situation kind of like the one in Swordfish where Wolverine has a gun to his head, and I was asked to explain purely functional random access lists, the general idea is that you have a two-dimensional data structure where you can traverse to the right in logarithmic time and traverse downward in logarithmic time, thereby giving you an overall logarithmic run time on operations.

This can be done by creating a list of complete binary trees, where only the first two trees are allowed to be of equal rank (i.e. two trees that each have exactly 1 node or exactly 3 nodes or exactly 7 nodes). When those first two trees are equal, the next inserted element becomes the root of a new tree whose two children are those trees, joining them into a single larger complete binary tree. If each insertion corresponds to an index in a list, then the shape of the trees is deterministic and independent of the values at each node. To retrieve the node that corresponds to an index, scan the roots of the trees until you’ve found the tree containing that index. This is a linear search over the roots, but there are only logarithmically many roots relative to the size of the entire list. Then binary search that tree for the index, which is also logarithmic time.
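The shape of the trees depends only on how many elements have been appended, so it can be modelled with sizes alone. The hypothetical helpers below are not part of the original code. They follow the same append rule as the `RandomAccessList` class below: trees are listed newest first, and the newest tree holds the highest indices.

```python
def tree_sizes(n):
    """Sizes of the complete binary trees after n appends, newest first."""
    sizes = []
    for _ in range(n):
        if len(sizes) >= 2 and sizes[0] == sizes[1]:
            # The two equal trees plus the new element become one tree.
            sizes = [2 * sizes[0] + 1] + sizes[2:]
        else:
            # Otherwise the new element starts a tree of size 1.
            sizes = [1] + sizes
    return sizes


def locate(n, index):
    """Return (tree position, offset from that tree's lowest index)."""
    if not 0 <= index < n:
        raise IndexError("list index out of range")
    upper = n  # one past the highest index covered by the current tree
    # Scan the roots newest to oldest; each tree covers a contiguous range.
    for position, size in enumerate(tree_sizes(n)):
        lower = upper - size
        if index >= lower:
            return position, index - lower
        upper = lower
    raise AssertionError("unreachable: sizes always sum to n")


print(tree_sizes(6))  # [3, 3]
print(tree_sizes(7))  # [7]
print(locate(6, 5))   # (0, 2)
print(locate(6, 0))   # (1, 0)
```

Every size has the form 2**k - 1, and only the first two sizes can match. Because the remaining sizes are strictly increasing, the number of trees, and therefore the root scan, stays logarithmic.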

Inserting an item is constant time because it only requires either adding a single node to create a tree of size 1 or adding a single node that becomes the new root of two trees of equal rank. Searching for an item is logarithmic time, as outlined above. Updating an item is also logarithmic time: after we create a new node to replace an existing node in the tree, we trace back up to the corresponding root (logarithmically many steps) and then back along the top-level linked list of trees, creating new node copies as we go.
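Here is a minimal sketch of that path copying inside one tree. It uses a hypothetical tuple-based `Tree` rather than the `Node` class below. It also assumes a preorder layout (root first, then the left half, then the right half), so `offset` is a position inside one tree rather than the global index that `Node` stores.

```python
from typing import NamedTuple, Optional


class Tree(NamedTuple):
    """Hypothetical immutable node of a complete binary tree (size 2**k - 1)."""
    value: object
    left: Optional["Tree"]
    right: Optional["Tree"]
    size: int


def build(values):
    # Preorder layout: values[0] is the root, then the left half, then the right half.
    if not values:
        return None
    half = (len(values) - 1) // 2
    return Tree(values[0], build(values[1:1 + half]), build(values[1 + half:]), len(values))


def lookup(tree, offset):
    # One step down per level: O(log n) for a tree of n nodes.
    if tree is None:
        raise IndexError("offset out of range")
    if offset == 0:
        return tree.value
    half = (tree.size - 1) // 2
    if offset <= half:
        return lookup(tree.left, offset - 1)
    return lookup(tree.right, offset - 1 - half)


def update(tree, offset, value):
    # Copies only the nodes on the root-to-target path (one per level);
    # every subtree off that path is shared with the original tree.
    if tree is None:
        raise IndexError("offset out of range")
    if offset == 0:
        return tree._replace(value=value)
    half = (tree.size - 1) // 2
    if offset <= half:
        return tree._replace(left=update(tree.left, offset - 1, value))
    return tree._replace(right=update(tree.right, offset - 1 - half, value))


old = build(list(range(7)))
new = update(old, 5, "x")
print(lookup(old, 5), lookup(new, 5))  # 5 x
print(new.left is old.left)            # True: the untouched half is shared
```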

class NullObject(object):

    def __init__(self):
        self.index = -1
        self.value = None

    @property
    def left_child(self):
        return NullObject()

    @property
    def right_child(self):
        return NullObject()

    @property
    def right_sibling(self):
        return NullObject()

    @property
    def head(self):
        return NullObject()

    def __len__(self):
        return 0

    def __setitem__(self, index, value):
        raise IndexError("Assignment index out of range")

    def __getitem__(self, index):
        raise IndexError("list index out of range")


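class Node(object):
    """A node in one of the complete binary trees.

    Its own index is the largest index in its subtree; the right child holds
    the oldest (lowest) indices and the left child holds the ones in between.
    """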

    def __init__(self,
                 index,
                 value,
                 left_child=NullObject(),
                 right_child=NullObject()):

        self.index = index
        self.left_child = left_child
        self.right_child = right_child
        self.value = value
        self._length = len(self.left_child) + len(self.right_child) + 1

    def __getitem__(self, index):
        if index == self.index:
            return self.value
        if index <= self.right_child.index:
            return self.right_child[index]
        return self.left_child[index]

    def __setitem__(self, index, value):
        if index == self.index:
            return Node(
                index,
                value,
                left_child=self.left_child,
                right_child=self.right_child
            )
        if index <= self.right_child.index:
            return Node(
                self.index,
                self.value,
                left_child=self.left_child,
                right_child=self.right_child.__setitem__(index, value)
            )
        return Node(
            self.index,
            self.value,
            left_child=self.left_child.__setitem__(index, value),
            right_child=self.right_child,
        )

    def __len__(self):
        return self._length


class RandomAccessList(object):

    def __init__(self, head=NullObject(), right_sibling=NullObject()):
        self.head = head
        self.right_sibling = right_sibling

    def __getitem__(self, index):
        if index > self.right_sibling.head.index:
            return self.head[index]
        return self.right_sibling[index]

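    def __setitem__(self, index, value):
        # Returns a NEW list and leaves this one untouched. Because the plain
        # `lst[i] = v` statement discards the return value, call
        # `lst.__setitem__(i, v)` directly to get the updated list.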
        if index > self.right_sibling.head.index:
            return RandomAccessList(
                head=self.head.__setitem__(index, value),
                right_sibling=self.right_sibling
            )
        return RandomAccessList(
            head=self.head,
            right_sibling=self.right_sibling.__setitem__(index, value),
        )

    def append(self, value):
        if self._first_two_trees_equal_height():
            return RandomAccessList(
                head=Node(
                    index=self.head.index + 1,
                    value=value,
                    left_child=self.head,
                    right_child=self.right_sibling.head,
                ),
                right_sibling=self.right_sibling.right_sibling
            )
        else:
            return RandomAccessList(
                head=Node(
                    index=self.head.index + 1,
                    value=value,
                ),
                right_sibling=self,
            )

    def _first_two_trees_equal_height(self):
        return len(self.right_sibling.head) == len(self.head) and len(self.right_sibling.head) > 0

    def __nonzero__(self):
        return not isinstance(self.head, NullObject)

    __bool__ = __nonzero__  # Python 3 looks for __bool__ instead of __nonzero__

    def __len__(self):
        return len(self.head) + len(self.right_sibling)

Skew Binomial Queues

The paper I read to understand these concepts is Optimal Purely Functional Priority Queues by Chris “The Running Man” Okasaki and Gerth Brodal.

Now we get to our first data structure that I would predict to be practical in my own projects where I’m not actually using a functional language. At this point this is all theory, so I won’t go too deeply into why, but consider how often the average developer / hacker uses a priority queue: almost never. The built-in data types are fast enough that we don’t need one when working in process, and if the dataset is large enough to warrant a database, we’re probably just going to use the database to query for results.

With that in mind, there are many cases where a redundant database query is in fact asking for “the next batch of items I need to work with”. This starts to sound like – get this – a priority queue, which, if used in a SQL query’s place, would surface the benefits of purely functional data structures (immutability, caching, parallelism, etc.).

With all of the refinements described below in place, a skew binomial queue allows for enqueueing in constant time, peeking in constant time, melding in constant time, and dequeueing in logarithmic time.

It took me a few days to actually figure the whole thing out, and in the process it took me back to the days of Dr. Okasaki’s homework problems where the breakthroughs happened through recursive leaps of faith and the solutions seemed trivially easy after the fact.

Now the idea, also far less informative than Dr. Okasaki’s writing because I’m distilling it down into a single blog post, is that we start with a binomial heap, where the root of a tree with rank N has N children. This means that two trees of equal rank can be combined in a single operation: one tree simply becomes a child of the other tree’s root, and this idea can be applied recursively. With each combination, the root with the higher priority stays at the top.
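Here is a minimal sketch of that single-step combination. It assumes a hypothetical `BinomialTree` tuple where a smaller value means higher priority, matching the `<` comparisons used in the code below.

```python
from typing import NamedTuple, Tuple


class BinomialTree(NamedTuple):
    """Hypothetical immutable binomial tree; a smaller value means higher priority."""
    rank: int
    value: int
    children: Tuple["BinomialTree", ...]


def singleton(value):
    return BinomialTree(0, value, ())


def link(t1, t2):
    """Combine two rank-r trees into one rank-(r + 1) tree in O(1)."""
    if t1.rank != t2.rank:
        raise ValueError("only trees of equal rank are linked")
    if t2.value < t1.value:
        t1, t2 = t2, t1  # the higher-priority root stays on top
    # The other tree becomes the new first child; nothing else is copied.
    return BinomialTree(t1.rank + 1, t1.value, (t2,) + t1.children)


def size(tree):
    return 1 + sum(size(child) for child in tree.children)


a = link(singleton(5), singleton(3))
b = link(singleton(7), singleton(1))
c = link(a, b)
print(c.rank, c.value, size(c), len(c.children))  # 2 1 4 2
```

The rank-2 result has 2 children and 2**2 nodes, and linking it with another rank-2 tree would again take a single step.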

Now, skew binomial heaps don’t need to be complete binomial heaps. Dr. Okasaki distinguishes between simple links, type A skew links, and type B skew links, but all links can be generically described as creating a new skew binomial heap from an input list of up to 3 skew binomial heaps. Given the input heaps, the one with the highest priority root becomes the new root. Sort the remaining heaps by rank, lowest to highest, and add them to the left side of the new root’s list of children. The rank of the new heap is the rank of the rightmost of those sorted heaps + 1.
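The generic link can be written almost word for word from that description. This sketch uses a hypothetical `SkewTree` tuple, again with a smaller value meaning higher priority. It follows the same idea as `SkewBinomialHeap.link` below but is not that code.

```python
from typing import NamedTuple, Tuple


class SkewTree(NamedTuple):
    """Hypothetical immutable skew binomial tree; a smaller value means higher priority."""
    rank: int
    value: int
    children: Tuple["SkewTree", ...]


def singleton(value):
    return SkewTree(0, value, ())


def generic_link(*trees):
    """Link two or three trees following the generic rule described above."""
    if not 2 <= len(trees) <= 3:
        raise ValueError("link takes two or three trees")
    # The tree with the highest-priority (smallest) root becomes the new root.
    position = min(range(len(trees)), key=lambda i: trees[i].value)
    winner = trees[position]
    # The remaining trees, sorted by rank, go to the left of the winner's children.
    losers = sorted(trees[:position] + trees[position + 1:], key=lambda t: t.rank)
    # The new rank is the rank of the rightmost of those sorted trees, plus 1.
    return SkewTree(losers[-1].rank + 1, winner.value, tuple(losers) + winner.children)


left = generic_link(singleton(4), singleton(6))           # rank 1, root 4
right = generic_link(singleton(5), singleton(9))          # rank 1, root 5
new_root_wins = generic_link(singleton(1), left, right)   # the new singleton wins
old_root_wins = generic_link(singleton(8), left, right)   # an existing root wins
print(new_root_wins.rank, new_root_wins.value, [t.value for t in new_root_wins.children])  # 2 1 [4, 5]
print(old_root_wins.rank, old_root_wins.value, [t.value for t in old_root_wins.children])  # 2 4 [8, 5, 6]
```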

Now apply the ideas from random access lists to create a skew binomial queue. Keep a list of skew binomial heaps with the rule that the first two heaps in the list may have equal rank, but the others may not and are sorted in ascending order by rank (like a binary number where each subsequent position corresponds to a higher power of 2). Then adding an item might cause it to be linked with the first two heaps into a single heap. Since only the first two heaps are allowed to have equal rank, this operation is guaranteed not to cascade. Note that the rank of each tree, not the values present in it, determines its order in the list.

To meld two skew binomial queues, the process is similar to adding 2 binary numbers, where you might end up having to carry a “1” up to log N times. Each position in the base 2 numbering system corresponds to a heap’s rank, and we still allow the first two heaps to be of equal rank.
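The binary-addition analogy can be checked with a hypothetical helper that only tracks ranks, not values. Each pair of equal-rank trees links into one tree of the next rank, exactly like a carry. The helper returns the resulting ranks and how many links (carries) were needed.

```python
from collections import Counter


def meld_ranks(ranks_a, ranks_b):
    """Return (resulting ranks, number of links) for melding two queues.

    Only the ranks are modelled: linking two rank-r trees yields one
    rank-(r + 1) tree, just as adding two 1 bits in position r carries a 1
    into position r + 1.
    """
    counts = Counter(ranks_a) + Counter(ranks_b)
    result, links, carry = [], 0, 0
    rank = min(counts) if counts else 0
    while counts or carry:
        total = counts.pop(rank, 0) + carry
        if total % 2:
            result.append(rank)  # an odd tree out stays at this rank
        carry = total // 2       # each pair links into one tree of the next rank
        links += carry
        rank += 1
    return result, links


print(meld_ranks([0, 1, 2], [0]))  # ([3], 3): like 0b111 + 0b1 = 0b1000
print(meld_ranks([1], [0, 2]))     # ([0, 1, 2], 0): no carries at all
```

The first call is the worst case for its size: every position carries, so the number of links grows with the number of ranks, which is logarithmic in the number of elements.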

To dequeue from the skew binomial queue, scan the roots for the highest priority value. Pop the root node from that heap to return its value to the caller. Now we need to reconstruct the data structure to put it back into a valid state. The heap that you popped from is removed from the list of heaps. The children of the node you popped can be treated as a skew binomial queue of their own, since that list of children is itself a valid skew binomial queue. Meld that queue with the remaining queue that you just popped from. (In the code below, the rank-0 children are single values, so they are re-enqueued one at a time instead.)
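The dequeue steps are easiest to see on a plain binomial queue, so this sketch deliberately simplifies. It uses a hypothetical tuple-based `Tree`, no skew links, and min-heap ordering. The skew version below follows the same four steps but also has to re-enqueue rank-0 children.

```python
from typing import NamedTuple, Tuple


class Tree(NamedTuple):
    """Hypothetical binomial tree; children are stored highest rank first."""
    rank: int
    value: int
    children: Tuple["Tree", ...]


def link(t1, t2):
    # Equal ranks: the smaller root wins and adopts the other tree.
    if t2.value < t1.value:
        t1, t2 = t2, t1
    return Tree(t1.rank + 1, t1.value, (t2,) + t1.children)


def insert_tree(tree, trees):
    # `trees` has strictly ascending ranks, all >= tree.rank; carry on a tie.
    if not trees or tree.rank < trees[0].rank:
        return [tree] + trees
    return insert_tree(link(tree, trees[0]), trees[1:])


def meld(a, b):
    # Binary addition over ranks; both inputs have strictly ascending ranks.
    if not a:
        return b
    if not b:
        return a
    if a[0].rank < b[0].rank:
        return [a[0]] + meld(a[1:], b)
    if b[0].rank < a[0].rank:
        return [b[0]] + meld(a, b[1:])
    return insert_tree(link(a[0], b[0]), meld(a[1:], b[1:]))


def enqueue(trees, value):
    return insert_tree(Tree(0, value, ()), trees)


def dequeue(trees):
    """Return (smallest value, new queue); `trees` itself is never modified."""
    if not trees:
        raise IndexError("Queue is empty")
    # 1. Scan the roots for the highest-priority (smallest) value.
    position = min(range(len(trees)), key=lambda i: trees[i].value)
    popped = trees[position]
    # 2. Remove that tree from the list of trees.
    remaining = trees[:position] + trees[position + 1:]
    # 3. Its children, reversed into ascending rank order, form a valid queue.
    children = list(reversed(popped.children))
    # 4. Meld that queue with what was left.
    return popped.value, meld(children, remaining)


queue = []
for value in [5, 3, 8, 1, 9, 2, 7]:
    queue = enqueue(queue, value)
drained, rest = [], queue
while rest:
    value, rest = dequeue(rest)
    drained.append(value)
print(drained)  # [1, 2, 3, 5, 7, 8, 9]
```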

At this point, you can probably stop, and you have a reasonable data structure: insertion is O(1), melding is O(log N), peeking is O(log N), and dequeueing is O(log N). To make peeking take O(1), you could simply keep track of the root nodes in the skew binomial queue as you manipulate them and internally cache the highest priority value among them. This takes advantage of the data structure’s immutability, but as I found out while trying to follow Dr. Okasaki’s instructions, it was the wrong way to go about making that operation take constant time.

The reason this is not “correct” is that the O(1) peek concept is a building block for making meld take O(1) time as well. Instead, you should create a new data structure entirely, one that, when thought of as a whole, forms another heap. You don’t even need to think about the whole structure if you just take a recursive leap of faith.

Simply put, create a single node that acts as an entry point to a priority queue. This node has a single value and a priority queue (two distinctly different data types!). When you enqueue a value, if this value is higher priority than the current highest priority value, store that value in the node, and push the current highest priority value into the priority queue. If it is a lower priority, push the value into the queue and keep the currently highest priority value as is. In other words, keep the highest priority value at bay from ever entering the queue! Follow this train of thought for dequeue and meld.
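Here is a minimal sketch of the “keep the highest priority value at bay” idea. To keep it self-contained, the inner queue is a hypothetical `SortedTupleQueue` stand-in with O(n) inserts; in the post, that role is played by the skew binomial queue. `RootedQueue` is likewise a hypothetical name that plays the same role as `RootedSkewBinomialQueue` below.

```python
from bisect import insort


class SortedTupleQueue(object):
    """Hypothetical stand-in persistent min-queue (O(n) inserts)."""

    def __init__(self, items=()):
        self.items = tuple(items)

    def is_empty(self):
        return not self.items

    def enqueue(self, value):
        items = list(self.items)  # copy, so the old queue is untouched
        insort(items, value)
        return SortedTupleQueue(items)

    def dequeue(self):
        return self.items[0], SortedTupleQueue(self.items[1:])

    def meld(self, other):
        return SortedTupleQueue(sorted(self.items + other.items))


_EMPTY = object()  # sentinel so that None can still be a stored value


class RootedQueue(object):
    """The smallest value lives in `root` and never enters `rest`."""

    def __init__(self, root=_EMPTY, rest=None):
        self.root = root
        self.rest = rest if rest is not None else SortedTupleQueue()

    def is_empty(self):
        return self.root is _EMPTY

    def peek(self):
        # O(1): the answer is always sitting in the root.
        if self.is_empty():
            raise IndexError("Queue is empty")
        return self.root

    def enqueue(self, value):
        if self.is_empty():
            return RootedQueue(value, self.rest)
        if value < self.root:
            # The new value takes the root; the old root goes into the queue.
            return RootedQueue(value, self.rest.enqueue(self.root))
        return RootedQueue(self.root, self.rest.enqueue(value))

    def dequeue(self):
        if self.is_empty():
            raise IndexError("Queue is empty")
        if self.rest.is_empty():
            return self.root, RootedQueue()
        # The inner queue's minimum is promoted to be the new root.
        new_root, new_rest = self.rest.dequeue()
        return self.root, RootedQueue(new_root, new_rest)

    def meld(self, other):
        if self.is_empty():
            return other
        if other.is_empty():
            return self
        winner, loser = (self, other) if self.root <= other.root else (other, self)
        # The losing root joins the melded inner queues.
        return RootedQueue(winner.root, winner.rest.meld(loser.rest).enqueue(loser.root))


def drain(queue):
    values = []
    while not queue.is_empty():
        value, queue = queue.dequeue()
        values.append(value)
    return values


queue = RootedQueue()
for value in [4, 2, 6, 1, 3]:
    queue = queue.enqueue(value)
print(queue.peek(), queue.rest.items)  # 1 (2, 3, 4, 6)
```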

Finally, we come to bootstrapped skew binomial queues, which allow us to turn meld into an O(1) operation. Again, read Dr. Okasaki’s work for the best explanation, but I ended up having to read both Dr. Okasaki’s book and his white paper to finally “get it”. When we create a bootstrapped skew binomial queue, we don’t modify the existing skew binomial queue at all. Like the above concept of a global root, we create a layer on top of what we already have and trust that it works. This will also require a recursive leap of faith.

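Create a bootstrapped skew binomial node that contains the highest priority value and a primitive priority queue whose elements are objects of its own type (bootstrapped skew binomial nodes). Boom, I’m done, there it is. I can’t fully articulate the data structure, but it works. Melding two of these takes O(1) time because it only compares the two roots and enqueues the losing node, as a single element, into the winner’s primitive queue, and that enqueue is a constant-time skew binomial insert.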
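Here is a minimal sketch of the bootstrapping structure. The primitive queue is a hypothetical `SortedTupleQueue` stand-in that only needs `<` on its elements, so here it holds whole `Bootstrapped` queues ordered by their roots. `Bootstrapped`, `meld`, `enqueue`, and `dequeue` are hypothetical names, and `None` stands for the empty queue. Because the stand-in’s insert is O(n), this sketch shows the shape of the recursion, not the O(1) meld; that bound depends on using the skew binomial queue as the primitive queue, as the code below does.

```python
from bisect import insort


class SortedTupleQueue(object):
    """Hypothetical stand-in persistent min-queue; needs only `<` on elements."""

    def __init__(self, items=()):
        self.items = tuple(items)

    def is_empty(self):
        return not self.items

    def enqueue(self, value):
        items = list(self.items)
        insort(items, value)
        return SortedTupleQueue(items)

    def dequeue(self):
        return self.items[0], SortedTupleQueue(self.items[1:])

    def meld(self, other):
        return SortedTupleQueue(sorted(self.items + other.items))


class Bootstrapped(object):
    """A root value plus a primitive queue whose elements are Bootstrapped queues."""

    def __init__(self, root, children=None):
        self.root = root
        self.children = children if children is not None else SortedTupleQueue()

    def __lt__(self, other):
        # Ordering whole queues by their roots lets the primitive queue hold them.
        return self.root < other.root

    def peek(self):
        return self.root


def meld(a, b):
    """Compare the two roots; the loser becomes ONE element of the winner's queue."""
    if a is None:
        return b
    if b is None:
        return a
    if b.root < a.root:
        a, b = b, a
    return Bootstrapped(a.root, a.children.enqueue(b))


def enqueue(queue, value):
    return meld(Bootstrapped(value), queue)


def dequeue(queue):
    """Return (smallest value, remaining queue or None)."""
    if queue is None:
        raise IndexError("Queue is empty")
    if queue.children.is_empty():
        return queue.root, None
    best, rest = queue.children.dequeue()
    # The best child's root becomes the new global root; its own children
    # are melded with the rest of the primitive queue.
    return queue.root, Bootstrapped(best.root, best.children.meld(rest))


def drain(queue):
    values = []
    while queue is not None:
        value, queue = dequeue(queue)
        values.append(value)
    return values


queue = None
for value in [5, 3, 9, 1, 4]:
    queue = enqueue(queue, value)
print(drain(queue))  # [1, 3, 4, 5, 9]
```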

from functools import reduce  # a builtin in Python 2, but must be imported in Python 3


class NullObject(object):

    def __init__(self):
        self.value = None

    @property
    def children(self):
        return []

    @property
    def head(self):
        return NullObject()

    def __len__(self):
        return 0

    @property
    def rank(self):
        return -1

    def is_empty(self):
        return True


class SkewBinomialHeap(object):

    def __init__(self, value, rank=0, children=tuple()):
        self.value = value
        self.rank = rank
        self.children = children
        self._length = 1 + sum([len(child) for child in children])

    @classmethod
    def create_new(cls, value):
        return SkewBinomialHeap(value)

    def peek(self):
        return self.value

    def __len__(self):
        return self._length

    @classmethod
    def link(cls, *heaps):
        """ This method accounts for both simple and skewed links. """
        return SkewBinomialHeap(
            cls._highest_priority_heap(*heaps).peek(),
            rank=heaps[-1].rank + 1,
            children=cls._lower_priority_heaps(*heaps) + list(cls._highest_priority_heap(*heaps).children),
        )

    @classmethod
    def _highest_priority_heap(cls, *heaps):
        return sorted(
            heaps,
            key=lambda heap: heap.peek()
        )[0]

    @classmethod
    def _lower_priority_heaps(cls, *heaps):
        return sorted(
            sorted(
                heaps,
                key=lambda heap: heap.peek()
            )[1:],
            key=lambda heap: heap.rank
        )


class SkewBinomialQueue(object):

    def __init__(self, heap_head=NullObject(), right_sibling=NullObject()):
        self.heap_head = heap_head
        self.right_sibling = right_sibling

    def enqueue(self, value):
        if self._first_two_trees_equal_rank():
            return SkewBinomialQueue(
                heap_head=SkewBinomialHeap.link(
                    SkewBinomialHeap.create_new(value),
                    self.heap_head,
                    self.right_sibling.heap_head,
                ),
                right_sibling=self.right_sibling.right_sibling,
            )

        return SkewBinomialQueue(
            heap_head=SkewBinomialHeap.create_new(value),
            right_sibling=self,
        )

    def dequeue(self):
        def _merge_popped_queue_with_remaining_queue(popped_queue, remaining_queue):
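            children = popped_queue.heap_head.children
            # Children of rank > 0 are valid heaps: wrap each one as a queue and
            # meld them all with the remaining queue. Rank-0 children are
            # singletons, so their values are simply re-enqueued.
            return popped_queue.peek(), reduce(
                SkewBinomialQueue.meld,
                [SkewBinomialQueue(heap_head=child) for child in children if child.rank > 0] +
                [remaining_queue]
            )._bulk_insert([child.value for child in children if child.rank == 0])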
        return _merge_popped_queue_with_remaining_queue(*self._pop_highest_priority_queue())

    def _pop_highest_priority_queue(self):
        if isinstance(self.right_sibling, NullObject):
            return self, SkewBinomialQueue()
        if (
            isinstance(self.right_sibling, NullObject) or
            isinstance(self.right_sibling.heap_head, NullObject) or
            self.heap_head.peek() < self.right_sibling.peek()
        ):
            return (SkewBinomialQueue(heap_head=self.heap_head), self.right_sibling)

        def distribute_popped_and_original(popped_queue, original_queue):
            return (
                popped_queue,
                SkewBinomialQueue(heap_head=self.heap_head, right_sibling=original_queue)
            )

        return distribute_popped_and_original(*self.right_sibling._pop_highest_priority_queue())

    def _bulk_insert(self, values):
        if not values:
            return self
        return self.enqueue(values[0])._bulk_insert(values[1:])

    def _first_two_trees_equal_rank(self):
        if isinstance(self.heap_head, NullObject):
            return False
        if isinstance(self.right_sibling, NullObject):
            return False
        return self.heap_head.rank == self.right_sibling.heap_head.rank

    def peek(self):
        if isinstance(self.heap_head, NullObject):
            raise IndexError("Queue is empty")
        if self.right_sibling.is_empty():
            return self.heap_head.value
        return min(self.heap_head.value, self.right_sibling.peek())

    def __len__(self):
        return len(self.heap_head) + len(self.right_sibling)

    @classmethod
    def meld(cls, q1, q2):

        if q1.is_empty():
            return q2
        if q2.is_empty():
            return q1

        def meld_queues(accum, forest_one_tree):
            if accum.heap_head.rank == forest_one_tree.heap_head.rank:
                return SkewBinomialQueue.meld(
                    accum.right_sibling,
                    SkewBinomialQueue(
                        heap_head=SkewBinomialHeap.link(
                            accum.heap_head,
                            forest_one_tree.heap_head
                        )
                    )
                )
            if forest_one_tree.heap_head.rank < accum.heap_head.rank:
                return SkewBinomialQueue(
                    heap_head=forest_one_tree.heap_head,
                    right_sibling=accum
                )
            return SkewBinomialQueue(
                heap_head=accum.heap_head,
                right_sibling=forest_one_tree
            )

        return reduce(
            meld_queues,
            sorted(
                q1._as_individual_forests() + q2._as_individual_forests(),
                key=lambda skew_bin_q: skew_bin_q.heap_head.rank,
                reverse=True
            )
        )

    def _as_individual_forests(self):
        if isinstance(self.heap_head, NullObject):
            return []
        if isinstance(self.right_sibling, NullObject):
            return [SkewBinomialQueue(heap_head=self.heap_head)]
        return (
            [SkewBinomialQueue(heap_head=self.heap_head)] +
            self.right_sibling._as_individual_forests()
        )

    def is_empty(self):
        return isinstance(self.heap_head, NullObject)

    def __lt__(self, other_binomial_queue):
        if other_binomial_queue.is_empty():
            return True
        if self.is_empty():
            return False
        return self.heap_head.value < other_binomial_queue.heap_head.value


class RootedSkewBinomialQueue(object):

    def __init__(self, highest_priority_value=NullObject(), primitive_priority_queue=SkewBinomialQueue()):
        self.highest_priority_value = highest_priority_value
        self.primitive_priority_queue = primitive_priority_queue

    def is_empty(self):
        return isinstance(self.highest_priority_value, NullObject)

    def enqueue(self, value):
        if self.is_empty():
            return RootedSkewBinomialQueue(
                highest_priority_value=value
            )

        if self.highest_priority_value <= value:
            return RootedSkewBinomialQueue(
                self.highest_priority_value,
                self.primitive_priority_queue.enqueue(value)
            )

        return RootedSkewBinomialQueue(
            value,
            self.primitive_priority_queue.enqueue(
                self.highest_priority_value
            )
        )

    def dequeue(self):
        if self.primitive_priority_queue.is_empty():
            return self.highest_priority_value, RootedSkewBinomialQueue()

        return self.highest_priority_value, RootedSkewBinomialQueue(
            *self.primitive_priority_queue.dequeue()
        )

    @classmethod
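    def meld(cls, rooted_q1, rooted_q2):
        # An empty queue has no root to compare against, so return the other one.
        if rooted_q1.is_empty():
            return rooted_q2
        if rooted_q2.is_empty():
            return rooted_q1
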
        def merge_queues_with_root(highest_priority_value, lower_priority_queue):
            return RootedSkewBinomialQueue(
                highest_priority_value,
                SkewBinomialQueue.meld(
                    rooted_q1.primitive_priority_queue,
                    rooted_q2.primitive_priority_queue
                ).enqueue(lower_priority_queue.highest_priority_value)
            )

        if rooted_q1.highest_priority_value <= rooted_q2.highest_priority_value:
            return merge_queues_with_root(rooted_q1.highest_priority_value, rooted_q2)
        return merge_queues_with_root(rooted_q2.highest_priority_value, rooted_q1)

    def peek(self):
        if self.is_empty():
            raise IndexError("Queue is empty")
        return self.highest_priority_value


class BootstrappedSkewBinomialQueue(object):

    def __init__(self, highest_priority_value=NullObject(), primitive_priority_queue=SkewBinomialQueue(), length=0):
        self.highest_priority_value = highest_priority_value
        self.primitive_priority_queue = primitive_priority_queue
        self._length = length

    def is_empty(self):
        return isinstance(self.highest_priority_value, NullObject)

    def enqueue(self, value):
        return BootstrappedSkewBinomialQueue.meld(
            BootstrappedSkewBinomialQueue(
                value,
                length=1
            ),
            self,
        )

    def dequeue(self):
        if self.is_empty():
            raise IndexError("Queue is empty")

        if self.primitive_priority_queue.is_empty():
            return self.highest_priority_value, BootstrappedSkewBinomialQueue()

        def _bump_global_root(highest_priority_bootstrapped_queue, updated_primitive_queue):
            return self.highest_priority_value, BootstrappedSkewBinomialQueue(
                highest_priority_bootstrapped_queue.highest_priority_value,
                SkewBinomialQueue.meld(
                    highest_priority_bootstrapped_queue.primitive_priority_queue,
                    updated_primitive_queue
                ),
                length=len(self) - 1
            )

        return _bump_global_root(*self.primitive_priority_queue.dequeue())

    @classmethod
    def meld(cls, bootstrapped_q1, bootstrapped_q2):
        if bootstrapped_q1.is_empty():
            return bootstrapped_q2
        if bootstrapped_q2.is_empty():
            return bootstrapped_q1

        if bootstrapped_q1.peek() <= bootstrapped_q2.peek():
            return BootstrappedSkewBinomialQueue(
                bootstrapped_q1.peek(),
                bootstrapped_q1.primitive_priority_queue.enqueue(
                    bootstrapped_q2
                ),
                length=len(bootstrapped_q1) + len(bootstrapped_q2),
            )
        return BootstrappedSkewBinomialQueue(
            bootstrapped_q2.peek(),
            bootstrapped_q2.primitive_priority_queue.enqueue(
                bootstrapped_q1
            ),
            length=len(bootstrapped_q1) + len(bootstrapped_q2),
        )

    def peek(self):
        if self.is_empty():
            raise IndexError("Queue is empty")
        return self.highest_priority_value

    def __lt__(self, other_bootstrapped_q):
        if other_bootstrapped_q.is_empty():
            return True
        if self.is_empty():
            return False
        return self.peek() < other_bootstrapped_q.peek()

    def __len__(self):
        return self._length