// Copyright (C) 2012-2015 ChaosForge Ltd
// http://chaosforge.org/
//
// This file is part of Nova libraries. 
// For conditions of distribution and use, see copying.txt file in root folder.

#include "nv/formats/obj_loader.hh"
#include "nv/io/std_stream.hh"
#include "nv/interface/data_channel_access.hh"

#include <sstream>

using namespace nv;

struct obj_vertex_vt
{
	vec3 position;
	vec2 texcoord;

	obj_vertex_vt( vec3 a_position, vec2 a_texcoord, vec3 ) 
		: position( a_position ), texcoord( a_texcoord ) {}
};

struct obj_vertex_vtn
{
	vec3 position;
	vec2 texcoord;
	vec3 normal;

	obj_vertex_vtn( vec3 a_position, vec2 a_texcoord, vec3 a_normal ) 
		: position( a_position ), texcoord( a_texcoord ), normal( a_normal ) {}
};


// Output vertex: position + texcoord + normal + tangent.
// tangent (xyz = direction, w = handedness sign) is filled in afterwards by
// mesh_data_reader_vtnt::calculate_tangents(), which skips some vertices; it
// therefore starts at zero so those vertices never carry indeterminate values
// (vector default constructors are not guaranteed to zero their components).
struct obj_vertex_vtnt
{
	vec3 position;
	vec2 texcoord;
	vec3 normal;
	vec4 tangent;

	obj_vertex_vtnt( vec3 a_position, vec2 a_texcoord, vec3 a_normal ) 
		: position( a_position ), texcoord( a_texcoord ), normal( a_normal )
		, tangent( 0.0f, 0.0f, 0.0f, 0.0f ) {}
};

struct obj_reader
{
	vector< vec3 > v;
	vector< vec3 > n;
	vector< vec2 > t;

	std::string line;
	std::string cmd;
	std::string name;
	std::string next_name;

	nv::size_t size;
	bool   eof;

	obj_reader();
	bool read_stream( std::istream& stream );
	virtual nv::size_t add_face( uint32* vi, uint32* ti, uint32* ni, nv::size_t count ) = 0;
	virtual nv::size_t raw_size() const = 0;
	virtual void reset() = 0;
	virtual const uint8* raw_pointer() const = 0;
	virtual void calculate_tangents() {}

	virtual ~obj_reader(){}
};

obj_reader::obj_reader()
{
	// push in dummy 0-index objects for faster indexing
	v.push_back( vec3() );
	n.push_back( vec3() );
	t.push_back( vec2() );
	size = 0;
	eof = false;
}

bool obj_reader::read_stream( std::istream& stream )
{
	name = next_name;
	bool added_faces = false;
	f32 x, y, z;
	if ( eof ) return false;

	while ( std::getline( stream, line ) )
	{
		if ( line.length() < 3 || line[0] == '#' )
		{
			continue;
		}

		std::istringstream ss(line);
		ss >> cmd;

		if ( cmd == "v" )
		{
			ss >> x >> y >> z;
			v.push_back( vec3( x, y, z ) );
			continue;
		}

		if ( cmd == "vn" )
		{
			ss >> x >> y >> z;
			n.push_back( vec3( x, y, z ) );
			continue;
		}

		if ( cmd == "vt" )
		{
			ss >> x >> y;
			t.push_back( vec2( x, 1.0f - y ) );
			continue;
		}

		if ( cmd == "f" )
		{
			added_faces = true;
			ss >> cmd;

			uint32 vis[8];
			uint32 tis[8];
			uint32 nis[8];
			bool   normals = false;
			uint32 count = 0;

			while ( !ss.fail() )
			{
				char ch;

				std::istringstream ss2( cmd );
				ss2 >> vis[count] >> ch;
				ss2 >> tis[count] >> ch;
				if ( ch == '/')
				{
					normals = true;
					ss2 >> nis[count];
				}

				ss >> cmd;
				count++;
			}

			size += add_face( vis, tis, normals ? nis : nullptr, count );
			continue;
		}

		if ( cmd == "g" )
		{
			ss >> next_name;
			if (added_faces) 
				return true;
			name = next_name;
			continue;
		}

		if ( cmd == "s" )
		{
			continue;
		}

		// unknown command
	}

	eof = true;
	return true;
}

template < typename VTX >
struct mesh_data_reader : public obj_reader
{
	mesh_data_reader( bool normals ) : m_normals( normals ) {}
	virtual nv::size_t add_face( uint32* vi, uint32* ti, uint32* ni, nv::size_t count )
	{
		if ( count < 3 ) return 0; // TODO : report error?

		// TODO : support if normals not present;
		vec3 nullvec;
		nv::size_t result = 0;
		// Simple triangulation - obj's shouldn't have more than quads anyway

		if ( m_normals )
		{
			for ( nv::size_t i = 2; i < count; ++i )
			{
				result++;
				m_data.emplace_back( v[ vi[ 0 ]   ], t[ ti[ 0   ] ], n[ ni[ 0   ] ] );
				m_data.emplace_back( v[ vi[ i-1 ] ], t[ ti[ i-1 ] ], n[ ni[ i-1 ] ] );
				m_data.emplace_back( v[ vi[ i ]   ], t[ ti[ i   ] ], n[ ni[ i   ] ] );
			}
		}
		else
		{
			for ( nv::size_t i = 2; i < count; ++i )
			{
				result++;
				m_data.emplace_back( v[ vi[ 0 ]   ], t[ ti[ 0   ] ], nullvec );
				m_data.emplace_back( v[ vi[ i-1 ] ], t[ ti[ i-1 ] ], nullvec );
				m_data.emplace_back( v[ vi[ i ]   ], t[ ti[ i   ] ], nullvec );
			}
		}
		return result;
	}
	bool m_normals;
	vector< VTX > m_data;
	virtual void reset() { m_data.clear(); }
	virtual nv::size_t raw_size() const { return m_data.size() * sizeof( VTX ); }
	virtual const uint8* raw_pointer() const { return reinterpret_cast< const uint8* >( m_data.data() ); }
};


// Reader emitting position + texcoord vertices; normal indices are ignored.
struct mesh_data_reader_vt : public mesh_data_reader< obj_vertex_vt >
{
	mesh_data_reader_vt()
		: mesh_data_reader< obj_vertex_vt >( false )
	{
	}
};

// Reader emitting position + texcoord + normal vertices.
struct mesh_data_reader_vtn : public mesh_data_reader< obj_vertex_vtn >
{
	mesh_data_reader_vtn()
		: mesh_data_reader< obj_vertex_vtn >( true )
	{
	}
};

// Reader emitting position + texcoord + normal + tangent vertices.
struct mesh_data_reader_vtnt : public mesh_data_reader< obj_vertex_vtnt >
{
	mesh_data_reader_vtnt() : mesh_data_reader( true ) {}

	// Computes tangents for the un-indexed triangle list in m_data (each
	// consecutive triple of vertices is one triangle). tangent.xyz is the
	// texture-space s direction orthogonalized against the vertex normal;
	// tangent.w is the handedness sign (+1 / -1) derived from the t direction.
	// Vertices whose tangent cannot be determined (zero or non-finite s
	// direction, or one parallel to the normal) are left unchanged.
	// based on http://www.terathon.com/code/tangent.html
	void calculate_tangents()
	{
		nv::size_t count = m_data.size();
		nv::size_t tcount = count / 3;

		// Per-vertex accumulated s (tan1) and t (tan2) directions; tan2 is
		// needed for the handedness sign below.
		vector< vec3 > tan1( count );
		vector< vec3 > tan2( count );

		for ( nv::size_t a = 0; a < tcount; ++a )
		{
			nv::size_t i1 = a * 3;
			nv::size_t i2 = a * 3 + 1;
			nv::size_t i3 = a * 3 + 2;
			const obj_vertex_vtnt& vtx1 = m_data[ i1 ];
			const obj_vertex_vtnt& vtx2 = m_data[ i2 ];
			const obj_vertex_vtnt& vtx3 = m_data[ i3 ];

			vec3 xyz1 = vtx2.position - vtx1.position;
			vec3 xyz2 = vtx3.position - vtx1.position;

			float s1 = vtx2.texcoord.x - vtx1.texcoord.x;
			float t1 = vtx2.texcoord.y - vtx1.texcoord.y;
			float s2 = vtx3.texcoord.x - vtx1.texcoord.x;
			float t2 = vtx3.texcoord.y - vtx1.texcoord.y;

			// degenerate texture mapping (zero UV area) contributes nothing
			float stst = s1 * t2 - s2 * t1;
			float r = 0.0f;
			if (stst > 0.0f || stst < 0.0f) r = 1.0f / stst;

			vec3 sdir = ( t2 * xyz1 - t1 * xyz2 ) * r;
			vec3 tdir = ( s1 * xyz2 - s2 * xyz1 ) * r;

			// the += below obviously doesn't make sense in this case, but I'll
			// leave it here for when I move to indices
			tan1[i1] += sdir;
			tan1[i2] += sdir;
			tan1[i3] += sdir;

			tan2[i1] += tdir;
			tan2[i2] += tdir;
			tan2[i3] += tdir;
		}

		for ( nv::size_t a = 0; a < count; ++a )
		{
			const vec3& normal = m_data[a].normal;
			const vec3& tv     = tan1[a];
			if ( tv.x == 0.0f && tv.y == 0.0f && tv.z == 0.0f ) continue;

			// Gram-Schmidt: remove the normal component. If nothing (or a
			// non-finite value) remains, normalize would produce NaNs.
			vec3 ortho = tv - normal * glm::dot( normal, tv );
			if ( !( glm::dot( ortho, ortho ) > 0.0f ) ) continue;

			m_data[a].tangent    = vec4( glm::normalize( ortho ), 0.0f );
			m_data[a].tangent[3] = ( glm::dot( glm::cross( normal, tv ), tan2[a] ) < 0.0f ) ? -1.0f : 1.0f;
		}
	}
};

nv::obj_loader::obj_loader( string_table* strings, bool normals /*= true*/, bool tangents /*= false */ )
	: mesh_loader( strings ), m_normals( normals ), m_tangents( tangents )
{
	// The descriptor must describe the vertex struct that load() produces for
	// this flag combination; tangents only take effect when normals are on.
	if ( !normals )
	{
		m_descriptor.initialize<obj_vertex_vt>();
	}
	else if ( tangents )
	{
		m_descriptor.initialize<obj_vertex_vtnt>();
	}
	else
	{
		m_descriptor.initialize<obj_vertex_vtn>();
	}
}

bool nv::obj_loader::load( stream& source )
{
	// Pick the reader whose vertex type matches m_descriptor (see constructor).
	obj_reader* reader;
	if ( !m_normals )
		reader = new mesh_data_reader_vt();
	else if ( !m_tangents )
		reader = new mesh_data_reader_vtn();
	else
		reader = new mesh_data_reader_vtnt();

	std_stream input( &source );

	// Each successful read_stream() call yields one batch of triangles (ended
	// by a following "g" line or by the end of input), stored as one mesh.
	while ( reader->read_stream( input ) )
	{
		if ( m_tangents )
			reader->calculate_tangents();

		data_channel_set* mesh = data_channel_set_creator::create_set( 1 );
		data_channel_set_creator creator( mesh );
		creator.set_name( make_name( reader->name ) );
		uint8* dst = creator.add_channel( m_descriptor, reader->size * 3 ).raw_data();

		const nv::size_t bytes = reader->raw_size();
		if ( bytes > 0 )
			raw_copy_n( reader->raw_pointer(), bytes, dst );

		m_meshes.push_back( mesh );
		reader->reset();
	}

	delete reader;
	return true;
}

// Transfers ownership of mesh `index` to the caller and clears the slot.
// Returns nullptr for a slot that was already released or for an index past
// the last loaded mesh (previously undefined behaviour).
data_channel_set* nv::obj_loader::release_mesh_data( size_t index )
{
	if ( index >= m_meshes.size() ) return nullptr;
	data_channel_set* result = m_meshes[ index ];
	m_meshes[ index ] = nullptr;
	return result;
}

nv::obj_loader::~obj_loader()
{
	// Free every mesh the loader still owns. Slots handed out through
	// release_mesh_data() hold nullptr, and deleting nullptr is a no-op.
	for ( data_channel_set* mesh : m_meshes )
		delete mesh;
}

// Moves every mesh still owned by the loader into a newly allocated
// mesh_data_pack (returned to the caller) and empties the loader.
// Slots already handed out by release_mesh_data() are nullptr; they are
// skipped instead of dereferenced, so the pack holds only the remaining meshes.
mesh_data_pack* nv::obj_loader::release_mesh_data_pack()
{
	uint32 count = 0;
	for ( auto mesh : m_meshes )
		if ( mesh ) ++count;

	data_channel_set* meshes = data_channel_set_creator::create_set_array( count, 1 );
	uint32 out = 0;
	for ( auto mesh : m_meshes )
	{
		if ( !mesh ) continue;
		meshes[ out++ ] = move( *mesh );
		delete mesh;
	}
	m_meshes.clear();
	return new mesh_data_pack( count, meshes );
}
