#include "entity.h"

// Entities.
#include "entities/antifaShip.h"
#include "entities/soldato.h"
#include "entities/caporale.h"
#include "entities/sergente.h"
#include "entities/maresciallo.h"
#include "entities/generale.h"
#include "entities/mussolini.h"

// This fucker is used for creating entities.
const EntityTypeInfo entityTypeInfo[ENTITY_TYPE_COUNT] = {
	(EntityTypeInfo){initAntifaShip, closeAntifaShip, updateAntifaShip, drawAntifaShip},
	(EntityTypeInfo){initSoldato, closeSoldato, updateSoldato, drawSoldato},
	(EntityTypeInfo){initCaporale, closeCaporale, updateCaporale, drawCaporale},
	(EntityTypeInfo){initSergente, closeSergente, updateSergente, drawSergente},
	(EntityTypeInfo){initMaresciallo, closeMaresciallo, updateMaresciallo, drawMaresciallo},
	(EntityTypeInfo){initGenerale, closeGenerale, updateGenerale, drawGenerale},
	(EntityTypeInfo){initMussolini, closeMussolini, updateMussolini, drawMussolini}
};

// Returns an EntityVelocity "at rest": zero linear velocity,
// AxisAngleIdentity() angular velocity, a zero stick vector and zero speed.
// Declared with (void) so the definition is a proper prototype.
EntityVelocity entityVelocityIdentity(void) {
	return (EntityVelocity){
		.velocity = Vector3Zero(),
		.angularVelocity = AxisAngleIdentity(),
		.stick = Vector3Zero(),
		.speed = 0
	};
}

// Rate-limits a change from lastValue to value.
//   value:     the requested new value.
//   lastValue: the value from the previous step.
//   up:        the largest allowed increase in one step.
//   down:      the largest allowed decrease in one step.
// Returns lastValue + up when the requested increase reaches `up`,
// lastValue - down when the requested decrease reaches `down`, and
// value unchanged otherwise.
float accelerateValue(float value, float lastValue, float up, float down) {
	return (value - lastValue >= up)   ? lastValue + up
	     : (lastValue - value >= down) ? lastValue - down
	     : value;
}

// Component-wise accelerateValue(): each axis of `lastValue` moves towards
// the matching axis of `value`, limited by the matching axis of `up`
// (increase) and `down` (decrease).
Vector3 accelerateVector3(Vector3 value, Vector3 lastValue, Vector3 up, Vector3 down) {
	const float nextX = accelerateValue(value.x, lastValue.x, up.x, down.x);
	const float nextY = accelerateValue(value.y, lastValue.y, up.y, down.y);
	const float nextZ = accelerateValue(value.z, lastValue.z, up.z, down.z);

	return (Vector3){ .x = nextX, .y = nextY, .z = nextZ };
}

// Builds an entity of the given type.
// Every field starts from a default:
//   - no model, origin position, identity rotation;
//   - velocities at rest, useAcceleration = false;
//   - health 1.0, no data.
// The type's update/draw callbacks are copied from entityTypeInfo. The type's
// initCb then runs on the result and may overwrite any of these defaults.
// NOTE(review): initCb receives the address of a local that is returned by
// value; any pointer initCb keeps to it becomes dangling — confirm none do.
Entity createEntity(EntityType type, Game * game) {
	const EntityTypeInfo * typeInfo = &entityTypeInfo[type];

	Entity entity = {
		.type = type,
		.model = NULL,
		.position = Vector3Zero(),
		.rotation = QuaternionIdentity(),
		.velocity = entityVelocityIdentity(),
		.lastVelocity = entityVelocityIdentity(),
		.useAcceleration = false,
		.updateCb = typeInfo->updateCb,
		.drawCb = typeInfo->drawCb,
		.health = 1.0,
		.data = NULL
	};

	typeInfo->initCb(&entity, game);
	return entity;
}

// Tears an entity down by dispatching to the closeCb registered for its type.
void closeEntity(Entity * entity) {
	const EntityType type = entity->type;
	entityTypeInfo[type].closeCb(entity);
}

// Basic wireframe drawing: renders entity->model as a GREEN wireframe at
// entity->position with scale 1, oriented by entity->rotation.
// createEntity() defaults `model` to NULL, so entities without a model are
// skipped rather than dereferencing a null pointer.
void entityDraw(Entity * entity) {
	if (entity->model == NULL)
		return;

	// The rotation is baked into the model's transform (overwriting it);
	// the translation is supplied separately to DrawModelWires below.
	entity->model->transform = QuaternionToMatrix(entity->rotation);

	DrawModelWires(
		*entity->model,
		entity->position,
		1,
		GREEN
	);
}

// Integrates position over one frame:
// position += velocity.velocity * GetFrameTime().
void entityUpdatePosition(Entity * entity) {
	const float dt = GetFrameTime();
	const Vector3 v = entity->velocity.velocity;

	const Vector3 step = (Vector3){
		.x = v.x * dt,
		.y = v.y * dt,
		.z = v.z * dt
	};

	entity->position = Vector3Add(entity->position, step);
}

// Applies one frame of angular velocity.
// Builds a quaternion for (angularVelocity.angle * GetFrameTime()) about
// angularVelocity.axis, and combines it with the current rotation as
// QuaternionMultiply(rotation, step).
void entityUpdateRotation(Entity * entity) {
	const float dt = GetFrameTime();
	const float frameAngle = entity->velocity.angularVelocity.angle * dt;
	const Quaternion step = QuaternionFromAxisAngle(entity->velocity.angularVelocity.axis, frameAngle);

	entity->rotation = QuaternionMultiply(entity->rotation, step);
}

// Drives an entity from joystick-style input.
//   stick: rotation input. It is used as the angular-velocity axis, and its
//          length * PI is used as the angle rate.
//   speed: requested forward speed.
// When entity->useAcceleration is set, speed and stick are rate-limited
// (via accelerateValue/accelerateVector3) relative to the values stored in
// entity->lastVelocity, scaled by GetFrameTime().
// The entity is then rotated, moved, and its velocity snapshotted into
// lastVelocity.
void entityJoystickControl(Entity * entity, Vector3 stick, float speed) {
	float t = GetFrameTime();
	float s = speed;
	Vector3 st = stick;

	// Handle acceleration.
	if (entity->useAcceleration) {
		s = accelerateValue(
			speed,
			entity->lastVelocity.speed,
			entity->acceleration.speedUp * t,
			entity->acceleration.speedDown * t
		);

		// The same rotation rate limit is used for both directions.
		Vector3 rotationStep = Vector3Scale(entity->acceleration.rotation, t);

		st = accelerateVector3(
			stick,
			entity->lastVelocity.stick,
			rotationStep,
			rotationStep
		);
	}

	entity->velocity.stick = st;
	entity->velocity.speed = s;

	// Set angular velocity: axis follows the stick, rate is |stick| * PI.
	Vector3 angularVelocity = Vector3Scale(st, PI);
	entity->velocity.angularVelocity.angle = Vector3Length(angularVelocity);
	entity->velocity.angularVelocity.axis = st;

	entityUpdateRotation(entity);

	// Set position: move along (m2, m6, m10) of the inverted rotation's matrix.
	// NOTE(review): with raymath's Matrix layout this should be the entity's
	// local Z axis in world space — confirm.
	Matrix m = QuaternionToMatrix(QuaternionInvert(entity->rotation));

	entity->velocity.velocity = (Vector3){
		m.m2 * s,
		m.m6 * s,
		m.m10 * s
	};

	entityUpdatePosition(entity);

	// Snapshot only after this frame's velocity is fully computed, so
	// lastVelocity is one consistent state (as in entityFlyToPoint) instead of
	// mixing the new speed/stick with the previous frame's velocities.
	entity->lastVelocity = entity->velocity;
}

// Turns the entity towards `point` and moves it forward.
//
// Rotation (info->rotationSpeed):
//   - 0 snaps straight to the target rotation.
//   - Otherwise the rotation is slerped towards the target by
//     GetFrameTime() * rotationSpeed per call, clamped to 1 so a long frame
//     cannot overshoot the target.
//
// Speed (info->controlType):
//   - ENTITY_FLY_TO_POINT_PID: runPID(0, -distance, &controller.speedPID).
//   - ENTITY_FLY_TO_POINT_BANG_BANG: controller.bangbang.speed, or 0 once
//     within controller.bangbang.stopAt of the point.
//   - Any other value: speed 0.
//
// When entity->useAcceleration is set, speed is rate-limited against
// entity->lastVelocity.speed. The entity then moves, and its velocity is
// snapshotted into lastVelocity.
void entityFlyToPoint(Entity * entity, Vector3 point, EntityFlyToPointInfo * info) {
	float t = GetFrameTime();

	// Get distance and direction.
	Vector3 dis = Vector3Subtract(entity->position, point);
	Vector3 direction = Vector3Normalize(dis);
	float distance = Vector3Length(dis);

	// The look-at basis uses cross(up, direction) with up = +Y. That cross
	// product is zero when direction has no X/Z component: the entity is
	// exactly on the point, or directly above/below it. The resulting rotation
	// would be degenerate, so keep the current rotation in that case.
	if (direction.x != 0.0f || direction.z != 0.0f) {
		Matrix matrix = MatrixLookAt(Vector3Zero(), direction, (Vector3){0.0, 1.0, 0.0});
		Quaternion rotation = QuaternionInvert(QuaternionFromMatrix(matrix));

		if (info->rotationSpeed == 0.0) {
			entity->rotation = rotation;
		} else {
			float amount = t * info->rotationSpeed;

			// Never slerp past the target rotation.
			if (amount > 1.0f)
				amount = 1.0f;

			entity->rotation = QuaternionSlerp(entity->rotation, rotation, amount);
		}
	}

	// Velocity control.
	float speed = 0.0;

	switch (info->controlType) {
		case ENTITY_FLY_TO_POINT_PID:
			speed = runPID(0.0, -distance, &info->controller.speedPID);
			break;
		case ENTITY_FLY_TO_POINT_BANG_BANG:
			speed = info->controller.bangbang.speed;

			if (distance <= info->controller.bangbang.stopAt)
				speed = 0.0;

			break;
		default: // Unknown control type: do not move.
			break;
	}

	// Forward direction: (m2, m6, m10) of the inverted rotation's matrix.
	Matrix m = QuaternionToMatrix(QuaternionInvert(entity->rotation));

	// Accelerate.
	if (entity->useAcceleration)
		speed = accelerateValue(
			speed,
			entity->lastVelocity.speed,
			entity->acceleration.speedUp * t,
			entity->acceleration.speedDown * t
		);

	// Velocity.
	entity->velocity.velocity = (Vector3){
		m.m2 * speed,
		m.m6 * speed,
		m.m10 * speed
	};

	entityUpdatePosition(entity);

	entity->velocity.speed = speed;
	entity->lastVelocity = entity->velocity;
}