#include "entity.h"

// Entities.
#include "entities/antifaShip.h"
#include "entities/soldato.h"
#include "entities/caporale.h"
#include "entities/sergente.h"
#include "entities/maresciallo.h"
#include "entities/generale.h"
#include "entities/mussolini.h"

// This fucker is used for creating entities.
const EntityTypeInfo entityTypeInfo[ENTITY_TYPE_COUNT] = {
	(EntityTypeInfo){initAntifaShip, closeAntifaShip, updateAntifaShip, drawAntifaShip},
	(EntityTypeInfo){initSoldato, closeSoldato, updateSoldato, drawSoldato},
	(EntityTypeInfo){initCaporale, closeCaporale, updateCaporale, drawCaporale},
	(EntityTypeInfo){initSergente, closeSergente, updateSergente, drawSergente},
	(EntityTypeInfo){initMaresciallo, closeMaresciallo, updateMaresciallo, drawMaresciallo},
	(EntityTypeInfo){initGenerale, closeGenerale, updateGenerale, drawGenerale},
	(EntityTypeInfo){initMussolini, closeMussolini, updateMussolini, drawMussolini}
};

EntityVelocity entityVelocityIdentity() {
	return (EntityVelocity){
		.velocity = Vector3Zero(),
		.angularVelocity = AxisAngleIdentity(),
		.stick = Vector3Zero(),
		.speed = 0
	};
}

// Moves lastValue toward value, limiting the step size.
// An increase is capped at `up` and a decrease at `down`. If the target is
// within both limits, it is returned unchanged.
float accelerateValue(float value, float lastValue, float up, float down) {
	const float rise = value - lastValue;
	const float fall = lastValue - value;

	return (rise >= up)   ? lastValue + up
	     : (fall >= down) ? lastValue - down
	     : value;
}

// Applies accelerateValue() to each axis independently.
// Each component of lastValue moves toward value by at most up.<axis> when
// increasing and down.<axis> when decreasing.
Vector3 accelerateVector3(Vector3 value, Vector3 lastValue, Vector3 up, Vector3 down) {
	Vector3 limited;

	limited.x = accelerateValue(value.x, lastValue.x, up.x, down.x);
	limited.y = accelerateValue(value.y, lastValue.y, up.y, down.y);
	limited.z = accelerateValue(value.z, lastValue.z, up.z, down.z);

	return limited;
}

// Creates an entity of the given type.
// Defaults are filled in first: no model, zero radius, empty collision
// models, origin position, identity rotation, resting velocity, full health,
// no collision hit and no data. The update/draw callbacks are copied from
// entityTypeInfo[type], and then the type's init callback customises the
// entity. The entity is returned by value.
Entity createEntity(EntityType type, Game * game) {
	const EntityTypeInfo * typeInfo = &entityTypeInfo[type];
	const EntityCollisionModel noCollisionModel = (EntityCollisionModel){0, NULL};

	Entity entity = {
		.type = type,

		// Geometry.
		.model = NULL,
		.radius = 0.0,
		.collisionModel = noCollisionModel,
		.transformedCollisionModel = noCollisionModel,
		.collisionModelTransformed = false,

		// Pose and motion.
		.position = Vector3Zero(),
		.rotation = QuaternionIdentity(),
		.lastPosition = Vector3Zero(),
		.lastRotation = QuaternionIdentity(),
		.velocity = entityVelocityIdentity(),
		.lastVelocity = entityVelocityIdentity(),
		.useAcceleration = false,

		// Per-type behaviour.
		.updateCb = typeInfo->updateCb,
		.drawCb = typeInfo->drawCb,

		// Gameplay state.
		.health = ENTITY_MAX_HEALTH,
		.collision.hit = false,
		.data = NULL
	};

	typeInfo->initCb(&entity, game);

	return entity;
}

// Releases an entity by dispatching to the close callback registered for
// its type in entityTypeInfo.
void closeEntity(Entity * entity) {
	const EntityTypeInfo * typeInfo = &entityTypeInfo[entity->type];

	typeInfo->closeCb(entity);
}

// Sets entity->radius to the largest distance from the model's local origin
// to any vertex, across all meshes of entity->model. This is a bounding-sphere
// radius centred on the origin. Raw vertex data is used, so no model
// transform or scale is applied. entity->model must be set.
void setEntityRadius(Entity * entity) {
	float maxDistance = 0.0;

	for (int meshIndex = 0; meshIndex < entity->model->meshCount; ++meshIndex) {
		const Mesh * mesh = &entity->model->meshes[meshIndex];

		// Vertices are packed as consecutive x, y, z floats.
		for (int vertexIndex = 0; vertexIndex < mesh->vertexCount; ++vertexIndex) {
			const float * xyz = &mesh->vertices[vertexIndex * 3];
			Vector3 vertex = (Vector3){xyz[0], xyz[1], xyz[2]};

			maxDistance = fmaxf(maxDistance, Vector3Length(vertex));
		}
	}

	entity->radius = maxDistance;
}

// Triangle helper (not called from the code in this file).
// Copies triangle number `num` of `mesh` into `triangle`, writing through the
// array parameter. Each corner is rotated by `rotation` about the origin and
// then translated by `position`.
// Indexed meshes (mesh.indices != NULL) are resolved through the index
// buffer. Non-indexed meshes use three consecutive vertices per triangle.
// Assuming every mesh is non-indexed would read the wrong vertices, or past
// the end of mesh.vertices, for indexed meshes.
void getTriangleFromMeshAndTransform(int num, Mesh mesh, Triangle3D triangle, Quaternion rotation, Vector3 position) {
	int i;

	for (i = 0; i < 3; ++i) {
		// Vertex number of corner i; each vertex occupies 3 floats.
		int vertexIndex = (mesh.indices != NULL) ? (int)mesh.indices[num * 3 + i] : num * 3 + i;
		const float * v = &mesh.vertices[vertexIndex * 3];

		// Get vertex.
		triangle[i] = (Vector3){v[0], v[1], v[2]};

		// Transform vertex: rotate about the origin, then move into place.
		triangle[i] = Vector3RotateByQuaternion(triangle[i], rotation);
		triangle[i] = Vector3Add(triangle[i], position);
	}
}

// Normal helper (not called from the code in this file).
// Reads the three floats starting at raw offset `num` in mesh.normals and
// rotates that vector by `rotation`. `num` is a float offset into the
// array, not a triangle or vertex number.
Vector3 getNormalsFromMeshAndTransform(int num, Mesh mesh, Quaternion rotation) {
	const float * n = &mesh.normals[num];
	Vector3 localNormal = (Vector3){n[0], n[1], n[2]};

	return Vector3RotateByQuaternion(localNormal, rotation);
}

// Builds a standalone local-space collision copy of `mesh`: one Triangle3D
// and one normal per triangle. Both arrays are allocated with KF_CALLOC and
// must be released with freeCollisionMesh().
//
// Triangle corners:
// - Indexed meshes (mesh.indices != NULL) are resolved through the index
//   buffer.
// - Otherwise every three consecutive vertices form one triangle.
//
// Triangle normal:
// - It is the mesh normal of the triangle's first vertex. Normals are stored
//   per vertex (3 floats each), so reading normals[i * 3] would give vertex
//   i's normal, not triangle i's.
// - If the mesh carries no normals, the face normal is computed from the
//   corners using counter-clockwise winding.
//
// A mesh with no triangles yields an empty collision mesh {0, NULL, NULL}.
// On allocation failure, ALLOCATION_ERROR is raised, any partial allocation
// is released and {0, NULL, NULL} is returned.
EntityCollisionMesh createCollisionMesh(Mesh mesh) {
	int i, j;
	EntityCollisionMesh collisionMesh;
	collisionMesh.triangleCount = mesh.triangleCount;

	// Nothing to copy; avoid a zero-sized allocation being reported as an error.
	if (collisionMesh.triangleCount <= 0)
		return (EntityCollisionMesh){0, NULL, NULL};

	// Allocate.
	collisionMesh.triangles = (Triangle3D*)KF_CALLOC(collisionMesh.triangleCount, sizeof(Triangle3D));
	collisionMesh.normals = (Vector3*)KF_CALLOC(collisionMesh.triangleCount, sizeof(Vector3));

	if (collisionMesh.triangles == NULL || collisionMesh.normals == NULL) {
		ALLOCATION_ERROR;

		// Release whichever allocation did succeed so nothing leaks.
		if (collisionMesh.triangles != NULL)
			KF_FREE(collisionMesh.triangles);
		if (collisionMesh.normals != NULL)
			KF_FREE(collisionMesh.normals);

		return (EntityCollisionMesh){0, NULL, NULL};
	}

	// Copy triangles and normals over.
	for (i = 0; i < collisionMesh.triangleCount; ++i) {
		int vertexIndex[3];

		// Get triangle: resolve each corner to a vertex, then copy its xyz.
		for (j = 0; j < 3; ++j) {
			vertexIndex[j] = (mesh.indices != NULL) ? (int)mesh.indices[i * 3 + j] : i * 3 + j;

			const float * v = &mesh.vertices[vertexIndex[j] * 3];
			collisionMesh.triangles[i][j] = (Vector3){v[0], v[1], v[2]};
		}

		// Get normal.
		if (mesh.normals != NULL) {
			// Normal of the triangle's first vertex.
			const float * n = &mesh.normals[vertexIndex[0] * 3];
			collisionMesh.normals[i] = (Vector3){n[0], n[1], n[2]};
		} else {
			// No normals in the mesh: derive the face normal from two edges.
			Vector3 edge1 = Vector3Subtract(collisionMesh.triangles[i][1], collisionMesh.triangles[i][0]);
			Vector3 edge2 = Vector3Subtract(collisionMesh.triangles[i][2], collisionMesh.triangles[i][0]);
			collisionMesh.normals[i] = Vector3Normalize(Vector3CrossProduct(edge1, edge2));
		}
	}

	return collisionMesh;
}

// Releases the triangle and normal arrays of a collision mesh built by
// createCollisionMesh(). A mesh without a triangle array, such as the empty
// mesh returned on allocation failure, is ignored. The struct is passed by
// value, so the caller's copies of the pointers are left dangling and must
// not be reused.
void freeCollisionMesh(EntityCollisionMesh mesh) {
	if (mesh.triangles != NULL) {
		KF_FREE(mesh.triangles);
		KF_FREE(mesh.normals);
	}
}

// Builds a collision model with one collision mesh per mesh of `model`
// (see createCollisionMesh). The mesh array is allocated with KF_CALLOC;
// release it with entityFreeCollisionModel(). If that allocation fails,
// ALLOCATION_ERROR is raised and an empty model {0, NULL} is returned.
EntityCollisionModel entityCreateCollisionModel(Model model) {
	EntityCollisionModel result;

	result.meshCount = model.meshCount;
	result.meshes = (EntityCollisionMesh*)KF_CALLOC(result.meshCount, sizeof(EntityCollisionMesh));

	if (result.meshes == NULL) {
		ALLOCATION_ERROR;
		return (EntityCollisionModel){0, NULL};
	}

	for (int meshIndex = 0; meshIndex < result.meshCount; ++meshIndex)
		result.meshes[meshIndex] = createCollisionMesh(model.meshes[meshIndex]);

	return result;
}

// Frees every collision mesh of `model` and then the mesh array itself.
// A model without a mesh array is ignored.
void entityFreeCollisionModel(EntityCollisionModel model) {
	if (model.meshes != NULL) {
		for (int meshIndex = 0; meshIndex < model.meshCount; ++meshIndex)
			freeCollisionMesh(model.meshes[meshIndex]);

		KF_FREE(model.meshes);
	}
}

// Rebuilds mesh `num` of entity->transformedCollisionModel from the matching
// local-space mesh in entity->collisionModel.
// Every triangle vertex is rotated and then offset by entity->position.
// Every per-triangle normal is rotated only, since directions are not
// translated.
// Assumes transformedCollisionModel.meshes[num] is already allocated with
// room for collisionModel.meshes[num].triangleCount triangles and normals.
// TODO confirm where that allocation happens (not visible here).
void transformCollisionMesh(Entity * entity, int num) {
	int i, j;
	EntityCollisionMesh mesh1 = entity->collisionModel.meshes[num];
	EntityCollisionMesh * mesh2 = &entity->transformedCollisionModel.meshes[num];
	Vector3 vertex;

	// NOTE(review): with raylib's column-major Matrix fields (m0, m1, m2 form
	// the first column), the products below multiply by the transpose of m.
	// Because m is the matrix of the *inverted* rotation, this should amount
	// to rotating by entity->rotation itself, matching entityDraw. Confirm
	// against the raylib version in use.
	Matrix m = QuaternionToMatrix(QuaternionInvert(entity->rotation));

	for (i = 0; i < mesh1.triangleCount; ++i) {

		// Transform triangle.
		for (j = 0; j < 3; ++j) {
			vertex = mesh1.triangles[i][j];

			// Rotate.
			mesh2->triangles[i][j] = (Vector3){
				m.m0 * vertex.x + m.m1 * vertex.y + m.m2 * vertex.z,
				m.m4 * vertex.x + m.m5 * vertex.y + m.m6 * vertex.z,
				m.m8 * vertex.x + m.m9 * vertex.y + m.m10 * vertex.z
			};

			// Move to position.
			mesh2->triangles[i][j] = Vector3Add(entity->position, mesh2->triangles[i][j]);
		}

		// Transform normals (rotation only; `vertex` is reused as scratch).
		vertex = mesh1.normals[i];

		// Rotate with the same matrix products as the vertices above.
		mesh2->normals[i] = (Vector3){
			m.m0 * vertex.x + m.m1 * vertex.y + m.m2 * vertex.z,
			m.m4 * vertex.x + m.m5 * vertex.y + m.m6 * vertex.z,
			m.m8 * vertex.x + m.m9 * vertex.y + m.m10 * vertex.z
		};

	}
}

// Refreshes every mesh of entity->transformedCollisionModel from the
// local-space collision model, using the entity's current position and
// rotation (see transformCollisionMesh).
void entityTransformCollisionModel(Entity * entity) {
	int meshIndex = 0;

	while (meshIndex < entity->collisionModel.meshCount) {
		transformCollisionMesh(entity, meshIndex);
		++meshIndex;
	}
}

// Invalidates the cached world-space collision model when the entity has
// moved or turned since entityUpdateLastValues() last recorded its pose.
// The flag is only ever cleared here, never set.
void entityCheckTransformedCollisionModel(Entity * entity) {
	bool samePosition = Vector3Equals(entity->lastPosition, entity->position);
	bool sameRotation = QuaternionEquals(entity->lastRotation, entity->rotation);

	if (!(samePosition && sameRotation))
		entity->collisionModelTransformed = false;
}

// Mesh-vs-mesh helper for checkEntityCollision.
// Tests every triangle of transformed mesh entity1MeshNum of entity1 against
// every triangle of transformed mesh entity2MeshNum of entity2, using
// checkTriangleCollision3D. Returns true on the first intersecting pair.
// Uses the world-space transformedCollisionModel, so both entities must
// already have been transformed.
bool checkEntityMeshCollision(Entity entity1, Entity entity2, int entity1MeshNum, int entity2MeshNum) {
	const EntityCollisionMesh * meshA = &entity1.transformedCollisionModel.meshes[entity1MeshNum];
	const EntityCollisionMesh * meshB = &entity2.transformedCollisionModel.meshes[entity2MeshNum];

	for (int a = 0; a < meshA->triangleCount; ++a) {
		for (int b = 0; b < meshB->triangleCount; ++b) {
			bool hit = checkTriangleCollision3D(
				meshA->triangles[a],
				meshB->triangles[b],
				meshA->normals[a],
				meshB->normals[b]
			);

			if (hit)
				return true;
		}
	}

	return false;
}

// Tests two entities for collision.
// 1. A bounding-sphere test rejects distant pairs early: the distance between
//    positions is compared with the sum of radii.
// 2. Each entity's world-space collision model is rebuilt if its cache flag
//    says it is stale, and the flag is then set.
// 3. Every mesh pair is tested triangle-by-triangle.
// Returns true on the first hit.
bool checkEntityCollision(Entity * entity1, Entity * entity2) {
	Entity * const entities[2] = {entity1, entity2};
	int e, i, j;

	if (Vector3Distance(entity1->position, entity2->position) > entity1->radius + entity2->radius)
		return false;

	// Refresh stale world-space collision models (entity1 first, then entity2).
	for (e = 0; e < 2; ++e) {
		if (entities[e]->collisionModelTransformed)
			continue;

		entityTransformCollisionModel(entities[e]);
		entities[e]->collisionModelTransformed = true;
	}

	for (i = 0; i < entity1->collisionModel.meshCount; ++i) {
		for (j = 0; j < entity2->collisionModel.meshCount; ++j) {
			if (checkEntityMeshCollision(*entity1, *entity2, i, j))
				return true;
		}
	}

	return false;
}

// Basic wireframe drawing: draws entity->model in GREEN at entity->position
// with scale 1.0, oriented by entity->rotation. Side effect: it overwrites
// the model's transform matrix with the entity's rotation.
void entityDraw(Entity * entity) {
	const Matrix orientation = QuaternionToMatrix(entity->rotation);

	entity->model->transform = orientation;
	DrawModelWires(*entity->model, entity->position, 1.0, GREEN);
}

// Advances entity->position by its linear velocity scaled by this frame's
// duration (GetFrameTime()).
void entityUpdatePosition(Entity * entity) {
	Vector3 displacement = Vector3Scale(entity->velocity.velocity, GetFrameTime());

	entity->position = Vector3Add(entity->position, displacement);
}

// Applies one frame of angular velocity to entity->rotation.
// The rotation angle is angularVelocity.angle scaled by GetFrameTime(), about
// angularVelocity.axis. It is multiplied onto the current rotation, which is
// then re-normalized to limit numeric drift.
void entityUpdateRotation(Entity * entity) {
	float frameTime = GetFrameTime();
	float frameAngle = entity->velocity.angularVelocity.angle * frameTime;

	Quaternion step = QuaternionFromAxisAngle(entity->velocity.angularVelocity.axis, frameAngle);

	entity->rotation = QuaternionNormalize(QuaternionMultiply(entity->rotation, step));
}

// Drives an entity for one frame from joystick-style input.
//   stick - rotation input. It is used as the rotation axis, and its length
//           times PI is the rotation rate. entityUpdateRotation multiplies
//           that rate by the frame time, so it is presumably radians per
//           second.
//   speed - desired forward speed.
// When entity->useAcceleration is set, speed and stick are rate-limited
// against entity->lastVelocity via accelerateValue/accelerateVector3. The
// entity is then rotated (entityUpdateRotation) and moved along its forward
// axis (entityUpdatePosition).
void entityJoystickControl(Entity * entity, Vector3 stick, float speed) {
	float s = speed;
	Vector3 st = stick;
	float t = GetFrameTime();

	// Handle acceleration: limit how much speed and stick may change this frame.
	if (entity->useAcceleration) {
		s = accelerateValue(
			speed,
			entity->lastVelocity.speed,
			entity->acceleration.speedUp * t,
			entity->acceleration.speedDown * t
		);

		// The same per-axis limit is used for increasing and decreasing.
		st = accelerateVector3(
			stick,
			entity->lastVelocity.stick,
			Vector3Scale(entity->acceleration.rotation, t),
			Vector3Scale(entity->acceleration.rotation, t)
		);
	}

	entity->velocity.stick = st;
	entity->velocity.speed = s;
	// Snapshot taken before angularVelocity/velocity are recomputed below.
	// This file only reads .speed and .stick from lastVelocity.
	entity->lastVelocity = entity->velocity;

	// Set angular velocity: rate = |stick| * PI, axis = stick.
	Vector3 angularVelocity = Vector3Scale(st, PI);
	entity->velocity.angularVelocity.angle = Vector3Length(angularVelocity);
	// NOTE(review): the axis is not normalized here. This relies on
	// QuaternionFromAxisAngle (called in entityUpdateRotation) normalizing
	// it. Confirm for the raylib version in use.
	entity->velocity.angularVelocity.axis = st;

	entityUpdateRotation(entity);

	// Set position: move along (m2, m6, m10) of the inverse-rotation matrix.
	// With raylib's column-major layout this should be the entity's local +Z
	// axis in world space. TODO confirm.
	Matrix m = QuaternionToMatrix(QuaternionInvert(entity->rotation));

	entity->velocity.velocity = (Vector3){
		m.m2 * s,
		m.m6 * s,
		m.m10 * s
	};

	entityUpdatePosition(entity);
}

// Records the current pose, so that a later call to
// entityCheckTransformedCollisionModel() can tell whether the entity moved.
void entityUpdateLastValues(Entity * entity) {
	entity->lastRotation = entity->rotation;
	entity->lastPosition = entity->position;
}

// Steers an entity toward `point` for one frame.
// Orientation: a look-at rotation is built from the entity-to-point
// direction. It is applied immediately when info->rotationSpeed is 0;
// otherwise entity->rotation is slerped toward it by t * rotationSpeed.
// Speed is chosen by info->controlType:
//   - ENTITY_FLY_TO_POINT_PID: a PID controller on the distance.
//   - ENTITY_FLY_TO_POINT_BANG_BANG: a fixed speed that drops to 0 within
//     info->controller.bangbang.stopAt.
//   - any other value: speed 0.
// If entity->useAcceleration is set, the speed change is rate-limited against
// entity->lastVelocity.speed. The entity then moves along its forward axis,
// and velocity.speed / lastVelocity are updated.
void entityFlyToPoint(Entity * entity, Vector3 point, EntityFlyToPointInfo * info) {
	float t = GetFrameTime();

	// Get distance and direction. Note: `dis` points from `point` toward the
	// entity (position - point).
	Vector3 dis = Vector3Subtract(entity->position, point);
	Vector3 direction = Vector3Normalize(dis);

	// Get look at and rotation.
	// NOTE(review): if the entity sits exactly on `point` (zero direction), or
	// direction is parallel to the +Y up vector, the look-at basis is
	// degenerate and the resulting rotation may be meaningless. Confirm that
	// callers avoid this.
	Matrix matrix = MatrixLookAt(Vector3Zero(), direction, (Vector3){0.0, 1.0, 0.0});
	Quaternion rotation = QuaternionInvert(QuaternionFromMatrix(matrix));

	// Rotate toward the target orientation (instantly if rotationSpeed is 0).
	if (info->rotationSpeed == 0.0)
		entity->rotation = rotation;
	else
		entity->rotation = QuaternionSlerp(entity->rotation, rotation, t * info->rotationSpeed);

	// Velocity control.
	float speed = 0.0;

	float distance = Vector3Length(dis);

	switch (info->controlType) {
		case ENTITY_FLY_TO_POINT_PID:
			// runPID is defined elsewhere; presumably (setpoint, measurement, state).
			speed = runPID(0.0, -distance, &info->controller.speedPID);
			break;
		case ENTITY_FLY_TO_POINT_BANG_BANG:
			speed = info->controller.bangbang.speed;

			if (distance <= info->controller.bangbang.stopAt)
				speed = 0.0;

			break;
		default: // Unknown control type: leave speed at 0.
			break;
	}

	// Forward axis is (m2, m6, m10) of the inverse of the updated rotation.
	Matrix m = QuaternionToMatrix(QuaternionInvert(entity->rotation));

	// Accelerate: limit the change from last frame's speed.
	if (entity->useAcceleration)
		speed = accelerateValue(
			speed,
			entity->lastVelocity.speed,
			entity->acceleration.speedUp * t,
			entity->acceleration.speedDown * t
		);

	// Velocity.
	entity->velocity.velocity = (Vector3){
		m.m2 * speed,
		m.m6 * speed,
		m.m10 * speed
	};

	entityUpdatePosition(entity);

	entity->velocity.speed = speed;
	entity->lastVelocity = entity->velocity;
}