// Copyright (C) 2015-2016 ChaosForge Ltd
// http://chaosforge.org/
//
// This file is part of Nova libraries. 
// For conditions of distribution and use, see copying.txt file in root folder.

// WARNING: this file is explicitly designed to fuck with your brain

#ifndef NV_INTERFACE_INTERPOLATE_HH
#define NV_INTERFACE_INTERPOLATE_HH

#include <nv/common.hh>
#include <nv/core/transform.hh>
#include <nv/stl/math.hh>
#include <nv/interface/data_descriptor.hh>

namespace nv
{

	template < typename T >
	struct no_interpolator
	{
		inline static T interpolate( float, const T& lhs, const T& ) { return lhs; }
		inline static T interpolate( float, const T& , const T& v1, const T& , const T& ) { return v1; }
	};

	template < typename T >
	struct linear_interpolator
	{
		inline static T interpolate( float f, const T& lhs, const T& rhs ) { return math::lerp( lhs, rhs, f ); }
		inline static T interpolate( float f, const T& , const T& v1, const T& v2, const T& ) { return math::lerp( v1, v2, f ); }
	};

	template < typename T >
	struct normalized_interpolator
	{
		inline static T interpolate( float f, const T& lhs, const T& rhs ) { return math::lerp( lhs, rhs, f ); }
		inline static T interpolate( float f, const T& , const T& v1, const T& v2, const T& ) { return math::lerp( v1, v2, f ); }
	};

	template <>
	struct normalized_interpolator< quat >
	{
		inline static quat interpolate( float f, const quat& lhs, const quat& rhs ) { return math::nlerp( lhs, rhs, f ); }
		inline static quat interpolate( float f, const quat& , const quat& v1, const quat& v2, const quat& ) { return math::nlerp( v1, v2, f ); }
	};

	template <>
	struct normalized_interpolator< transform >
	{
		inline static transform interpolate( float f, const transform& lhs, const transform& rhs ) { return transform( math::lerp( lhs.get_position(), rhs.get_position(), f ), math::nlerp( lhs.get_orientation(), rhs.get_orientation(), f ) ); }
		inline static transform interpolate( float f, const transform& , const transform& v1, const transform& v2, const transform& ) { return transform( math::lerp( v1.get_position(), v2.get_position(), f ), math::nlerp( v1.get_orientation(), v2.get_orientation(), f ) ); }
	};

	template < typename T >
	struct spherical_interpolator
	{
		inline static T interpolate( float f, const T& lhs, const T& rhs ) { return math::lerp( lhs, rhs, f ); }
		inline static T interpolate( float f, const T& , const T& v1, const T& v2, const T& ) { return math::lerp( v1, v2, f ); }
	};

	template <>
	struct spherical_interpolator< quat >
	{
		inline static quat interpolate( float f, const quat& lhs, const quat& rhs ) { return math::slerp( lhs, rhs, f ); }
		inline static quat interpolate( float f, const quat& , const quat& v1, const quat& v2, const quat& ) { return math::slerp( v1, v2, f ); }
	};

	template <>
	struct spherical_interpolator< transform >
	{
		inline static transform interpolate( float f, const transform& lhs, const transform& rhs ) { return transform( math::lerp( lhs.get_position(), rhs.get_position(), f ), math::slerp( lhs.get_orientation(), rhs.get_orientation(), f ) ); }
		inline static transform interpolate( float f, const transform& , const transform& v1, const transform& v2, const transform& ) { return transform( math::lerp( v1.get_position(), v2.get_position(), f ), math::slerp( v1.get_orientation(), v2.get_orientation(), f ) ); }
	};

	struct quadratic_interpolator_base
	{
		float weights[4];

		quadratic_interpolator_base( float value ) 
		{
			float interp_squared = value*value;
			float interp_cubed = interp_squared*value;
			weights[0] = 0.5f * ( -interp_cubed + 2.0f * interp_squared - value );
			weights[1] = 0.5f * ( 3.0f * interp_cubed - 5.0f * interp_squared + 2.0f );
			weights[2] = 0.5f * ( -3.0f * interp_cubed + 4.0f * interp_squared + value );
			weights[3] = 0.5f * ( interp_cubed - interp_squared );
		}
	};

	template < typename T >
	struct quadratic_interpolator : public quadratic_interpolator_base
	{
		using quadratic_interpolator_base::quadratic_interpolator_base;
		inline static T interpolate( float f, const T& lhs, const T& rhs ) { return math::lerp( lhs, rhs, f ); }
		inline T interpolate( float, const T& v0, const T& v1, const T& v2, const T& v3 ) const
		{
			return
				weights[0] * v0 +
				weights[1] * v1 +
				weights[2] * v2 +
				weights[3] * v3;
		}
	};

	template <>
	struct quadratic_interpolator< quat > : public quadratic_interpolator_base
	{
		using quadratic_interpolator_base::quadratic_interpolator_base;
		inline static quat interpolate( float f, const quat& lhs, const quat& rhs ) { return math::lerp( lhs, rhs, f ); }
		inline quat interpolate( float, const quat& v0, const quat& v1, const quat& v2, const quat& v3 ) const
		{
			float a = dot( v1, v2 ) > 0.0f ? 1.0f : -1.0f;
			return normalize(
				weights[0] * v0 +
				weights[1] * ( a * v1 ) +
				weights[2] * v2 +
				weights[3] * v3
				);
		}
	};

	template <>
	struct quadratic_interpolator< transform > : public quadratic_interpolator_base
	{
		using quadratic_interpolator_base::quadratic_interpolator_base;
		inline static transform interpolate( float f, const transform& lhs, const transform& rhs ) { return math::lerp( lhs, rhs, f ); }
		inline transform interpolate( float, const transform& v0, const transform& v1, const transform& v2, const transform& v3 ) const
		{
			float a = dot( v1.get_orientation(), v2.get_orientation() ) > 0.0f ? 1.0f : -1.0f;
			return transform( 
				weights[0] * v0.get_position() +
				weights[1] * v1.get_position() +
				weights[2] * v2.get_position() +
				weights[3] * v3.get_position(),
				normalize(
				weights[0] * v0.get_orientation() +
				weights[1] * ( a * v1.get_orientation() ) +
				weights[2] * v2.get_orientation() +
				weights[3] * v3.get_orientation()
				) );
		}
	};

	template < typename T >
	struct squad_interpolator
	{
		inline static T interpolate( float f, const T& lhs, const T& rhs ) { return math::slerp( lhs, rhs, f ); }
		inline static T interpolate( float f, const T& v0, const T& v1, const T& v2, const T& v3 ) { return math::slerp( v1, v2, f ); }
	};

	template <>
	struct squad_interpolator< quat >
	{
		inline static quat interpolate( float f, const quat& lhs, const quat& rhs ) { return math::slerp( lhs, rhs, f ); }
		inline static quat interpolate( float f, const quat& v0, const quat& v1, const quat& v2, const quat& v3 )
		{
			return normalize( math::squad(
				v1, v2,
				math::intermediate( v0, v1, v2 ),
				math::intermediate( v1, v2, v3 ),
				f ) );
		}
	};

	template <>
	struct squad_interpolator< transform >
	{
		inline static transform interpolate( float f, const transform& lhs, const transform& rhs ) { return spherical_interpolator<transform>::interpolate( f, lhs, rhs ); }
		inline static transform interpolate( float f, const transform& v0, const transform& v1, const transform& v2, const transform& v3 )
		{
			return transform(
				mix( v1.get_position(), v2.get_position(), f ),
				squad_interpolator< quat >::interpolate( f,
					v0.get_orientation(),
					v1.get_orientation(),
					v2.get_orientation(),
					v3.get_orientation()
					) );
		}
	};


	enum class interpolation
	{
		NONE,
		LINEAR,
		NORMALIZED,
		SPHERICAL,
		QUADRATIC,
		SQUADRATIC,
	};

	template < typename T, typename Interpolator, typename ...Args >
	T interpolate( float f, const Interpolator& i, Args&&... args )
	{
		NV_UNUSED( i );
		return i.interpolate( f, ::nv::forward<Args>( args )... );
	}
	
	template < typename T, typename ...Args >
	T interpolate( float f, interpolation i, Args&&... args )
	{
		switch ( i )
		{
		case interpolation::LINEAR        : return linear_interpolator<T>::interpolate( f, ::nv::forward<Args>( args )... );
		case interpolation::NORMALIZED    : return normalized_interpolator<T>::interpolate( f, ::nv::forward<Args>( args )... );
		case interpolation::SPHERICAL     : return spherical_interpolator<T>::interpolate( f, ::nv::forward<Args>( args )... );
		case interpolation::QUADRATIC     : return quadratic_interpolator<T>::interpolate( f, ::nv::forward<Args>( args )... );
		case interpolation::SQUADRATIC    : return squad_interpolator<T>::interpolate( f, ::nv::forward<Args>( args )... );
		default: case interpolation::NONE : return no_interpolator<T>::interpolate( f, ::nv::forward<Args>( args )... );
		}
	}

	template < typename T, typename ...Args >
	void interpolate( T& result, float f, interpolation i, Args&&... args )
	{
		typedef typename T::value_type value_type;
		switch ( i )
		{
		case interpolation::LINEAR        : interpolate_array( result, f, linear_interpolator<value_type>(), ::nv::forward<Args>( args )... ); break;
		case interpolation::NORMALIZED    : interpolate_array( result, f, normalized_interpolator<value_type>(), ::nv::forward<Args>( args )... ); break;
		case interpolation::SPHERICAL     : interpolate_array( result, f, spherical_interpolator<value_type>(), ::nv::forward<Args>( args )... ); break;
		case interpolation::QUADRATIC     : interpolate_array( result, f, quadratic_interpolator<value_type>( f ), ::nv::forward<Args>( args )... ); break;
		case interpolation::SQUADRATIC    : interpolate_array( result, f, squad_interpolator<value_type>(), ::nv::forward<Args>( args )... ); break;
		default: case interpolation::NONE : interpolate_array( result, f, no_interpolator<value_type>(), ::nv::forward<Args>( args )... ); break;
		}
	}

	template < typename T, typename Interpolator >
	auto interpolate_extract( float f, const Interpolator& i, uint32 n, const T& a1, const T& a2 )
		-> typename T::value_type
	{
		NV_UNUSED( i );
		return i.interpolate( f, a1[n], a2[n] );
	}

	template < typename T, typename Interpolator >
	auto interpolate_extract( float f, const Interpolator& i, uint32 n, const T& a0, const T& a1, const T& a2, const T& a3 )
		-> typename T::value_type
	{
		NV_UNUSED( i );
		return i.interpolate( f, a0[n], a1[n], a2[n], a3[n] );
	}


	template < typename T, typename Interpolator, typename ...Args >
	void interpolate_array( T& result, float f, const Interpolator& in, Args&&... args )
	{
		uint32 size = result.size();
		for ( uint32 n = 0; n < size; ++n )
		{
			result[n] = interpolate_extract( f, in, n, ::nv::forward<Args>( args )... );
		}
	}

	template < typename T, typename Interpolator, typename ...Args >
	void interpolate_array( T& a, float f, const Interpolator& in, const array_view< bool >& mask, Args&&... args )
	{
		uint32 size = a.size();
		for ( uint32 n = 0; n < size; ++n )
			if ( mask[ n ] )
			{
				a[n] = interpolate_extract( f, in, n, ::nv::forward<Args>( args )... ); 
			}
	}

	template < typename T, typename Interpolator, typename... Args >
	void interpolate_array( T& result, float f, const Interpolator& in, float blend_factor, interpolation bi, Args&&... args )
	{
		typedef typename T::value_type value_type;
		switch ( bi )
		{
		case interpolation::LINEAR        : interpolate_blend( result, f, in, blend_factor, linear_interpolator<value_type>(), ::nv::forward<Args>( args )... ); break;
		case interpolation::NORMALIZED    : interpolate_blend( result, f, in, blend_factor, normalized_interpolator<value_type>(), ::nv::forward<Args>( args )... ); break;
		case interpolation::SPHERICAL     : interpolate_blend( result, f, in, blend_factor, spherical_interpolator<value_type>(), ::nv::forward<Args>( args )... ); break;
		case interpolation::QUADRATIC     : interpolate_blend( result, f, in, blend_factor, normalized_interpolator<value_type>(), ::nv::forward<Args>( args )... ); break;
		case interpolation::SQUADRATIC    : interpolate_blend( result, f, in, blend_factor, spherical_interpolator<value_type>(), ::nv::forward<Args>( args )... ); break;
		default: case interpolation::NONE : interpolate_blend( result, f, in, blend_factor, no_interpolator<value_type>(), ::nv::forward<Args>( args )... ); break;
		}
	}

	template < typename T, typename Interpolator, typename BlendInterpolator, typename... Args >
	void interpolate_blend( T& a, float f, const Interpolator& in, float blend_factor, const BlendInterpolator& bin, Args&&... args )
	{
		NV_UNUSED( bin );
		typedef typename T::value_type value_type;
		uint32 size = a.size();
		for ( uint32 n = 0; n < size; ++n )
		{
			value_type temp = interpolate_extract( f, in, n, ::nv::forward<Args>( args )... );
			a[n] = bin.interpolate( blend_factor, a[n], temp );
		}
	}

	template < typename T, typename Interpolator, typename BlendInterpolator, typename... Args >
	void interpolate_blend( T& a, float f, const Interpolator& in, float blend_factor, const BlendInterpolator& bin, const array_view< bool >& mask, Args&&... args )
	{
		NV_UNUSED( bin );
		typedef typename T::value_type value_type;
		uint32 size = a.size();
		for ( uint32 n = 0; n < size; ++n )
			if ( mask[n] )
			{
				value_type temp = interpolate_extract( f, in, n, ::nv::forward<Args>( args )... );
				a[n] = bin.interpolate( blend_factor, a[n], temp );
			}
	}



}
#endif // NV_INTERFACE_INTERPOLATE_HH
