// Copyright (C) 2012-2013 ChaosForge / Kornel Kisielewicz
// http://chaosforge.org/
//
// This file is part of NV Libraries.
// For conditions of distribution and use, see copyright notice in nv.hh

#include "nv/formats/md5_loader.hh"

#include <glm/gtc/constants.hpp>
#include "nv/logging.hh"
#include "nv/io/std_stream.hh"
#include "nv/profiler.hh"
#include <cstring>

using namespace nv;

// based on http://tfc.duke.free.fr/coding/md5-specs-en.html

static void next_line( std::istream& stream )
{
	stream.ignore( std::numeric_limits<std::streamsize>::max(), '\n' );
}

static inline void discard( std::istream& stream, const std::string& token )
{
//	stream.ignore( std::numeric_limits<std::streamsize>::max(), ' ' );
	std::string discarded;
	stream >> discarded;
	assert( discarded == token );
}


static void remove_quotes( std::string& str )
{
	size_t n;
	while ( ( n = str.find('\"') ) != std::string::npos ) str.erase(n,1);
}

static void unit_quat_w( glm::quat& quat )
{
	float t = 1.0f - ( quat.x * quat.x ) - ( quat.y * quat.y ) - ( quat.z * quat.z );
	quat.w = ( t < 0.0f ? 0.0f : -sqrtf(t) );
}

bool md5_loader::load( stream& source )
{
	std_stream sstream( &source );
	std::string command;

	sstream >> command;
	while ( !sstream.eof() )
	{
		if ( command == "MD5Version" )
		{
			sstream >> m_md5_version;
			assert( m_md5_version == 10 );
		}
		else if ( command == "commandline" )
		{
			next_line( sstream ); 
		}
		else if ( command == "numJoints" )
		{
			sstream >> m_num_joints;
			m_joints.reserve( m_num_joints );
		}
		else if ( command == "numMeshes" )
		{
			sstream >> m_num_meshes;
			m_meshes.reserve( m_num_meshes );
		}
		else if ( command == "joints" )
		{
			discard( sstream, "{" );
			md5_joint joint;
			for ( size_t i = 0; i < m_num_joints; ++i )
			{
				int parent_id;
				sstream >> joint.name >> parent_id;
				discard( sstream, "(" );
				sstream >> joint.pos.x >> joint.pos.y >> joint.pos.z;
				discard( sstream, ")" );
				discard( sstream, "(" );
				sstream >> joint.orient.x >> joint.orient.y >> joint.orient.z;
				remove_quotes( joint.name );
				unit_quat_w( joint.orient );
				m_joints.push_back( joint );
				next_line( sstream );
			}
			discard( sstream, "}" );
		}
		else if ( command == "mesh" )
		{
			md5_mesh_data* mesh = new md5_mesh_data();

			int num_verts, num_tris, num_weights;

			discard( sstream, "{" );
			sstream >> command;
			while ( command != "}" ) 
			{
				if ( command == "shader" )
				{
					sstream >> mesh->m_shader;
					remove_quotes( mesh->m_shader );
					// texturePath.replace_extension( ".tga" );
					next_line( sstream );
				}
				else if ( command == "numverts")
				{
					sstream >> num_verts; 

					{
						mesh_raw_channel* ch_pnt = mesh_raw_channel::create<md5_vtx_pnt>( num_verts );
						mesh_raw_channel* ch_t   = mesh_raw_channel::create<md5_vtx_t>( num_verts );
						mesh->m_pntdata          = (md5_vtx_pnt*)ch_pnt->data;
						mesh->m_tdata            = (md5_vtx_t*)ch_t->data;
						mesh->add_channel( ch_pnt );
						mesh->add_channel( ch_t );
					}
					mesh->m_vtx_data.resize( num_verts );

					next_line( sstream );
					std::string line;
					for ( int i = 0; i < num_verts; ++i )
					{
						md5_vtx_data& vdata = mesh->m_vtx_data[i];
						size_t weight_count;
						size_t start_weight;
						vec2 texcoord;

						std::getline( sstream, line );
						sscanf( line.c_str(), "%*s %*u ( %f %f ) %u %u", &(texcoord.x), &(texcoord.y), &(start_weight), &(weight_count) );
						vdata.start_weight = start_weight;
						vdata.weight_count = weight_count;
						mesh->m_tdata[i].texcoord = texcoord;
					}  
				}
				else if ( command == "numtris" )
				{
					sstream >> num_tris;

					mesh_raw_index_channel* ch_i = mesh_raw_index_channel::create<uint32>( num_tris * 3 );
					uint32* vtx_i                = (uint32*)ch_i->data;
					mesh->m_idata                = vtx_i;
					uint32 idx = 0;
					mesh->set_index_channel( ch_i );

					next_line( sstream );
					std::string line;
					for ( int i = 0; i < num_tris; ++i )
					{
						size_t ti0;
						size_t ti1;
						size_t ti2;

						std::getline( sstream, line );
						sscanf( line.c_str(), "%*s %*u %u %u %u )", &(ti0), &(ti1), &(ti2));

						vtx_i[idx++] = (uint32)ti0;
						vtx_i[idx++] = (uint32)ti1;
						vtx_i[idx++] = (uint32)ti2;
					}              
				}
				else if ( command == "numweights" )
				{
					sstream >> num_weights;
					mesh->m_weights.reserve( num_weights );
					next_line( sstream );
					std::string line;
					for ( int i = 0; i < num_weights; ++i )
					{
						md5_weight weight;

						std::getline( sstream, line );
						sscanf( line.c_str(), "%*s %*u %u %f ( %f %f %f )", &(weight.joint_id), &(weight.bias), &(weight.pos.x), &(weight.pos.y), &(weight.pos.z));
 						mesh->m_weights.push_back(weight);
					}
				}
				else
				{
					next_line( sstream );
				}

				sstream >> command;
			}

			prepare_mesh( mesh );

			m_meshes.push_back(mesh);
		}
		sstream >> command;
	}

	assert( m_joints.size() == m_num_joints );
	assert( m_meshes.size() == m_num_meshes );
	return true;
}

bool md5_loader::prepare_mesh( md5_mesh_data* mdata )
{
	uint32 vtx_count = mdata->m_vtx_data.size();
	md5_vtx_pnt* vtcs = mdata->m_pntdata;

	for ( uint32 i = 0; i < vtx_count; ++i )
	{
		md5_vtx_data& vdata = mdata->m_vtx_data[i];
		md5_vtx_pnt& vtc = vtcs[i];

		vtc.position = glm::vec3(0);
		vtc.normal   = glm::vec3(0);
		vtc.tangent  = glm::vec3(0);

		std::sort( mdata->m_weights.begin() + vdata.start_weight, mdata->m_weights.begin() + vdata.start_weight + vdata.weight_count, [](const md5_weight& a, const md5_weight& b) -> bool { return a.bias > b.bias; } );

		if ( vdata.weight_count > 4 )
		{
			float sum = 0.0f;
			for ( size_t j = 0; j < 4; ++j )
			{
				sum += mdata->m_weights[vdata.start_weight + j].bias;
			}
			float ratio = 1.0f / sum;
			for ( size_t j = 0; j < 4; ++j )
			{
				mdata->m_weights[vdata.start_weight + j].bias = 
					ratio * mdata->m_weights[vdata.start_weight + j].bias;
			}
			vdata.weight_count = 4;
		}

		for ( size_t j = 0; j < vdata.weight_count; ++j )
		{
			md5_weight& weight = mdata->m_weights[vdata.start_weight + j];
			md5_joint&  joint  = m_joints[weight.joint_id];

			glm::vec3 rot_pos = joint.orient * weight.pos;

			vtc.position += ( joint.pos + rot_pos ) * weight.bias;
		}
	}

	// Prepare normals
	uint32 tri_count = mdata->get_count() / 3;
	for ( unsigned int i = 0; i < tri_count; ++i )
	{
		uint32 ti0 = mdata->m_idata[ i * 3 ];
		uint32 ti1 = mdata->m_idata[ i * 3 + 1 ];
		uint32 ti2 = mdata->m_idata[ i * 3 + 2 ];
 
		glm::vec3 v1 = vtcs[ ti0 ].position;
		glm::vec3 v2 = vtcs[ ti1 ].position;
		glm::vec3 v3 = vtcs[ ti2 ].position;
		glm::vec3 xyz1 = v3 - v1;
		glm::vec3 xyz2 = v2 - v1;

		glm::vec3 normal = glm::cross( xyz1, xyz2 );

		vtcs[ ti0 ].normal += normal;
		vtcs[ ti1 ].normal += normal;
		vtcs[ ti2 ].normal += normal;

		const vec2& w1 = mdata->m_tdata[ ti0 ].texcoord;
		const vec2& w2 = mdata->m_tdata[ ti1 ].texcoord;
		const vec2& w3 = mdata->m_tdata[ ti2 ].texcoord;

		vec2 st1 = w3 - w1;
		vec2 st2 = w2 - w1;

		float coef = 1.0f / (st1.x * st2.y - st2.x * st1.y);

		vec3 tangent = (( xyz1 * st2.y ) - ( xyz2 * st1.y )) * coef;

		vtcs[ ti0 ].tangent += tangent;
		vtcs[ ti1 ].tangent += tangent;
		vtcs[ ti2 ].tangent += tangent;
	}

	for ( size_t i = 0; i < vtx_count; ++i )
	{
		md5_vtx_data& vdata = mdata->m_vtx_data[i];

		glm::vec3 normal  = glm::normalize( vtcs[i].normal );
		glm::vec3 tangent = glm::normalize( vtcs[i].tangent );
		vtcs[i].normal   = normal;
		vtcs[i].tangent  = tangent;

 		vdata.normal  = glm::vec3(0);
 		vdata.tangent = glm::vec3(0);
 
 		for ( size_t j = 0; j < vdata.weight_count; ++j )
 		{
 			const md5_weight& weight = mdata->m_weights[vdata.start_weight + j];
 			const md5_joint&  joint  = m_joints[weight.joint_id];
 			vdata.normal  += ( normal  * joint.orient ) * weight.bias;
 			vdata.tangent += ( tangent * joint.orient ) * weight.bias;
 		}
	}

	return true;
}


md5_animation::md5_animation()
	: m_md5_version( 0 )
	, m_num_frames( 0 )
	, m_num_joints( 0 )
	, m_frame_rate( 0 )
	, m_num_animated_components( 0 )
	, m_anim_duration( 0 )
	, m_frame_duration( 0 )
{

}

md5_animation::~md5_animation()
{

}

bool md5_animation::load_animation( stream& source )
{
	std::vector<md5_joint_info> joint_infos;
	std::vector<transform>      base_frames;
	m_num_frames = 0;

	std_stream sstream( &source );
	std::string command;

	sstream >> command;
	while ( !sstream.eof() )
	{
		if ( command == "MD5Version" )
		{
			sstream >> m_md5_version;
			assert( m_md5_version == 10 );
		}
		else if ( command == "commandline" )
		{
			next_line( sstream ); 
		}
		else if ( command == "numFrames" )
		{
			sstream >> m_num_frames;
			next_line( sstream ); 
		}
		else if ( command == "numJoints" )
		{
			sstream >> m_num_joints;
			m_joints.reserve( m_num_joints );
			next_line( sstream ); 
		}
		else if ( command == "frameRate" )
		{
			sstream >> m_frame_rate;
			next_line( sstream ); 
		}
		else if ( command == "numAnimatedComponents" )
		{
			sstream >> m_num_animated_components;
			next_line( sstream ); 
		}
		else if ( command == "hierarchy" )
		{
			discard( sstream, "{" );
			for ( size_t i = 0; i < m_num_joints; ++i )
			{
				md5_joint_info joint;
				sstream >> joint.name >> joint.parent_id >> joint.flags >> joint.start_index;
				remove_quotes( joint.name );
				joint_infos.push_back( joint );
				m_joints.push_back( md5_joint( joint.parent_id, m_num_frames ) );
				next_line( sstream ); 
			}
			discard( sstream, "}" );
		}
		else if ( command == "bounds" )
		{
			discard( sstream, "{" );
			next_line( sstream ); 
			for ( size_t i = 0; i < m_num_frames; ++i ) 
			{
//  				vec3 min;
//  				vec3 max;
//  				discard( sstream, "(" );
//  				sstream >> min.x >> min.y >> min.z;
//  				discard( sstream, ")" );
//  				discard( sstream, "(" );
//  				sstream >> max.x >> max.y >> max.z;
// 				m_bounds.push_back( bound );
				next_line( sstream ); 
			}

			discard( sstream, "}" );
			next_line( sstream ); 
		}
		else if ( command == "baseframe" )
		{
			discard( sstream, "{" );
			next_line( sstream ); 

			for ( size_t i = 0; i < m_num_joints; ++i )
			{
				transform base_frame;
				vec3 pos;
				quat orient;
				discard( sstream, "(" );
				sstream >> pos.x >> pos.y >> pos.z;
				discard( sstream, ")" );
				discard( sstream, "(" );
				sstream >> orient.x >> orient.y >> orient.z;
				next_line( sstream ); 

				base_frames.emplace_back( pos, orient );
			}
			discard( sstream, "}" );
			next_line( sstream ); 
		}
		else if ( command == "frame" )
		{
			std::vector<float> frame;
			int frame_id;
			sstream >> frame_id;
			discard( sstream, "{" );
			next_line( sstream ); 

			frame.reserve( m_num_animated_components );
			char buf[50];
			for ( size_t i = 0; i < m_num_animated_components; ++i )
			{
				sstream >> buf;
				frame.push_back((float)atof(buf));
			}

			build_frame_skeleton( joint_infos, base_frames, frame );

			discard( sstream, "}" );
			next_line( sstream ); 
		}

		sstream >> command;
	} 


	m_frame_duration = 1.0f / (float)m_frame_rate;
	m_anim_duration = ( m_frame_duration * (float)m_num_frames );

	return true;
}


void nv::md5_animation::update_skeleton( std::vector<transform>& skeleton, float anim_time ) const
{
	NV_ASSERT( skeleton.size() == m_num_joints, "Incompatible skeleton passed!" );
	anim_time = glm::clamp( anim_time, 0.0f, m_anim_duration );
	float frame_num = anim_time * (float)m_frame_rate;
	size_t frame0 = (size_t)floorf( frame_num );
	size_t frame1 = (size_t)ceilf( frame_num );
	frame0 = frame0 % m_num_frames;
	frame1 = frame1 % m_num_frames;

	float interpolation = fmodf( anim_time, m_frame_duration ) / m_frame_duration;

	for ( size_t i = 0; i < m_num_joints; ++i )
	{
		const transform_vector& keys = m_joints[i].keys;
		skeleton[i] = interpolate( keys.get(frame0), keys.get(frame1), interpolation );
	}
}

void md5_animation::build_frame_skeleton( const std::vector<md5_joint_info>& joint_infos, const std::vector<transform>& base_frames, const std::vector<float>& frame_data )
{
	size_t index = m_joints[0].keys.size();
	for ( unsigned int i = 0; i < joint_infos.size(); ++i )
	{
		unsigned int j = 0;

		const md5_joint_info& jinfo = joint_infos[i];


		int parent_id = jinfo.parent_id;

		vec3 pos    = base_frames[i].get_position();
		quat orient = base_frames[i].get_orientation();
		if ( jinfo.flags & 1 )  pos.x    = frame_data[ jinfo.start_index + j++ ];
		if ( jinfo.flags & 2 )  pos.y    = frame_data[ jinfo.start_index + j++ ];
		if ( jinfo.flags & 4 )  pos.z    = frame_data[ jinfo.start_index + j++ ];
		if ( jinfo.flags & 8 )  orient.x = frame_data[ jinfo.start_index + j++ ];
		if ( jinfo.flags & 16 ) orient.y = frame_data[ jinfo.start_index + j++ ];
		if ( jinfo.flags & 32 )	orient.z = frame_data[ jinfo.start_index + j++ ];
		unit_quat_w( orient );

		if ( parent_id >= 0 ) // Has a parent joint
		{
			const transform_vector& ptv = m_joints[ size_t( parent_id ) ].keys;
			transform ptr;
			if ( ptv.size() > index ) ptr = ptv.get( index );
			glm::vec3 rot_pos = ptr.get_orientation() * pos;

			pos    = ptr.get_position() + rot_pos;
			orient = ptr.get_orientation() * orient;

			orient = glm::normalize( orient );
		}

		m_joints[i].keys.insert( transform( pos, orient ) );
	}
}

mesh_data* nv::md5_loader::release_mesh_data( size_t index )
{
	mesh_data* result = m_meshes[ index ];
	m_meshes[ index ] = nullptr;
	return result;
}

md5_mesh_instance* nv::md5_mesh_data::spawn() const
{
	return new md5_mesh_instance( this );
}

nv::md5_loader::~md5_loader()
{
	for ( auto m : m_meshes ) { if (m) delete m; }
}

nv::md5_mesh_instance::md5_mesh_instance( const md5_mesh_data* a_data ) 
	: m_data( a_data ), m_size( 0 ), m_indices( 0 ), m_pntdata( nullptr )
{
	m_size = m_data->m_vtx_data.size();
	m_indices = m_data->get_count();
	m_pntdata = new md5_vtx_pnt[ m_size ];
	std::copy_n( m_data->m_pntdata, m_size, m_pntdata );
}

void nv::md5_mesh_instance::apply( const std::vector< transform >& skeleton )
{
	NV_PROFILE("md5::apply");
	char* fill_ptr = (char*)&(m_pntdata[0]);
	std::fill( fill_ptr, fill_ptr + m_size * ( sizeof( md5_vtx_pnt ) ), 0 );
	for ( unsigned int i = 0; i < m_size; ++i )
	{
		const md5_vtx_data& vert = m_data->m_vtx_data[i];
		md5_vtx_pnt& result = m_pntdata[i];

		for ( size_t j = 0; j < vert.weight_count; ++j )
		{
			const md5_weight& weight = m_data->m_weights[vert.start_weight + j];
			const transform& joint = skeleton[weight.joint_id];

			glm::vec3 rot_pos = joint.get_orientation() * weight.pos;
			result.position += ( joint.get_position() + rot_pos ) * weight.bias;

			result.normal  += ( joint.get_orientation() * vert.normal  ) * weight.bias;
			result.tangent += ( joint.get_orientation() * vert.tangent ) * weight.bias;
		}
	}
}
