#include <stdlib.h>
#include <stdio.h>
#include "mpi.h"

#define LENGTH 30000000

/*
 * Integer base-2 logarithm rounded down: returns floor(log2(n)) for n >= 1,
 * i.e. the number of times n can be halved (integer division) before it
 * reaches 1. The result is exact only when n is a power of two.
 *
 * Returns -1 for n <= 0, where the logarithm is undefined.
 */
int logd(int n) {
    int i = 0;

    /* Reject negatives as well as 0: halving a negative value reaches 0,
       never 1, so the loop below would otherwise never terminate. */
    if (n <= 0) return -1;
    while (n > 1) {
        n /= 2;
        i++;
    }
    return i;
}


int main(int argc, char** argv) {
    float *a, *b, *a_loc, *b_loc;
    float r, buf;
    long i, i_loc, proc, n_loc;
    double start, end;
    int my_rank, p;
    MPI_Status status;

    MPI_Init(&argc,(char***) &argv);

    MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);
    MPI_Comm_size(MPI_COMM_WORLD, &p);

    n_loc = LENGTH / p;

    if (my_rank == 0) {
	a = (float*) malloc(LENGTH * sizeof(float));
	b = (float*) malloc(LENGTH * sizeof(float));

	a_loc = a;
	b_loc = b;

	for (i=0; i<LENGTH; i++) {
	    a[i] = rand();
	    b[i] = rand();
	}
	for (proc=1; proc<p; proc++) {
	    MPI_Send(&(a[n_loc*proc]), n_loc, MPI_FLOAT, proc, 1, MPI_COMM_WORLD);
	    MPI_Send(&(b[n_loc*proc]), n_loc, MPI_FLOAT, proc, 2, MPI_COMM_WORLD);
	}
    } else {
	a_loc = (float*) malloc(n_loc * sizeof(float));
	b_loc = (float*) malloc(n_loc * sizeof(float));

	MPI_Recv(a_loc, n_loc, MPI_FLOAT, 0, 1, MPI_COMM_WORLD, &status);
	MPI_Recv(b_loc, n_loc, MPI_FLOAT, 0, 2, MPI_COMM_WORLD, &status);
    }

    start = MPI_Wtime();

    r = 0;
    for (i=0; i<n_loc; i++) {
	r += a_loc[i] * b_loc[i];
    }

    end = MPI_Wtime();

    printf("Process %d: elapsed time %f\n", my_rank, end-start);

    proc = p;
    for (i=0; i<logd(p); i++) {
	if (my_rank >= (proc / 2)) {
	    MPI_Send(&r, 1, MPI_FLOAT, my_rank - (proc / 2), 3, MPI_COMM_WORLD);
	    printf("step %d: %d sends its result to %d\n", i, my_rank, my_rank - (proc / 2));
	    break;
	} else {
	    MPI_Recv(&buf, 1, MPI_FLOAT, my_rank + (proc / 2), 3, MPI_COMM_WORLD, &status);
	    printf("step %d: %d receives from %d\n", i, my_rank, my_rank + (proc / 2));
	    r += buf;
	}
	proc /= 2;
    }

    if (my_rank == 0) {
	printf("Final result on process 0: %e\n",r);
    }

    fflush(stdout);
    fflush(stderr);
    MPI_Barrier(MPI_COMM_WORLD);
    MPI_Finalize();

    return 0;
}

