// Copyright (C) 2012-2013 ChaosForge / Kornel Kisielewicz
// http://chaosforge.org/
//
// This file is part of NV Libraries.
// For conditions of distribution and use, see copyright notice in nv.hh

#include "nv/formats/obj_loader.hh"
#include "nv/io/std_stream.hh"
#include <sstream>

using namespace nv;

struct obj_reader
{
	std::vector< vec3 > v;
	std::vector< vec3 > n;
	std::vector< vec2 > t;

	std::string line;
	std::string cmd;

	std::size_t size;

	obj_reader();
	bool read_stream( std::istream& stream );

	virtual std::size_t add_face( uint32* vi, uint32* ti, uint32* ni, size_t count ) = 0;
	virtual ~obj_reader(){}
};

obj_reader::obj_reader()
{
	// push in dummy 0-index objects for faster indexing
	v.push_back( vec3() );
	n.push_back( vec3() );
	t.push_back( vec2() );
	size = 0;
}

bool obj_reader::read_stream( std::istream& stream )
{
	f32 x, y, z;

	while ( std::getline( stream, line ) )
	{
		if ( line.length() < 3 || line[0] == '#' )
		{
			continue;
		}

		std::istringstream ss(line);
		ss >> cmd;

		if ( cmd == "v" )
		{
			ss >> x >> y >> z;
			v.push_back( vec3( x, y, z ) );
			continue;
		}

		if ( cmd == "vn" )
		{
			ss >> x >> y >> z;
			n.push_back( vec3( x, y, z ) );
			continue;
		}

		if ( cmd == "vt" )
		{
			ss >> x >> y;
			t.push_back( vec2( x, 1.0f - y ) );
			continue;
		}

		if ( cmd == "f" )
		{
			ss >> cmd;

			uint32 vis[8];
			uint32 tis[8];
			uint32 nis[8];
			bool   normals = false;
			uint32 count = 0;

			while ( !ss.fail() )
			{
				char ch;

				std::istringstream ss2( cmd );
				ss2 >> vis[count] >> ch;
				ss2 >> tis[count] >> ch;
				if ( ch == '/')
				{
					normals = true;
					ss2 >> nis[count];
				}

				ss >> cmd;
				count++;
			}

			size += add_face( vis, tis, normals ? nis : nullptr, count );
			continue;
		}

		if ( cmd == "g" || cmd == "s" )
		{
			// ignored
			continue;
		}

		// unknown command
	}

	return true;
}


struct mesh_obj_reader : public obj_reader
{
	mesh_obj_reader( mesh* m ) : m_mesh( m ), m_position( nullptr ), m_normal( nullptr ), m_tex_coord( nullptr ), m_tangent( nullptr ) {}
	virtual std::size_t add_face( uint32* v, uint32* t, uint32* n, size_t count );
	virtual void calculate_tangents();

	vertex_attribute< vec3 >* m_position;
	vertex_attribute< vec3 >* m_normal;
	vertex_attribute< vec2 >* m_tex_coord;
	vertex_attribute< vec4 >* m_tangent;
	mesh* m_mesh;
};

std::size_t mesh_obj_reader::add_face( uint32* vi, uint32* ti, uint32* ni, size_t count )
{
	if ( count < 3 )
	{
		// TODO : report error?
		return 0;
	}

	if ( m_position == nullptr )
	{
		m_position  = m_mesh->add_attribute< vec3 >( "position" );
	}
	if ( m_tex_coord == nullptr )
	{
		m_tex_coord = m_mesh->add_attribute< vec2 >( "texcoord" );
	}
	if ( m_normal == nullptr && ni != nullptr )
	{
		m_normal = m_mesh->add_attribute< vec3 >( "normal" );
	}

	// TODO : support if normals not present;

	std::vector< vec3 >& vp = m_position->get();
	std::vector< vec2 >& vt = m_tex_coord->get();
	std::vector< vec3 >& vn = m_normal->get();

	std::size_t result = 0;

	// Simple triangulation - obj's shouldn't have more than quads anyway
	for ( size_t i = 2; i < count; ++i )
	{
		result++;
		vp.push_back( v[ vi[ 0 ] ] );   vt.push_back( t[ ti[ 0 ] ] );   vn.push_back( n[ ni[ 0 ] ] );
		vp.push_back( v[ vi[ i-1 ] ] ); vt.push_back( t[ ti[ i-1 ] ] ); vn.push_back( n[ ni[ i-1 ] ] );
		vp.push_back( v[ vi[ i ] ] );   vt.push_back( t[ ti[ i ] ] );   vn.push_back( n[ ni[ i ] ] );
	}

	return result;
}

// based on http://www.terathon.com/code/tangent.html
void mesh_obj_reader::calculate_tangents()
{
	m_tangent = m_mesh->add_attribute< vec4 >( "tangent" );

	std::vector< vec3 >& vp = m_position->get();
	std::vector< vec2 >& vt = m_tex_coord->get();
	std::vector< vec3 >& vn = m_normal->get();
	std::vector< vec4 >& tg = m_tangent->get();

	std::size_t count  = vp.size();
	std::size_t tcount = count / 3;

	std::vector< vec3 > tan1( count );
	std::vector< vec3 > tan2( count );
	tg.resize( count );

	for (std::size_t a = 0; a < tcount; ++a )
	{
		uint32 i1 = a * 3;
		uint32 i2 = a * 3 + 1;
		uint32 i3 = a * 3 + 2;

		// TODO: simplify

		const vec3& v1 = vp[i1];
		const vec3& v2 = vp[i2];
		const vec3& v3 = vp[i3];

		const vec2& w1 = vt[i1];
		const vec2& w2 = vt[i2];
		const vec2& w3 = vt[i3];

		vec3 xyz1 = v2 - v1;
		vec3 xyz2 = v3 - v1;
		vec2 st1  = w2 - w1;
		vec2 st2  = w3 - w1;

		float s1 = w2.x - w1.x;
		float t1 = w2.y - w1.y;
		float s2 = w3.x - w1.x;
		float t2 = w3.y - w1.y;

		float r = 1.0f / (s1 * t2 - s2 * t1);

		vec3 sdir = ( t2 * xyz1 - t1 * xyz2 ) * r;
		vec3 tdir = ( s1 * xyz2 - s2 * xyz1 ) * r;

		// the += below obviously doesn't make sense in this case, but I'll
		// leave it here for when I move to indices
		tan1[i1] += sdir;
		tan1[i2] += sdir;
		tan1[i3] += sdir;

		tan2[i1] += tdir;
		tan2[i2] += tdir;
		tan2[i3] += tdir;
	}

	for (std::size_t a = 0; a < count; ++a )
	{
		const vec3& n = vn[a];
		const vec3& t = tan1[a];

		tg[a] = vec4( glm::normalize(t - n * glm::dot( n, t )), 
		    (glm::dot(glm::cross(n, t), tan2[a]) < 0.0f) ? -1.0f : 1.0f );
	}

}

nv::obj_loader::obj_loader( bool tangents )
	: m_mesh( nullptr ), m_tangents( tangents )
{

}

nv::obj_loader::~obj_loader()
{
	delete m_mesh;
}

bool nv::obj_loader::load( stream& source )
{
	if ( m_mesh != nullptr )
	{
		delete m_mesh;
	}
	m_mesh = new mesh();
	mesh_obj_reader reader( m_mesh );
	std_stream sstream( &source );
	reader.read_stream( sstream );
	m_size = reader.size;
	if ( m_tangents )
	{
		reader.calculate_tangents();
	}
	return true;
}
