// Copyright (C) 2011-2015 ChaosForge Ltd
// http://chaosforge.org/
//
// This file is part of Nova libraries. 
// For conditions of distribution and use, see copying.txt file in root folder.

#include "nv/gfx/skeletal_mesh.hh"

#include "nv/interface/context.hh"
#include "nv/interface/device.hh"
#include "nv/stl/unordered_map.hh"

nv::skeletal_mesh_cpu::skeletal_mesh_cpu( context* a_context, const mesh_data* a_mesh_data, const mesh_nodes_data* bones )
	: skeletal_mesh( a_context )
{
	const raw_data_channel* pnt_chan   = a_mesh_data->get_channel<md5_vtx_pnt>();
	const raw_data_channel* pntiw_chan = a_mesh_data->get_channel<md5_vtx_pntiw>();

	m_pntdata.assign( pnt_chan->data_cast< md5_vtx_pnt >(), pnt_chan->size() );
	m_bone_offset.resize( bones->get_count() );
	m_transform.resize( bones->get_count() );

	for ( uint32 i = 0; i < bones->get_count(); ++i )
	{
		m_bone_offset[i] = transform( bones->get_node(i)->transform );
	}

	m_vtx_data  = a_mesh_data->get_channel_data<md5_vtx_pntiw>();
	m_indices   = a_mesh_data->get_channel_size( slot::INDEX );
	m_va        = a_context->create_vertex_array();

	//array_view< raw_data_channel* > channels = a_mesh_data->get_raw_channels();
	for ( auto channel : *a_mesh_data )
	{
		//const raw_data_channel* channel = channels[ch];
		if ( channel->size() > 0 && channel != pntiw_chan )
		{
			const data_descriptor& desc = channel->descriptor();
			if ( desc[0].vslot == slot::INDEX )
			{
				buffer b = a_context->get_device()->create_buffer( INDEX_BUFFER, STREAM_DRAW, channel->raw_size(), channel->raw_data() );
				a_context->set_index_buffer( m_va, b, desc[0].etype, true );
			}
			else
			{
				buffer b = a_context->get_device()->create_buffer( VERTEX_BUFFER, STREAM_DRAW, channel->raw_size(), channel->raw_data() );
				a_context->add_vertex_buffers( m_va, b, desc );
			}
		}
	}

	m_pbuffer   = a_context->find_buffer( m_va, slot::POSITION );
}

void nv::skeletal_mesh_cpu::update_animation( animation_entry* a_anim, uint32 a_anim_time )
{
	if ( a_anim )
	{
		skeletal_animation_entry_cpu * anim = static_cast<skeletal_animation_entry_cpu*>( a_anim );
		anim->update_skeleton( m_transform.data(), static_cast<float>( a_anim_time ) );
		{
			size_t skeleton_size = m_bone_offset.size();
			size_t vertex_count  = m_pntdata.size();
			m_pos_offset.resize( skeleton_size );
			for ( unsigned int i = 0; i < skeleton_size; ++i )
			{
				m_pos_offset[i] = m_transform[i] * m_bone_offset[i];
			}

			fill( m_pntdata.raw_data(), m_pntdata.raw_data() + m_pntdata.raw_size(), 0 );
			for ( unsigned int i = 0; i < vertex_count; ++i )
			{
				const md5_vtx_pntiw& vert = m_vtx_data[i];

				for ( int j = 0; j < 4; ++j )
				{
					unsigned index = unsigned( vert.boneindex[j] );
					float weight   = vert.boneweight[j];
					const quat& orient      = m_transform[index].get_orientation();
					const transform& offset = m_pos_offset[index];
					m_pntdata[i].position += offset.transformed( vert.position )        * weight;
					m_pntdata[i].normal   += ( orient * vert.normal  ) * weight;
					m_pntdata[i].tangent  += ( orient * vert.tangent ) * weight;
				}
			}
		}

		m_context->update( m_pbuffer, m_pntdata.data(), 0, m_pntdata.raw_size() );
	}
}


void nv::skeletal_animation_entry_cpu::update_skeleton( transform* skeleton, float time ) const
{
	float frame_duration = 1000.f / static_cast<float>( m_node_data->get_frame_rate() );
	float anim_duration = frame_duration * m_node_data->get_duration();
	float new_time = fmodf( time, anim_duration ) * 0.001f;

	float frame_num = new_time * m_node_data->get_frame_rate();
	for ( size_t i = 0; i < m_node_data->get_count(); ++i )
	{
		skeleton[i] = m_node_data->get_node(i)->data->get_transform( frame_num );
	}
}

void nv::skeletal_animation_entry_gpu::initialize()
{
	m_prepared  = false;
	m_children  = nullptr;
	m_offsets   = nullptr;
	uint32 node_count = m_node_data->get_count();
	m_bone_ids  = new sint16[ node_count ];

	if ( !m_node_data->is_flat() )
	{
		m_children = new vector< uint32 >[ node_count ];
		for ( uint32 n = 0; n < node_count; ++n )
		{
			const mesh_node_data* node = m_node_data->get_node(n);
			if ( node->parent_id != -1 )
			{
				m_children[ node->parent_id ].push_back( n );
			}
		}
	}
}

void nv::skeletal_animation_entry_gpu::update_skeleton( mat4* data, uint32 time ) const
{
	float tick_time = ( time * 0.001f ) * m_frame_rate;
	float anim_time = m_start;
	if ( m_duration > 0.0f ) anim_time += fmodf( tick_time, m_duration );

	if ( !m_node_data->is_flat() )
	{
		animate_rec( data, anim_time, 0, mat4() );
		return;
	}

	for ( uint32 n = 0; n < m_node_data->get_count(); ++n )
		if ( m_bone_ids[n] >= 0 )
		{
			const mesh_node_data* node = m_node_data->get_node(n);
			nv::mat4 node_mat( node->transform );

			if ( node->data )
			{
				node_mat = node->data->get_matrix( anim_time );
			}

			sint16 bone_id = m_bone_ids[n];
			data[ bone_id ] = node_mat * m_offsets[ bone_id ];
		}
}

void nv::skeletal_animation_entry_gpu::prepare( const mesh_nodes_data* bones )
{
	if ( m_prepared ) return;
	unordered_map< std::string, nv::uint16 > bone_names;
	m_offsets = new mat4[ bones->get_count() ];
	for ( nv::uint16 bi = 0; bi < bones->get_count(); ++bi )
	{
		const mesh_node_data* bone = bones->get_node(bi);
		bone_names[ bone->name ] = bi;
		m_offsets[bi] = bone->transform;
	}

	for ( uint32 n = 0; n < m_node_data->get_count(); ++n )
	{
		const mesh_node_data* node = m_node_data->get_node(n);
		sint16 bone_id = -1;

		auto bi = bone_names.find( node->name );
		if ( bi != bone_names.end() )
		{
			bone_id = sint16( bi->second );
		}
		m_bone_ids[n] = bone_id;
	}
	m_prepared = true;
}

void nv::skeletal_animation_entry_gpu::animate_rec( mat4* data, float time, uint32 node_id, const mat4& parent_mat ) const
{
	// TODO: fix transforms, which are now embedded,
	//       see note in assimp_loader.cc:load_node
	const mesh_node_data* node = m_node_data->get_node( node_id );
	mat4 node_mat( node->transform );

	if ( node->data )
	{
		node_mat = node->data->get_matrix( time );
	}

	mat4 global_mat = parent_mat * node_mat;

	sint16 bone_id = m_bone_ids[ node_id ];
	if ( bone_id >= 0 )
	{
		data[ bone_id ] = global_mat * m_offsets[ bone_id ];
	}

	for ( auto child : m_children[ node_id ] )
	{
		animate_rec( data, time, child, global_mat );
	}
}

nv::skeletal_animation_entry_gpu::~skeletal_animation_entry_gpu()
{
	delete[] m_offsets;
	delete[] m_children;
	delete[] m_bone_ids;
}

nv::skeletal_mesh_gpu::skeletal_mesh_gpu( context* a_context, const mesh_data* a_mesh, const mesh_nodes_data* a_bone_data )
	: skeletal_mesh( a_context ), m_bone_data( a_bone_data ), m_transform( nullptr )
{
	m_va          = a_context->create_vertex_array( a_mesh, nv::STATIC_DRAW );
	m_index_count = a_mesh->get_channel_size( slot::INDEX );
	if ( m_bone_data )
	{
		m_transform = new mat4[ m_bone_data->get_count() ];
	}
}

void nv::skeletal_mesh_gpu::update_animation( animation_entry* a_anim, uint32 a_anim_time )
{
	if ( m_bone_data && a_anim )
	{
		skeletal_animation_entry_gpu * anim = static_cast<skeletal_animation_entry_gpu*>( a_anim );
		anim->prepare( m_bone_data );
		anim->update_skeleton( m_transform, a_anim_time );
	}
}

void nv::skeletal_mesh_gpu::update( program a_program )
{
	if ( m_bone_data )
		m_context->get_device()->set_opt_uniform_array( a_program, "nv_m_bones", m_transform, m_bone_data->get_count() );
}

nv::transform nv::skeletal_mesh_gpu::get_node_transform( uint32 node_id ) const
{
	if ( node_id == 0 ) return transform();
	return transform( m_transform[ node_id ] );
}

nv::mat4 nv::skeletal_mesh_gpu::get_node_matrix( uint32 node_id ) const
{
	return m_transform[ node_id ];
}
