// Copyright (C) 2015 ChaosForge Ltd
// http://chaosforge.org/
//
// This file is part of Nova libraries. 
// For conditions of distribution and use, see copying.txt file in root folder.

/**
* @file hash_table.hh
* @author Kornel Kisielewicz epyon@chaosforge.org
* @brief hash table classes
*/

#ifndef NV_STL_HASH_TABLE_HH
#define NV_STL_HASH_TABLE_HH

#include <nv/common.hh>
#include <nv/stl/memory.hh>
#include <nv/stl/functional/hash.hh>
#include <nv/stl/container/hash_table_policy.hh>
#include <nv/stl/utility/pair.hh>

namespace nv
{

	// Shared placeholder bucket array. hash_table_storage::zero() points
	// m_buckets at it, so a table can exist without an allocation of its own.
	extern void* g_hash_table_empty[2];

	template <
		typename HashEntryPolicy,
		typename RehashPolicy,
		typename BaseClass
	>
	class hash_table_storage : protected BaseClass
	{
	public:
		// Types taken over from the entry policy.
		typedef HashEntryPolicy                      base_type;
		typedef typename HashEntryPolicy::value_type value_type;
		typedef typename HashEntryPolicy::entry_type entry_type;
		typedef typename HashEntryPolicy::hash_type  hash_type;

		// Chain node: a policy-defined entry extended with a link to the next
		// node of the same bucket.
		struct node_type : entry_type
		{
			node_type* next;
		};

	private:
		// Shared part of both iterator kinds: holds a pointer to the current node
		// and gives access to the value stored in it. IsConst selects whether the
		// value is exposed as const.
		template < bool IsConst >
		struct iterator_base
		{
		public:
			using value_type        = typename HashEntryPolicy::value_type;
			using pointer           = conditional_t<IsConst, const value_type*, value_type*>;
			using reference         = conditional_t<IsConst, const value_type&, value_type&>;
			using difference_type   = ptrdiff_t;
			using iterator_category = forward_iterator_tag;

		public:
			// A default constructed iterator refers to no node.
			constexpr iterator_base() : m_node( nullptr ) {}

			// The current node, seen as a plain policy entry.
			const entry_type* entry() const
			{
				return m_node;
			}

			reference operator*() const
			{
				return m_node->value;
			}

			pointer operator->() const
			{
				return &( m_node->value );
			}

		protected:
			// Only derived iterators may bind to a specific node.
			constexpr explicit iterator_base( node_type* node ) : m_node( node ) {}

			node_type* m_node;
		};


		// Iterator over the nodes chained in a single bucket; the chain ends at
		// a null node pointer.
		template < bool IsConst >
		struct node_iterator : iterator_base< IsConst >
		{
			using base_type = iterator_base< IsConst >;

			constexpr node_iterator() = default;

			// Conversion from the mutable iterator to the const one (enabled only
			// in that direction).
			template < bool B, typename = enable_if_t< IsConst && !B > >
			constexpr node_iterator( const node_iterator< B >& other ) : base_type( other.m_node ) {}

			node_iterator& operator++()
			{
				increment();
				return *this;
			}

			node_iterator operator++( int )
			{
				node_iterator previous( *this );
				increment();
				return previous;
			}

			// Two iterators are equal when they refer to the same node.
			friend constexpr bool operator==( const node_iterator& a, const node_iterator& b ) { return a.m_node == b.m_node; }
			friend constexpr bool operator!=( const node_iterator& a, const node_iterator& b ) { return a.m_node != b.m_node; }
		protected:
			constexpr explicit node_iterator( node_type* node ) : base_type( node ) {}

			// The storage class and the sibling instantiation need m_node access.
			template < typename T1, typename T2, typename T3 >
			friend class hash_table_storage;
			template < bool B > friend struct node_iterator;

			// Step to the next node of the chain.
			inline void increment()
			{
				this->m_node = this->m_node->next;
			}

		};

		// Iterator over all elements of the table. Next to the current node it
		// remembers the bucket slot the node belongs to, so that it can move on
		// to the following buckets once a chain is exhausted.
		template < bool IsConst >
		struct table_iterator : public iterator_base< IsConst >
		{
			using base_type = iterator_base< IsConst >;

			constexpr table_iterator()
				: base_type(), m_bucket( nullptr )
			{
			}
			constexpr table_iterator( const table_iterator& other )
				: base_type( other ), m_bucket( other.m_bucket )
			{
			}

			// Conversion from the mutable iterator to the const one (enabled only
			// in that direction).
			template < bool B, typename = enable_if_t< IsConst && !B > >
			constexpr table_iterator( const table_iterator< B >& other )
				: base_type( other.m_node ), m_bucket( other.m_bucket )
			{
			}

			table_iterator& operator++()
			{
				increment();
				return *this;
			}

			table_iterator operator++( int )
			{
				table_iterator previous( *this );
				increment();
				return previous;
			}

			// Only the node takes part in comparison, the bucket slot does not.
			friend constexpr bool operator==( const table_iterator& a, const table_iterator& b ) { return a.m_node == b.m_node; }
			friend constexpr bool operator!=( const table_iterator& a, const table_iterator& b ) { return a.m_node != b.m_node; }
		protected:
			template < typename T1, typename T2, typename T3 >
			friend class hash_table_storage;
			template < bool B >     friend struct table_iterator;

			constexpr table_iterator( node_type* node, node_type** bucket )
				: base_type( node ), m_bucket( bucket )
			{
			}
			// Starts at the head node of the given bucket slot.
			constexpr explicit table_iterator( node_type** bucket )
				: base_type( *bucket ), m_bucket( bucket )
			{
			}

			// Moves to the head of the first non-empty slot after the current one.
			void next_bucket()
			{
				// Sentinel guarantees end
				do
				{
					++m_bucket;
				}
				while ( *m_bucket == nullptr );
				this->m_node = *m_bucket;
			}

			// Steps to the next node, crossing into later slots when the current
			// chain has ended.
			void increment()
			{
				node_type* following = this->m_node->next;
				// Sentinel guarantees end
				while ( following == nullptr )
				{
					++m_bucket;
					following = *m_bucket;
				}
				this->m_node = following;
			}
			node_type** m_bucket;
		};

	public:
		// Iterators limited to one bucket chain.
		typedef node_iterator< false > local_iterator;
		typedef node_iterator< true >  const_local_iterator;
		// Iterators over the whole table.
		typedef table_iterator< false >      iterator;
		typedef table_iterator< true >       const_iterator;

		typedef size_t size_type;
		// A bucket slot holds the head node of its chain.
		typedef node_type* bucket_type;

		hash_table_storage()
		{
			zero();
			m_max_load_factor = 1.0f;
		}

		// Creates an empty table sized for `count`. For counts of 0 or 1 no
		// bucket array is allocated and the table starts in the zero() state.
		explicit hash_table_storage( size_t count )
			: m_buckets( nullptr ), m_bucket_count( 0 ), m_element_count( 0 ), m_max_load_factor( 1.0f )
		{
			if ( count <= 1 )
			{
				zero();
				return;
			}
			// The rehash policy decides the actual number of buckets.
			m_bucket_count = RehashPolicy::get_bucket_count( count );
			m_buckets      = allocate_buckets( m_bucket_count );
		}

 		// The storage is move-only: copy construction and copy assignment are disabled.
 		hash_table_storage( const hash_table_storage& ) = delete;
 		hash_table_storage& operator=( const hash_table_storage& ) = delete;

		// Move constructor: takes over the bucket array, the counters and the
		// maximum load factor of `other`.
		//
		// `other` is reset with zero() rather than being left with a null bucket
		// array and a bucket count of 0. In that null state begin()/end() would
		// dereference a null bucket pointer and bucket index computation would
		// be done against zero buckets. After zero() the source is an ordinary
		// empty table that can be used again.
		inline hash_table_storage( hash_table_storage&& other )
			: m_buckets( other.m_buckets )
			, m_bucket_count( other.m_bucket_count )
			, m_element_count( other.m_element_count )
			, m_max_load_factor( other.m_max_load_factor )
		{
			other.zero();
		}
		// Move assignment: releases the current content, then takes over the
		// bucket array, the counters and the maximum load factor of `other`.
		// Self-assignment is a no-op.
		inline hash_table_storage& operator=( hash_table_storage&& other )
		{
			if ( this != &other )
			{
				// Drop what this table currently owns.
				free_nodes( m_buckets, m_bucket_count );
				free_buckets( m_buckets, m_bucket_count );

				m_buckets         = other.m_buckets;
				m_bucket_count    = other.m_bucket_count;
				m_element_count   = other.m_element_count;
				// Read the data member; `max_load_factor` alone names the accessor.
				m_max_load_factor = other.m_max_load_factor;

				// Ownership has been transferred: detach the source so that its
				// destructor does not free the nodes and buckets a second time.
				other.zero();
			}
			return *this;
		}

		// Number of nodes chained in bucket `n`; 0 when `n` is not a valid
		// bucket index.
		size_type bucket_size( size_type n ) const
		{
			size_type length = 0;
			if ( n < m_bucket_count )
			{
				for ( const node_type* node = m_buckets[n]; node != nullptr; node = node->next )
					++length;
			}
			return length;
		}

		// Removes the element `which` points at and returns an iterator to the
		// element that followed it.
		iterator erase( const_iterator which )
		{
			node_type*  doomed = which.m_node;
			node_type** bucket = which.m_bucket;

			// Advance first - the node is gone once do_remove returns.
			iterator following( doomed, bucket );
			++following;

			do_remove( doomed, bucket );
			return following;
		}

		// Rebuilds the bucket array. The new bucket count is the larger of what
		// the rehash policy returns for the request and what it returns for the
		// current element count at the maximum load factor.
		inline void rehash( size_type new_buckets )
		{
			const size_type for_elements = RehashPolicy::get_bucket_count( m_element_count, m_max_load_factor );
			const size_type for_request  = RehashPolicy::get_bucket_count( new_buckets );
			do_rehash( for_elements < for_request ? for_request : for_elements );
		}

		// Rehashes for `new_elements` elements. Does nothing unless the request
		// exceeds the current element count.
		inline void reserve( size_type new_elements )
		{
			if ( new_elements <= m_element_count ) return;

			do_rehash( RehashPolicy::get_bucket_count( new_elements, m_max_load_factor ) );
		}

		// Destroys all elements. The bucket array itself is kept.
		void clear()
		{
			m_element_count = 0;
			free_nodes( m_buckets, m_bucket_count );
		}

		// Iterator to the first element: the head of the first non-empty bucket.
		iterator begin()
		{
			iterator first( m_buckets );
			if ( first.m_node == nullptr )
				first.next_bucket();
			return first;
		}

		// Const counterpart of begin().
		const_iterator cbegin() const
		{
			const_iterator first( m_buckets );
			if ( first.m_node == nullptr )
				first.next_bucket();
			return first;
		}

		// The end iterator is built from the slot one past the last bucket.
		inline iterator end()
		{
			return iterator( m_buckets + m_bucket_count );
		}

		inline const_iterator cend() const
		{
			return const_iterator( m_buckets + m_bucket_count );
		}

		inline const_iterator begin() const
		{
			return cbegin();
		}

		inline const_iterator end() const
		{
			return cend();
		}

		// Iteration over a single bucket: starts at the head node of bucket `n`;
		// the end of any bucket is an iterator holding a null node.
		inline local_iterator begin( size_type n )
		{
			return local_iterator( m_buckets[n] );
		}

		inline local_iterator end( size_type )
		{
			return local_iterator( nullptr );
		}

		inline const_local_iterator begin( size_type n ) const
		{
			return const_local_iterator( m_buckets[n] );
		}

		inline const_local_iterator end( size_type ) const
		{
			return const_local_iterator( nullptr );
		}

		inline const_local_iterator cbegin( size_type n ) const
		{
			return const_local_iterator( m_buckets[n] );
		}

		inline const_local_iterator cend( size_type ) const
		{
			return const_local_iterator( nullptr );
		}

		// True when the table holds no elements.
		inline bool empty() const
		{
			return m_element_count == 0;
		}

		// Number of stored elements.
		inline size_type size() const
		{
			return m_element_count;
		}

		// Number of bucket slots.
		inline size_type bucket_count() const
		{
			return m_bucket_count;
		}

		// Fixed upper limits reported by the container.
		inline size_type max_size() const
		{
			return 2147483647;
		}

		inline size_type max_bucket_count() const
		{
			return 2147483647;
		}

		// Elements per bucket.
		inline float load_factor() const
		{
			const float elements = static_cast<float>( m_element_count );
			const float buckets  = static_cast<float>( m_bucket_count );
			return elements / buckets;
		}

		// Getter and setter of the maximum load factor.
		inline float max_load_factor() const
		{
			return m_max_load_factor;
		}

		inline void max_load_factor( float ml )
		{
			m_max_load_factor = ml;
		}

		// Destroys all nodes, then releases the bucket array.
		~hash_table_storage()
		{
			// Nodes go first: they are reached through the bucket array.
			free_nodes( m_buckets, m_bucket_count );
			free_buckets( m_buckets, m_bucket_count );
		}
	protected:
		// Rebuilds a mutable iterator from a const one.
		inline iterator unconst( const_iterator i ) const
		{
			return iterator( i.m_node, i.m_bucket );
		}

		// Turns a local iterator of bucket `n` into a table iterator.
		inline iterator unlocalize( size_type n, const_local_iterator i ) const
		{
			node_type** bucket = m_buckets + n;
			return iterator( i.m_node, bucket );
		}

		// Allocates a node, lets the entry policy construct it from `hash_code`
		// and `args`, and links it at the head of bucket `index`.
		// Returns an iterator to the new element.
		template < typename... Args >
		iterator insert( size_type index, hash_type hash_code, Args&&... args )
		{
			node_type* fresh = alloc_node();
			HashEntryPolicy::entry_construct( fresh, hash_code, nv::forward<Args>( args )... );

			// Push onto the front of the chain.
			node_type** slot = m_buckets + index;
			fresh->next = *slot;
			*slot       = fresh;

			++m_element_count;
			return iterator( fresh, slot );
		}

		// Bucket index of `hash_code` for the current bucket count, as computed
		// by the rehash policy.
		size_type get_bucket_index( hash_type hash_code ) const
		{
			return RehashPolicy::template get_index<hash_type>( hash_code, m_bucket_count );
		}

		// Asks the rehash policy whether adding `new_elements` elements calls
		// for a rehash and performs it if so.
		// Returns true when the table was rehashed, i.e. bucket indices computed
		// earlier are no longer valid.
		bool rehash_check( size_type new_elements )
		{
			uint32 target = RehashPolicy::is_rehash_required( m_bucket_count, m_element_count + new_elements, m_max_load_factor );
			// A result of zero means no rehash is needed.
			if ( target == 0 )
				return false;

			do_rehash( target );
			return true;
		}

		// Removes the node `which` points at from bucket `index` and returns a
		// local iterator to the node that followed it in the chain.
		local_iterator erase_local( size_type index, const_local_iterator which )
		{
			node_type* doomed = which.m_node;

			// Advance first - the node is gone once do_remove returns.
			local_iterator following( doomed );
			++following;

			do_remove( doomed, m_buckets + index );
			return following;
		}

		// Moves every node into a newly allocated array of `new_count` buckets
		// and releases the old array. Nodes are relinked, not reallocated.
		void do_rehash( size_type new_count )
		{
			NV_ASSERT( new_count > 1, "Rehash fail!" );
			node_type** fresh = allocate_buckets( new_count );

			for ( size_type slot = 0; slot < m_bucket_count; ++slot )
			{
				// Keep taking the head of the old chain until it is empty.
				for ( node_type* node = m_buckets[slot]; node != nullptr; node = m_buckets[slot] )
				{
					const size_type target = bucket_index( HashEntryPolicy::get_entry_hash( node ), new_count );
					// Unlink from the old chain...
					m_buckets[slot] = node->next;
					// ...and push onto the head of the new one.
					node->next    = fresh[target];
					fresh[target] = node;
				}
			}

			free_buckets( m_buckets, m_bucket_count );
			m_buckets      = fresh;
			m_bucket_count = new_count;
		}

		// Puts the table into the empty state: no elements, a bucket count of 1
		// and m_buckets pointing at the shared g_hash_table_empty array.
		void zero()
		{
			m_element_count = 0;
			m_bucket_count  = 1;
			m_buckets       = reinterpret_cast<node_type**>( &g_hash_table_empty[0] );
		}

		// Unlinks `node` from the chain starting at `*bucket`, destroys it and
		// decrements the element count. `node` has to be part of that chain.
		void do_remove( node_type* node, bucket_type* bucket )
		{
			// Find the link (bucket slot or some node's `next`) that points at `node`.
			node_type** link = bucket;
			while ( *link != node )
				link = &( ( *link )->next );

			// Bypass the node.
			*link = node->next;

			free_node( node );
			--m_element_count;
		}

		// Bucket index of `hash_code` in a table of `b_count` buckets, as
		// computed by the rehash policy.
		size_type bucket_index( hash_type hash_code, size_t b_count ) const
		{
			const size_type index = RehashPolicy::template get_index<hash_type>( hash_code, b_count );
			return index;
		}
		// Allocates a zero-filled array of `new_count` bucket slots plus one
		// extra slot holding a non-null sentinel value, which terminates the
		// bucket scans of table_iterator.
		node_type** allocate_buckets( size_type new_count )
		{
			NV_ASSERT( new_count > 1, "allocate_buckets fail!" );
			const size_type byte_size = ( new_count + 1 ) * sizeof( node_type* );
			node_type** result = reinterpret_cast<node_type**>( nvmalloc( byte_size ) );
			nvmemset( result, 0, byte_size );
			result[ new_count ] = reinterpret_cast<node_type*>( uintptr_t(~0) ); // sentinel
			return result;
		}
		// Releases a bucket array. Arrays with a count of 0 or 1 are skipped;
		// zero() sets a count of 1 for the shared placeholder array.
		void free_buckets( node_type** buckets, size_type count )
		{
			if ( count <= 1 ) return;
			nvfree( buckets );
		}

		// Allocates raw memory for one node. The entry is not constructed here.
		node_type* alloc_node()
		{
			void* memory = nvmalloc( sizeof( node_type ) );
			return static_cast<node_type*>( memory );
		}

		// Destroys the entry stored in `node` through the entry policy and
		// releases the node's memory.
		void free_node( node_type* node )
		{
			HashEntryPolicy::entry_destroy( node );
			nvfree( node );
		}

		// Destroys every node reachable from the first `count` slots of
		// `buckets` and resets those slots to null. Arrays with a count of 0 or
		// 1 are left untouched.
		void free_nodes( node_type** buckets, size_type count )
		{
			if ( count <= 1 ) return;

			for ( size_type slot = 0; slot < count; ++slot )
			{
				node_type* node = buckets[slot];
				while ( node != nullptr )
				{
					// Read the link before the node is released.
					node_type* following = node->next;
					free_node( node );
					node = following;
				}
				buckets[slot] = nullptr;
			}
		}

		// Array of chain heads, one per bucket.
		node_type** m_buckets;
		// Number of bucket slots in m_buckets.
		size_type   m_bucket_count;
		// Number of elements currently stored.
		size_type   m_element_count;
		// Load factor limit handed to the rehash policy.
		float       m_max_load_factor;
	};

	template <
		typename HashEntryPolicy,
		template < typename > class HashQueryPoilcy = hash_table_no_extra_types_policy,
		typename RehashPolicy = hash_table_prime_rehash_policy,
		typename SuperClass = empty_type
	>
	class hash_table_base : public hash_table_storage< HashEntryPolicy, RehashPolicy, SuperClass >
	{
		// The storage layer this class builds its lookup and insertion on.
		typedef hash_table_storage< HashEntryPolicy, RehashPolicy, SuperClass > base_type;
	public:
		// Types taken over from the storage and the entry policy.
		typedef typename base_type::value_type        value_type;
		typedef typename base_type::entry_type        entry_type;
		typedef typename base_type::hash_type         hash_type;
		typedef typename HashEntryPolicy::query_type  query_type;
		typedef typename HashEntryPolicy::key_type    key_type;
		typedef typename HashEntryPolicy::mapped_type mapped_type;
		typedef size_t                                size_type;
		typedef ptrdiff_t                             difference_type;
		typedef value_type&                           reference;
		typedef const value_type&                     const_reference;
		typedef value_type*                           pointer;
		typedef const value_type*                     const_pointer;

		typedef typename base_type::local_iterator       local_iterator;
		typedef typename base_type::const_local_iterator const_local_iterator;
		typedef typename base_type::iterator             iterator;
		typedef typename base_type::const_iterator       const_iterator;

		// Result of the insert functions: iterator to the element and a flag
		// that is true when a new element was inserted.
		typedef pair< iterator, bool >                   insert_return_type;

	public: // constructors
		// Empty table.
		hash_table_base() {}
		// Empty table; `size` is passed on to the storage constructor.
		explicit hash_table_base( size_type size ) : base_type( size ) {}
		// The table is move-only.
		hash_table_base( const hash_table_base& other ) = delete;
		hash_table_base( hash_table_base&& other ) = default;
		// Table filled with the values of the range [first, last).
		template< typename InputIterator >
		hash_table_base( InputIterator first, InputIterator last ) : base_type()
		{
			insert( first, last );
		}

	public: // assignments

		// Copy assignment is disabled; move assignment is defaulted.
		hash_table_base& operator=( const hash_table_base& other ) = delete;
		hash_table_base& operator=( hash_table_base&& other ) = default;

	public: // iterators

		// Iteration is provided by the storage layer.
		using base_type::begin;
		using base_type::cbegin;
		using base_type::end;
		using base_type::cend;

	public: // capacity

		// Size queries are provided by the storage layer.
		using base_type::empty;
		using base_type::size;
		using base_type::max_size;

	public: // modifiers

		using base_type::clear;

		// Inserts a copy of `value` unless a matching element is already
		// stored. Returns the iterator to the stored element and true when the
		// value was inserted, false when it was already present.
		inline insert_return_type insert( const value_type& value )
		{
			const hash_type hash_code = HashEntryPolicy::get_value_hash( value );
			size_type       bucket    = base_type::get_bucket_index( hash_code );
			iterator        existing  = do_find_node( bucket, hash_code, value );

			// Already present - report the stored element.
			if ( existing != this->end() )
				return insert_return_type( existing, false );

			// A rehash changes the bucket count, so the index is recomputed.
			if ( base_type::rehash_check( 1 ) )
				bucket = base_type::get_bucket_index( hash_code );

			return insert_return_type( base_type::insert( bucket, hash_code, value ), true );
		}

		// Inserts `value` by moving it into the table unless a matching element
		// is already stored. Returns the iterator to the stored element and
		// true when the value was inserted, false when it was already present
		// (in which case `value` is left untouched).
		inline insert_return_type insert( value_type&& value )
		{
			const hash_type h = HashEntryPolicy::get_value_hash( value );
			size_type       b = base_type::get_bucket_index( h );
			iterator        r = do_find_node( b, h, value );

			if ( r == this->end() )
			{
				// A rehash changes the bucket count, so the index is recomputed.
				if ( base_type::rehash_check( 1 ) ) b = base_type::get_bucket_index( h );
				// `value` is a plain rvalue reference, not a forwarding one, so
				// it is moved. The former nv::forward call had no template
				// argument, unlike every other nv::forward call in this file.
				return insert_return_type( base_type::insert( b, h, ::nv::move( value ) ), true );
			}

			return insert_return_type( r, false );
		}

		template <typename InputIterator>
		void insert( InputIterator first, InputIterator last )
		{
			size_type estimate = estimate_distance( first, last );
			base_type::rehash_check( estimate );
			for ( ; first != last; ++first ) insert( *first );
		}

		// Single element erase comes from the storage layer.
		using base_type::erase;

		// Erases the elements of the range [first, last) one by one and
		// returns a mutable iterator to `last`.
		iterator erase( const_iterator first, const_iterator last )
		{
			for ( ; first != last; first = erase( first ) )
			{
			}
			return base_type::unconst( first );
		}

		// Erases all elements matching `key`; returns how many were removed.
		size_type erase( const query_type& key )
		{
			const hash_type hash_code = HashEntryPolicy::get_hash( key );
			return do_erase( base_type::get_bucket_index( hash_code ), hash_code, key );
		}

		// Same as above for the additional key types enabled by the query policy.
		template < typename T >
		enable_if_t< HashQueryPoilcy<T>::value, size_type >
		erase( const T& key )
		{
			const hash_type hash_code = HashEntryPolicy::get_hash( key );
			return do_erase( base_type::get_bucket_index( hash_code ), hash_code, key );
		}

	public: // lookup

		// Looks `key` up; returns end() when no element matches.
		iterator find( const query_type& key )
		{
			const hash_type hash_code = HashEntryPolicy::get_hash( key );
			return do_find_node( base_type::get_bucket_index( hash_code ), hash_code, key );
		}

		// Const counterpart of the lookup above.
		const_iterator find( const query_type& key ) const
		{
			const hash_type hash_code = HashEntryPolicy::get_hash( key );
			return do_find_node( base_type::get_bucket_index( hash_code ), hash_code, key );
		}

		// Lookup by one of the additional key types enabled by the query
		// policy; returns end() when no element matches.
		template < typename T >
		enable_if_t< HashQueryPoilcy<T>::value, iterator >
		find( const T& key )
		{
			const hash_type hash_code = HashEntryPolicy::get_hash( key );
			return do_find_node( base_type::get_bucket_index( hash_code ), hash_code, key );
		}

		// Const counterpart of the lookup above.
		template < typename T >
		enable_if_t< HashQueryPoilcy<T>::value, const_iterator >
		find( const T& key ) const
		{
			const hash_type hash_code = HashEntryPolicy::get_hash( key );
			return do_find_node( base_type::get_bucket_index( hash_code ), hash_code, key );
		}

	protected:

		// Makes sure an element for `key` exists, creating one with a default
		// constructed mapped value when it is missing. Returns the iterator to
		// the element and true when it was created by this call.
		template < typename KeyConvertible >
		inline insert_return_type insert_key( const KeyConvertible& key )
		{
			const hash_type hash_code = HashEntryPolicy::get_hash( key );
			size_type       bucket    = base_type::get_bucket_index( hash_code );
			iterator        existing  = do_find_node( bucket, hash_code, key );

			// Already present - report the stored element.
			if ( existing != this->end() )
				return insert_return_type( existing, false );

			// A rehash changes the bucket count, so the index is recomputed.
			if ( base_type::rehash_check( 1 ) )
				bucket = base_type::get_bucket_index( hash_code );

			return insert_return_type( base_type::insert( bucket, hash_code, key, mapped_type() ), true );
		}

		// Inserts an element built from `key` and `obj` unless `key` is already
		// present. Returns the iterator to the stored element and true when a
		// new element was inserted; `obj` is only forwarded in that case.
		template < typename KeyConvertible, typename M >
		inline insert_return_type try_insert( const KeyConvertible& key, M&& obj )
		{
			const hash_type hash_code = HashEntryPolicy::get_hash( key );
			size_type       bucket    = base_type::get_bucket_index( hash_code );
			iterator        existing  = do_find_node( bucket, hash_code, key );

			// Already present - report the stored element.
			if ( existing != this->end() )
				return insert_return_type( existing, false );

			// A rehash changes the bucket count, so the index is recomputed.
			if ( base_type::rehash_check( 1 ) )
				bucket = base_type::get_bucket_index( hash_code );

			return insert_return_type( base_type::insert( bucket, hash_code, key, ::nv::forward< M >( obj ) ), true );
		}

		// Inserts an element built from `key` and `obj` unless `key` is already
		// present. Returns the iterator to the stored element and true when a
		// new element was inserted; `key` and `obj` are only forwarded in that
		// case.
		template < typename KeyConvertible, typename M >
		inline insert_return_type try_insert( KeyConvertible&& key, M&& obj )
		{
			const hash_type h = HashEntryPolicy::get_hash( key );
			size_type       b = base_type::get_bucket_index( h );
			iterator        r = do_find_node( b, h, key );

			if ( r == this->end() )
			{
				// A rehash changes the bucket count, so the index is recomputed.
				if ( base_type::rehash_check( 1 ) ) b = base_type::get_bucket_index( h );
				// KeyConvertible&& is a forwarding reference and also binds to
				// non-const lvalues. Forwarding moves only from real rvalues;
				// an unconditional move would steal the caller's lvalue key.
				return insert_return_type( base_type::insert( b, h, ::nv::forward< KeyConvertible >( key ), ::nv::forward< M >( obj ) ), true );
			}
			return insert_return_type( r, false );
		}

		// Walks the chain of bucket `index` and returns an iterator to the
		// first entry the entry policy reports as matching (`h`, `query`), or
		// the end iterator when there is none.
		template < typename ComparableType >
		inline iterator do_find_node( size_type index, hash_type h, const ComparableType& query ) const
		{
			const_local_iterator stop = this->cend( index );
			for ( const_local_iterator it = this->cbegin( index ); it != stop; ++it )
			{
				if ( HashEntryPolicy::entry_compare( it.entry(), h, query ) )
					return base_type::unlocalize( index, it );
			}
			return base_type::unconst( this->cend() );
		}

		// Removes every entry of bucket `index` that the entry policy reports
		// as matching (`h`, `query`). Returns the number of removed entries.
		template < typename ComparableType >
		inline size_type do_erase( size_type index, hash_type h, const ComparableType& query ) 
		{
			// TODO : optimize for non-multi sets! (one key possible)
			size_type erased = 0;
			const_local_iterator it   = this->cbegin( index );
			const_local_iterator stop = this->cend( index );
			while ( it != stop )
			{
				if ( !HashEntryPolicy::entry_compare( it.entry(), h, query ) )
				{
					++it;
					continue;
				}
				// erase_local hands back the iterator following the removed node.
				it = base_type::erase_local( index, it );
				++erased;
			}
			return erased;
		}

	};

/*
	template < 
		typename Key, 
		typename T, 
		typename Compare,
		typename EntryPolicy,
		typename StoragePolicy
	>
	class associative_container : public StoragePolicy< EntryPolicy >
	{
	public:
		typedef Key   key_type;
		typedef T     value_type;
		typedef typename EntryPolicy::value_type value_type;
		typedef typename StoragePolicy::storage_type storage_type;
		typedef typename StoragePolicy::iterator iterator;
		typedef typename StoragePolicy::const_iterator const_iterator;
	public:
		associative_container();
		bool empty() const;
		bool size() const;
		bool max_size() const;
		void clear() const;
		void insert( const key_type& k, value_type&& obj );
		void insert( key_type&& k, value_type&& obj );
		void insert( entry_type&& obj );
		void insert( const entry_type& obj );
		template < typename ...Args >
		void emplace( key_type&& k, Args&&... args ); 
		template < typename ...Args >
		void emplace( const key_type& k, Args&&... args );

	};
*/
}

#endif // NV_STL_HASH_TABLE_HH
