// Copyright (C) 2014 ChaosForge / Kornel Kisielewicz
// http://chaosforge.org/
//
// This file is part of NV Libraries.
// For conditions of distribution and use, see copyright notice in nv.hh

#ifndef NV_ANIMATION_HH
#define NV_ANIMATION_HH

#include <nv/common.hh>
#include <vector>
#include <nv/interface/stream.hh>
#include <nv/math.hh>
#include <nv/interface/animation_key.hh>
#include <nv/interface/interpolation_raw.hh>
#include <nv/interface/interpolation_template.hh>
#include <nv/transform.hh>
#include <glm/gtc/matrix_transform.hpp>

namespace nv
{

	/// A single animation channel stored as an untyped, contiguous key buffer.
	///
	/// `data` holds `count` keys of `desc.size` bytes each, laid out as
	/// described by `desc`. The channel owns `data` (allocated with new[] in
	/// create(), released with delete[] in the destructor), which is why
	/// copying is disabled. The raw accessors read the buffer as floats
	/// (`desc.size / 4` floats per key) -- presumably every slot type is
	/// float-based; TODO confirm.
	struct key_raw_channel
	{
		key_descriptor desc;  // layout of one key
		uint8*         data;  // count * desc.size bytes, owned
		uint32         count; // number of keys stored in data

		key_raw_channel() : data( nullptr ), count( 0 ) {}
		~key_raw_channel() 
		{
			delete[] data; // delete[] on nullptr is a no-op
		}

		// The channel owns `data`; a member-wise copy would share the buffer
		// and delete[] it twice.
		key_raw_channel( const key_raw_channel& ) = delete;
		key_raw_channel& operator=( const key_raw_channel& ) = delete;

		/// Size of the key buffer in bytes.
		uint32 size() const { return count * desc.size; }

		/// Allocates a channel whose key layout is taken from KEY, with room
		/// for `count` (uninitialized) keys. The caller owns the result.
		template < typename KEY >
		static key_raw_channel* create( uint32 count = 0 )
		{
			key_raw_channel* result = new key_raw_channel();
			result->desc.initialize<KEY>();
			result->count = count;
			result->data  = (count > 0 ? ( new uint8[ result->size() ] ) : nullptr );
			return result;
		}

		/// Allocates a channel with the layout `keydesc`, with room for
		/// `count` (uninitialized) keys. The caller owns the result.
		static key_raw_channel* create( const key_descriptor& keydesc, uint32 count = 0 )
		{
			key_raw_channel* result = new key_raw_channel();
			result->desc  = keydesc;
			result->count = count;
			result->data  = (count > 0 ? ( new uint8[ result->size() ] ) : nullptr );
			return result;
		}

		/// Copies key `index` into `result` as floats. When slot 0 is TIME the
		/// first float of the key is skipped. Returns the number of floats
		/// written (0 for an empty channel). `index` is not bounds-checked.
		uint32 get_raw( uint32 index, float* result ) const 
		{
			if ( count == 0 ) return 0;
			uint32 keyfsize   = desc.size / 4;
			const float* fdata = ((const float*)data) + keyfsize * index;
			uint32 mod        = 0;
			if ( desc.slots[0].vslot == animation_slot::TIME ) mod = 1;
			std::copy_n( fdata + mod, keyfsize - mod, result );
			return keyfsize - mod;
		}

		/// Interpolates the channel at `time` and writes the non-TIME slot
		/// values to `result`. Returns the number of floats written.
		///
		/// Timed channels (slot 0 is TIME) search for the segment whose end
		/// time is greater than `time`. A time before the first key yields the
		/// first key. A time at or past the last key asserts in debug builds
		/// and otherwise yields the last key.
		/// Untimed channels treat `time` as a fractional key index, clamped to
		/// the key range.
		uint32 interpolate_raw( float time, float* result ) const 
		{
			if ( count == 0 ) return 0;
			uint32 keyfsize   = desc.size / 4;
			uint32 keyfresult = keyfsize;
			const float* fdata = (const float*)data;

			uint32 slot = 0;
			int index    = -1;
			float factor = 1.0f;
			if ( desc.slots[0].vslot == animation_slot::TIME )
			{
				slot++;
				keyfresult--;
				if ( count == 1 ) 
				{
					std::copy_n( fdata + 1, keyfresult, result );
					return keyfresult;
				}
				uint32 toffset = desc.slots[0].offset / 4;
				for ( int i = 0 ; i < (int)count - 1 ; i++ )
				{
					if ( time < fdata[ i * keyfsize + keyfsize + toffset ] ) { index = i; break; }
				}
				NV_ASSERT( index >= 0, "animation time fail!");
				// Past the last key: use the final segment so the factor clamps
				// to 1 (the last key), instead of indexing fdata with -1.
				if ( index < 0 ) index = (int)count - 2;
				float time0  = fdata[ index * keyfsize + toffset ];
				float time1  = fdata[ index * keyfsize + keyfsize + toffset ];
				float delta  = time1 - time0;
				// Coincident key times would make (time - time0) / delta a 0/0 NaN.
				factor = delta != 0.0f
					? glm::clamp( (time - time0) / delta, 0.0f, 1.0f )
					: ( time < time0 ? 0.0f : 1.0f );
			}
			else
			{
				if ( count == 1 ) 
				{
					std::copy_n( fdata, keyfresult, result );
					return keyfresult;
				}
				index  = glm::clamp<int>( int( time ), 0, count - 2 );
				factor = glm::clamp<float> ( time - index, 0.0f, 1.0f );
			}
			uint32 ret = 0;
			for ( ; slot < desc.count; ++slot )
			{
				ret += nv::interpolate_raw( 
					desc.slots[slot], factor, 
					fdata + index * keyfsize + desc.slots[slot].offset / 4,
					fdata + index * keyfsize + keyfsize + desc.slots[slot].offset / 4,
					result + ret );
			}
			return ret;
		}
	};

	class key_data
	{
	public:
		key_data() {}

		void add_channel( key_raw_channel* channel ) 
		{
			NV_ASSERT( channel, "nullptr passed to add_channel!" );
			m_channels.push_back( channel );
			for ( uint16 i = 0; i < channel->desc.count; ++i )
			{
				const key_descriptor_slot& ksi = channel->desc.slots[i];
				if ( ksi.vslot != animation_slot::TIME )
				{
					uint32 index = final_key.count;
					final_key.slots[ index ].offset = final_key.size;
					final_key.slots[ index ].etype  = ksi.etype;
					final_key.slots[ index ].vslot  = ksi.vslot;
					final_key.size += get_datatype_info( ksi.etype ).size;
					final_key.count++;
				}
			}
		}

		mat4 get_raw_matrix( uint32 index ) const 
		{
			float key[ 16 ];
			float* pkey = key;
			for ( uint16 i = 0; i < m_channels.size(); ++i )
			{
				pkey += m_channels[i]->get_raw( index, pkey );
			}
			return extract_matrix_raw( final_key, key );
		};

		transform get_raw_transform( uint32 index ) const 
		{
			float key[ 16 ];
			float* pkey = key;
			for ( uint16 i = 0; i < m_channels.size(); ++i )
			{
				pkey += m_channels[i]->get_raw( index, pkey );
			}
			return extract_transform_raw( final_key, key );
		};

		mat4 get_matrix( float time ) const
		{
			float key[ 16 ];
			float* pkey = key;
			for ( uint16 i = 0; i < m_channels.size(); ++i )
			{
				pkey += m_channels[i]->interpolate_raw( time, pkey );
			}
			return extract_matrix_raw( final_key, key );
		};

		transform get_transform( float time ) const
		{
			float key[ 16 ];
			float* pkey = key;
			for ( uint16 i = 0; i < m_channels.size(); ++i )
			{
				pkey += m_channels[i]->interpolate_raw( time, pkey );
			}
			return extract_transform_raw( final_key, key );
		};

		size_t get_channel_count() const { return m_channels.size(); }
		const key_raw_channel* get_channel( size_t index ) const { return m_channels[ index ]; }

		virtual ~key_data()
		{
			for ( auto channel : m_channels ) delete channel;
		}
	private:
		key_descriptor final_key;
		std::vector< key_raw_channel* > m_channels;
	};

	/// Typed, non-owning view over a key_raw_channel that interpolates keys of
	/// type KEY. TIMED selects between frame-index based (false) and per-key
	/// time based (true) lookup.
	template < typename KEY, bool TIMED >
	class key_channel_interpolator;


	/// Untimed interpolator: `frame` is a fractional key index.
	template < typename KEY >
	class key_channel_interpolator< KEY, false >
	{
	public:
		key_channel_interpolator() : m_data( nullptr ) {}
		key_channel_interpolator( const key_raw_channel* data ) : m_data( nullptr ) { set_data( data ); }
		// Binds `data` without checking its layout against KEY.
		key_channel_interpolator( const key_raw_channel* data, bool ) : m_data( data ) {}

		/// Binds `data` (not owned); asserts its layout matches KEY.
		void set_data( const key_raw_channel* data )
		{
			m_data = data;
			key_descriptor desc;
			desc.initialize<KEY>();
			NV_ASSERT( data->desc == desc, "Bad channel passed!" );
		}

		/// Writes the key at fractional index `frame` into `result`. The index
		/// is clamped to the key range. `result` is left untouched when the
		/// channel is empty.
		void get_interpolated( KEY& result, float frame ) const
		{
			NV_ASSERT( m_data, "Data is null!" );
			if ( m_data->count == 0 ) return;
			const KEY* keys = (const KEY*)m_data->data;
			if ( m_data->count == 1 ) 
			{
				result = keys[0];
				return;
			}
			// Clamp in float before converting: converting a negative float to
			// size_t is undefined behaviour.
			float  clamped = glm::clamp( frame, 0.0f, float( m_data->count - 2 ) );
			size_t index   = size_t( clamped );
			float  factor  = glm::clamp<float> ( frame - index, 0.0f, 1.0f );
			interpolate_key( result, keys[index], keys[index+1], factor );
		}

	private:
		const key_raw_channel* m_data;
	};
 
 	/// Timed interpolator: every KEY carries a `time` field, and keys are
	/// expected to be sorted by time.
	template < typename KEY >
	class key_channel_interpolator< KEY, true >
	{
	public:
		key_channel_interpolator() : m_data( nullptr ) {}
		key_channel_interpolator( const key_raw_channel* data ) : m_data( nullptr ) { set_data( data ); }

		/// Binds `data` (not owned); asserts its layout matches KEY.
		void set_data( const key_raw_channel* data )
		{
			m_data = data;
			key_descriptor desc;
			desc.initialize<KEY>();
			NV_ASSERT( data->desc == desc, "Bad channel passed!" );
		}

		/// Writes the key interpolated at `time` into `result`. A time before
		/// the first key yields the first key. A time at or past the last key
		/// asserts in debug builds and otherwise yields the last key. `result`
		/// is left untouched when the channel is empty.
		void get_interpolated( KEY& result, float time ) const
		{
			// TODO: this probably could be optimized
			NV_ASSERT( m_data, "Data is null!" );
			if ( m_data->count == 0 ) return;
			// Only dereference m_data after the null check above.
			const KEY* keys = (const KEY*)(m_data->data);
			if ( m_data->count == 1 ) 
			{
				result = keys[0];
				return;
			}
			int index = -1;
			for ( int i = 0 ; i < (int)m_data->count - 1 ; i++ )
			{
				if ( time < keys[i + 1].time ) { index = i; break; }
			}
			NV_ASSERT( index >= 0, "animation time fail!");
			// Past the last key: use the final segment so the factor clamps to
			// 1 (the last key), instead of reading keys[-1].
			if ( index < 0 ) index = (int)m_data->count - 2;
			float delta  = keys[index + 1].time - keys[index].time;
			// Coincident key times would make the division a 0/0 NaN.
			float factor = delta != 0.0f
				? glm::clamp( (time - keys[index].time) / delta, 0.0f, 1.0f )
				: ( time < keys[index].time ? 0.0f : 1.0f );
			interpolate_key( result, keys[index], keys[index+1], factor );
		}

	private:
		const key_raw_channel* m_data;
	};
 
// 	template < typename KEY1, typename KEY2 = void, typename KEY3 = void >
// 	class key_data_interpolator
// 	{
// 
// 	};

	/// Abstract interface for a sampled animation track.
	/// Implementations in this header: key_vectors_prs (separate timed
	/// position/rotation/scale channels) and transform_vector (one untimed
	/// transform channel).
	class key_animation_data
	{
	public:
		/// Samples the track at `time` and returns the result as a matrix.
		virtual mat4 get_matrix( float time ) const = 0;
		/// Samples the track at `time` and returns the result as a transform.
		virtual transform get_transform( float time ) const = 0;
		/// True when the track holds no keys.
		virtual bool empty() const = 0;
		/// Number of keys. NOTE(review): key_vectors_prs always returns 0 here.
		virtual size_t size() const = 0;
		/// Approximate serialized size in bytes (as computed by implementations).
		virtual uint32 raw_size() const = 0;
		virtual ~key_animation_data() {}
	};


	/// Animation track built from three independent timed channels: position,
	/// rotation and scale.
	/// The channels are NOT owned (the destructor does not delete them), so
	/// they must outlive this object.
	class key_vectors_prs : public key_animation_data
	{
		struct key_p { float time; vec3 position; };
		struct key_r { float time; quat rotation; };
		struct key_s { float time; vec3 scale; };
	public:
		explicit key_vectors_prs( const key_raw_channel* p, const key_raw_channel* r, const key_raw_channel* s )
		{
			m_pchannel = p;
			m_rchannel = r;
			m_schannel = s;
			m_pinter.set_data( m_pchannel );
			m_rinter.set_data( m_rchannel );
			m_sinter.set_data( m_schannel );
		}
		size_t size() const { return 0; } // TODO: remove?
		bool empty() const { return m_pchannel->count == 0 && m_rchannel->count == 0 && m_schannel->count == 0; }

		/// Composes position * rotation * scale sampled at `time`.
		virtual mat4 get_matrix( float time ) const
		{
			key_p p;
			key_r r;
			key_s s;
			// get_interpolated leaves its output untouched for an empty
			// channel. Start from zero translation and unit scale so a missing
			// channel does not collapse the matrix. Rotation keeps quat's
			// default value, as before.
			p.position = vec3( 0.0f );
			s.scale    = vec3( 1.0f );

			m_pinter.get_interpolated( p, time );
			m_rinter.get_interpolated( r, time );
			m_sinter.get_interpolated( s, time );

			return extract_matrix( p ) * extract_matrix( r ) * extract_matrix( s );
		}

		/// Position + rotation sampled at `time`; scale is not represented.
		virtual transform get_transform( float time ) const
		{
			key_p p;
			key_r r;
			// See get_matrix: an empty position channel yields zero translation.
			p.position = vec3( 0.0f );

			m_pinter.get_interpolated( p, time );
			m_rinter.get_interpolated( r, time );

			return transform( p.position, r.rotation );
		}

		virtual uint32 raw_size() const 
		{
			return 3 * sizeof( size_t ) 
				+ m_pchannel->count * sizeof( key_p )
				+ m_rchannel->count * sizeof( key_r )
				+ m_schannel->count * sizeof( key_s );
		}
		~key_vectors_prs()
		{
		}
	protected:
		const key_raw_channel* m_pchannel;
		const key_raw_channel* m_rchannel;
		const key_raw_channel* m_schannel;
		key_channel_interpolator< key_p, true > m_pinter;
		key_channel_interpolator< key_r, true > m_rinter;
		key_channel_interpolator< key_s, true > m_sinter;
	};

	/// Untimed animation track: one transform per frame, interpolated by
	/// fractional frame index.
	/// Takes ownership of the channel passed to the constructor (it is deleted
	/// in the destructor), so instances are non-copyable.
	class transform_vector : public key_animation_data
	{
		struct key
		{
			transform tform;
		};
	public:
		explicit transform_vector( const key_raw_channel* channel ) 
		{
			key_descriptor kd;
			kd.initialize<key>();
			NV_ASSERT( kd == channel->desc, "bad channel!" );
			m_channel = channel;
			m_interpolator.set_data( m_channel );
		}

		// The channel is owned; a shallow copy would delete it twice.
		transform_vector( const transform_vector& ) = delete;
		transform_vector& operator=( const transform_vector& ) = delete;

		~transform_vector()
		{
			delete m_channel;
		}
		bool empty() const { return m_channel->count == 0; }
		size_t size() const { return m_channel->count; }

		/// Transform stored for frame `index`.
		const transform& get( size_t index ) const
		{
			NV_ASSERT( index < m_channel->count, "transform_vector index out of range!" );
			return ((key*)(m_channel->data))[ index ].tform;
		}

		/// Raw view of the stored transforms. This assumes `key` has the same
		/// layout as `transform`.
		const transform* data() const { return (const transform*)m_channel->data; }

		virtual uint32 raw_size() const 
		{
			return sizeof( size_t ) + m_channel->count * sizeof( key );
		}

		virtual mat4 get_matrix( float time ) const
		{
			return get_transform( time ).extract();
		}

		/// Transform interpolated at fractional frame `time`.
		virtual transform get_transform( float time ) const
		{
			key result;
			m_interpolator.get_interpolated( result, time );
			return extract_transform< key >( result );
		}
	protected:
		key_channel_interpolator< key, false > m_interpolator;
		const key_raw_channel* m_channel;
	};


}

#endif // NV_ANIMATION_HH
