// Copyright (C) 2012-2015 ChaosForge Ltd
// http://chaosforge.org/
//
// This file is part of Nova libraries. 
// For conditions of distribution and use, see copying.txt file in root folder.

#include "nv/formats/md5_loader.hh"

#include "nv/core/logging.hh"
#include "nv/stl/vector.hh"
#include "nv/io/std_stream.hh"
#include "nv/interface/data_channel_access.hh"

#include <stdio.h>  // sscanf
#include <stdlib.h> // atof

using namespace nv;

static void next_line( std::istream& stream )
{
	stream.ignore( 1024*1024, '\n' );
}

static inline void discard( std::istream& stream, const std::string& token )
{
	std::string discarded;
	stream >> discarded;
	assert( discarded == token );
}

static void remove_quotes( std::string& str )
{
	nv::size_t n;
	while ( ( n = str.find('\"') ) != std::string::npos ) str.erase(n,1);
}

static void unit_quat_w( nv::quat& quat )
{
	float t = 1.0f - ( quat.x * quat.x ) - ( quat.y * quat.y ) - ( quat.z * quat.z );
	quat.w = ( t < 0.0f ? 0.0f : -sqrtf(t) );
}

bool md5_loader::load( stream& source )
{
	reset();
	std_stream sstream( &source );
	std::string command;
	mesh_node_data* nodes = nullptr;
	size_t num_joints = 0;

	// MESH data
	dynamic_array< md5_weight > weights;
	dynamic_array< md5_weight_info > weight_info;
	size_t num_meshes = 0;

	// MESH data
	dynamic_array< md5_joint_info > joint_infos;
	vector< transform >             base_frames;
	size_t num_animated_components = 0;
	size_t frame_rate = 0;
	size_t num_frames = 0;

	sstream >> command;
	while ( !sstream.eof() )
	{
		if ( command == "MD5Version" )
		{
			sstream >> m_md5_version;
			assert( m_md5_version == 10 );
		}
		else if ( command == "commandline" )
		{
			next_line( sstream ); 
		}
		else if ( command == "numJoints" )
		{
			sstream >> num_joints;
			next_line( sstream ); 
		}
		else if ( command == "numMeshes" )
		{
			assert( m_type == UNKNOWN );
			m_type = MESH;
			sstream >> num_meshes;
			m_meshes.resize( num_meshes );
			num_meshes = 0;
		}
		else if ( command == "numFrames" )
		{
			assert( m_type == UNKNOWN || m_type == ANIMATION );
			m_type = ANIMATION;
			sstream >> num_frames;
			next_line( sstream ); 
		}
		else if ( command == "frameRate" )
		{
			assert( m_type == UNKNOWN || m_type == ANIMATION );
			m_type = ANIMATION;
			sstream >> frame_rate;
			next_line( sstream ); 
		}
		else if ( command == "numAnimatedComponents" )
		{
			assert( m_type == UNKNOWN || m_type == ANIMATION );
			m_type = ANIMATION;
			sstream >> num_animated_components;
			next_line( sstream ); 
		}
		else if ( command == "joints" )
		{
			assert( m_type == MESH );
			assert( m_nodes == nullptr );
			nodes = new mesh_node_data[ num_joints ];
			m_nodes = new mesh_nodes_data( "md5_bones", num_joints, nodes );
			discard( sstream, "{" );
			for ( size_t i = 0; i < m_nodes->get_count(); ++i )
			{
				sstream >> nodes[i].name >> nodes[i].parent_id;
				vec3 pos;
				quat orient;
				discard( sstream, "(" );
				sstream >> pos.x >> pos.y >> pos.z;
				discard( sstream, ")" );
				discard( sstream, "(" );
				sstream >> orient.x >> orient.y >> orient.z;
				unit_quat_w( orient );
				remove_quotes( nodes[i].name );
				nodes[i].target_id       = -1;
				nodes[i].parent_id       = -1;
				nodes[i].transform       = transform( pos, orient ).inverse().extract();
				nodes[i].data            = nullptr;
				next_line( sstream );
			}
			discard( sstream, "}" );
		}
		else if ( command == "mesh" )
		{
			assert( m_type == MESH );
			data_channel_set_creator mesh;

			uint32 num_verts   = 0;
			uint32 num_tris    = 0;
			uint32 num_weights = 0;

			discard( sstream, "{" );
			sstream >> command;
			while ( command != "}" ) 
			{
				if ( command == "shader" )
				{
					std::string shader;
					sstream >> shader;
					remove_quotes( shader );
					next_line( sstream );
				}
				else if ( command == "numverts")
				{
					sstream >> num_verts; 

					md5_vtx_t* tdata = nullptr;
					{
						data_channel_creator<md5_vtx_pnt>   ch_pnt( num_verts );
						data_channel_creator<md5_vtx_t>     ch_t( num_verts );
						data_channel_creator<md5_vtx_pntiw> ch_pntiw( num_verts );
						tdata = ch_t.data();
						mesh.add_channel( ch_pnt.release() );
						mesh.add_channel( ch_t.release() );
						// TODO: hack to prevent rendering
						//ch_pntiw->m_count = 0;
						mesh.add_channel( ch_pntiw.release() );
					}
					weight_info.resize( num_verts );

					next_line( sstream );
					std::string line;
					for ( uint32 i = 0; i < num_verts; ++i )
					{
						size_t weight_count;
						size_t start_weight;
						vec2 texcoord;

						std::getline( sstream, line );
						sscanf( line.c_str(), "%*s %*u ( %f %f ) %u %u", &(texcoord.x), &(texcoord.y), &(start_weight), &(weight_count) );
						weight_info[i].start_weight = start_weight;
						weight_info[i].weight_count = weight_count;
						tdata[i].texcoord = texcoord;
					}  
				}
				else if ( command == "numtris" )
				{
					sstream >> num_tris;

					data_channel_creator< index_u32 > ch_i( num_tris * 3 );
					uint32* vtx_i                = reinterpret_cast< uint32* >( ch_i.raw_data() );
					uint32 idx = 0;

					next_line( sstream );
					std::string line;
					for ( uint32 i = 0; i < num_tris; ++i )
					{
						unsigned ti0;
						unsigned ti1;
						unsigned ti2;

						std::getline( sstream, line );
						sscanf( line.c_str(), "%*s %*u %u %u %u )", &(ti0), &(ti1), &(ti2));

						vtx_i[idx++] = ti0;
						vtx_i[idx++] = ti1;
						vtx_i[idx++] = ti2;
					}              

					mesh.add_channel( ch_i.release() );
				}
				else if ( command == "numweights" )
				{
					sstream >> num_weights;
					weights.resize( num_weights );
					next_line( sstream );
					std::string line;
					for ( uint32 i = 0; i < num_weights; ++i )
					{
						md5_weight weight;

						std::getline( sstream, line );
						sscanf( line.c_str(), "%*s %*u %u %f ( %f %f %f )", &(weight.joint_id), &(weight.bias), &(weight.pos.x), &(weight.pos.y), &(weight.pos.z));
 						weights[i] = weight;
					}
				}
				else
				{
					next_line( sstream );
				}

				sstream >> command;
			}

			data_channel_set* mdata = mesh.release();
			prepare_mesh( nodes, weight_info.size(), mdata, weights.data(), weight_info.data() );

			m_meshes[ num_meshes ] = mdata;
			num_meshes++;
		} // mesh
		else if ( command == "hierarchy" )
		{
			assert( m_type == ANIMATION );
			assert( nodes == nullptr );
			nodes = new mesh_node_data[ num_joints ];
			m_nodes = new mesh_nodes_data( "md5_animation", num_joints, nodes, static_cast< nv::uint16 >( frame_rate ), static_cast< float >( num_frames ), true );
			joint_infos.resize( num_joints );

			discard( sstream, "{" );
			for ( size_t i = 0; i < m_nodes->get_count(); ++i )
			{
				std::string    name;
				sstream >> nodes[i].name >> nodes[i].parent_id >> joint_infos[i].flags >> joint_infos[i].start_index;
				remove_quotes( name );
				nodes[i].transform = mat4();
				nodes[i].target_id = -1;
				nodes[i].data      = new key_data;
				data_channel_creator< md5_key_t > fc( num_frames );
				nodes[i].data->add_key_channel( fc.release() );
				next_line( sstream ); 
			}
			discard( sstream, "}" );
		}
		else if ( command == "bounds" )
		{
			assert( m_type == ANIMATION );
			discard( sstream, "{" );
			next_line( sstream ); 
			for ( size_t i = 0; i < num_frames; ++i ) 
			{
				//  				vec3 min;
				//  				vec3 max;
				//  				discard( sstream, "(" );
				//  				sstream >> min.x >> min.y >> min.z;
				//  				discard( sstream, ")" );
				//  				discard( sstream, "(" );
				//  				sstream >> max.x >> max.y >> max.z;
				// 				m_bounds.push_back( bound );
				next_line( sstream ); 
			}

			discard( sstream, "}" );
			next_line( sstream ); 
		}
		else if ( command == "baseframe" )
		{
			assert( m_type == ANIMATION );
			discard( sstream, "{" );
			next_line( sstream ); 

			for ( size_t i = 0; i < m_nodes->get_count(); ++i )
			{
				transform base_frame;
				vec3 pos;
				quat orient;
				discard( sstream, "(" );
				sstream >> pos.x >> pos.y >> pos.z;
				discard( sstream, ")" );
				discard( sstream, "(" );
				sstream >> orient.x >> orient.y >> orient.z;
				next_line( sstream ); 

				base_frames.emplace_back( pos, orient );
			}
			discard( sstream, "}" );
			next_line( sstream ); 
		}
		else if ( command == "frame" )
		{
			vector<float> frame;
			uint32 frame_id;
			sstream >> frame_id;
			discard( sstream, "{" );
			next_line( sstream ); 

			frame.reserve( num_animated_components );
			char buf[50];
			for ( size_t i = 0; i < num_animated_components; ++i )
			{
				sstream >> buf;
				frame.push_back( static_cast< float >( atof(buf) ) );
			}

			build_frame_skeleton( nodes, frame_id, joint_infos, base_frames, frame );

			discard( sstream, "}" );
			next_line( sstream ); 
		}

		sstream >> command;
	}

	return true;
}

bool md5_loader::prepare_mesh( mesh_node_data* nodes, uint32 vtx_count, data_channel_set* mdata, md5_weight* weights, md5_weight_info* weight_info )
{
	assert( m_type == MESH );
	data_channel_creator< md5_vtx_pnt >   pnt  ( const_cast< raw_data_channel* >( mdata->get_channel< md5_vtx_pnt >() ) );
	data_channel_creator< md5_vtx_pntiw > pntiw( const_cast< raw_data_channel* >( mdata->get_channel< md5_vtx_pntiw >() ) );
	md5_vtx_pntiw* vtx_data = pntiw.data();
	md5_vtx_pnt* vtcs = pnt.data();

	for ( uint32 i = 0; i < vtx_count; ++i )
	{
		size_t start_weight = weight_info[i].start_weight;
		size_t weight_count = weight_info[i].weight_count;
		md5_vtx_pntiw& vdata = vtx_data[i];
		md5_vtx_pnt& vtc = vtcs[i];

		vtc.position = vec3(0);
		vtc.normal   = vec3(0);
		vtc.tangent  = vec3(0);

		stable_sort( weights + start_weight, weights + start_weight + weight_count, [] ( const md5_weight& a, const md5_weight& b ) -> bool { return a.bias > b.bias; } );
		//std::sort( weights + start_weight, weights + start_weight + weight_count, [](const md5_weight& a, const md5_weight& b) -> bool { return a.bias > b.bias; } );

		if ( weight_count > 4 )
		{
			float sum = 0.0f;
			for ( size_t j = 0; j < 4; ++j )
			{
				sum += weights[start_weight + j].bias;
			}
			float ratio = 1.0f / sum;
			for ( size_t j = 0; j < 4; ++j )
			{
				weights[start_weight + j].bias = ratio * weights[start_weight + j].bias;
			}
			weight_count = 4;
		}

		for ( int j = 0; j < 4; ++j )
		{
			if ( j < int(weight_count) )
			{
				vdata.boneindex[j]  = int( weights[int(start_weight) + j].joint_id );
				vdata.boneweight[j] = weights[int(start_weight) + j].bias;
			}
			else
			{
				vdata.boneindex[j]  = 0;
				vdata.boneweight[j] = 0.0f;
			}
		}

		for ( size_t j = 0; j < 4; ++j )
		{
			if ( j < weight_count )
			{
				md5_weight& weight           = weights[start_weight + j];
				const mesh_node_data&  joint = nodes[weight.joint_id];
				const transform tr = transform( joint.transform ).inverse();
				vec3 rot_pos = tr.get_orientation() * weight.pos;

				vtc.position += ( tr.get_position() + rot_pos ) * weight.bias;
			}
		}
	}

	const uint32*    idata = reinterpret_cast< uint32* >( const_cast< uint8* >( mdata->get_channel( slot::INDEX )->raw_data() ) );
	const md5_vtx_t* tdata = mdata->get_channel_data<md5_vtx_t>();

	// Prepare normals
	uint32 tri_count = mdata->get_channel_size( slot::INDEX ) / 3;
	for ( unsigned int i = 0; i < tri_count; ++i )
	{
		uint32 ti0 = idata[ i * 3 ];
		uint32 ti1 = idata[ i * 3 + 1 ];
		uint32 ti2 = idata[ i * 3 + 2 ];
 
		vec3 v1 = vtcs[ ti0 ].position;
		vec3 v2 = vtcs[ ti1 ].position;
		vec3 v3 = vtcs[ ti2 ].position;
		vec3 xyz1 = v3 - v1;
		vec3 xyz2 = v2 - v1;

		vec3 normal = glm::cross( xyz1, xyz2 );

		vtcs[ ti0 ].normal += normal;
		vtcs[ ti1 ].normal += normal;
		vtcs[ ti2 ].normal += normal;

		const vec2& w1 = tdata[ ti0 ].texcoord;
		const vec2& w2 = tdata[ ti1 ].texcoord;
		const vec2& w3 = tdata[ ti2 ].texcoord;

		vec2 st1 = w3 - w1;
		vec2 st2 = w2 - w1;

		float coef = 1.0f / (st1.x * st2.y - st2.x * st1.y);

		vec3 tangent = (( xyz1 * st2.y ) - ( xyz2 * st1.y )) * coef;

		vtcs[ ti0 ].tangent += tangent;
		vtcs[ ti1 ].tangent += tangent;
		vtcs[ ti2 ].tangent += tangent;
	}

	for ( size_t i = 0; i < vtx_count; ++i )
	{
		md5_vtx_pntiw& vdata = vtx_data[i];

		vec3 normal  = glm::normalize( vtcs[i].normal );
		vec3 tangent = glm::normalize( vtcs[i].tangent );
		vtcs[i].normal   = normal;
		vtcs[i].tangent  = tangent;

		vdata.position = vtcs[i].position;
		vdata.normal   = vec3(0);
 		vdata.tangent  = vec3(0);
 
 		for ( int j = 0; j < 4; ++j )
 		{
			const mesh_node_data&  joint = nodes[vdata.boneindex[j]];
			const transform tr = transform( joint.transform ).inverse();
 			vdata.normal  += ( normal  * tr.get_orientation() ) * vdata.boneweight[j];
 			vdata.tangent += ( tangent * tr.get_orientation() ) * vdata.boneweight[j];
 		}
	}

	return true;
}

void md5_loader::build_frame_skeleton( mesh_node_data* nodes, uint32 index, const array_view<md5_joint_info>& joint_infos, const array_view<transform>& base_frames, const array_view<float>& frame_data )
{
	assert( m_type == ANIMATION );
	for ( unsigned int i = 0; i < joint_infos.size(); ++i )
	{
		unsigned int j = 0;

		const md5_joint_info& jinfo = joint_infos[i];
		mesh_node_data& joint = nodes[i];
		int parent_id         = joint.parent_id;

		vec3 pos    = base_frames[i].get_position();
		quat orient = base_frames[i].get_orientation();
		if ( jinfo.flags & 1 )  pos.x    = frame_data[ jinfo.start_index + j++ ];
		if ( jinfo.flags & 2 )  pos.y    = frame_data[ jinfo.start_index + j++ ];
		if ( jinfo.flags & 4 )  pos.z    = frame_data[ jinfo.start_index + j++ ];
		if ( jinfo.flags & 8 )  orient.x = frame_data[ jinfo.start_index + j++ ];
		if ( jinfo.flags & 16 ) orient.y = frame_data[ jinfo.start_index + j++ ];
		if ( jinfo.flags & 32 )	orient.z = frame_data[ jinfo.start_index + j++ ];
		unit_quat_w( orient );

		if ( parent_id >= 0 ) // Has a parent joint
		{
			const mesh_node_data& pjoint = nodes[parent_id];
			const transform* ptv = reinterpret_cast< const transform* >( pjoint.data->get_channel(0)->raw_data() );
			transform ptr;
			if ( pjoint.data->get_channel(0)->size() > index ) ptr = ptv[ index ];
			vec3 rot_pos = ptr.get_orientation() * pos;

			pos    = ptr.get_position() + rot_pos;
			orient = ptr.get_orientation() * orient;

			orient = glm::normalize( orient );
		}

		reinterpret_cast< transform* >( const_cast< uint8* >( joint.data->get_channel(0)->raw_data() ) )[index] = transform( pos, orient );
	}
}

data_channel_set* nv::md5_loader::release_mesh_data( size_t index )
{
	data_channel_set* result = m_meshes[ index ];
	m_meshes[ index ] = nullptr;
	return result;
}

mesh_nodes_data* nv::md5_loader::release_mesh_nodes_data( size_t )
{
	mesh_nodes_data* nodes = m_nodes;
	m_nodes = nullptr; 
	return nodes;
}

mesh_data_pack* nv::md5_loader::release_mesh_data_pack()
{
	uint32 size = m_meshes.size();
	data_channel_set* meshes = data_channel_set_creator::create_array( size, 4 );
	for ( uint32 i = 0; i < size; ++i )
	{
		data_channel_set_creator( m_meshes[i] ).move_to( meshes[i] );
		delete m_meshes[i];
		m_meshes[i] = nullptr;
	}
	return new mesh_data_pack( size, meshes, release_mesh_nodes_data() );
}


nv::md5_loader::~md5_loader()
{
	reset();
}

void nv::md5_loader::reset()
{
	if ( m_nodes ) delete m_nodes;
	for ( auto m : m_meshes ) { if (m) delete m; }
	m_meshes.resize(0);
	m_nodes = nullptr;
}

