Purely Functional Dictionaries in Python

True Story Follows

So I’ve been on the grind lately, straight writing purely functional data structures in Python…because who wouldn’t? I recently wrote about purely functional heaps and purely functional random access lists and purely functional red black trees inspired by none other than the main man, Dr. Chris Okasaki. Every day, I’m one step closer to actually using these for a practical purpose, and another workhorse in the game of data structures is a dictionary, so here I’ve put that together.

The context for these is in my previous 2 posts (referenced above); this is just a continuation of that work, putting together another data structure in which every operation, including deletes, is an additive operation that returns an entirely new root to an otherwise immutable object.

Purely Functional Dictionaries

The idea here is to take an input key, convert it to bits, and use those bits to walk a binary tree in which a 0 corresponds to a left child and a 1 corresponds to a right child. The values are stored in the leaves of the tree. This idea can be expanded upon by collapsing any series of nodes in which no branching occurs and instead storing that chain of bits on the node itself.

The most significant bits of the key are stored at the top. In this way, data is stored in the order of the keys, and now the tree supports range scans and what have you.

It should be noted that 1.) The representation of this structure is not in the least bit efficient since I’m using lists of booleans to represent bits (where each item in a list is effectively a pointer, and now we’re inefficient both because the pointer dramatically exceeds the size of 1 bit and we don’t get any of those sweet bitwise operations that you’d actually want), 2.) when tested on large data sets, I got segmentation faults that led me to conclude that perhaps Python is not the greatest purely functional language, and 3.) despite what you might think, I’m not a hipster, as much as I might enjoy coffee shops.

from bitarray import bitarray
import uuid


class NullObject(object):
    """Falsy, empty placeholder used where a child node, key or value is absent."""

    def __init__(self):
        # Exposes the same ``value`` attribute that real nodes carry.
        self.value = None

    def __nonzero__(self):
        # Python 2 truth hook: the placeholder is always false.
        return False

    def __len__(self):
        return 0

    def is_empty(self):
        """Report that this placeholder holds nothing."""
        return True

    def iteritems(self, *args):
        """Return an iterator that produces no ``(key, value)`` pairs.

        Any positional arguments are accepted and ignored.
        """
        return (pair for pair in ())


# Shared NullObject instance; IntegerMap.__init__ uses it as the default for
# absent children, original keys and values.
NULL_OBJECT = NullObject()


class IntegerMap(object):

    def __init__(self,
                 big_endian_k_value=-1,
                 bits_this_node=tuple(),
                 left_child=NULL_OBJECT,
                 right_child=NULL_OBJECT,
                 original_key=NULL_OBJECT,
                 value=NULL_OBJECT):

        self.left_child = left_child
        self.right_child = right_child
        self.bits_this_node = bits_this_node or []
        self.big_endian_k_value = big_endian_k_value
        self.value = value
        self.original_key = original_key

    def __nonzero__(self):
        return not self.is_empty()

    def __getitem__(self, *ordered_keys):
        return self._lookup(
            self._keys_to_bits(ordered_keys)
        )

    def __setitem__(self, key, value):
        if not isinstance(key, tuple):
            return self.__setitem__((key, ), value)
        return self._insert(
            self._keys_to_bits(key),
            value,
            key,
        )

    def __delitem__(self, key):
        if not isinstance(key, tuple):
            return self.__delitem__((key, ))

        return self._delete(
            self._keys_to_bits(key)
        )

    def __add__(self, other_int_map):
        return self.update(other_int_map)

    def __sub__(self, other_int_map):
        return self.subtract(other_int_map)

    def is_empty(self):
        return self.big_endian_k_value == -1

    def intersecting_keys(self, other_int_map):
        return self._intersection(other_int_map).keys()

    def keys(self, *base_keys):
        for key, _ in self.iteritems():
            yield key

    def values(self, *base_keys):
        for _, value in self.iteritems():
            yield value

    def iteritems(self, *base_keys):
        for key, value in self.left_child.iteritems(*base_keys):
            yield key, value

        if not isinstance(self.value, NullObject):
            yield self.original_key, self.value

        for key, value in self.right_child.iteritems(*base_keys):
            yield key, value

    def subtract(self, other_int_map):
        if self.is_empty():
            return self

        if other_int_map.is_empty():
            return self

        if self._is_leaf() and other_int_map._is_leaf():
            if (
                (
                    self.bits_this_node +
                    [False] * (other_int_map.big_endian_k_value - self.big_endian_k_value) +
                    [False] * (self.big_endian_k_value - len(self.bits_this_node) + 1)
                ) ==
                (
                    other_int_map.bits_this_node +
                    [False] * (self.big_endian_k_value - other_int_map.big_endian_k_value) +
                    [False] * (other_int_map.big_endian_k_value - len(other_int_map.bits_this_node) + 1)
                )
            ):
                return NULL_OBJECT

        if self.bits_this_node == other_int_map.bits_this_node:
            return IntegerMap(
                big_endian_k_value=max(self.big_endian_k_value, other_int_map.big_endian_k_value),
                bits_this_node=self.bits_this_node,
                left_child=(self.left_child or IntegerMap()).subtract(other_int_map.left_child),
                right_child=(self.right_child or IntegerMap()).subtract(other_int_map.right_child),
                original_key=self.original_key,
                value=self.value,
            )._collapse()

        if len(other_int_map.bits_this_node) > len(self.bits_this_node):
            return self.subtract(
                self._elongate_shorter_parent_to_direction(
                    self,
                    other_int_map,
                    min(len(self.bits_this_node), len(other_int_map.bits_this_node)),
                )
            )
        return self._elongate_shorter_parent_to_direction(
            other_int_map,
            self,
            min(len(self.bits_this_node), len(other_int_map.bits_this_node)),
        ).subtract(other_int_map)

    def _get_next_big_endian_k_delta(self, child, big_endian_k_delta):
        return (
            self.big_endian_k_value -
            len(self.bits_this_node) -
            1 -
            child.big_endian_k_value +
            big_endian_k_delta
        )

    def _lookup(self, bit_array, big_endian_k_delta=0):
        """Return the value stored for ``bit_array`` within this subtree.

        ``bit_array`` is the list of key bits still to be matched against this
        node.  ``big_endian_k_delta`` is the adjustment produced by
        ``_get_next_big_endian_k_delta`` while descending; it is 0 for the
        initial call.

        Raises ``KeyError`` when one of the numbered mismatch checks fires.
        """
        # Case 1: the key has more bits than this subtree accounts for.
        if len(bit_array) - 1 > self.big_endian_k_value + big_endian_k_delta:
            raise KeyError("case 1 Input key is not present, key bigger than others")

        # Case 2: the remaining key is exactly as long as this node's bits plus
        # the delta, so it either matches here (padded with False) or nowhere.
        # NOTE(review): self.value may be the NULL_OBJECT placeholder on a node
        # that holds no entry -- confirm whether that can be returned here.
        if len(bit_array) == len(self.bits_this_node) + big_endian_k_delta:
            if bit_array == self.bits_this_node + [False] * big_endian_k_delta:
                return self.value
            raise KeyError("case 2 Input key not present")

        # Pads a bit list with trailing False bits up to big_endian_k_value + 1.
        def _amend_least_significant_bits(incomplete_bit_array):
            return incomplete_bit_array + [False] * (self.big_endian_k_value - len(incomplete_bit_array) + 1)

        # Case 3: this node stores more bits than the key has left.
        if len(self.bits_this_node) > len(bit_array):
            raise KeyError("case 3 Input key not present")

        # A leaf whose padded bits equal the full-length key is a hit; on a
        # mismatch control falls through to the checks below.
        if self._is_leaf() and len(bit_array) - 1 == self.big_endian_k_value:
            if bit_array == _amend_least_significant_bits(self.bits_this_node + [False] * big_endian_k_delta):
                return self.value

        # Case 5: the key does not start with this node's bits.
        if self.bits_this_node != bit_array[:len(self.bits_this_node)]:
            raise KeyError("case 5 Input key is not present")

        # The bit right after this node's prefix picks the branch:
        # False descends left, anything else descends right.
        # NOTE(review): if no bit remains after the prefix, the [0] index raises
        # IndexError rather than KeyError -- confirm whether that is reachable.
        if bit_array[len(self.bits_this_node):][0] is False:
            if isinstance(self.left_child, NullObject):
                raise KeyError("case 6 input key is not present")
            return self.left_child._lookup(
                bit_array[1 + len(self.bits_this_node):],
                big_endian_k_delta=self._get_next_big_endian_k_delta(
                    self.left_child,
                    big_endian_k_delta,
                ),
            )
        if isinstance(self.right_child, NullObject):
            raise KeyError("case 7 input key is not present")

        return self.right_child._lookup(
            bit_array[1 + len(self.bits_this_node):],
            big_endian_k_delta=self._get_next_big_endian_k_delta(
                self.right_child,
                big_endian_k_delta,
            ),
        )

    def _get_first_differing_bit_index(self, bit_array1, bit_array2):

        def _traverse_to_min_index(current_index, max_index):
            if current_index == max_index:
                return None
            if bit_array1[:max_index][current_index] != bit_array2[:max_index][current_index]:
                return current_index
            return _traverse_to_min_index(current_index + 1, max_index)

        return _traverse_to_min_index(0, min(len(bit_array1), len(bit_array2)))

    def _insert(self, bit_array, value, original_key):
        return self.update(
            IntegerMap(
                big_endian_k_value=len(bit_array) - 1,
                bits_this_node=bit_array,
                original_key=original_key,
                value=value
            )
        )

    def _keys_to_bits(self, ordered_keys):
        return self._amend_least_significant_bits(
            reduce(
                lambda accum, bit_list: accum + bit_list,
                [self._key_to_bits(key) for key in ordered_keys]
            )
        )

    def _amend_least_significant_bits(self, bit_array):
        return bit_array + (
            [False] * (self.big_endian_k_value - len(bit_array) + 1)
        )

    def _key_to_bits(self, key):
        try:
            return self._uuid_to_bits(uuid.UUID(str(key)))
        except ValueError:
            try:
                return self._int_to_bits(int(key))
            except ValueError:
                return self._string_to_bits(key)

    def _get_non_empty(self, other_int_map):
        if self.is_empty():
            return other_int_map
        if other_int_map.is_empty():
            return self

    def _update_for_equal_top_level_bits(self, other_int_map):
        """Merge ``other_int_map`` into this node when their bits do not diverge.

        ``update`` calls this when no differing bit exists within the common
        length of the two nodes' bits, i.e. the bit lists are equal or one is
        a prefix of the other.  Returns the merged node; on equal bits the key
        and value of ``other_int_map`` replace this node's.
        """
        if other_int_map.bits_this_node == self.bits_this_node:

            # Rebuilds a leaf with the larger big_endian_k_value of the two
            # nodes and its bits padded by the difference in False bits.
            def _amend_least_significant_bits_to_leaf(leaf):
                return IntegerMap(
                    big_endian_k_value=max(self.big_endian_k_value, other_int_map.big_endian_k_value),
                    bits_this_node=leaf.bits_this_node + [False] * abs(self.big_endian_k_value - other_int_map.big_endian_k_value),
                    left_child=leaf.left_child,
                    right_child=leaf.right_child,
                    original_key=leaf.original_key,
                    value=leaf.value
                )

            # Exactly one side is a leaf: pad that leaf and run update again.
            # NOTE(review): if both big_endian_k_value are equal the padding is
            # empty and the retry sees identical inputs -- confirm unreachable.
            if self._is_leaf() and not other_int_map._is_leaf():
                return _amend_least_significant_bits_to_leaf(
                    self
                ).update(other_int_map)

            if not self._is_leaf() and other_int_map._is_leaf():
                return self.update(
                    _amend_least_significant_bits_to_leaf(other_int_map)
                )

            # Both leaves or both internal: merge children pairwise; a falsy
            # merge result is stored as NULL_OBJECT.
            return IntegerMap(
                big_endian_k_value=max(self.big_endian_k_value, other_int_map.big_endian_k_value),
                bits_this_node=self.bits_this_node,
                left_child=(self.left_child or IntegerMap()).update(
                    other_int_map.left_child
                ) or NULL_OBJECT,
                right_child=(self.right_child or IntegerMap()).update(
                    other_int_map.right_child
                ) or NULL_OBJECT,
                original_key=other_int_map.original_key,
                value=other_int_map.value
            )._collapse()

        # The other node has more bits: split it at this node's bit length
        # (it is passed in the helper's ``shorter_parent`` slot) and merge.
        if len(other_int_map.bits_this_node) > len(self.bits_this_node):
            return self.update(
                self._elongate_shorter_parent_to_direction(
                    self,
                    other_int_map,
                    min(len(self.bits_this_node), len(other_int_map.bits_this_node)),
                )
            )
        # Otherwise this node has more bits: split this node instead.
        return self._elongate_shorter_parent_to_direction(
            other_int_map,
            self,
            min(len(self.bits_this_node), len(other_int_map.bits_this_node)),
        ).update(other_int_map)

    def _update_for_insert_into_child(self, left_child, right_child, first_differing_bit_index):

        def _move_down_one_level(child):
            return IntegerMap(
                big_endian_k_value=max(
                    left_child.big_endian_k_value,
                    right_child.big_endian_k_value
                ) - first_differing_bit_index - 1,
                bits_this_node=child.bits_this_node[first_differing_bit_index + 1:],
                left_child=child.left_child,
                right_child=child.right_child,
                original_key=child.original_key,
                value=child.value,
            )

        return IntegerMap(
            big_endian_k_value=max(left_child.big_endian_k_value, right_child.big_endian_k_value),
            bits_this_node=self.bits_this_node[:first_differing_bit_index],
            left_child=_move_down_one_level(left_child),
            right_child=_move_down_one_level(right_child),
        )

    def update(self, other_int_map, big_endian_k_delta=0):
        if self.is_empty() or other_int_map.is_empty():
            return self._get_non_empty(other_int_map)

        def _get_int_map_for_first_differing_bit(first_differing_bit_index):
            if first_differing_bit_index is None:
                return self._update_for_equal_top_level_bits(other_int_map)

            if self.bits_this_node[first_differing_bit_index] is False:
                return self._update_for_insert_into_child(self, other_int_map, first_differing_bit_index)
            return self._update_for_insert_into_child(other_int_map, self, first_differing_bit_index)

        return _get_int_map_for_first_differing_bit(
            self._get_first_differing_bit_index(
                self.bits_this_node,
                other_int_map.bits_this_node,
            )
        )

    def _is_leaf(self):
        return isinstance(self.right_child, NullObject) and isinstance(self.left_child, NullObject)

    def _int_to_bits(self, integer):
        return [True if bit == '1' else False for bit in list(bin(integer))[2:]]

    def _uuid_to_bits(self, uuid_value):
        """Return the bits of ``uuid_value.bytes`` as a list, via a ``bitarray``."""
        # assumes bitarray.frombytes fills the array in place and returns None,
        # which the original ``x.frombytes(...) or x`` idiom relied on
        # -- TODO confirm against the installed bitarray version.
        bits = bitarray()
        bits.frombytes(uuid_value.bytes)
        return bits.tolist()

    def _string_to_bits(self, string):
        """Return the bits of ``string`` as a list, via a ``bitarray``."""
        # NOTE(review): ``fromstring`` looks like a legacy bitarray method;
        # newer releases may only offer ``frombytes`` -- confirm against the
        # installed bitarray version.
        bits = bitarray()
        bits.fromstring(string)
        return bits.tolist()

    def _delete(self, bit_array, big_endian_k_delta=0):
        """Return a copy of this subtree without the entry for ``bit_array``.

        ``bit_array`` is the list of key bits still to be matched against this
        node.  ``big_endian_k_delta`` is the adjustment produced by
        ``_get_next_big_endian_k_delta`` while descending; it is 0 for the
        initial call.

        Raises ``KeyError`` when the key has more bits than this subtree
        accounts for.
        """
        if len(bit_array) - 1 > self.big_endian_k_value + big_endian_k_delta:
            raise KeyError("Input key is not present, key bigger than others")

        # The key ends exactly at this node: the whole subtree becomes empty.
        if bit_array == self.bits_this_node + [False] * big_endian_k_delta:
            return IntegerMap()

        # True when ``child`` exists and this node's bits, the branch bit and
        # the child's bits together spell the key being deleted.
        def _branch_matches_bit_array(child, direction):
            return (
                not isinstance(child, NullObject) and
                self.bits_this_node + [direction] + child.bits_this_node == bit_array
            )

        # Replaces this node by the sibling that survives, folding this node's
        # bits and the branch bit into the sibling's bits.
        # NOTE(review): if the surviving sibling is a NullObject it has no
        # bits_this_node attribute -- confirm that cannot occur.
        def _collapsed_current_node(child_to_keep, direction_to_collapse_from):
            return IntegerMap(
                big_endian_k_value=self.big_endian_k_value,
                bits_this_node=self.bits_this_node + [direction_to_collapse_from] + child_to_keep.bits_this_node,
                left_child=child_to_keep.left_child,
                right_child=child_to_keep.right_child,
                original_key=child_to_keep.original_key,
                value=child_to_keep.value
            )

        # Deleting a direct child leaves only its sibling, so collapse onto it.
        if _branch_matches_bit_array(self.left_child, False):
            return _collapsed_current_node(self.right_child, True)

        if _branch_matches_bit_array(self.right_child, True):
            return _collapsed_current_node(self.left_child, False)

        # No differing bit within this node's prefix means the branch bit sits
        # right after the prefix.
        def _convert_none_to_max(first_differing_bit_index):
            if first_differing_bit_index is None:
                return len(self.bits_this_node)
            return first_differing_bit_index

        def _delete_left_or_right_from_first_differing_bit(first_differing_bit_index):

            # Copy of this node with the given children substituted.
            def _current_node_with_children(left_child, right_child):
                return IntegerMap(
                    big_endian_k_value=self.big_endian_k_value,
                    bits_this_node=self.bits_this_node,
                    left_child=left_child,
                    right_child=right_child,
                    original_key=self.original_key,
                    value=self.value,
                )

            # Recurse into ``child`` with the bits after the branch bit.
            # NOTE(review): a NullObject child has no _delete method, so a
            # missing branch would raise AttributeError rather than KeyError
            # -- confirm.
            def _child_with_delete(child):
                return child._delete(
                    bit_array[first_differing_bit_index + 1:],
                    big_endian_k_delta=self._get_next_big_endian_k_delta(
                        child,
                        big_endian_k_delta,
                    ),
                )

            # A False branch bit deletes from the left child, otherwise from
            # the right child.
            if bit_array[first_differing_bit_index] is False:
                return _current_node_with_children(
                    _child_with_delete(self.left_child),
                    self.right_child,
                )
            return _current_node_with_children(
                self.left_child,
                _child_with_delete(self.right_child),
            )

        return _delete_left_or_right_from_first_differing_bit(
            _convert_none_to_max(
                self._get_first_differing_bit_index(
                    self.bits_this_node,
                    bit_array,
                )
            )
        )

    def _collapse(self):
        if self.is_empty():
            return self
        if (
            isinstance(self.right_child, NullObject) and
            not isinstance(self.left_child, NullObject)
        ):
            return IntegerMap(
                big_endian_k_value=self.big_endian_k_value,
                bits_this_node=self.bits_this_node + [False] + self.left_child.bits_this_node,
                left_child=self.left_child.left_child,
                right_child=self.left_child.right_child,
                original_key=self.left_child.original_key,
                value=self.left_child.value,
            )._collapse()
        if (
            isinstance(self.left_child, NullObject) and
            not isinstance(self.right_child, NullObject)
        ):
            return IntegerMap(
                big_endian_k_value=self.big_endian_k_value,
                bits_this_node=self.bits_this_node + [True] + self.right_child.bits_this_node,
                left_child=self.right_child.left_child,
                right_child=self.right_child.right_child,
                original_key=self.right_child.original_key,
                value=self.right_child.value,
            )._collapse()
        if self._is_leaf() and isinstance(self.value, NullObject):
            return IntegerMap()
        return self

    @staticmethod
    def _elongate_shorter_parent_to_direction(taller_parent, shorter_parent, lowest_length):

        def _elongate(shorter_parent, direction, big_endian_k_value, lowest_length):

            def _child_from_parent():
                return IntegerMap(
                    big_endian_k_value=big_endian_k_value - lowest_length - 1,
                    bits_this_node=shorter_parent.bits_this_node[lowest_length + 1:],
                    left_child=shorter_parent.left_child,
                    right_child=shorter_parent.right_child,
                    original_key=shorter_parent.original_key,
                    value=shorter_parent.value
                )
            if direction is True:
                return IntegerMap(
                    big_endian_k_value=big_endian_k_value,
                    bits_this_node=shorter_parent.bits_this_node[:lowest_length],
                    left_child=NULL_OBJECT,
                    right_child=_child_from_parent(),
                )
            return IntegerMap(
                big_endian_k_value=big_endian_k_value,
                bits_this_node=shorter_parent.bits_this_node[:lowest_length],
                left_child=_child_from_parent(),
                right_child=NULL_OBJECT,
            )

        return _elongate(
            shorter_parent,
            shorter_parent.bits_this_node[lowest_length:][0],
            max(taller_parent.big_endian_k_value, shorter_parent.big_endian_k_value),
            lowest_length,
        )

    def _intersection(self, other_int_map):
        if self.is_empty():
            return self
        if other_int_map.is_empty():
            return IntegerMap()

        if self._is_leaf() and other_int_map._is_leaf():

            if (
                (
                    self.bits_this_node +
                    [False] * (other_int_map.big_endian_k_value - self.big_endian_k_value) +
                    [False] * (self.big_endian_k_value - len(self.bits_this_node) + 1)
                ) ==
                (
                    other_int_map.bits_this_node +
                    [False] * (self.big_endian_k_value - other_int_map.big_endian_k_value) +
                    [False] * (other_int_map.big_endian_k_value - len(other_int_map.bits_this_node) + 1)
                )
            ):
                return self

        if self.bits_this_node == other_int_map.bits_this_node:
            return IntegerMap(
                big_endian_k_value=max(self.big_endian_k_value, other_int_map.big_endian_k_value),
                bits_this_node=self.bits_this_node,
                left_child=(self.left_child or IntegerMap())._intersection(other_int_map.left_child),
                right_child=(self.right_child or IntegerMap())._intersection(other_int_map.right_child),
                original_key=self.original_key,
                value=self.value,
            )._collapse()

        if len(other_int_map.bits_this_node) > len(self.bits_this_node):
            return self._intersection(
                self._elongate_shorter_parent_to_direction(
                    self,
                    other_int_map,
                    min(len(self.bits_this_node), len(other_int_map.bits_this_node)),
                )
            )
        return self._elongate_shorter_parent_to_direction(
            other_int_map,
            self,
            min(len(self.bits_this_node), len(other_int_map.bits_this_node)),
        )._intersection(other_int_map)