Sans Pareil Technologies, Inc.

Key To Your Business

Lab 13 - Binary Search Tree

In this exercise we will implement a binary search tree (unbalanced) using std::unique_ptr to manage dynamically allocated node instances stored in the tree.

Node


We will use a Node structure that will store the data for the tree element, as well as std::unique_ptr instances to the left and right nodes in the tree. We will also store a raw pointer to the parent node to enable tree traversal.

Iterators


We will use a BaseIterator class template, declared inside the tree class, which is aliased as iterator and const_iterator for the binary search tree. Because every node stores a pointer to its parent, the iterator can move to both the next and the previous element in order, so it is bidirectional.

using iterator = BaseIterator<Data, Comparator>;
using const_iterator = BaseIterator<Data const, Comparator>;

BinarySearchTree.h

Declaration of the binary search tree (unbalanced) implementation.

#pragma once
#include <algorithm>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <utility>

namespace csc240
{
  template <typename Data, typename Comparator = std::less<Data>>
  class BinarySearchTree
  {
    struct Node;

  public:
#include "BSTIterator.hpp"

    using iterator = BaseIterator<Data,Comparator>;
    using const_iterator = BaseIterator<Data const,Comparator>;
    using reverse_iterator = std::reverse_iterator<iterator>;
    using const_reverse_iterator = std::reverse_iterator<const_iterator>;

    BinarySearchTree() = default;

    template <typename Iterator>
    BinarySearchTree( Iterator first, Iterator last ) { assign( first, last ); }

    BinarySearchTree( const std::initializer_list<Data>& list ) :
      BinarySearchTree( list.begin(), list.end() ) {}

    const Data& root();
    const Data& root() const
    {
      return const_cast<BinarySearchTree*>( this )->root();
    }

    const Data& leftMost();
    const Data& leftMost() const
    {
      return const_cast<BinarySearchTree*>( this )->leftMost();
    }

    const Data& rightMost();
    const Data& rightMost() const
    {
      return const_cast<BinarySearchTree*>( this )->rightMost();
    }

    bool empty() const noexcept { return !rootNode; }
    bool empty() noexcept { return !rootNode; }

    bool exists( const Data& data ) noexcept
    {
      return exists( data, rootNode.get() );
    }

    bool exists( const Data& data ) const noexcept
    {
      return const_cast<BinarySearchTree*>( this )->exists( data );
    }

    std::size_t size() const noexcept { return count; }
    std::size_t size() noexcept { return count; }

    template<typename InputIterator>
    void assign( InputIterator first, InputIterator last );

    void assign( const std::initializer_list<Data>& list )
    {
      assign( list.begin(), list.end() );
    }

    iterator find( const Data& data ) noexcept
    {
      return find( data, rootNode.get() );
    }

    const_iterator find( const Data& data ) const noexcept
    {
      auto iter = const_cast<BinarySearchTree*>( this )->find( data );
      return const_iterator{ iter.node };
    }

    iterator emplace( Data&& data );

    iterator remove( const Data& data ) { return remove( data, rootNode.get() ); }

    void clear();

    iterator begin() noexcept;
    const_iterator cbegin() const noexcept;

    iterator end() noexcept { return iterator{ nullptr }; }
    const_iterator cend() const noexcept { return const_iterator{ nullptr }; }

    reverse_iterator rbegin() noexcept;
    const_reverse_iterator crbegin() const noexcept;

    reverse_iterator rend() noexcept { return std::make_reverse_iterator( begin() ); }
    const_reverse_iterator crend() const noexcept
    {
      return std::make_reverse_iterator( cbegin() );
    }

  private:
    struct Node
    {
      using Ptr = std::unique_ptr<Node>;

      explicit Node( Data&& data ) : data{ std::move( data ) }, left{}, right{} {}

      Node( const Node& ) = delete;
      Node& operator=( const Node& ) = delete;

      Data data;
      Ptr left;
      Ptr right;
      Node* parent = nullptr;
    };

    bool exists( const Data& data, const Node* node ) noexcept;
    iterator find( const Data& data, Node* node ) noexcept;
    iterator remove( const Data& data, Node* node );

    Comparator comparator{};
    typename Node::Ptr rootNode;
    mutable typename Node::Ptr sentinel = std::make_unique<Node>( Data{} );
    Node* first = nullptr;
    Node* last = nullptr;
    std::size_t count = 0;
  };

  template <typename Data, typename Comparator>
  bool operator==( const BinarySearchTree<Data, Comparator>& lhs,
    const BinarySearchTree<Data, Comparator>& rhs )
  {
    return ( lhs.size() == rhs.size() ) &&
      std::equal( lhs.cbegin(), lhs.cend(), rhs.cbegin() );
  }

  template <typename Data, typename Comparator>
  bool operator!=( const BinarySearchTree<Data, Comparator>& lhs,
    const BinarySearchTree<Data, Comparator>& rhs )
  {
    return !( lhs == rhs );
  }

#include "private/BSTImpl.h"
}

BSTIterator.hpp

The iterator for the binary search tree. It has been extracted into a separate include file so that the public interface of the binary search tree is easier to read.

template <typename T, typename C>
struct BaseIterator : std::iterator<std::bidirectional_iterator_tag, T>
{
  BaseIterator() = delete;
  BaseIterator( const BaseIterator& ) = default;

  BaseIterator& operator++()
  {
    if ( node->right )
    {
      if ( node->right->left ) node = nextLeft( node->right.get() );
      else node = node->right.get();
    }
    else if ( node->parent )
    {
      node = nextParent( node );
    }
    else node = nullptr;

    return *this;
  }

  BaseIterator operator++( int )
  {
    BaseIterator temp{ node };
    operator++();
    return temp;
  }

  BaseIterator& operator--()
  {
    if ( !node ) return *this;
    if ( node->left )
    {
      if ( node->left->right ) node = previousRight( node->left.get() );
      else node = node->left.get();
    }
    else node = node->parent;

    return *this;
  }

  BaseIterator operator--( int )
  {
    BaseIterator temp{ node };
    operator--();
    return temp;
  }

  bool operator==( const BaseIterator& rhs ) { return node == rhs.node; }
  bool operator!=( const BaseIterator& rhs ) { return !operator==( rhs ); }

  T& operator*() { return node->data; }
  T* operator->() { return &( node->data ); }

private:
  bool visited( Node* current )
  {
    return ( node == current ) || comparator( node->data, current->data );
  }

  Node* nextLeft( Node* current )
  {
    if ( !current ) return current;
    Node* temp = current;
    while ( temp->left && comparator( node->data, temp->left->data ) )
    {
      temp = temp->left.get();
    }

    return temp;
  }

  Node* nextParent( Node* current )
  {
    if ( !current ) return current;
    Node* temp = current->parent;
    while ( temp && comparator( temp->data, node->data ) )
    {
      temp = temp->parent;
    }

    return temp;
  }

  Node* previousRight( Node* current )
  {
    if ( !current ) return current;
    Node* temp = current;

    while ( temp->right && comparator( temp->right->data, node->data ) )
    {
      temp = temp->right.get();
    }

    return temp;
  }

  explicit BaseIterator( Node* node ) : node{ node } {}

  Node* node;
  C comparator;

  friend class BinarySearchTree<Data,Comparator>;
};

BSTImpl.h

The include file in which the more complex functions for the binary search tree are implemented.

template <typename Data, typename Comparator>
const Data& BinarySearchTree<Data, Comparator>::root()
{
  // An empty tree has no root to report.
  if ( rootNode == nullptr )
  {
    throw std::out_of_range( "Empty tree!" );
  }
  return rootNode->data;
}

template <typename Data, typename Comparator>
const Data& BinarySearchTree<Data,Comparator>::leftMost()
{
  if ( !rootNode ) throw std::out_of_range( "Empty tree!" );

  // Refill the cache by following left links from the root.
  if ( first == nullptr )
  {
    Node* cursor = rootNode.get();
    while ( cursor->left ) cursor = cursor->left.get();
    first = cursor;
  }
  return first->data;
}

template <typename Data, typename Comparator>
const Data& BinarySearchTree<Data,Comparator>::rightMost()
{
  if ( !rootNode ) throw std::out_of_range( "Empty tree!" );

  // Refill the cache by following right links from the root.
  if ( last == nullptr )
  {
    Node* cursor = rootNode.get();
    while ( cursor->right ) cursor = cursor->right.get();
    last = cursor;
  }
  return last->data;
}

template <typename Data, typename Comparator>
template<typename InputIterator>
void BinarySearchTree<Data,Comparator>::assign(
  InputIterator from, InputIterator to )
{
  // Start from an empty tree, then insert each source element in turn.
  // (Parameters are named from/to so they do not shadow members first/last.)
  clear();

  while ( from != to )
  {
    Data value{ *from };
    emplace( std::move( value ) );
    ++from;
  }
}

template <typename Data, typename Comparator>
bool BinarySearchTree<Data,Comparator>::exists( const Data& data, const Node* node ) noexcept
{
  // Walk down from node, steering by the comparator, until a match or a leaf.
  for ( const Node* cursor = node; cursor != nullptr; )
  {
    if ( cursor->data == data ) return true;
    cursor = comparator( data, cursor->data ) ? cursor->left.get() : cursor->right.get();
  }
  return false;
}

template <typename Data, typename Comparator>
auto BinarySearchTree<Data, Comparator>::find( const Data& data, Node* node ) noexcept -> iterator
{
  // Descend until the matching node is found or the search falls off a leaf.
  while ( node != nullptr && !( node->data == data ) )
  {
    node = comparator( data, node->data ) ? node->left.get() : node->right.get();
  }
  return node ? iterator{ node } : end();
}

template <typename Data, typename Comparator>
auto BinarySearchTree<Data,Comparator>::emplace( Data&& data ) -> iterator
{
  auto node = std::make_unique<Node>( std::move( data ) );

  if ( ! rootNode )
  {
    rootNode = std::move( node );
    first = rootNode.get();
    last = rootNode.get();
    ++count;
    return iterator{ first };
  }

  Node* current = rootNode.get();
  Node* parent = current;

  if ( current->data == node->data ) return iterator{ parent };

  while ( current )
  {
    parent = current;
    current = ( comparator( current->data, node->data ) ) ?
      current->right.get() :
      current->left.get();
  }

  node->parent = parent;
  if ( comparator( parent->data, node->data ) )
  {
    if ( last && comparator( last->data, node->data ) ) last = node.get();
    parent->right = std::move( node );
  }
  else
  {
    if ( first && comparator( node->data, first->data ) ) first = node.get();
    parent->left = std::move( node );
  }

  ++count;
  return iterator{ parent };
}

template<typename Data, typename Comparator>
auto BinarySearchTree<Data, Comparator>::remove( const Data& data, Node* node ) -> iterator
{
  if ( ! node ) return end();

  if ( node->data == data )
  {
    if ( first == node ) first = nullptr;
    if ( last == node ) last = nullptr;

    if ( node->parent )
    {
      if ( node->parent->left.get() == node )
      {
        node->parent->left = nullptr;
      }
      else
      {
        node->parent->right = nullptr;
      }
    }
    else
    {
      rootNode = nullptr;
      count = 1;
    }

    --count;
    return ( node->parent ) ? iterator{ node->parent } : begin();
  }

  return ( comparator( data, node->data ) ) ?
    remove( data, node->left.get() ) : remove( data, node->right.get() );
}

/// Destroys every node and resets the cached extremes and the size.
///
/// An unbalanced tree can be as deep as it is large (for example when built
/// from sorted input), and letting the unique_ptr chain destroy it would
/// recurse once per level.  Instead, left children are rotated up until the
/// top node has no left subtree; that node is then freed on its own.
/// Teardown therefore needs constant stack depth.
template<typename Data, typename Comparator>
void BinarySearchTree<Data, Comparator>::clear()
{
  typename Node::Ptr top = std::move( rootNode );
  while ( top )
  {
    if ( top->left )
    {
      // Right rotation: the left child becomes the new top node.
      typename Node::Ptr child = std::move( top->left );
      top->left = std::move( child->right );
      child->right = std::move( top );
      top = std::move( child );
    }
    else
    {
      // Move-assignment releases top->right before deleting the old top, so
      // only that single, now childless, node is destroyed here.
      top = std::move( top->right );
    }
  }

  first = last = nullptr;
  count = 0;
}

template<typename Data, typename Comparator>
auto BinarySearchTree<Data,Comparator>::begin() noexcept -> iterator
{
  // A non-empty tree with no cached leftmost node needs it recomputed first.
  const bool needsRefresh = rootNode && first == nullptr;
  if ( needsRefresh ) leftMost();
  return iterator{ first };
}

template<typename Data, typename Comparator>
auto BinarySearchTree<Data,Comparator>::cbegin() const noexcept -> const_iterator
{
  // leftMost() is non-const because it refreshes the cache, hence the cast.
  const bool needsRefresh = rootNode && first == nullptr;
  if ( needsRefresh )
  {
    auto* self = const_cast<BinarySearchTree*>( this );
    self->leftMost();
  }
  return const_iterator{ first };
}

/// Reverse iterator to the last element, or rend() for an empty tree.
template<typename Data, typename Comparator>
auto BinarySearchTree<Data,Comparator>::rbegin() noexcept -> reverse_iterator
{
  // Without this guard an empty tree would yield a bogus one-element range
  // whose dereference goes through a null node.
  if ( ! rootNode ) return rend();

  if ( ! last ) rightMost();
  // Dereferencing a reverse_iterator decrements its base; the sentinel has
  // no left child, so it decrements to its parent: the last element.
  sentinel->parent = last;
  return std::make_reverse_iterator( iterator{ sentinel.get() } );
}

/// Const reverse iterator to the last element, or crend() for an empty tree.
template<typename Data, typename Comparator>
auto BinarySearchTree<Data,Comparator>::crbegin() const noexcept -> const_reverse_iterator
{
  // Without this guard an empty tree would yield a bogus one-element range
  // whose dereference goes through a null node.
  if ( ! rootNode ) return crend();

  if ( ! last ) const_cast<BinarySearchTree*>( this )->rightMost();
  // The sentinel has no left child, so decrementing it reaches its parent,
  // which is pointed at the last element here.
  sentinel->parent = last;
  return std::make_reverse_iterator( const_iterator{ sentinel.get() } );
}

BinarySearchTree.cpp

Unit test suite for the binary search tree implementation

#include "catch.hpp"
#include "BinarySearchTree.h"
#include <iostream>
#include <array>

namespace csc240
{
  namespace internal
  {
    template <typename Data, typename Comparator>
    void iterate( BinarySearchTree<Data, Comparator>& bst )
    {
      for ( auto& value : bst ) {}
      for ( auto iter = bst.cbegin(); iter != bst.cend(); std::advance( iter, 1 ) )
      {
        if ( iter != bst.cbegin() )
        {
          std::advance( iter, -1 );
          std::advance( iter, 1 );
        }
      }

      auto cend = bst.crend();
      for ( auto iter = bst.crbegin(); iter != bst.crend(); ++iter )
      {
        *iter;
      }
    }

    template <typename Data, typename Comparator>
    void testExists( const BinarySearchTree<Data,Comparator>& bst, const Data& value )
    {
      REQUIRE( bst.exists( value ) );
      iterate( const_cast<BinarySearchTree<Data,Comparator>&>( bst ) );
    }

    template <typename Comparator>
    void testUint8( BinarySearchTree<uint8_t,Comparator>& bst )
    {
      REQUIRE( bst.empty() );
      REQUIRE_THROWS( bst.root() );
      iterate<uint8_t,Comparator>( bst );

      REQUIRE_FALSE( bst.exists( 5 ) );
      REQUIRE( bst.find( 5 ) == bst.end() );
      auto iter = bst.emplace( 5 );
      testExists<uint8_t,Comparator>( bst, 5 );
      REQUIRE( 5 == bst.root() );
      REQUIRE( 1 == bst.size() );
      REQUIRE( 5 == *iter );
      REQUIRE( 5 == *( bst.find( 5 ) ) );
      REQUIRE_FALSE( bst.empty() );

      REQUIRE_FALSE( bst.exists( 3 ) );
      REQUIRE( bst.find( 3 ) == bst.end() );
      iter = bst.emplace( 3 );
      testExists<uint8_t,Comparator>( bst, 3 );
      REQUIRE( 2 == bst.size() );
      REQUIRE( 5 == *iter );
      REQUIRE( 3 == *( bst.find( 3 ) ) );
      REQUIRE_FALSE( bst.empty() );

      REQUIRE_FALSE( bst.exists( 8 ) );
      REQUIRE( bst.find( 8 ) == bst.end() );
      iter = bst.emplace( 8 );
      testExists<uint8_t,Comparator>( bst, 8 );
      REQUIRE( 3 == bst.size() );
      REQUIRE( 5 == *iter );
      REQUIRE( 8 == *( bst.find( 8 ) ) );
      REQUIRE_FALSE( bst.empty() );

      REQUIRE_FALSE( bst.exists( 1 ) );
      REQUIRE( bst.find( 1 ) == bst.end() );
      iter = bst.remove( 1 );
      REQUIRE( 3 == bst.size() );
      REQUIRE( iter == bst.end() );
      REQUIRE_FALSE( bst.empty() );

      iter = bst.remove( 3 );
      REQUIRE_FALSE( bst.exists( 3 ) );
      REQUIRE( bst.find( 3 ) == bst.end() );
      REQUIRE( 2 == bst.size() );
      REQUIRE( iter != bst.end() );
      REQUIRE_FALSE( bst.empty() );

      iter = bst.remove( 8 );
      REQUIRE_FALSE( bst.exists( 8 ) );
      REQUIRE( bst.find( 8 ) == bst.end() );
      REQUIRE( 1 == bst.size() );
      REQUIRE( iter != bst.end() );
      REQUIRE_FALSE( bst.empty() );

      REQUIRE_FALSE( bst.exists( 1 ) );
      REQUIRE( bst.find( 1 ) == bst.end() );
      bst.emplace( 1 );
      testExists<uint8_t,Comparator>( bst, 1 );
      REQUIRE( 2 == bst.size() );
      REQUIRE( 1 == *( bst.find( 1 ) ) );
      REQUIRE_FALSE( bst.empty() );

      bst.remove( 5 );
      REQUIRE( bst.empty() );
      REQUIRE( 0 == bst.size() );
    }
  }
}

SCENARIO( "Binary search tree works with simple types" )
{
  // The same scripted checks run against both orderings.
  using Ascending = csc240::BinarySearchTree<uint8_t>;
  using Descending = csc240::BinarySearchTree<uint8_t, std::greater<uint8_t>>;

  GIVEN( "A BST that orders uint8_t in ascending order" )
  {
    Ascending bst;
    csc240::internal::testUint8( bst );
  }

  GIVEN( "A BST that orders uint8_t in descending order" )
  {
    Descending bst;
    csc240::internal::testUint8( bst );
  }
}

SCENARIO( "Binary tree works with initialiser lists" )
{
  GIVEN( "A BST of strings in ascending order" )
  {
    WHEN( "Constructed using initialiser list" )
    {
      csc240::BinarySearchTree<std::string> bst{ "c", "f", "a", "e", "d", "b" };
      REQUIRE( 6 == bst.size() );
      REQUIRE( "c" == bst.root() );
      REQUIRE( "a" == bst.leftMost() );
      REQUIRE( "f" == bst.rightMost() );

      uint8_t c = 97;
      for ( const auto& value : bst  )
      {
        REQUIRE( value[0] == c++ );
      }
    }

    AND_WHEN( "Constructed using different order" )
    {
      csc240::BinarySearchTree<std::string> bst{ "c", "f", "b", "e", "d", "a" };
      REQUIRE( 6 == bst.size() );
      REQUIRE( "c" == bst.root() );
      REQUIRE( "a" == bst.leftMost() );
      REQUIRE( "f" == bst.rightMost() );

      uint8_t c = 97;
      for ( const auto& value : bst  )
      {
        REQUIRE( value[0] == c++ );
      }
    }
  }

  GIVEN( "A BST of strings in descending order" )
  {
    WHEN( "Constructed using initialiser list" )
    {
      csc240::BinarySearchTree<std::string,std::greater<std::string>> bst{ "c", "f", "a", "e", "b", "d" };
      REQUIRE( 6 == bst.size() );
      REQUIRE( "c" == bst.root() );
      REQUIRE( "f" == bst.leftMost() );
      REQUIRE( "a" == bst.rightMost() );

      uint8_t c = 102;
      for ( const auto& value : bst )
      {
        REQUIRE( value[0] == c-- );
      }
    }

    AND_WHEN( "Constructed using different order" )
    {
      csc240::BinarySearchTree<std::string,std::greater<std::string>> bst{ "c", "f", "a", "e", "d", "b" };
      REQUIRE( 6 == bst.size() );
      REQUIRE( "c" == bst.root() );
      REQUIRE( "f" == bst.leftMost() );
      REQUIRE( "a" == bst.rightMost() );

      uint8_t c = 102;
      for ( const auto& value : bst  )
      {
        REQUIRE( value[0] == c-- );
      }
    }
  }
}

SCENARIO( "Binary tree works with input iterators" )
{
  std::array<uint8_t, 9> array{ 27, 20, 62, 15, 25, 40, 71, 16, 21 };

  GIVEN( "A BST of integers in ascending order" )
  {
    csc240::BinarySearchTree<uint8_t> bst{ std::begin( array ), std::end( array ) };

    WHEN( "Constructed using input iterators" )
    {
      REQUIRE( bst.size() == array.size() );
      auto iter = bst.cbegin();
      REQUIRE( 15 == *(iter++) );
      REQUIRE( 16 == *(iter++) );
      REQUIRE( 20 == *(iter++) );
      REQUIRE( 21 == *(iter++) );
      REQUIRE( 25 == *(iter++) );
      REQUIRE( 27 == *(iter++) );
      REQUIRE( 40 == *(iter++) );
      REQUIRE( 62 == *(iter++) );
      REQUIRE( 71 == *(iter++) );
    }
  }

  GIVEN( "A BST of integers in descending order" )
  {
    csc240::BinarySearchTree<uint8_t,std::greater<uint8_t>> bst{ std::begin( array ), std::end( array ) };

    WHEN( "Constructed using input iterators" )
    {
      REQUIRE( bst.size() == array.size() );
      auto iter = bst.cbegin();
      REQUIRE( 71 == *( iter++ ) );
      REQUIRE( 62 == *( iter++ ) );
      REQUIRE( 40 == *( iter++ ) );
      REQUIRE( 27 == *( iter++ ) );
      REQUIRE( 25 == *( iter++ ) );
      REQUIRE( 21 == *( iter++ ) );
      REQUIRE( 20 == *( iter++ ) );
      REQUIRE( 16 == *( iter++ ) );
      REQUIRE( 15 == *( iter++ ) );
    }
  }
}

SCENARIO( "const qualified functions work" )
{
  using Tree = csc240::BinarySearchTree<std::string>;

  GIVEN( "A const BST of strings" )
  {
    const Tree bst{ "c", "f", "a", "e", "b", "d" };
    REQUIRE( 6 == bst.size() );
    REQUIRE( "c" == bst.root() );
    REQUIRE( "a" == bst.leftMost() );
    REQUIRE( "f" == bst.rightMost() );

    REQUIRE( bst.exists( "d" ) );
    REQUIRE( "d" == *( bst.find( "d" ) ) );

    // Equality compares sizes and in-order contents, not insertion order.
    Tree rhs{ "a", "b", "c", "d", "e", "f" };
    REQUIRE( bst == rhs );
  }
}