// Copyright (C) 2015 ChaosForge Ltd
// http://chaosforge.org/
//
// This file is part of NV Libraries.
// For conditions of distribution and use, see copyright notice in nv.hh

/**
* @file hash_table.hh
* @author Kornel Kisielewicz epyon@chaosforge.org
* @brief hash table classes
*/

#ifndef NV_STL_HASH_TABLE_HH
#define NV_STL_HASH_TABLE_HH

#include <nv/core/common.hh>
#include <nv/stl/memory.hh>
#include <nv/stl/functional/hash.hh>
#include <nv/stl/container/hash_table_policy.hh>
#include <nv/stl/utility/pair.hh>

namespace nv
{

	extern void* g_hash_table_empty[2];

	template <
		typename HashEntryPolicy,
		typename RangeHashPolicy,
		typename RehashPolicy
	>
	class hash_table_storage : protected HashEntryPolicy
	{
	public:
		typedef HashEntryPolicy base_type;
		typedef typename base_type::value_type     value_type;
		typedef typename base_type::entry_type     entry_type;
		typedef typename base_type::hash_type      hash_type;

		struct node_type : entry_type
		{
			node_type* next;
		};

	private:
		template < bool IsConst >
		struct iterator_base
		{
		public:
			typedef typename base_type::value_type                         value_type;
			typedef conditional_t<IsConst, const value_type*, value_type*> pointer;
			typedef conditional_t<IsConst, const value_type&, value_type&> reference;
			typedef ptrdiff_t                                              difference_type;
			typedef forward_iterator_tag                                   iterator_category;

		public:
			constexpr iterator_base() : m_node( nullptr ) {}
			constexpr iterator_base( const iterator_base& it ) : m_node( it.m_node ) {}
			const entry_type* entry() const { return m_node;  }
			reference operator*() const { return m_node->value; }
			pointer operator->() const { return &( m_node->value ); }
		protected:
			constexpr explicit iterator_base( node_type* node ) : m_node( node ) {}

			node_type* m_node;
		};


		template < bool IsConst >
		struct node_iterator : iterator_base< IsConst >
		{
			typedef iterator_base< IsConst > base_type;

			constexpr node_iterator() : iterator_base( nullptr ) {}
			constexpr node_iterator( const node_iterator& it ) : iterator_base( it.m_node ) {}

			template < bool B, typename = enable_if_t< IsConst && !B > >
			constexpr node_iterator( const node_iterator< B >& it )
				: iterator_base( it.m_node )
			{
			}

			node_iterator& operator++() { increment(); return *this; }
			node_iterator operator++( int ) { node_iterator temp( *this ); increment(); return temp; }

			friend constexpr bool operator==( const node_iterator& lhs, const node_iterator& rhs ) { return lhs.m_node == rhs.m_node; }
			friend constexpr bool operator!=( const node_iterator& lhs, const node_iterator& rhs ) { return lhs.m_node != rhs.m_node; }
		protected:
			constexpr explicit node_iterator( node_type* node ) : iterator_base( node ) {}

			template < typename T1, typename T2, typename T3 >
			friend class hash_table_storage;
			template < bool B > friend struct node_iterator;

			inline void increment() { m_node = m_node->next; }

		};

		template < bool IsConst >
		struct table_iterator : public iterator_base< IsConst >
		{
			typedef iterator_base< IsConst > base_type;

			constexpr table_iterator()
				: iterator_base(), m_bucket( nullptr )
			{
			}
			constexpr table_iterator( const table_iterator& other )
				: iterator_base( other ), m_bucket( other.m_bucket )
			{
			}

			template < bool B, typename = enable_if_t< IsConst && !B > >
			constexpr table_iterator( const table_iterator< B >& it )
				: iterator_base( it.m_node ), m_bucket( it.m_bucket )
			{
			}

			table_iterator& operator++() { increment(); return *this; }
			table_iterator operator++( int ) { bucket_iterator temp( *this ); increment(); return temp; }

			friend constexpr bool operator==( const table_iterator& lhs, const table_iterator& rhs ) { return lhs.m_node == rhs.m_node; }
			friend constexpr bool operator!=( const table_iterator& lhs, const table_iterator& rhs ) { return lhs.m_node != rhs.m_node; }
		protected:
			template < typename T1, typename T2, typename T3 >
			friend class hash_table_storage;
			template < bool B >     friend struct table_iterator;

			constexpr table_iterator( node_type* node, node_type** bucket )
				: iterator_base( node ), m_bucket( bucket )
			{
			}
			constexpr explicit table_iterator( node_type** bucket )
				: iterator_base( *bucket ), m_bucket( bucket )
			{
			}

			void next_bucket()
			{
				++m_bucket;
				while ( *m_bucket == nullptr ) ++m_bucket; // Sentinel guarantees end
				m_node = *m_bucket;
			}

			void increment()
			{
				m_node = m_node->next;
				while ( m_node == nullptr ) m_node = *++m_bucket; // Sentinel guarantees end
			}
			node_type** m_bucket;
		};

	public:
		typedef node_iterator< false > local_iterator;
		typedef node_iterator< true >  const_local_iterator;
		typedef table_iterator< false >      iterator;
		typedef table_iterator< true >       const_iterator;

		typedef size_t size_type;
		typedef node_type* bucket_type;

		hash_table_storage()
		{
			zero();
			m_max_load_factor = 1.0f;
		}

		explicit hash_table_storage( size_t count )
			: m_buckets( nullptr ), m_bucket_count( 0 ), m_element_count( 0 ), m_max_load_factor( 1.0f )
		{
			if ( count > 1 )
			{
				m_bucket_count = RehashPolicy::get_bucket_count( count );
				m_buckets = allocate_buckets( m_bucket_count );
			}
			else zero();
		}

		iterator erase( const_iterator which )
		{
			iterator result( which.m_node, which.m_bucket );
			++result;
			do_remove( which.m_node, which.m_bucket );
			return result;
		}

		inline void rehash( size_type new_buckets )
		{
			size_type needed_count    = RehashPolicy::get_bucket_count( m_element_count, m_max_load_factor );
			size_type requested_count = RehashPolicy::get_bucket_count( new_buckets );
			do_rehash( nv::max( needed_count, requested_count ) );
		}

		inline void reserve( size_type new_elements )
		{
			if ( new_elements > m_element_count )
			{
				size_type requested_count = RehashPolicy::get_bucket_count( new_elements, m_max_load_factor );
				do_rehash( requested_count );
			}
		}

		void clear( size_type new_count )
		{
			free_nodes( m_buckets, m_bucket_count );
			m_element_count = 0;
		}

		iterator begin()
		{
			iterator result( m_buckets );
			if ( !result.m_node ) result.next_bucket(); 
			return result;
		}
		const_iterator cbegin() const
		{
			const_iterator result( m_buckets );
			if ( !result.m_node ) result.next_bucket();
			return result;
		}
		inline iterator             end() { return iterator( m_buckets + m_bucket_count ); }
		inline const_iterator       cend() const { return const_iterator( m_buckets + m_bucket_count ); }
		inline const_iterator       begin() const { return cbegin(); }
		inline const_iterator       end() const { return cend(); }

		inline local_iterator       begin( size_type n ) { return local_iterator( m_buckets[n] ); }
		inline local_iterator       end( size_type ) { return local_iterator( nullptr ); }
		inline const_local_iterator begin( size_type n ) const { return const_local_iterator( m_buckets[n] ); }
		inline const_local_iterator end( size_type ) const { return const_local_iterator( nullptr ); }
		inline const_local_iterator cbegin( size_type n ) const { return const_local_iterator( m_buckets[n] ); }
		inline const_local_iterator cend( size_type ) const { return const_local_iterator( nullptr ); }

		inline bool empty() const { return m_element_count == 0; }
		inline size_type size() const { return m_element_count; }
		inline size_type bucket_count() const { return m_bucket_count; }
		inline size_type max_size() const { return 2147483647; }
		inline size_type max_bucket_count() const { return 2147483647; }
		inline float load_factor() const { return (float)m_element_count / (float)m_bucket_count; }
		inline float max_load_factor() const { return m_max_load_factor; }
		inline void max_load_factor( float ml ) { m_max_load_factor = ml; }

		using base_type::hash_function;
		using base_type::key_eq;

		~hash_table_storage()
		{
			free_nodes( m_buckets, m_bucket_count );
			free_buckets( m_buckets, m_bucket_count );
		}
	public:
		inline iterator unconst( const_iterator i )                       const { return iterator( i.m_node, i.m_bucket ); }
		inline iterator unlocalize( size_type n, const_local_iterator i ) const { return iterator( i.m_node, m_buckets + n ); }

		template < typename... Args >
		iterator insert( size_type index, Args&&... args )
		{
			node_type* node = alloc_node();
			base_type::entry_construct( node, forward<Args>( args )... );
			node->next = m_buckets[index];
			m_buckets[index] = node;
			m_element_count++;
			return iterator( node, m_buckets + index );
		}

		size_type get_bucket_index( hash_type hash_code )
		{
			return bucket_index( hash_code, m_bucket_count );
		}

		bool rehash_check( size_type new_elements )
		{
			uint32 new_buckets = RehashPolicy::is_rehash_required( m_bucket_count, m_element_count + new_elements, m_max_load_factor );
			if ( new_buckets )
			{
				do_rehash( new_buckets );
				return true;
			}
			else
				return false;
		}

		local_iterator erase_local( size_type index, const_local_iterator which )
		{
			local_iterator result( which.m_node );
			++result;
			do_remove( which.m_node, m_buckets + index );
			return result;
		}


	protected:
		void do_rehash( size_type new_count )
		{
			NV_ASSERT( new_count > 1, "Rehash fail!" );
			node_type** new_buckets = allocate_buckets( new_count );
			node_type* current;
			for ( size_type i = 0; i < m_bucket_count; ++i )
			{
				while ( ( current = m_buckets[i] ) != nullptr )
				{
					size_type new_index = bucket_index( base_type::get_entry_hash( current ), new_count );
					m_buckets[i] = current->next;
					current->next = new_buckets[new_index];
					new_buckets[new_index] = current;
				}
			}
			free_buckets( m_buckets, m_bucket_count );
			m_bucket_count = new_count;
			m_buckets = new_buckets;
		}

		void zero()
		{
			m_buckets       = (node_type**)&g_hash_table_empty[0];
			m_bucket_count  = 1;
			m_element_count = 0;
		}

		void do_remove( node_type* node, bucket_type* bucket )
		{
			node_type* current = *bucket;
			node_type* next = current->next;
			if ( current == node )
				*bucket = next;
			else
			{
				while ( next != node )
				{
					current = next;
					next = current->next;
				}
				current->next = next->next;
			}
			free_node( node );
			m_element_count--;
		}

		size_type bucket_index( hash_type hash_code, size_t b_count )
		{
			return RangeHashPolicy::get<hash_type>( hash_code, b_count );
		}
		node_type** allocate_buckets( size_type new_count )
		{
			NV_ASSERT( new_count > 1, "allocate_buckets fail!" );
			node_type** buckets = ( node_type** )nvmalloc( ( new_count + 1 ) * sizeof( node_type* ) );
			nvmemset( buckets, 0, ( new_count + 1 ) * sizeof( node_type* ) );
			buckets[ new_count ] = reinterpret_cast<node_type*>( (uintptr_t)~0 ); // sentinel
			return buckets;
		}
		void free_buckets( node_type** buckets, size_type count )
		{
			if ( count > 1 )
				nvfree( buckets );
		}

		node_type* alloc_node()
		{
			return (node_type*)nvmalloc( sizeof( node_type ) );
		}

		void free_node( node_type* node )
		{
			node->~node_type();
			nvfree( node );
		}

		void free_nodes( node_type** buckets, size_type count )
		{
			if ( count > 1 )
			for ( size_type i = 0; i < count; ++i )
			{
				node_type* node = buckets[i];
				while ( node )
				{
					node_type* temp = node;
					node = node->next;
					free_node( temp );
				}
				buckets[i] = nullptr;
			}
		}

		node_type** m_buckets;
		size_type   m_bucket_count;
		size_type   m_element_count;
		float       m_max_load_factor;
	};

	template <
		typename HashEntryPolicy,
		typename Storage = hash_table_storage< 
			HashEntryPolicy,
			hash_table_range_mod_policy,
			hash_table_prime_rehash_policy
		>
	>
	class hash_table : public Storage
	{
	public:
		typedef Storage                            base_type;
		typedef typename base_type::key_type       key_type;
		typedef typename base_type::value_type     value_type;
		typedef typename base_type::entry_type     entry_type;
		typedef typename base_type::hash_type      hash_type;
		typedef typename base_type::mapped_type    mapped_type;
		typedef size_t                             size_type;
		typedef ptrdiff_t                          difference_type;
		typedef value_type&                        reference;
		typedef const value_type&                  const_reference;

		typedef typename base_type::local_iterator       local_iterator;
		typedef typename base_type::const_local_iterator const_local_iterator;
		typedef typename base_type::iterator             iterator;
		typedef typename base_type::const_iterator       const_iterator;

		typedef pair< iterator, bool >                   insert_return_type;

	public:
		explicit hash_table() {}
		explicit hash_table( size_type size ) : base_type( size ) {}

		iterator       find( const key_type& key )
		{
			const hash_type c = base_type::get_hash( key );
			size_type       b = base_type::get_bucket_index( c );
			return do_find_node( b, key, c );
		}
		const_iterator find( const key_type& key ) const
		{
			const hash_type c = base_type::get_hash( key );
			size_type       b = base_type::get_bucket_index( c );
			return do_find_node( b, key, c );
		}

		inline insert_return_type insert( const value_type& value )
		{
			const hash_type h = base_type::get_value_hash( value );
			size_type       b = base_type::get_bucket_index( h );
			iterator        r = do_find_node( b, value, h );

			if ( r == this->end() )
			{
				if ( base_type::rehash_check( 1 ) )
				{
					b = base_type::get_bucket_index( h );
				}
				return insert_return_type( base_type::insert( b, value, h ), true );
			}

			return insert_return_type( r, false );
		}

		template <typename InputIterator>
		void insert( InputIterator first, InputIterator last )
		{
			size_type estimate = estimate_distance( first, last );
			base_type::rehash_check( estimate );
			for ( ; first != last; ++first ) insert( *first );
		}

		using base_type::erase;

		iterator erase( const_iterator first, const_iterator last )
		{
			while ( first != last )
				first = erase( first );
			return base_type::unconst( first );
		}

		size_type erase( const key_type& key )
		{
			// TODO : optimize for non-multi sets! (one key possible)
			const hash_type c = base_type::get_hash( key );
			size_type       b = base_type::get_bucket_index( c );
			const_local_iterator i = this->cbegin( b );
			size_type result = 0;
			while ( i != this->cend( b ) )
			{
				if ( base_type::entry_compare( i.entry(), key, c ) )
				{
					i = base_type::erase_local( b, i );
					++result;
				}
				else
					++i;
			}
			return result;
		}

	protected:

		inline insert_return_type insert_key( const key_type& key )
		{
			const hash_type h = base_type::get_hash( key );
			size_type       b = base_type::get_bucket_index( h );
			iterator        r = do_find_node( b, key, h );

			if ( r == this->end() )
			{
				if ( base_type::rehash_check( 1 ) )
				{
					b = base_type::get_bucket_index( h );
				}
				return insert_return_type( base_type::insert( b, value_type( key ), h ), true );
			}

			return insert_return_type( r, false );
		}

		template < typename QueryType >
		iterator do_find_node( size_type index, const QueryType& query, hash_type h ) const
		{
			const_local_iterator first = this->cbegin( index );
			const_local_iterator last  = this->cend( index );
			while ( first != last )
			{
				if ( base_type::entry_compare( first.entry(), query, h ) )
					return base_type::unlocalize( index, first );
				++first;
			}
			return base_type::unconst( cend() );
		}

	};

/*
	template < 
		typename Key, 
		typename T, 
		typename Compare,
		typename EntryPolicy,
		typename StoragePolicy
	>
	class associative_container : public StoragePolicy< EntryPolicy >
	{
	public:
		typedef Key   key_type;
		typedef T     value_type;
		typedef typename EntryPolicy::value_type value_type;
		typedef typename StoragePolicy::storage_type storage_type;
		typedef typename StoragePolicy::iterator iterator;
		typedef typename StoragePolicy::const_iterator const_iterator;
	public:
		associative_container();
		bool empty() const;
		bool size() const;
		bool max_size() const;
		void clear() const;
		void insert( const key_type& k, value_type&& obj );
		void insert( key_type&& k, value_type&& obj );
		void insert( entry_type&& obj );
		void insert( const entry_type& obj );
		template < typename ...Args >
		void emplace( key_type&& k, Args&&... args ); 
		template < typename ...Args >
		void emplace( const key_type& k, Args&&... args );

	};
*/
}

#endif // NV_STL_HASH_TABLE_HH
