#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <math.h>
#include "mpi.h"

// Compile-time configuration of the problem instance.
#define WIDTH 100           // Size of the first (x) dimension of the city placement matrix
#define HEIGHT 100          // Size of the second (y) dimension of the city placement matrix
#define NUM_OF_CITIES 40    // Number of cities that will be created; also the side length of the distance table
#define MASTER 0            // Rank (task ID) of the master process

// Status object handed to every MPI_Recv call in this file; its contents are never inspected.
MPI_Status status;

// Integer grid coordinate of a city.
struct POINT
{    
    int x;                  // Position along the X-axis
    int y;                  // Position along the Y-axis
};

// One city of the tour. Arrays of this struct are transferred between
// processes as raw bytes (MPI_BYTE), so its layout must be the same on
// every rank.
struct CITY
{    
    int id;                 // ID of the city (main assigns the city's array index)
    struct POINT location;  // Location in the matrix
    int visited;            // 1 = visited, 0 = not visited; reset and updated by SolveRoute
};

// Pairwise distance lookup table: distances[i][j] holds the distance between
// cities[i] and cities[j] as filled in by PopulateDistanceTable. The whole
// struct is transferred between processes as raw bytes (MPI_BYTE).
struct DISTANCE_TABLE
{
    double distances[NUM_OF_CITIES][NUM_OF_CITIES];
};

// Returns the Euclidean (straight-line) distance between two points.
//
// loc1, loc2: the two locations; neither is modified and both must be non-NULL.
//
// The coordinate differences are widened to double before they are subtracted
// and squared, so large coordinate values cannot trigger signed integer
// overflow (undefined behaviour in C). For small coordinates the result is
// bit-identical to doing the arithmetic in int.
double CalculateDistance(struct POINT* loc1, struct POINT* loc2)
{
    double dx = (double)loc1->x - (double)loc2->x;
    double dy = (double)loc1->y - (double)loc2->y;

    return sqrt(dx * dx + dy * dy);
}

// Fills dist->distances with the distance between every pair of cities.
//
// dist:   table to fill; every one of its NUM_OF_CITIES x NUM_OF_CITIES
//         entries is overwritten.
// cities: array of NUM_OF_CITIES cities whose locations are read.
//
// Each unordered pair is measured once and the value is stored in both
// [row][col] and [col][row], so the resulting table is symmetric. The
// diagonal is produced by measuring a city against itself.
void PopulateDistanceTable(struct DISTANCE_TABLE* dist, struct CITY* cities)
{
    int row, col;

    for (row = 0; row < NUM_OF_CITIES; row++)
    {
        for (col = row; col < NUM_OF_CITIES; col++)
        {
            double between = CalculateDistance(&cities[row].location, &cities[col].location);

            dist->distances[row][col] = between;
            dist->distances[col][row] = between;
        }
    }
}

// Builds a closed tour with the nearest-neighbour heuristic and returns its
// total length.
//
// cities:        array of NUM_OF_CITIES cities; the `visited` flag of every
//                element is reset and then rewritten by this function.
// distanceTable: pairwise distances, indexed [from][to] by city index.
// start:         index of the city where the tour begins and ends; must be in
//                the range [0, NUM_OF_CITIES).
//
// Starting at `start`, the tour repeatedly moves to the closest city that has
// not been visited yet (the lowest index wins on ties) and finally returns to
// `start`. Every leg, including the closing one, is taken from distanceTable.
double SolveRoute(struct CITY* cities, struct DISTANCE_TABLE* distanceTable, int start)
{
    double distanceTravelled = 0.0;

    int citiesVisited = 1, i, curr = start;

    for (i = 0; i < NUM_OF_CITIES; i++)
    {
        cities[i].visited = 0;
    }

    cities[start].visited = 1;

    while (citiesVisited != NUM_OF_CITIES)
    {
        double lowestDistance = 0.0;
        int nextCity = -1;  // -1 = no candidate found yet in this round

        for (i = 0; i < NUM_OF_CITIES; i++)
        {
            double distance;

            if (cities[i].visited)
            {
                continue;
            }

            distance = distanceTable->distances[curr][i];

            // The first unvisited city is always accepted, so no "large
            // enough" sentinel distance is needed and nextCity is guaranteed
            // to be set whenever an unvisited city remains.
            if (nextCity < 0 || distance < lowestDistance)
            {
                nextCity = i;
                lowestDistance = distance;
            }
        }

        curr = nextCity;
        cities[curr].visited = 1;
        citiesVisited++;
        distanceTravelled += lowestDistance;
    }

    // Close the tour: travel from the last city back to the starting city.
    distanceTravelled += distanceTable->distances[curr][start];

    return distanceTravelled;
}

// Program entry point: master/worker search for the shortest nearest-neighbour
// tour over randomly placed cities.
//
// Master (rank MASTER):
//   1. places NUM_OF_CITIES cities on distinct cells, with x in 1..WIDTH and
//      y in 1..HEIGHT;
//   2. sends the city array and then the distance table to every worker;
//   3. collects tour lengths from the workers until each worker has sent the
//      end marker -1.0, and prints the smallest length received.
// Workers (all other ranks):
//   receive the city array and the distance table, run SolveRoute for their
//   share of the start cities, send each tour length to the master and finish
//   with the end marker -1.0.
//
// Returns 0 on success, 1 if the program was started without any worker
// process.
int main(int argc, char **argv)
{
    int i, numtasks, numworkers, taskid, dest, x, y, rangeStart, rangeEnd, processComplete = 0;
    time_t start, end;

    int matrix[WIDTH][HEIGHT] = {0};

    struct CITY cities[NUM_OF_CITIES];

    struct DISTANCE_TABLE distanceTable;

    // INFINITY: any real tour length compares lower, whatever the map size.
    double result = INFINITY, received;

    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &taskid);
    MPI_Comm_size(MPI_COMM_WORLD, &numtasks);

    srand((unsigned int)time(NULL));

    start = time(NULL);

    numworkers = numtasks - 1;

    // Without a worker the master would wait forever for results that are
    // never sent, so stop early. Every rank sees the same numtasks, so all
    // ranks take this exit together.
    if (numworkers < 1)
    {
        if (taskid == MASTER)
        {
            fprintf(stderr, "At least 2 MPI processes are required (1 master and 1 or more workers).\n");
        }

        MPI_Finalize();
        return 1;
    }

    if (taskid == MASTER)
    {
        printf("Master is creating the cities...\n");

        for (i = 0; i < NUM_OF_CITIES; i++)
        {
            // Draw coordinates until a free cell is hit. Coordinates are
            // 1-based, the occupancy matrix is 0-based, hence the "- 1".
            do
            {
                x = 1 + rand() % WIDTH;
                y = 1 + rand() % HEIGHT;
            } while (matrix[x - 1][y - 1] == 1);

            matrix[x - 1][y - 1] = 1;

            cities[i].id = i;
            cities[i].location.x = x;
            cities[i].location.y = y;
            cities[i].visited = 0;
        }

        printf("Master's created the cities. It's now sending the city data to other nodes...\n");

        for (dest = 1; dest <= numworkers; dest++)
        {
            MPI_Send(cities, sizeof(struct CITY) * NUM_OF_CITIES, MPI_BYTE, dest, 0, MPI_COMM_WORLD);
        }

        printf("Master's sent the city data to nodes\n");

        printf("Master is calculating distances between the cities...\n");

        PopulateDistanceTable(&distanceTable, cities);

        printf("Master is sending the distance table to nodes...\n");

        for (dest = 1; dest <= numworkers; dest++)
        {
            MPI_Send(&distanceTable, sizeof(struct DISTANCE_TABLE), MPI_BYTE, dest, 0, MPI_COMM_WORLD);
        }

        // Each worker sends its tour lengths followed by one -1.0 end marker;
        // keep receiving until every worker has signalled completion.
        while (processComplete < numworkers)
        {
            MPI_Recv(&received, 1, MPI_DOUBLE, MPI_ANY_SOURCE, 0, MPI_COMM_WORLD, &status);

            if (received == -1.0)
            {
                processComplete++;
            }
            else if (received < result)
            {
                result = received;
            }
        }

        printf("The result is %f\n\n", result);
    }
    else
    {
        MPI_Recv(cities, sizeof(struct CITY) * NUM_OF_CITIES, MPI_BYTE, MASTER, 0, MPI_COMM_WORLD, &status);

        printf("NODE %d has received city data...\n", taskid);

        MPI_Recv(&distanceTable, sizeof(struct DISTANCE_TABLE), MPI_BYTE, MASTER, 0, MPI_COMM_WORLD, &status);

        printf("NODE %d has received distance table...\n", taskid);

        // Split the start cities into contiguous ranges that cover all of
        // them even when NUM_OF_CITIES is not a multiple of numworkers; the
        // last worker's range always ends at NUM_OF_CITIES.
        rangeStart = (taskid - 1) * NUM_OF_CITIES / numworkers;
        rangeEnd = taskid * NUM_OF_CITIES / numworkers;

        for (i = rangeStart; i < rangeEnd; i++)
        {
            result = SolveRoute(cities, &distanceTable, i);

            MPI_Send(&result, 1, MPI_DOUBLE, MASTER, 0, MPI_COMM_WORLD);
        }

        // End marker: tells the master this worker has no more results.
        result = -1;
        MPI_Send(&result, 1, MPI_DOUBLE, MASTER, 0, MPI_COMM_WORLD);
    }

    end = time(NULL);

    printf("Program finished in %d seconds...\n\n", (int)difftime(end, start));

    MPI_Finalize();

    return 0;
}
