// Copyright (C) 2014-2015 ChaosForge Ltd
// http://chaosforge.org/
//
// This file is part of Nova libraries. 
// For conditions of distribution and use, see copying.txt file in root folder.

#ifndef NV_GFX_ANIMATION_HH
#define NV_GFX_ANIMATION_HH

#include <nv/common.hh>
#include <nv/stl/vector.hh>
#include <nv/interface/stream.hh>
#include <nv/stl/math.hh>
#include <nv/interface/data_channel.hh>
#include <nv/interface/interpolation_raw.hh>
#include <nv/interface/interpolation_template.hh>
#include <nv/core/transform.hh>

namespace nv
{

	class key_data : public data_channel_set
	{
	public:
		key_data() {}

		void add_key_channel( raw_data_channel* channel ) 
		{
			NV_ASSERT( channel, "nullptr passed to add_channel!" );
			add_channel( channel );
			for ( const auto& cslot : channel->descriptor() )
				if ( cslot.vslot != slot::TIME )
					m_final_key.push_slot( cslot.etype, cslot.vslot );
		}

		mat4 get_raw_matrix( uint32 index ) const 
		{
			float key[ 16 ];
			float* pkey = key;
			for ( uint16 i = 0; i < size(); ++i )
			{
				pkey += get_raw( m_channels[i], index, pkey );
			}
			return extract_matrix_raw( m_final_key, key );
		}

		transform get_raw_transform( uint32 index ) const 
		{
			float key[ 16 ];
			float* pkey = key;
			for ( uint16 i = 0; i < size(); ++i )
			{
				pkey += get_raw( m_channels[i], index, pkey );
			}
			return extract_transform_raw( m_final_key, key );
		}

		mat4 get_matrix( float time ) const
		{
			float key[ 16 ];
			float* pkey = key;
			for ( uint16 i = 0; i < size(); ++i )
			{
				pkey += interpolate_raw( m_channels[i], time, pkey );
			}
			return extract_matrix_raw( m_final_key, key );
		}

		transform get_transform( float time ) const
		{
			float key[ 16 ];
			float* pkey = key;
			for ( uint16 i = 0; i < size(); ++i )
			{
				pkey += interpolate_raw( m_channels[i], time, pkey );
			}
			return extract_transform_raw( m_final_key, key );
		}

		const data_descriptor& get_final_key() const { return m_final_key; }

		static uint32 get_raw( const raw_data_channel* channel, uint32 index, float* result )
		{
			if ( channel->size() == 0 ) return 0;
			uint32 keyfsize = channel->element_size() / 4;
			const float* fdata = reinterpret_cast<const float*>( channel->raw_data() ) + keyfsize * index;
			uint32 mod = 0;
			if ( channel->descriptor()[0].vslot == slot::TIME ) mod = 1;
			raw_copy_n( fdata + mod, keyfsize - mod, result );
			return keyfsize - mod;
		}

		static uint32 interpolate_raw( const raw_data_channel* channel, float time, float* result )
		{
			if ( channel->size() == 0 ) return 0;
			uint32 keyfsize = channel->element_size() / 4;
			uint32 keyfresult = keyfsize;
			const float* fdata = reinterpret_cast<const float*>( channel->raw_data() );

			uint32 slot = 0;
			int index0 = -1;
			int index1 = -1;
			float factor = 1.0f;
			if ( channel->descriptor()[0].vslot == slot::TIME )
			{
				NV_ASSERT( channel->descriptor()[0].offset == 0, "time offset not zero!" );
				slot++;
				keyfresult--;
				if ( channel->size() == 1 )
				{
					raw_copy_n( fdata + 1, keyfresult, result );
					return keyfresult;
				}
				for ( unsigned i = 1; i < channel->size(); i++ )
				{
					if ( time < fdata[i * keyfsize] )
					{
						index0 = static_cast<int>( i ) - 1;
						break;
					}
				}
				NV_ASSERT( index0 >= 0, "animation time fail!" );
				index1 = index0 + 1;
				float time0 = fdata[index0 * static_cast<int>( keyfsize )];
				float time1 = fdata[index1 * static_cast<int>( keyfsize )];
				float delta = time1 - time0;
				factor = glm::clamp( ( time - time0 ) / delta, 0.0f, 1.0f );
			}
			else
			{
				if ( channel->size() == 1 )
				{
					raw_copy_n( fdata, keyfresult, result );
					return keyfresult;
				}
				index0 = glm::clamp<int>( int( time ), 0, int( channel->size() ) - 2 );
				index1 = index0 + 1;
				factor = glm::clamp<float>( time - index0, 0.0f, 1.0f );
			}
			uint32 ret = 0;
			for ( ; slot < channel->descriptor().size(); ++slot )
			{
				ret += nv::interpolate_raw(
					channel->descriptor()[slot], factor,
					fdata + index0 * static_cast<int>( keyfsize ) + channel->descriptor()[slot].offset / 4,
					fdata + index1 * static_cast<int>( keyfsize ) + channel->descriptor()[slot].offset / 4,
					result + ret );
			}
			return ret;
		}

	private:
		data_descriptor m_final_key;
	};

	template < typename KEY, bool TIMED >
	class key_channel_interpolator;


 	template < typename KEY >
	class key_channel_interpolator< KEY, false >
	{
	public:
		key_channel_interpolator() : m_data( nullptr ) {}
		key_channel_interpolator( const raw_data_channel* data ) : m_data( nullptr ) { set_data( data ); }
		key_channel_interpolator( const raw_data_channel* data, bool ) : m_data( data ) {}
		void set_data( const raw_data_channel* data )
		{
			m_data = data;
			data_descriptor desc;
			desc.initialize<KEY>();
			NV_ASSERT( data->descriptor() == desc, "Bad channel passed!" );
		}
		void get_interpolated( KEY& result, float frame ) const
		{
			NV_ASSERT( m_data, "Data is null!" );
			const KEY* keys = m_data->data_cast<KEY>( );
			uint32 count = m_data->size();
			if ( count == 0 ) return;
			if ( count == 1 )
			{
				result = keys[0];
				return;
			}
			size_t index = glm::clamp<size_t>( size_t( frame ), 0, count - 2 );
			float factor = glm::clamp<float> ( frame - index, 0.0f, 1.0f );
			interpolate_key( result, keys[index], keys[index+1], factor );
		}

	private:
		const raw_data_channel* m_data;
	};
 
 	template < typename KEY >
 	class key_channel_interpolator< KEY, true >
	{
	public:
		key_channel_interpolator() : m_data( nullptr ) {}
		key_channel_interpolator( const raw_data_channel* data ) : m_data( nullptr ) { set_data( data ); }
		void set_data( const raw_data_channel* data )
		{
			m_data = data;
			data_descriptor desc;
			desc.initialize<KEY>();
			NV_ASSERT( data->descriptor() == desc, "Bad channel passed!" );
		}
		void get_interpolated( KEY& result, float time ) const
		{
			// TODO: this probably could be optimized
			NV_ASSERT( m_data, "Data is null!" );
			const KEY* keys = m_data->data_cast<KEY>( );
			uint32 count = m_data->size();
			if ( count == 0 ) return;
			if ( count == 1 ) 
			{
				result = keys[0];
				return;
			}
			int index = -1;
			for ( int i = 0 ; i < int( count ) - 1 ; i++ )
			{
				if ( time < keys[i + 1].time ) { index = i; break; }
			}
			NV_ASSERT( index >= 0, "animation time fail!");
			float delta  = keys[index + 1].time - keys[index].time;
			float factor = glm::clamp( (time - keys[index].time) / delta, 0.0f, 1.0f );
			interpolate_key( result, keys[index], keys[index+1], factor );
		}

	private:
		const raw_data_channel* m_data;
	};
 
// 	template < typename KEY1, typename KEY2 = void, typename KEY3 = void >
// 	class key_data_interpolator
// 	{
// 
// 	};

	class key_animation_data
	{
	public:
		virtual mat4 get_matrix( float time ) const = 0;
		virtual transform get_transform( float time ) const = 0;
		virtual bool empty() const = 0;
		virtual size_t size() const = 0;
		virtual uint32 raw_size() const = 0;
		virtual ~key_animation_data() {}
	};


	class key_vectors_prs : public key_animation_data
	{
		struct key_p { float time; vec3 position; };
		struct key_r { float time; quat rotation; };
		struct key_s { float time; vec3 scale; };
	public:
		explicit key_vectors_prs( const raw_data_channel* p, const raw_data_channel* r, const raw_data_channel* s )
		{
			m_pchannel = p;
			m_rchannel = r;
			m_schannel = s;
			m_pinter.set_data( m_pchannel );
			m_rinter.set_data( m_rchannel );
			m_sinter.set_data( m_schannel );
		}
		size_t size() const { return 0; } // TODO: remove?
		bool empty() const { 
			return m_pchannel->size() == 0
				&& m_rchannel->size() == 0
				&& m_schannel->size() == 0; }
		virtual mat4 get_matrix( float time ) const
		{
			key_p p;
			key_r r;
			key_s s;

			m_pinter.get_interpolated( p, time );
			m_rinter.get_interpolated( r, time );
			m_sinter.get_interpolated( s, time );

			return extract_matrix( p ) * extract_matrix( r ) * extract_matrix( s );
		}
		virtual transform get_transform( float time ) const
		{
			key_p p;
			key_r r;

			m_pinter.get_interpolated( p, time );
			m_rinter.get_interpolated( r, time );

			return transform( p.position, r.rotation );
		}
		virtual uint32 raw_size() const 
		{
			return 3 * sizeof( size_t ) 
				+ m_pchannel->size() * sizeof( key_p )
				+ m_rchannel->size() * sizeof( key_r )
				+ m_schannel->size() * sizeof( key_s );
		}
		~key_vectors_prs()
		{
		}
	protected:
		const raw_data_channel* m_pchannel;
		const raw_data_channel* m_rchannel;
		const raw_data_channel* m_schannel;
		key_channel_interpolator< key_p, true > m_pinter;
		key_channel_interpolator< key_r, true > m_rinter;
		key_channel_interpolator< key_s, true > m_sinter;
	};

	class transform_vector : public key_animation_data
	{
		struct key
		{
			transform tform;
		};
	public:
		explicit transform_vector( const raw_data_channel* channel )
		{
			data_descriptor kd;
			kd.initialize<key>();
			NV_ASSERT( kd == channel->descriptor(), "bad channel!" );
			m_channel = channel;
			m_interpolator.set_data( m_channel );
		}

		~transform_vector()
		{
			delete m_channel;
		}
		bool empty() const { return m_channel->size() == 0; }
		size_t size() const { return m_channel->size(); }
		const transform& get( size_t index ) const { return m_channel->data_cast< key >()[ index ].tform; }
		const transform* data() const { return &(m_channel->data_cast< key >()[0].tform); }

		virtual uint32 raw_size() const 
		{
			return sizeof( size_t ) + m_channel->size() * sizeof( key );
		}

		virtual mat4 get_matrix( float time ) const
		{
			return get_transform( time ).extract();
		}
		virtual transform get_transform( float time ) const
		{
			key result;
			m_interpolator.get_interpolated( result, time );
			return extract_transform< key >( result );
		}
	protected:
		key_channel_interpolator< key, false > m_interpolator;
		const raw_data_channel* m_channel;
	};


}

#endif // NV_GFX_ANIMATION_HH
