// Copyright (C) 2011 Kornel Kisielewicz
// This file is part of NV Libraries.
// For conditions of distribution and use, see copyright notice in nv.hh

#include "nv/gfx/keyframed_mesh.hh"

#include <glm/gtc/matrix_access.hpp>
#include <glm/gtx/matrix_interpolation.hpp>
#include "nv/interface/context.hh"
#include "nv/interface/device.hh"


#include "nv/logging.hh"

using namespace nv;

nv::keyframed_mesh::keyframed_mesh( device* a_device, mesh_data* a_data, tag_map* a_tag_map )
	: animated_mesh()
	, m_mesh_data( a_data )
	, m_tag_map( a_tag_map )
	, m_start_frame( false )
	, m_stop_frame( false )
	, m_last_frame( 0 )
	, m_next_frame( 0 )
	, m_time( 0 )
	, m_fps( 0 )
	, m_interpolation( 0.0f )
	, m_looping( false )
	, m_active( false )
{
	m_va = a_device->create_vertex_array();

	m_index_count  = m_mesh_data->get_index_channel()->count;
	m_vertex_count = m_mesh_data->get_channel<vertex_t>()->count;
	m_frame_count  = m_mesh_data->get_channel<vertex_pn>()->count / m_vertex_count;
}

size_t keyframed_mesh::get_max_frames() const
{
	return m_frame_count;
}

transform keyframed_mesh::get_tag( const std::string& tag ) const
{
	NV_ASSERT( m_tag_map, "TAGMAP FAIL" );
	const transform_vector* transforms = m_tag_map->get_tag( tag );
	NV_ASSERT( transforms, "TAG FAIL" );
	return interpolate( transforms->get( m_last_frame ), transforms->get( m_next_frame ), m_interpolation  );
}

void keyframed_mesh::setup_animation( uint32 start, uint32 count, uint32 fps, bool loop )
{
	m_start_frame   = start;
	m_stop_frame    = start+count-1;
	m_looping       = loop;
	m_fps           = fps;
	m_active        = count > 1;
	m_time          = 0;
	m_last_frame    = start;
	m_next_frame    = (count > 1 ? start + 1 : start );
	m_interpolation = 0.0f;
}

void nv::keyframed_mesh::set_frame( uint32 frame )
{
	m_last_frame    = frame;
	m_next_frame    = frame;
	m_active        = false;
	m_interpolation = 0.0f;
}

void keyframed_mesh::update( uint32 ms )
{
	if ( m_active )
	{
		m_time += ms;
		uint32 f_diff = (m_stop_frame - m_start_frame);
		float f_time  = 1000 / (float)m_fps;
		float f_max   = ( m_looping ? ( f_diff + 1 ) : f_diff ) * f_time;
		float f_pos   = m_time / f_time;

		m_last_frame    = (uint32)glm::floor( f_pos ) + m_start_frame;
		m_next_frame    = m_last_frame + 1;
		if ( m_next_frame > m_stop_frame )
		{
			m_next_frame = m_start_frame;
		}

		if ( m_time >= f_max )
		{
			if ( m_looping )
			{
				uint32 left = m_time - static_cast< uint32 >( f_max );
				m_time = 0;
				update( left );
			}
			else
			{
				m_active     = false;
				m_last_frame = m_stop_frame;
				m_next_frame = m_stop_frame;
			}
		}
		m_interpolation = f_pos - glm::floor( f_pos );
	}
}

void nv::keyframed_mesh::update( program* a_program ) const
{
	a_program->set_opt_uniform( "nv_interpolate", m_interpolation );
}

nv::keyframed_mesh::~keyframed_mesh()
{
	delete m_va;
}

void nv::keyframed_mesh::run_animation( animation_entry* a_anim )
{
	if ( a_anim )
	{
		keyframed_animation_entry * anim = down_cast<keyframed_animation_entry>(a_anim);
		m_active = true;
		setup_animation( anim->m_start, anim->m_frames, anim->m_fps, anim->m_looping );
	}
	else
	{
		m_active = false;
	}
}

nv::keyframed_mesh_gpu::keyframed_mesh_gpu( device* a_device, mesh_data* a_data, tag_map* a_tag_map, program* a_program )
	: keyframed_mesh( a_device, a_data, a_tag_map )
	, m_loc_next_position( 0 )
	, m_loc_next_normal( 0 )
	, m_gpu_last_frame( 0xFFFFFFFF )
	, m_gpu_next_frame( 0xFFFFFFFF )
{
	m_loc_next_position = a_program->get_attribute( "nv_next_position" )->get_location();
	m_loc_next_normal   = a_program->get_attribute( "nv_next_normal" )->get_location();
	m_va = a_device->create_vertex_array( a_data, STATIC_DRAW );
	vertex_buffer* vb = m_va->find_buffer( slot::POSITION );
	m_va->add_vertex_buffer( m_loc_next_position, vb, FLOAT, 3, 0,              sizeof( vertex_pn ), false );
	m_va->add_vertex_buffer( m_loc_next_normal,   vb, FLOAT, 3, sizeof( vec3 ), sizeof( vertex_pn ), false );
}


void nv::keyframed_mesh_gpu::update( uint32 ms )
{
	keyframed_mesh::update( ms );

	if ( m_gpu_last_frame != m_last_frame )
	{
		m_va->update_vertex_buffer( slot::POSITION, m_last_frame * m_vertex_count * sizeof( vertex_pn ) );
		m_va->update_vertex_buffer( slot::NORMAL,   m_last_frame * m_vertex_count * sizeof( vertex_pn ) + sizeof( vec3 ) );
		m_gpu_last_frame = m_last_frame;
	}
	if ( m_gpu_next_frame != m_next_frame )
	{
		m_va->update_vertex_buffer( m_loc_next_position, m_next_frame * m_vertex_count * sizeof( vertex_pn ) );
		m_va->update_vertex_buffer( m_loc_next_normal,   m_next_frame * m_vertex_count * sizeof( vertex_pn ) + sizeof( vec3 ) );
		m_gpu_next_frame = m_next_frame;
	}
}

nv::keyframed_mesh_cpu::keyframed_mesh_cpu( device* a_device, mesh_data* a_data, tag_map* a_tag_map )
	: keyframed_mesh( a_device, a_data, a_tag_map )
{
	m_vb = a_device->create_vertex_buffer( nv::STATIC_DRAW, m_vertex_count * sizeof( vertex_pn ), (void*)m_mesh_data->get_channel<vertex_pn>()->data );
	m_va->add_vertex_buffers( m_vb, m_mesh_data->get_channel<vertex_pn>() );

	nv::vertex_buffer* vb = a_device->create_vertex_buffer( nv::STATIC_DRAW, m_vertex_count * sizeof( nv::vec2 ), (void*)m_mesh_data->get_channel<vertex_t>()->data );
	m_va->add_vertex_buffers( vb, m_mesh_data->get_channel<vertex_t>() );

	nv::index_buffer* ib = a_device->create_index_buffer( nv::STATIC_DRAW, m_mesh_data->get_index_channel()->size(), (void*)m_mesh_data->get_index_channel()->data );
	m_va->set_index_buffer( ib, m_mesh_data->get_index_channel()->desc.slots[0].etype, true );

	m_vertex.resize( m_vertex_count );
}

void nv::keyframed_mesh_cpu::update( uint32 ms )
{
	keyframed_mesh::update( ms );

	const vertex_pn* data = m_mesh_data->get_channel_data<vertex_pn>();
	const vertex_pn* prev = data + m_vertex_count * m_last_frame;
	const vertex_pn* next = data + m_vertex_count * m_next_frame;

	for ( size_t i = 0; i < m_vertex_count; ++i )
	{
		m_vertex[i].position = glm::mix( prev[i].position, next[i].position, m_interpolation );
		m_vertex[i].normal   = glm::mix( prev[i].normal,   next[i].normal,   m_interpolation );
	}

	m_vb->bind();
	m_vb->update( m_vertex.data(), 0, m_vertex_count * sizeof( vertex_pn ) );
	m_vb->unbind();
}
