// Copyright (C) 2012-2013 ChaosForge / Kornel Kisielewicz
// http://chaosforge.org/
//
// This file is part of NV Libraries.
// For conditions of distribution and use, see copyright notice in nv.hh

#include "nv/formats/md5_loader.hh"

#include <cstring>
#include <limits>

#include <glm/gtc/constants.hpp>

#include "nv/logging.hh"
#include "nv/io/std_stream.hh"

using namespace nv;

// based on http://tfc.duke.free.fr/coding/md5-specs-en.html

/// Consumes the rest of the current line of @p stream, including the
/// terminating '\n' (or everything up to end of input if there is none).
static void next_line( std::istream& stream )
{
	const std::streamsize no_limit = std::numeric_limits< std::streamsize >::max();
	stream.ignore( no_limit, '\n' );
}

/// Reads (and throws away) the next whitespace-delimited token of @p stream.
/// In debug builds, asserts that the token equals @p token; the read itself
/// always happens, even when assertions are compiled out.
static inline void discard( std::istream& stream, const std::string& token )
{
	std::string actual;
	stream >> actual;
	(void)token; // only inspected by the assert below
	assert( actual == token );
}


/// Removes every double-quote character from @p str in place. The loaders
/// below apply it to joint names and shader names read from the file.
static void remove_quotes( std::string& str )
{
	std::string unquoted;
	unquoted.reserve( str.size() );
	for ( size_t pos = 0; pos < str.size(); ++pos )
	{
		const char c = str[pos];
		if ( c != '\"' ) unquoted += c;
	}
	str.swap( unquoted );
}

/// Rebuilds the w component of an orientation quaternion whose x, y and z
/// were read from the file (only those three are parsed).
/// w is set to -sqrt(1 - x^2 - y^2 - z^2), or to 0 when x^2 + y^2 + z^2
/// exceeds 1 and the radicand would be negative.
static void unit_quat_w( glm::quat& quat )
{
	const float radicand = 1.0f - ( quat.x * quat.x ) - ( quat.y * quat.y ) - ( quat.z * quat.z );
	if ( radicand < 0.0f )
	{
		quat.w = 0.0f;
	}
	else
	{
		quat.w = -sqrtf( radicand );
	}
}

/// Parses an MD5 mesh file (MD5Version 10) from @p source.
///
/// Fills m_joints with the bind-pose joints from the "joints" section.
/// For every "mesh" section it allocates an md5_mesh with new, reads its
/// shader, vertices, triangles and weights, runs prepare_mesh() and
/// prepare_normals() on it, and appends it to m_meshes. Unrecognised
/// top-level tokens are skipped one token at a time; unrecognised lines
/// inside a mesh section are skipped whole.
///
/// Parsing stops at end of input or at the first failed extraction.
/// @return true if the whole stream was consumed. Returns false if parsing
///         stopped early because the input is malformed or truncated; a
///         truncated mesh is then discarded instead of being skinned.
bool md5_loader::load( stream& source )
{
	std_stream sstream( &source );
	std::string command;

	// Extracting in the loop condition stops on end of input *and* on a
	// failed extraction. The former `while ( !sstream.eof() )` spun forever
	// once a malformed value had put the stream into the fail state, because
	// a failed stream never reaches EOF.
	while ( sstream >> command )
	{
		if ( command == "MD5Version" )
		{
			sstream >> m_md5_version;
			assert( m_md5_version == 10 );
		}
		else if ( command == "commandline" )
		{
			next_line( sstream );
		}
		else if ( command == "numJoints" )
		{
			sstream >> m_num_joints;
			m_joints.reserve( m_num_joints );
		}
		else if ( command == "numMeshes" )
		{
			sstream >> m_num_meshes;
			m_meshes.reserve( m_num_meshes );
		}
		else if ( command == "joints" )
		{
			// Per joint: "name" parent ( px py pz ) ( qx qy qz )
			discard( sstream, "{" );
			md5_joint joint;
			for ( size_t i = 0; i < m_num_joints; ++i )
			{
				sstream >> joint.name >> joint.parent_id;
				discard( sstream, "(" );
				sstream >> joint.pos.x >> joint.pos.y >> joint.pos.z;
				discard( sstream, ")" );
				discard( sstream, "(" );
				sstream >> joint.orient.x >> joint.orient.y >> joint.orient.z;
				remove_quotes( joint.name );
				unit_quat_w( joint.orient );
				m_joints.push_back( joint );
				// skip the rest of the line (closing ')' and anything after it)
				next_line( sstream );
			}
			discard( sstream, "}" );
		}
		else if ( command == "mesh" )
		{
			md5_mesh* mesh = new md5_mesh;
			int num_verts   = 0;
			int num_tris    = 0;
			int num_weights = 0;

			discard( sstream, "{" );
			// Same fix as the outer loop: stop if the stream fails before "}".
			while ( sstream >> command && command != "}" )
			{
				if ( command == "shader" )
				{
					sstream >> mesh->shader;
					remove_quotes( mesh->shader );
					next_line( sstream );
				}
				else if ( command == "numverts" )
				{
					sstream >> num_verts;
					next_line( sstream );
					std::string line;
					for ( int i = 0; i < num_verts; ++i )
					{
						// vert <index> ( <s> <t> ) <startWeight> <countWeight>
						if ( !std::getline( sstream, line ) ) break;
						md5_vertex vert;
						sscanf( line.c_str(), "%*s %*u ( %f %f ) %u %u", &(vert.texcoord.x), &(vert.texcoord.y), &(vert.start_weight), &(vert.weight_count) );

						mesh->verts.push_back( vert );
						mesh->texcoord_buffer.push_back( vert.texcoord );
					}
				}
				else if ( command == "numtris" )
				{
					sstream >> num_tris;
					next_line( sstream );
					std::string line;
					for ( int i = 0; i < num_tris; ++i )
					{
						// tri <index> <v0> <v1> <v2>
						if ( !std::getline( sstream, line ) ) break;
						md5_triangle tri;
						sscanf( line.c_str(), "%*s %*u %u %u %u )", &(tri.indices[0]), &(tri.indices[1]), &(tri.indices[2]));

						mesh->tris.push_back( tri );
						mesh->index_buffer.push_back( (uint32)tri.indices[0] );
						mesh->index_buffer.push_back( (uint32)tri.indices[1] );
						mesh->index_buffer.push_back( (uint32)tri.indices[2] );
					}
				}
				else if ( command == "numweights" )
				{
					sstream >> num_weights;
					mesh->weights.reserve( num_weights );
					next_line( sstream );
					std::string line;
					for ( int i = 0; i < num_weights; ++i )
					{
						// weight <index> <joint> <bias> ( x y z )
						if ( !std::getline( sstream, line ) ) break;
						md5_weight weight;
						sscanf( line.c_str(), "%*s %*u %u %f ( %f %f %f )", &(weight.joint_id), &(weight.bias), &(weight.pos.x), &(weight.pos.y), &(weight.pos.z));
						mesh->weights.push_back( weight );
					}
				}
				else
				{
					next_line( sstream );
				}
			}

			if ( !sstream )
			{
				// Truncated or malformed mesh section. Its vertices may point at
				// weights that were never read, so skinning it in prepare_mesh()
				// would index out of range. Drop it and report failure.
				delete mesh;
				return false;
			}

			prepare_mesh( mesh );
			prepare_normals( mesh );

			m_meshes.push_back( mesh );
		}
	}

	// The loop normally ends when extraction hits end of input. Stopping
	// anywhere else means an extraction failed on malformed data.
	if ( !sstream.eof() ) return false;

	assert( m_joints.size() == m_num_joints );
	assert( m_meshes.size() == m_num_meshes );
	return true;
}

/// Skins every vertex of @p mesh against the bind-pose joints in m_joints.
/// It also rebuilds the mesh's position_buffer and texcoord_buffer, with
/// one entry per vertex in vertex order.
/// Each vertex position becomes the bias-weighted sum of
/// (joint.pos + joint.orient * weight.pos) over the vertex's weights.
/// The per-vertex normal and tangent accumulators are reset to zero, ready
/// for prepare_normals().
/// @return always true.
bool md5_loader::prepare_mesh( md5_mesh* mesh )
{
	mesh->position_buffer.clear();
	mesh->texcoord_buffer.clear();

	for ( size_t v = 0; v < mesh->verts.size(); ++v )
	{
		md5_vertex& vertex = mesh->verts[v];
		vertex.normal  = glm::vec3( 0 );
		vertex.tangent = glm::vec3( 0 );

		glm::vec3 skinned( 0 );
		for ( size_t w = 0; w < vertex.weight_count; ++w )
		{
			const md5_weight& wt  = mesh->weights[ vertex.start_weight + w ];
			const md5_joint&  jnt = m_joints[ wt.joint_id ];
			const glm::vec3 rotated = jnt.orient * wt.pos;
			skinned += ( jnt.pos + rotated ) * wt.bias;
		}
		vertex.position = skinned;

		mesh->position_buffer.push_back( vertex.position );
		mesh->texcoord_buffer.push_back( vertex.texcoord );
	}

	return true;
}

/// Computes per-vertex normals and tangents for @p mesh. It uses the
/// bind-pose positions and the zeroed normal/tangent accumulators produced
/// by prepare_mesh(), which must have run first.
///
/// Triangle normals (unnormalised cross products) and UV-derived tangents
/// are summed onto each corner vertex. The sums are then normalised and
/// appended to normal_buffer and tangent_buffer, which are rebuilt from
/// scratch. Finally, each vertex's normal/tangent fields are re-expressed
/// per weight through glm's vec3 * quat (the inverse joint rotation),
/// bias-weighted. prepare_animated_mesh() can then rotate them by the
/// animated joints.
/// @return always true.
bool md5_loader::prepare_normals( md5_mesh* mesh )
{
	mesh->normal_buffer.clear();
	// Previously only normal_buffer was reset, so a second call appended a
	// second set of tangents and left the two buffers out of step.
	mesh->tangent_buffer.clear();

	// Smallest |det| we are willing to divide by (1/FLT_MIN is still finite).
	const float min_det = std::numeric_limits< float >::min();

	for ( unsigned int i = 0; i < mesh->tris.size(); ++i )
	{
		const md5_triangle& tri = mesh->tris[i];
		glm::vec3 v1 = mesh->verts[ tri.indices[0] ].position;
		glm::vec3 v2 = mesh->verts[ tri.indices[1] ].position;
		glm::vec3 v3 = mesh->verts[ tri.indices[2] ].position;
		glm::vec3 xyz1 = v3 - v1;
		glm::vec3 xyz2 = v2 - v1;

		// Not normalised: its length grows with the triangle's area, so larger
		// faces contribute more to the vertex sum.
		glm::vec3 normal = glm::cross( xyz1, xyz2 );

		mesh->verts[ tri.indices[0] ].normal += normal;
		mesh->verts[ tri.indices[1] ].normal += normal;
		mesh->verts[ tri.indices[2] ].normal += normal;

		const vec2& w1 = mesh->verts[ tri.indices[0] ].texcoord;
		const vec2& w2 = mesh->verts[ tri.indices[1] ].texcoord;
		const vec2& w3 = mesh->verts[ tri.indices[2] ].texcoord;

		vec2 st1 = w3 - w1;
		vec2 st2 = w2 - w1;

		// Determinant of the UV edge matrix. Zero or denormal means the
		// triangle's UVs are degenerate and define no tangent direction.
		// Dividing by it produced inf/NaN that poisoned every vertex of the
		// triangle, so such triangles contribute no tangent.
		const float det = st1.x * st2.y - st2.x * st1.y;
		if ( det > min_det || det < -min_det )
		{
			const float coef = 1.0f / det;

			vec3 tangent = (( xyz1 * st2.y ) - ( xyz2 * st1.y )) * coef;

			mesh->verts[ tri.indices[0] ].tangent += tangent;
			mesh->verts[ tri.indices[1] ].tangent += tangent;
			mesh->verts[ tri.indices[2] ].tangent += tangent;
		}
	}

	for ( size_t i = 0; i < mesh->verts.size(); ++i )
	{
		md5_vertex& vert = mesh->verts[i];

		// A vertex touched by no (or only degenerate) triangles has a zero
		// accumulator; glm::normalize would turn that into NaN.
		glm::vec3 normal  = vert.normal;
		glm::vec3 tangent = vert.tangent;
		if ( glm::dot( normal, normal ) > 0.0f )   normal  = glm::normalize( normal );
		if ( glm::dot( tangent, tangent ) > 0.0f ) tangent = glm::normalize( tangent );
		mesh->normal_buffer.push_back( normal );
		mesh->tangent_buffer.push_back( tangent );

		vert.normal  = glm::vec3(0);
		vert.tangent = glm::vec3(0);

		for ( size_t j = 0; j < vert.weight_count; ++j )
		{
			const md5_weight& weight = mesh->weights[vert.start_weight + j];
			const md5_joint&  joint  = m_joints[weight.joint_id];
			// vec3 * quat in glm applies the inverse rotation: bring the
			// model-space vector into this joint's bind-pose frame.
			vert.normal  += ( normal  * joint.orient ) * weight.bias;
			vert.tangent += ( tangent * joint.orient ) * weight.bias;
		}
	}

	return true;
}

/// Copies the CPU-side buffers of mesh @p mesh_id into a new mesh_data_old
/// and returns it. The buffers are positions, normals, tangents, texcoords
/// and indices.
/// The loader's own buffers are only read, not cleared or moved from.
/// @param mesh_id index into m_meshes; not bounds-checked.
/// @return the result of mesh_data_creator::release(). Presumably the
///         caller takes ownership — TODO confirm against mesh_data_creator.
mesh_data_old* nv::md5_loader::release_submesh_data( uint32 mesh_id )
{
	const md5_mesh* mesh = m_meshes[mesh_id];
	mesh_data_creator m;
	// Each copy must cover begin()..end(). The previous begin()..begin()
	// ranges were empty, so every returned mesh came back with no data.
	m.get_positions().assign( mesh->position_buffer.begin(), mesh->position_buffer.end() );
	m.get_normals()  .assign( mesh->normal_buffer.begin(),   mesh->normal_buffer.end() );
	m.get_tangents() .assign( mesh->tangent_buffer.begin(),  mesh->tangent_buffer.end() );
	m.get_texcoords().assign( mesh->texcoord_buffer.begin(), mesh->texcoord_buffer.end() );
	m.get_indices()  .assign( mesh->index_buffer.begin(),    mesh->index_buffer.end() );

	return m.release();
}

/*
mesh* md5_loader::release_mesh()
{
	mesh* m = new mesh();
	auto position = m->add_attribute< vec3 >( "nv_position" );
	auto normal   = m->add_attribute< vec3 >( "nv_normal" );
	auto texcoord = m->add_attribute< vec2 >( "nv_texcoord" );
	auto tangent  = m->add_attribute< vec3 >( "nv_tangent" );
	auto indices  = m->add_indices< uint32 >();

	position->get().assign( m_meshes[0].position_buffer.begin(), m_meshes[0].position_buffer.end() );
	normal  ->get().assign( m_meshes[0].normal_buffer.begin(),   m_meshes[0].normal_buffer.end() );
	texcoord->get().assign( m_meshes[0].texcoord_buffer.begin(), m_meshes[0].texcoord_buffer.end() );
	tangent ->get().assign( m_meshes[0].tangent_buffer.begin(),  m_meshes[0].tangent_buffer.end() );
	indices ->get().assign( m_meshes[0].index_buffer.begin(),    m_meshes[0].index_buffer.end() );

	m_size = m_meshes[0].index_buffer.size();
	return m;
}
*/

/// Constructs an empty animation. Version, counts, frame rate and all
/// timing values start at zero until load_animation() fills them in.
md5_animation::md5_animation() :
	m_md5_version( 0 ),
	m_num_frames( 0 ),
	m_num_joints( 0 ),
	m_frame_rate( 0 ),
	m_num_animated_components( 0 ),
	m_anim_duration( 0 ),
	m_frame_duration( 0 ),
	m_anim_time( 0 )
{
}

/// Destructor. No explicit cleanup is performed here.
md5_animation::~md5_animation()
{
}

/// Parses an MD5 animation file (MD5Version 10) from @p source.
///
/// Previously loaded joint infos, bounds, base frames, frames and posed
/// skeletons are discarded first. Each "frame" section is decoded into a
/// posed skeleton via build_frame_skeleton() and appended to m_skeletons.
/// Afterwards the animated skeleton is sized to m_num_joints and playback
/// state is reset: frame duration, total duration, and current time = 0.
/// Unrecognised top-level tokens are skipped one at a time.
///
/// @return true if the whole stream was consumed. Returns false if an
///         extraction failed first (malformed input); the animation is then
///         left with m_num_frames = 0, so update() does nothing.
bool md5_animation::load_animation( stream& source )
{
	m_joint_infos.clear();
	m_bounds.clear();
	m_base_frames.clear();
	m_frames.clear();
	// build_frame_skeleton() appends to m_skeletons. Without this, reloading
	// kept the previous animation's poses in front of the new ones.
	m_skeletons.clear();
	m_animated_skeleton.joints.clear();
	m_num_frames = 0;

	std_stream sstream( &source );
	std::string command;

	// Extract in the loop condition: a failed stream never reaches EOF, so
	// the former `while ( !sstream.eof() )` looped forever on bad input.
	while ( sstream >> command )
	{
		if ( command == "MD5Version" )
		{
			sstream >> m_md5_version;
			assert( m_md5_version == 10 );
		}
		else if ( command == "commandline" )
		{
			next_line( sstream );
		}
		else if ( command == "numFrames" )
		{
			sstream >> m_num_frames;
			next_line( sstream );
		}
		else if ( command == "numJoints" )
		{
			sstream >> m_num_joints;
			next_line( sstream );
		}
		else if ( command == "frameRate" )
		{
			sstream >> m_frame_rate;
			next_line( sstream );
		}
		else if ( command == "numAnimatedComponents" )
		{
			sstream >> m_num_animated_components;
			next_line( sstream );
		}
		else if ( command == "hierarchy" )
		{
			// Per joint: "name" parent flags startIndex
			discard( sstream, "{" );
			for ( size_t i = 0; i < m_num_joints; ++i )
			{
				md5_joint_info joint;
				sstream >> joint.name >> joint.parent_id >> joint.flags >> joint.start_index;
				remove_quotes( joint.name );
				m_joint_infos.push_back( joint );
				next_line( sstream );
			}
			discard( sstream, "}" );
		}
		else if ( command == "bounds" )
		{
			// Per frame: ( minx miny minz ) ( maxx maxy maxz )
			discard( sstream, "{" );
			next_line( sstream );
			for ( size_t i = 0; i < m_num_frames; ++i )
			{
				md5_bound bound;
				discard( sstream, "(" );
				sstream >> bound.min.x >> bound.min.y >> bound.min.z;
				discard( sstream, ")" );
				discard( sstream, "(" );
				sstream >> bound.max.x >> bound.max.y >> bound.max.z;

				m_bounds.push_back( bound );

				next_line( sstream );
			}

			discard( sstream, "}" );
			next_line( sstream );
		}
		else if ( command == "baseframe" )
		{
			// Per joint: ( px py pz ) ( qx qy qz )
			discard( sstream, "{" );
			next_line( sstream );

			for ( size_t i = 0; i < m_num_joints; ++i )
			{
				md5_base_frame base_frame;
				discard( sstream, "(" );
				sstream >> base_frame.pos.x >> base_frame.pos.y >> base_frame.pos.z;
				discard( sstream, ")" );
				discard( sstream, "(" );
				sstream >> base_frame.orient.x >> base_frame.orient.y >> base_frame.orient.z;
				next_line( sstream );

				m_base_frames.push_back( base_frame );
			}
			discard( sstream, "}" );
			next_line( sstream );
		}
		else if ( command == "frame" )
		{
			md5_frame_data frame;
			sstream >> frame.frame_id;
			discard( sstream, "{" );
			next_line( sstream );

			frame.frame_data.reserve( m_num_animated_components );
			// A std::string token cannot be overrun. The former fixed char[50]
			// buffer overflowed on any token longer than 49 characters.
			std::string token;
			for ( size_t i = 0; i < m_num_animated_components; ++i )
			{
				sstream >> token;
				frame.frame_data.push_back( (float)atof( token.c_str() ) );
			}

			m_frames.push_back( frame );

			build_frame_skeleton( m_skeletons, m_joint_infos, m_base_frames, frame );

			discard( sstream, "}" );
			next_line( sstream );
		}
	}

	// Normal termination is extraction hitting end of input. Anything else
	// means malformed data. Make the animation frame-less so update() is a
	// no-op instead of indexing a partially built m_skeletons.
	const bool complete = sstream.eof();
	if ( !complete ) m_num_frames = 0;

	m_animated_skeleton.joints.assign( m_num_joints, md5_skeleton_joint() );

	m_frame_duration = 1.0f / (float)m_frame_rate;
	m_anim_duration  = ( m_frame_duration * (float)m_num_frames );
	m_anim_time      = 0.0f;

	if ( !complete ) return false;

	assert( m_joint_infos.size() == m_num_joints );
	assert( m_bounds.size()      == m_num_frames );
	assert( m_base_frames.size() == m_num_joints );
	assert( m_frames.size()      == m_num_frames );
	assert( m_skeletons.size()   == m_num_frames );

	return true;
}

/// Advances playback by @p delta_time seconds, which may be negative, and
/// wraps the time into [0, m_anim_duration).
/// It then writes into m_animated_skeleton the pose blended between the two
/// frames surrounding the current time. The frame after the last one wraps
/// to frame 0.
/// Does nothing when no frames are loaded.
void md5_animation::update( float delta_time )
{
	if ( m_num_frames < 1 ) return;

	// fmodf wraps in one step; the former subtract/add loops ran once per
	// animation cycle covered by delta_time.
	m_anim_time = fmodf( m_anim_time + delta_time, m_anim_duration );
	if ( m_anim_time < 0.0f ) m_anim_time += m_anim_duration;

	const float  frame_num   = m_anim_time * (float)m_frame_rate;
	const float  frame_floor = floorf( frame_num );
	const size_t frame0      = (size_t)frame_floor % m_num_frames;
	const size_t frame1      = (size_t)ceilf( frame_num ) % m_num_frames;

	// Take the blend factor from the same frame position the indices came
	// from. The former fmodf( m_anim_time, m_frame_duration ) was computed
	// independently. After rounding it could be ~0 while floorf() had
	// already picked the previous frame, showing a one-frame pop.
	const float interpolate = frame_num - frame_floor;

	interpolate_skeletons( m_animated_skeleton, m_skeletons[frame0], m_skeletons[frame1], interpolate );
}

/// Builds the posed skeleton for one animation frame and appends it to
/// @p skeletons.
///
/// For each joint, it starts from the joint's base-frame pose. Flag bits
/// 0..5 (pos.x, pos.y, pos.z, orient.x, orient.y, orient.z) select which
/// components are overridden, in that order, by consecutive values of
/// @p frame_data beginning at the joint's start_index. orient.w is then
/// rebuilt. If the joint has a parent (parent >= 0), the pose is
/// concatenated with the parent's already-built pose. The parent index must
/// therefore refer to an earlier joint in @p joint_infos.
void md5_animation::build_frame_skeleton( md5_frame_skeleton_list& skeletons, const md5_joint_info_list& joint_infos, const md5_base_frame_list& base_frames, const md5_frame_data& frame_data )
{
	md5_frame_skeleton skeleton;

	for ( unsigned int joint_idx = 0; joint_idx < joint_infos.size(); ++joint_idx )
	{
		const md5_joint_info& info = joint_infos[joint_idx];
		md5_skeleton_joint posed = base_frames[joint_idx];
		posed.parent = info.parent_id;

		// Component k is replaced when flag bit k is set, consuming the next
		// animated value each time.
		float* const components[6] =
		{
			&posed.pos.x,    &posed.pos.y,    &posed.pos.z,
			&posed.orient.x, &posed.orient.y, &posed.orient.z
		};
		unsigned int consumed = 0;
		for ( unsigned int bit = 0; bit < 6; ++bit )
		{
			if ( info.flags & ( 1 << bit ) )
			{
				*components[bit] = frame_data.frame_data[ info.start_index + consumed ];
				++consumed;
			}
		}

		unit_quat_w( posed.orient );

		if ( posed.parent >= 0 )
		{
			const md5_skeleton_joint& parent = skeleton.joints[ static_cast< size_t >( posed.parent ) ];
			posed.pos    = parent.pos + parent.orient * posed.pos;
			posed.orient = glm::normalize( parent.orient * posed.orient );
		}

		skeleton.joints.push_back( posed );
	}

	skeletons.push_back( skeleton );
}

/// Blends two posed skeletons joint by joint into @p final_skeleton.
/// Positions use glm::mix and orientations use glm::slerp, both with factor
/// @p interpolate (0 gives skeleton0). Parent indices are copied from
/// skeleton0.
/// All three skeletons must hold at least m_num_joints joints.
void md5_animation::interpolate_skeletons( md5_frame_skeleton& final_skeleton, const md5_frame_skeleton& skeleton0, const md5_frame_skeleton& skeleton1, float interpolate )
{
	for ( size_t idx = 0; idx < m_num_joints; ++idx )
	{
		const md5_skeleton_joint& from = skeleton0.joints[idx];
		const md5_skeleton_joint& to   = skeleton1.joints[idx];
		md5_skeleton_joint&       out  = final_skeleton.joints[idx];

		out.pos    = glm::mix( from.pos, to.pos, interpolate );
		out.orient = glm::slerp( from.orient, to.orient, interpolate );
		out.parent = from.parent;
	}
}

bool md5_loader::check_animation( const md5_animation& animation ) const
{
	if ( m_num_joints != animation.get_num_joints() )
	{
		return false;
	}

	for ( uint32 i = 0; i < m_joints.size(); ++i )
	{
		const md5_joint& mjoint = m_joints[i];
		const md5_animation::md5_joint_info& ajoint = animation.get_joint_info( i );

		if ( mjoint.name != ajoint.name || mjoint.parent_id != ajoint.parent_id )
		{
			return false;
		}
	}

	return true;
}

/// Re-skins @p mesh with the posed skeleton @p skel. For each vertex i it
/// overwrites entry i of position_buffer, normal_buffer and tangent_buffer.
/// These buffers must already hold one entry per vertex.
/// Position is the bias-weighted sum of (joint.pos + joint.orient * weight.pos).
/// Normal and tangent are the bias-weighted sums of joint.orient applied to
/// the vertex's stored normal/tangent; they are not renormalised here.
/// @return always true.
bool md5_loader::prepare_animated_mesh( md5_mesh* mesh, const md5_animation::md5_frame_skeleton& skel )
{
	for ( size_t v = 0; v < mesh->verts.size(); ++v )
	{
		const md5_vertex& vertex = mesh->verts[v];

		glm::vec3 pos( 0 );
		glm::vec3 nrm( 0 );
		glm::vec3 tng( 0 );

		for ( size_t w = 0; w < vertex.weight_count; ++w )
		{
			const md5_weight& wt = mesh->weights[ vertex.start_weight + w ];
			const md5_animation::md5_skeleton_joint& jnt = skel.joints[ wt.joint_id ];

			const glm::vec3 rotated = jnt.orient * wt.pos;
			pos += ( jnt.pos + rotated ) * wt.bias;
			nrm += ( jnt.orient * vertex.normal  ) * wt.bias;
			tng += ( jnt.orient * vertex.tangent ) * wt.bias;
		}

		mesh->position_buffer[v] = pos;
		mesh->normal_buffer[v]   = nrm;
		mesh->tangent_buffer[v]  = tng;
	}
	return true;
}

void md5_loader::apply( const md5_animation& animation )
{
	const md5_animation::md5_frame_skeleton& skeleton = animation.get_skeleton();

	for ( unsigned int i = 0; i < m_meshes.size(); ++i )
	{
		prepare_animated_mesh( m_meshes[i], skeleton );
	}
}

size_t nv::md5_loader::get_size()
{
	return m_size;
}
