// Copyright (C) 2012-2017 ChaosForge Ltd
// http://chaosforge.org/
//
// This file is part of Nova libraries. 
// For conditions of distribution and use, see copying.txt file in root folder.

#ifndef NV_CORE_RANDOM_HH
#define NV_CORE_RANDOM_HH

#include <nv/common.hh>
#include <nv/stl/math.hh>
#include <nv/stl/limits.hh>
#include <nv/stl/type_traits/primary.hh>

namespace nv
{
	enum class random_dist
	{
		LINEAR,
		GAUSSIAN,
		RGAUSSIAN,
		STEP_1,
		STEP_2,
		STEP_3,
		STEP_4,
		MLINEAR,
		MGAUSSIAN,
		MRGAUSSIAN,
		MSTEP_1,
		MSTEP_2,
		MSTEP_3,
		MSTEP_4,
	};


	template< typename T >
	struct random_range
	{
		T min;
		T max;
		random_dist dist = random_dist::LINEAR;
	};


	class random_base
	{
	public:
		typedef uint32 result_type; 
		typedef uint32 seed_type;

		virtual seed_type set_seed( seed_type seed = 0 ) = 0;
		virtual result_type rand() = 0;

		seed_type randomize()
		{
			return set_seed( randomized_seed() );
		}

		uint32 urand( uint32 val )
		{
			uint32 x, max = 0xFFFFFFFFUL - ( 0xFFFFFFFFUL % val );
			while ( ( x = rand() ) >= max );
			return x / ( max / val );
		}

		sint32 srand( sint32 val )
		{
			NV_ASSERT( val >= 0, "Bad srand range!" );
			return static_cast<sint32>( urand( static_cast<uint32>( val ) ) );
		}

		bool coin_flip()
		{
			return rand() % 2 == 0;
		}

		f32 frand()
		{
			return rand() / 4294967296.0f;
		}

		f32 frand( f32 val )
		{
			return frand() * val;
		}

		inline sint32 srange( sint32 min, sint32 max )
		{
			if ( max < min ) ::nv::swap( max, min );
			// this method probably reduces range //
			uint32 roll = urand( static_cast<uint32>( max - min ) + 1 );
			return static_cast<sint32>( roll ) + min;
		}

		inline uint32 urange( uint32 min, uint32 max )
		{
			if ( max < min ) ::nv::swap( max, min );
			return urand( max - min + 1 ) + min;
		}

		inline f32 frange( f32 min, f32 max )
		{
			if ( max < min ) ::nv::swap( max, min );
			return frand( max - min ) + min;
		}

		uint32 dice( uint32 count, uint32 sides )
		{
			if ( count == 0 || sides == 0 ) return 0;
			uint32 result = count;
			while ( count-- > 0 )
			{
				result += urand( sides );
			};
			return result;
		}

		template < typename T >
		T range( T min, T max )
		{
			return component_wise<T>::apply( min, max, range_op( this ) );

// 			return component_wise<T>::apply( min, max, [&]( auto a, auto b )
// 			{
// 				return range_impl( a, b );
// 			} );
		}


// 		template < typename T >
// 		math::tvec2<T> range( math::tvec2<T> min, math::tvec2<T> max )
// 		{
// 			return math::tvec2<T>(
// 				range_impl( min.x, max.x, is_floating_point<T>() ),
// 				range_impl( min.y, max.y, is_floating_point<T>() )
// 				);
// 		}
// 
// 		template < typename T >
// 		math::tvec3<T> range( math::tvec3<T> min, math::tvec3<T> max )
// 		{
// 			return math::tvec3<T>(
// 				range_impl( min.x, max.x, is_floating_point<T>() ),
// 				range_impl( min.y, max.y, is_floating_point<T>() ),
// 				range_impl( min.z, max.z, is_floating_point<T>() )
// 				);
// 		}
// 
// 		template < typename T >
// 		math::tvec4<T> range( math::tvec4<T> min, math::tvec4<T> max )
// 		{
// 			return math::tvec4<T>(
// 				range_impl( min.x, max.x, is_floating_point<T>() ),
// 				range_impl( min.y, max.y, is_floating_point<T>() ),
// 				range_impl( min.z, max.z, is_floating_point<T>() ),
// 				range_impl( min.w, max.w, is_floating_point<T>() )
// 				);
// 		}



		vec3 unit_vec3( bool = false )
		{
			return precise_unit_vec3();
			//			return precise ? precise_unit_vec3() : fast_unit_vec3();
		}
		vec2 unit_vec2( bool = false )
		{
			return precise_unit_vec2();
			//			return precise ? precise_unit_vec2() : fast_unit_vec2();
		}

		vec2 disk_point( bool precise = false )
		{
			return precise ? precise_disk_point() : fast_disk_point();
		}

		vec3 sphere_point( bool precise = false )
		{
			return precise ? precise_sphere_point() : fast_sphere_point();
		}

		vec2 ellipse_point( const vec2& radii, bool precise = false )
		{
			return precise ? precise_ellipse_point( radii ) : fast_ellipse_point( radii );
		}

		vec3 ellipsoid_point( const vec3& radii, bool precise = false )
		{
			return precise ? precise_ellipsoid_point( radii ) : fast_ellipsoid_point( radii );
		}

		vec2 ellipse_edge( const vec2& radii, bool = false )
		{
			return unit_vec2() * radii;
		}

		vec3 ellipsoid_edge( const vec3& radii, bool = false )
		{
			return unit_vec3() * radii;
		}

		vec2 hollow_disk_point( float iradius, float oradius, bool precise = false )
		{
			return precise ? precise_hollow_disk_point( iradius, oradius ) : fast_hollow_disk_point( iradius, oradius );
		}

		vec3 hollow_sphere_point( float iradius, float oradius, bool precise = false )
		{
			return precise ? precise_hollow_sphere_point( iradius, oradius ) : fast_hollow_sphere_point( iradius, oradius );
		}

		vec2 hollow_ellipse_point( const vec2& iradii, const vec2& oradii, bool precise = false )
		{
			return precise ? precise_hollow_ellipse_point( iradii, oradii ) : fast_hollow_ellipse_point( iradii, oradii );
		}

		vec3 hollow_ellipsoid_point( const vec3& iradii, const vec3& oradii, bool precise = false )
		{
			return precise ? precise_hollow_ellipsoid_point( iradii, oradii ) : fast_hollow_ellipsoid_point( iradii, oradii );
		}

		//vec2 fast_unit_vec2();
		vec2 precise_unit_vec2();
		//vec3 fast_unit_vec3();
		vec3 precise_unit_vec3();

		vec2 fast_disk_point();
		vec2 precise_disk_point();
		vec3 fast_sphere_point();
		vec3 precise_sphere_point();

		vec2 fast_ellipse_point( const vec2& radii )
		{
			return fast_disk_point() * radii;
		}
		vec2 precise_ellipse_point( const vec2& radii );

		vec3 fast_ellipsoid_point( const vec3& radii )
		{
			return fast_sphere_point() * radii;
		}

		vec3 precise_ellipsoid_point( const vec3& radii );

		vec2 fast_hollow_disk_point( float iradius, float oradius );
		vec2 precise_hollow_disk_point( float iradius, float oradius );
		vec3 fast_hollow_sphere_point( float iradius, float oradius );
		vec3 precise_hollow_sphere_point( float iradius, float oradius );

		vec2 fast_hollow_ellipse_point( const vec2& iradii, const vec2& oradii );
		vec2 precise_hollow_ellipse_point( const vec2& iradii, const vec2& oradii );
		vec3 fast_hollow_ellipsoid_point( const vec3& iradii, const vec3& oradii );
		vec3 precise_hollow_ellipsoid_point( const vec3& iradii, const vec3& oradii );

		// Box Muller
		template < typename T = f32 >
		T gaussian_bm()
		{
			T u1;
			T u2;
			do { u1 = frand(); u2 = frand(); } while ( u1 <= numeric_limits<T>::min() );
			T z0 = sqrt( T(-2) * log( u1 ) ) * cos( math::two_pi<T>() * u2 );
			return z0;
		}

		// Marsaglia
		template < typename T = f32 >
		T gaussian_m()
		{
			T x1, x2, w, y1;
			do
			{
				x1 = T( 2 ) * frand() - T( 1 );
				x2 = T( 2 ) * frand() - T( 1 );
				w = x1 * x1 + x2 * x2;
			} while ( w >= T( 1 ) );
			w = sqrt( ( T(-2) * log( w ) ) / w );
			y1 = x1 * w;
			return y1;
		}

		template < typename T = f32 >
		T gaussian_01( T sigma = T( 0.1 ) )
		{
			T g = T( 0.5 ) + gaussian_m<T>() * sigma * T(0.5);
			return math::clamp( g, T( 0 ), T( 1 ) );
		}

		template < typename T = f32 >
		T rgaussian_01( T sigma = T( 0.1 ) )
		{
			T g = gaussian_01<T>( sigma ) - T(0.5);
			if ( g < T(0) ) g += T(1);
			return math::clamp( g, T( 0 ), T( 1 ) );
		}

		template < typename T >
		T eval( const random_range<T>& r )
		{
			switch ( r.dist )
			{
			case random_dist::LINEAR     : return component_wise<T>::apply( r.min, r.max, range_op( this ) );
			case random_dist::GAUSSIAN   : return component_wise<T>::apply( r.min, r.max, gaussian_op( this ) );
			case random_dist::RGAUSSIAN  : return component_wise<T>::apply( r.min, r.max, rgaussian_op( this ) );
			case random_dist::STEP_1     : return component_wise<T>::apply( r.min, r.max, coin_flip_op( this ) );
			case random_dist::STEP_2     : return component_wise<T>::apply( r.min, r.max, step_op( this, 2 ) );
			case random_dist::STEP_3     : return component_wise<T>::apply( r.min, r.max, step_op( this, 3 ) );
			case random_dist::STEP_4     : return component_wise<T>::apply( r.min, r.max, step_op( this, 4 ) );
			case random_dist::MLINEAR    : return frand()                        * ( r.max - r.min ) + r.min;
			case random_dist::MGAUSSIAN  : return gaussian_01<float>()           * ( r.max - r.min ) + r.min;
			case random_dist::MRGAUSSIAN : return rgaussian_01<float>()          * ( r.max - r.min ) + r.min;
			case random_dist::MSTEP_1    : return coin_flip() ? r.min : r.max;
			case random_dist::MSTEP_2    : return fstep( r.min, r.max, 2 );
			case random_dist::MSTEP_3    : return fstep( r.min, r.max, 3 );
			case random_dist::MSTEP_4    : return fstep( r.min, r.max, 4 );
			default: return T();
			}
		}

	protected:
		static seed_type randomized_seed();

		struct range_op
		{
			range_op( random_base* base ) : self( base ) {}
			random_base* self;
			template <typename T, typename enable_if< is_floating_point<T>::value >::type* = nullptr>
			T operator()( T a, T b ) { return self->frange( a, b );  }
			template <typename T, typename enable_if< !is_floating_point<T>::value >::type* = nullptr>
			T operator()( T a, T b ) { return self->srange( a, b ); }
		};

		struct coin_flip_op
		{
			coin_flip_op( random_base* base ) : self( base ) {}
			random_base* self;
			template <typename T>
			T operator()( T a, T b ) { return self->coin_flip() ? a : b; }
		};

		struct step_op
		{
			step_op( random_base* base, uint32 s ) : self( base ), steps(s) {}
			uint32 steps;
			random_base* self;
			template <typename T>
			T operator()( T a, T b ) { return self->fstep( a, b, steps ); }
		};

		struct gaussian_op
		{
			gaussian_op( random_base* base ) : self( base ) {}
			random_base* self;
			template <typename T>
			T operator()( T a, T b ) { return self->gaussian_01<T>() * ( b - a ) + a; }
		};

		struct rgaussian_op
		{
			rgaussian_op( random_base* base ) : self( base ) {}
			random_base* self;
			template <typename T>
			T operator()( T a, T b ) { return ( self->rgaussian_01<T>() ) * ( b - a ) + a; }
		};

		template <typename T, typename enable_if< !math::is_vec<T>::value >::type* = nullptr >
		T fstep( T min, T max, uint32 steps )
		{
			return T( srand( steps + 1 ) ) * ( max - min ) / T( steps ) + min; 
		}

		template <typename T, typename enable_if< math::is_vec<T>::value >::type* = nullptr >
		T fstep( T min, T max, uint32 steps )
		{
			return value_type_t<T>( srand( steps + 1 ) ) * ( max - min ) / value_type_t<T>( steps ) + min;
		}

		template < typename T >
		struct component_wise
		{
			template < typename Functor >
			static T apply( T min, T max, typename enable_if< is_arithmetic<T>::value, Functor >::type f )
			{
				return f( min, max );
			}
		};

		template < typename T > 
		struct component_wise< math::tvec2< T > >
		{
			template < typename Functor >
			static math::tvec2<T> apply( math::tvec2<T> min, math::tvec2<T> max, Functor f )
			{
				return math::tvec2<T>(
					f( min.x, max.x ),
					f( min.y, max.y )
					);
			}
		};

		template < typename T >
		struct component_wise< math::tvec3< T > >
		{
			template < typename Functor >
			static math::tvec3<T> apply( math::tvec3<T> min, math::tvec3<T> max, Functor f )
			{
				return math::tvec3<T>(
					f( min.x, max.x ),
					f( min.y, max.y ),
					f( min.z, max.z )
					);
			}
		};

		template < typename T >
		struct component_wise< math::tvec4< T > >
		{
			template < typename Functor >
			static math::tvec4<T> apply( math::tvec4<T> min, math::tvec4<T> max, Functor f )
			{
				return math::tvec4<T>(
					f( min.x, max.x ),
					f( min.y, max.y ),
					f( min.z, max.z ),
					f( min.w, max.w )
					);
			}
		};

	};

	class random_mersenne : public random_base
	{
	public:
		static constexpr uint32 mersenne_n = 624;
		static constexpr uint32 mersenne_m = 397;
		static constexpr uint32 mersenne_static_seed = 5489;

		explicit random_mersenne( seed_type seed = randomized_seed() );
		virtual seed_type set_seed( seed_type seed = 0 );
		virtual result_type rand();
	private:
		void mt_init( uint32 seed );
		void mt_update();
		uint32 mt_uint32();

		uint32  m_state[mersenne_n];
		uint32* m_next;
		uint32  m_remaining;
		uint32  m_seeded : 1;
		uint32  m_static_system_seed : 1;
	};

	class random_xor128 : public random_base
	{
	public:
		explicit random_xor128( seed_type seed = randomized_seed() );
		random_xor128( const random_xor128& other )
		{
			m_state[0] = other.m_state[0];
			m_state[1] = other.m_state[1];
			m_state[2] = other.m_state[2];
			m_state[3] = other.m_state[3];
		}
		virtual seed_type set_seed( seed_type seed = randomized_seed() );
		virtual result_type rand();
	private:
		uint32  m_state[4]; // xyzw
	};

	class random : public random_mersenne
	{
	public:
		explicit random( seed_type seed = randomized_seed() ) : random_mersenne( seed ) {};
		static random& get();
	};

}

NV_RTTI_DECLARE( nv::random_dist )

#endif // NV_CORE_RANDOM_HH
