// Copyright (C) 2012-2015 ChaosForge Ltd
// http://chaosforge.org/
//
// This file is part of Nova libraries. 
// For conditions of distribution and use, see copying.txt file in root folder.

#include "nv/gfx/mesh_creator.hh"

struct nv_key_transform { nv::transform tform; };

void nv::mesh_nodes_creator::pre_transform_keys()
{
	if ( m_data->m_flat ) return;
	merge_keys();
	uint32 max_frames = 0;
	for ( size_t i = 0; i < m_data->get_count(); ++i )
	{
		sint16 parent_id = m_data->m_nodes[i].parent_id;
		key_data* keys   = m_data->m_nodes[i].data;
		key_data* pkeys  = ( parent_id != -1 ? m_data->m_nodes[parent_id].data : nullptr );
		size_t count     = ( keys ? keys->get_channel(0)->element_count() : 0 );
		size_t pcount    = ( pkeys ? pkeys->get_channel(0)->element_count() : 0 );
		max_frames = nv::max<uint32>( count, max_frames );
		if ( pkeys && pkeys->get_channel_count() > 0 && keys && keys->get_channel_count() > 0 )
		{
			data_channel_creator< nv_key_transform > channel_creator( const_cast< raw_data_channel* >( keys->get_channel( 0 ) ) );
			nv_key_transform* channel = channel_creator.data();
			const nv_key_transform* pchannel = pkeys->get_channel(0)->data_cast< nv_key_transform >();
			for ( unsigned n = 0; n < count; ++n )
			{
				channel[n].tform = pchannel[ nv::min( n, pcount-1 ) ].tform * channel[n].tform;
			}
		}
	}

	// DAE pre_transform hack
	if ( m_data->m_frame_rate == 1 )
	{
		m_data->m_frame_rate = 32;
		m_data->m_duration   = static_cast<float>( max_frames );
	}

	m_data->m_flat = true;
}

// TODO: DELETE
struct assimp_key_p  { float time; nv::vec3 position; };
struct assimp_key_r  { float time; nv::quat rotation; };


void nv::mesh_nodes_creator::merge_keys()
{
	for ( size_t i = 0; i < m_data->get_count(); ++i )
	{
		key_data* old_keys = m_data->m_nodes[i].data;
		if ( old_keys && old_keys->get_channel_count() > 0 )
		{
			size_t chan_count = old_keys->get_channel_count();
			if ( chan_count == 1 
				&& old_keys->get_channel(0)->descriptor().slot_count() == 1 
				&& old_keys->get_channel(0)->descriptor()[0].etype == TRANSFORM ) continue;

			size_t max_keys = 0;
			for ( size_t c = 0; c < chan_count; ++c )
			{
				max_keys = nv::max( max_keys, old_keys->get_channel(c)->element_count() );
			}

			data_channel_creator< nv_key_transform > kt_channel( max_keys );
			key_data* new_keys = new key_data;
			data_descriptor final_key = old_keys->get_final_key();

			for ( unsigned n = 0; n < max_keys; ++n )
			{
				float key[ 16 ];
				float* pkey = key;

				for ( uint16 c = 0; c < chan_count; ++c )
				{
					size_t idx = nv::min( old_keys->get_channel(c)->element_count() - 1, n );
					pkey += old_keys->get_raw( old_keys->get_channel(c), idx, pkey );
				}
				kt_channel.data()[n].tform = extract_transform_raw( final_key, key );
			}

			delete old_keys;
			new_keys->add_channel( kt_channel.release() );
			m_data->m_nodes[i].data = new_keys;
		}
	}
}

void nv::mesh_nodes_creator::transform( float scale, const mat3& r33 )
{
	mat3 ri33 = glm::inverse( r33 );
	mat4 pre_transform ( scale * r33 );
	mat4 post_transform( 1.f/scale * ri33 ); 

	for ( size_t i = 0; i < m_data->get_count(); ++i )
	{
		mesh_node_data& node = m_data->m_nodes[i];
		node.transform = pre_transform * node.transform * post_transform;
		if ( node.data )
		{
			key_data* kdata  = node.data;
			for ( size_t c = 0; c < kdata->get_channel_count(); ++c )
			{
				raw_data_channel_creator channel( const_cast< raw_data_channel* >( kdata->get_channel( c ) ) );
				size_t key_size = channel.element_size();
				for ( size_t n = 0; n < channel.size(); ++n )
				{
					transform_key_raw( kdata->get_channel( c )->descriptor(), channel.raw_data() + n * key_size, scale, r33, ri33 );
				}
			}
		}
	}
}

void nv::mesh_data_creator::transform( float scale, const mat3& r33 )
{
	vec3 vertex_offset     = vec3(); 
	mat3 vertex_transform  = scale * r33;
	mat3 normal_transform  = r33;

	for ( uint32 c = 0; c < m_data->get_channel_count(); ++c )
	{
		raw_data_channel_creator channel( m_data->m_channels[ c ] );
		const data_descriptor&  desc    = channel.descriptor();
		uint8* raw_data = channel.raw_data();
		uint32 vtx_size = desc.element_size();
		int p_offset = -1;
		int n_offset = -1;
		int t_offset = -1;
		for ( const auto& cslot : desc  )
			switch ( cslot.vslot )
			{
				case slot::POSITION : if ( cslot.etype == FLOAT_VECTOR_3 ) p_offset = int( cslot.offset ); break;
				case slot::NORMAL   : if ( cslot.etype == FLOAT_VECTOR_3 ) n_offset = int( cslot.offset ); break;
				case slot::TANGENT  : if ( cslot.etype == FLOAT_VECTOR_4 ) t_offset = int( cslot.offset ); break;
				default             : break;
			}

		if ( p_offset != -1 )
			for ( uint32 i = 0; i < channel.size(); i++)
			{
				vec3& p = *reinterpret_cast<vec3*>( raw_data + vtx_size*i + p_offset );
				p = vertex_transform * p + vertex_offset;
			}

		if ( n_offset != -1 )
			for ( uint32 i = 0; i < channel.size(); i++)
			{
				vec3& n = *reinterpret_cast<vec3*>( raw_data + vtx_size*i + n_offset );
				n = glm::normalize( normal_transform * n );
			}
		if ( t_offset != -1 )
			for ( uint32 i = 0; i < channel.size(); i++)
			{
				vec4& t = *reinterpret_cast<vec4*>(raw_data + vtx_size*i + t_offset );
				t = vec4( glm::normalize( normal_transform * vec3(t) ), t[3] );
			}
	}
}

struct vertex_g
{
	nv::vec4 tangent;
};

void nv::mesh_data_creator::flip_normals()
{
	int ch_n  = m_data->get_channel_index( slot::NORMAL );
	size_t n_offset = 0;
	if ( ch_n == -1 ) return;
	raw_data_channel_creator channel( m_data->m_channels[ unsigned( ch_n ) ] );
	for ( const auto& cslot : channel.descriptor() )
		if ( cslot.vslot == slot::NORMAL )
		{
			n_offset  = cslot.offset;
		}

	for ( uint32 i = 0; i < channel.size(); ++i )
	{
		vec3& normal = *reinterpret_cast<vec3*>( channel.raw_data() + channel.element_size() * i + n_offset );
		normal = -normal;
	}
}


void nv::mesh_data_creator::generate_tangents()
{
	int p_offset = -1;
	int n_offset = -1;
	int t_offset = -1;
	datatype i_type = NONE;
	uint32 n_channel_index = 0;

	const raw_data_channel* p_channel = nullptr;
	      raw_data_channel* n_channel = nullptr;
	const raw_data_channel* t_channel = nullptr;
	const raw_data_channel* i_channel = nullptr;

	for ( uint32 c = 0; c < m_data->get_channel_count(); ++c )
	{
		const raw_data_channel* channel = m_data->get_channel(c);

		for ( const auto& cslot : channel->descriptor() )
		switch ( cslot.vslot )
		{
			case slot::POSITION : 
				if ( cslot.etype == FLOAT_VECTOR_3 )
				{
					p_offset  = int( cslot.offset );
					p_channel = channel;
				}
				break;
			case slot::NORMAL   : 
				if ( cslot.etype == FLOAT_VECTOR_3 )
				{
					n_offset  = int( cslot.offset );
					n_channel = m_data->m_channels[ c ];
					n_channel_index = c;
				}
				break;
			case slot::TEXCOORD : 
				if ( cslot.etype == FLOAT_VECTOR_2 )
				{
					t_offset  = int( cslot.offset );
					t_channel = channel;
				}
				break;
			case slot::INDEX    : 
				{
					i_type    = cslot.etype;
					i_channel = channel;
				}
				break;
			case slot::TANGENT  : return;
			default             : break;
		}
	}
	if ( !p_channel || !n_channel || !t_channel ) return;

	if ( p_channel->element_count() != n_channel->element_count() 
		|| p_channel->element_count() % t_channel->element_count() != 0 
		|| ( i_type != UINT && i_type != USHORT && i_type != NONE ) )
	{
		return;
	}

	data_channel_creator< vertex_g > g_channel( p_channel->element_count() );
	vec4* tangents              = &( g_channel.data()[0].tangent );
	vec3* tangents2             = new vec3[ p_channel->element_count() ];
	uint32 tri_count = i_channel ? i_channel->element_count() / 3 : t_channel->element_count() / 3;
	uint32 vtx_count = p_channel->element_count();
	uint32 sets      = p_channel->element_count() / t_channel->element_count();

	for ( unsigned int i = 0; i < tri_count; ++i )
	{
		uint32 ti0 = 0;
		uint32 ti1 = 0;
		uint32 ti2 = 0;
		if ( i_type == UINT )
		{
			const uint32* idata = reinterpret_cast<const uint32*>( i_channel->raw_data() );
			ti0 = idata[ i * 3 ];
			ti1 = idata[ i * 3 + 1 ];
			ti2 = idata[ i * 3 + 2 ];
		}
		else if ( i_type == USHORT )
		{
			const uint16* idata = reinterpret_cast<const uint16*>( i_channel->raw_data() );
			ti0 = idata[ i * 3 ];
			ti1 = idata[ i * 3 + 1 ];
			ti2 = idata[ i * 3 + 2 ];
		}
		else // if ( i_type == NONE )
		{
			ti0 = i * 3;
			ti1 = i * 3 + 1;
			ti2 = i * 3 + 2;
		}

		const vec2& w1 = *reinterpret_cast<const vec2*>(t_channel->raw_data() + t_channel->element_size()*ti0 + t_offset );
		const vec2& w2 = *reinterpret_cast<const vec2*>(t_channel->raw_data() + t_channel->element_size()*ti1 + t_offset );
		const vec2& w3 = *reinterpret_cast<const vec2*>(t_channel->raw_data() + t_channel->element_size()*ti2 + t_offset );
		vec2 st1 = w3 - w1;
		vec2 st2 = w2 - w1;
		float stst = (st1.x * st2.y - st2.x * st1.y);
		float coef = ( stst != 0.0f ? 1.0f / stst : 0.0f );

		for ( uint32 set = 0; set < sets; ++set )
		{
			uint32 nti0 = t_channel->element_count() * set + ti0;
			uint32 nti1 = t_channel->element_count() * set + ti1;
			uint32 nti2 = t_channel->element_count() * set + ti2;
			const vec3& v1 = *reinterpret_cast<const vec3*>(p_channel->raw_data() + p_channel->element_size()*nti0 + p_offset );
			const vec3& v2 = *reinterpret_cast<const vec3*>(p_channel->raw_data() + p_channel->element_size()*nti1 + p_offset );
			const vec3& v3 = *reinterpret_cast<const vec3*>(p_channel->raw_data() + p_channel->element_size()*nti2 + p_offset );
			vec3 xyz1 = v3 - v1;
			vec3 xyz2 = v2 - v1;

			//vec3 normal = glm::cross( xyz1, xyz2 );
			//
			//vtcs[ ti0 ].normal += normal;
			//vtcs[ ti1 ].normal += normal;
			//vtcs[ ti2 ].normal += normal;
			vec3 tangent  = (( xyz1 * st2.y ) - ( xyz2 * st1.y )) * coef;
			vec3 tangent2 = (( xyz2 * st1.x ) - ( xyz1 * st2.x )) * coef;

			tangents[nti0] = vec4( vec3( tangents[nti0] ) + tangent, 0 );
			tangents[nti1] = vec4( vec3( tangents[nti1] ) + tangent, 0 );
			tangents[nti2] = vec4( vec3( tangents[nti2] ) + tangent, 0 );

			tangents2[nti0] += tangent2;
			tangents2[nti1] += tangent2;
			tangents2[nti2] += tangent2;
		}
	}

	for ( unsigned int i = 0; i < vtx_count; ++i )
	{
		const vec3 n = *reinterpret_cast<const vec3*>( n_channel->raw_data() + n_channel->element_size()*i + n_offset );
		const vec3 t = vec3(tangents[i]);
		if ( ! ( t.x == 0.0f && t.y == 0.0f && t.z == 0.0f ) )
		{
			tangents[i]    = vec4( glm::normalize(t - n * glm::dot( n, t )), 0.0f ); 
			tangents[i][3] = (glm::dot(glm::cross(n, t), tangents2[i]) < 0.0f) ? -1.0f : 1.0f;
		}
	}
	delete tangents2;

	m_data->m_channels[ n_channel_index ] = merge_channels( n_channel, g_channel.channel() );
	delete n_channel;
}

nv::raw_data_channel* nv::mesh_data_creator::merge_channels( raw_data_channel* a, raw_data_channel* b )
{
	NV_ASSERT( a->element_count() == b->element_count(), "merge_channel - bad channels!" );
	data_descriptor desc  = a->descriptor();
	desc.append( b->descriptor() );

	raw_data_channel_creator result( desc, a->element_count() );
	for ( uint32 i = 0; i < a->element_count(); ++i )
	{
		raw_copy_n( a->raw_data() + i * a->element_size(), a->element_size(), result.raw_data() + i*desc.element_size() );
		raw_copy_n( b->raw_data() + i * b->element_size(), b->element_size(), result.raw_data() + i*desc.element_size() + a->element_size() );
	}

	return result.release();
}

nv::raw_data_channel* nv::mesh_data_creator::append_channels( raw_data_channel* a, raw_data_channel* b, uint32 frame_count )
{
	if ( a->descriptor() != b->descriptor() ) return nullptr;
	if ( a->element_count() % frame_count != 0 ) return nullptr;
	if ( b->element_count() % frame_count != 0 ) return nullptr;
	size_t vtx_size = a->element_size();

	raw_data_channel_creator result( a->descriptor(), a->element_count() + b->element_count() );

	if ( frame_count == 1 )
	{
		size_t a_size = vtx_size * a->element_count();
		raw_copy_n( a->raw_data(), a_size, result.raw_data() );
		raw_copy_n( b->raw_data(), vtx_size * b->element_count(), result.raw_data() + a_size );
	}
	else
	{
		size_t frame_size_a = ( a->element_count() / frame_count ) * vtx_size;
		size_t frame_size_b = ( b->element_count() / frame_count ) * vtx_size;
		size_t pos_a = 0;
		size_t pos_b = 0;
		size_t pos   = 0;
		for ( size_t i = 0; i < frame_count; ++i )
		{
			raw_copy_n( a->raw_data() + pos_a, frame_size_a, result.raw_data() + pos );
			raw_copy_n( b->raw_data() + pos_b, frame_size_b, result.raw_data() + pos + frame_size_a );				pos_a += frame_size_a;
			pos_b += frame_size_b; 
			pos   += frame_size_a + frame_size_b;
		}
	}

	return result.release();
}



bool nv::mesh_data_creator::is_same_format( mesh_data* other )
{
	if ( m_data->get_channel_count() != other->get_channel_count() ) return false;
	for ( uint32 c = 0; c < m_data->get_channel_count(); ++c )
	{
		if ( m_data->get_channel(c)->descriptor() != other->get_channel(c)->descriptor() )
			return false;
	}
	return true;
}

void nv::mesh_data_creator::merge( mesh_data* other )
{
	if ( !is_same_format( other ) ) return;
	int ch_pi  = m_data->get_channel_index( slot::POSITION );
	int ch_ti  = m_data->get_channel_index( slot::TEXCOORD );
	int och_pi = other->get_channel_index( slot::POSITION );
	int och_ti = other->get_channel_index( slot::TEXCOORD );
	if ( ch_pi == -1 || ch_ti == -1 ) return;
	size_t size   = m_data->m_channels[ unsigned(ch_ti) ]->element_count();
	size_t osize  =  other->m_channels[ unsigned(och_ti) ]->element_count();
	size_t count  = m_data->m_channels[ unsigned(ch_pi) ]->element_count();
	size_t ocount =  other->m_channels[ unsigned(och_pi) ]->element_count();
	if ( count % size != 0 || ocount % osize != 0 ) return;
	if ( count / size != ocount / osize ) return;
	
	for ( uint32 c = 0; c < m_data->get_channel_count(); ++c )
	{
		raw_data_channel* old = m_data->m_channels[c];
		bool old_is_index = old->element_count() > 0 && old->descriptor()[0].vslot == slot::INDEX;
		size_t frame_count = ( old_is_index ? 1 : old->element_count() / size );
		m_data->m_channels[c] = append_channels( old, other->m_channels[c], frame_count );
		NV_ASSERT( m_data->m_channels[c], "Merge problem!" );
		if ( old_is_index )
		{
			switch ( old->descriptor()[0].etype )
			{
			case USHORT : 
				{
					NV_ASSERT( size + osize < uint16(-1), "Index out of range!" );
					raw_data_channel_creator ic( m_data->m_channels[c] );
					uint16* indexes = reinterpret_cast<uint16*>( ic.raw_data() );
					for ( uint16 i = uint16( old->element_count() ); i < ic.size(); ++i )
						indexes[i] += uint16( size );

				}
				break;
			case UINT   : 
				{
					raw_data_channel_creator ic( m_data->m_channels[c] );
					uint32* indexes = reinterpret_cast<uint32*>( ic.raw_data() );
					for ( uint32 i = old->element_count(); i < ic.size(); ++i )
						indexes[i] += size;
				}
				break;
			default : NV_ASSERT( false, "Unsupported index type!" ); break;
			}
			m_data->m_index_channel = m_data->m_channels[c];
		}
		delete old;
	}
}
