#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <math.h>
#include <mpi.h>


/*
 * Returns a pseudo-random number taken from the interval [*rmin, *rmax].
 *
 * The value is derived from a single call to rand(), so the sequence is
 * controlled by srand().  The fraction rand()/RAND_MAX is evaluated in
 * double precision; casting to float first would discard the low bits of
 * rand() and limit the result to about 24 significant bits.
 */
double random_number(double *rmin, double *rmax){
  double fraction = (double)rand() / (double)RAND_MAX;   /* in [0, 1] */
  return (*rmin) + fraction * ((*rmax) - (*rmin));
}


/* Writes the n x n matrix A to stdout: one matrix row per output line,
   followed by an empty line after the last row. */
void print_matrix(double **A, int n) {
  int row;
  for (row = 0; row < n; ++row) {
    int col = 0;
    while (col < n) {
      printf("% 6.2e  ", A[row][col]);
      ++col;
    }
    putchar('\n');
  }
  putchar('\n');
}


/* Writes the n entries of the vector c to stdout on a single line,
   terminated by a newline. */
void print_vector(double *c, int n) {
  int k = 0;
  while (k < n) {
    printf("% 6.2e  ", c[k]);
    ++k;
  }
  putchar('\n');
}


/*
 * Serial relaxed update of component i for the system A*x = d.
 *
 * Starting from d[i], every off-diagonal term A[i][j]*x[j] is subtracted
 * and the diagonal term is added with weight (1-om)/om; the accumulated
 * sum is then scaled by om / A[i][i].
 *
 *   i   index of the component to update
 *   A   n x n matrix, addressed as A[row][col]
 *   d   right-hand side vector
 *   x   current iterate (read only)
 *   n   dimension of the system
 *   om  relaxation factor
 *
 * Returns the new value for x[i]; x itself is not modified.
 */
double raw(int i, double **A, double *d, double *x, int n, double om){
  double acc = d[i];
  int col;
  for (col = 0; col < n; ++col) {
    double coeff;
    if (col == i) {
      coeff = (1-om)/om * A[i][col];   /* weighted diagonal entry */
    } else {
      coeff = -A[i][col];              /* off-diagonal entry */
    }
    acc += coeff * x[col];
  }
  return om / A[i][i] * acc;
}


/*
 * Parallel relaxed update of component i for the system A*x = d.
 *
 * Each process accumulates only the columns j with j % nproz == myrank
 * (off-diagonal terms negated, the diagonal term weighted by (1-om)/om).
 * The partial sums are added up with MPI_Reduce on rank i % nproz, which
 * then adds d[i] and scales by om / A[i][i].
 *
 * This is a collective operation: every process in MPI_COMM_WORLD has to
 * call it with the same i.
 *
 * Returns the new value of x[i] on rank i % nproz and 0 on all other ranks.
 */
double raw_par(int i, double **A, double *d, double *x, int n, double om,
               int myrank, int nproz){
  double partial = 0;   /* contribution of this process to the row sum */
  double total;         /* reduced sum, only used on rank i % nproz */
  int col;

  for (col = 0; col < n; ++col) {
    double coeff;
    if( myrank!=(col%nproz) ){
      continue;         /* column belongs to another process */
    }
    if( col==i ){
      coeff = (1-om)/om * A[i][col];
    }
    else{
      coeff = -A[i][col];
    }
    partial += coeff * x[col];
  }

  MPI_Reduce(&partial,&total,1,MPI_DOUBLE,MPI_SUM,i%nproz,MPI_COMM_WORLD);

  if( myrank!=(i%nproz) ){
    return 0.;
  }
  return om / A[i][i] * (total + d[i]);
}


/*
 * Solves the n x n linear system A*x = d with a relaxation iteration
 * (relaxation factor om) that is distributed over all MPI processes.
 *
 * Rank 0 creates a random, diagonally dominant matrix A and a random
 * right-hand side d and broadcasts both.  Component x[i] is updated and
 * kept by rank i % nproz.  The iteration continues while k <= k_stop and
 * ||x_new - x_old||_2 >= eps * ||x_old||_2.  Rank 0 writes the relative
 * change of every sweep to "residual.dat" and finally prints the residual
 * d - A*x.
 *
 * Returns 0 on success; aborts all processes if memory cannot be allocated.
 */
int main(int argc, char **argv){
  double **A = NULL, *d = NULL, *x = NULL, *e = NULL;
  double *Af = NULL;   /* contiguous storage behind the rows of A */
  double om = 1.7;     /* omega: over-relaxation */
  double eps;          /* stop criterion */
  double rmin,rmax;    /* for generation of random numbers */
  double tmp;
  double rhs,lhs;      /* rhs,lhs: right resp. left hand side of inequality */
  double my_rhs,my_lhs; /* local counterparts */
  int i, j, k, n;
  int k_stop;          /* maximal iteration number */
  int myrank, nproz;

  FILE *fp = NULL;     /* stays NULL on every rank except 0 */

  n = 8;   /* dimension of the problem */
  k_stop = 100;
  eps = 1.0e-10;

  MPI_Init(&argc, &argv);
  MPI_Comm_size(MPI_COMM_WORLD, &nproz);
  MPI_Comm_rank(MPI_COMM_WORLD, &myrank);

  rmin = 0;
  rmax = 1;

  A  = malloc(n*sizeof *A);
  Af = malloc((size_t)n*n*sizeof *Af);

  d = malloc(n*sizeof *d);
  x = malloc(n*sizeof *x);
  e = malloc(n*sizeof *e);

  if( A==NULL || Af==NULL || d==NULL || x==NULL || e==NULL ){
    fprintf(stderr,"rank %d: out of memory\n",myrank);
    /* the other ranks would block in the next collective call, so stop
       the whole job */
    MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
    return EXIT_FAILURE;
  }

  for(i=0; i<n; i++){
    A[i] = &Af[i*n];
  }

  /* initialisation */
  if(myrank == 0) {
    srand( (unsigned)time(NULL));
    for(i=0; i<n; i++) {
      for(j=0; j<n; j++){
        A[i][j] = 0.;
        while( fabs(A[i][j]) < 1.0e-10 ){
          A[i][j] = random_number(&rmin,&rmax);
        }
      }
      A[i][i] += n+2;   /* ensure diagonal dominant matrix */
      d[i] = random_number(&rmin,&rmax);
    }
  }
  MPI_Bcast(Af,n*n,MPI_DOUBLE,0,MPI_COMM_WORLD);
  MPI_Bcast(d,n,MPI_DOUBLE,0,MPI_COMM_WORLD);

  for(i=0; i<n; i++) {
    x[i] = 0.0;      /* start vector for each process */
  }

  k=0;
  /* The loop condition is evaluated before the first sweep, so the global
     norms need defined values that let the iteration start. */
  lhs = 1.;
  rhs = 1.;

  if( myrank==0 ){  /* writing is done only by process 0 */
    fp = fopen("residual.dat","w");
    if( fp==NULL ){
      /* the log is optional: report the problem and solve without it */
      perror("residual.dat");
    }
  }
  while( k<=k_stop && lhs>=eps*rhs ){
    my_rhs = 0.;
    my_lhs = 0.;
    for(i=0; i<n; i++) {
      tmp = raw_par(i,A,d,x,n,om,myrank,nproz);
      if( myrank==(i%nproz) ) {
        my_lhs += (tmp-x[i])*(tmp-x[i]);
        my_rhs += x[i]*x[i];
        x[i] = tmp;
      }
    }
    MPI_Allreduce(&my_lhs,&lhs,1,MPI_DOUBLE,MPI_SUM,MPI_COMM_WORLD);
    MPI_Allreduce(&my_rhs,&rhs,1,MPI_DOUBLE,MPI_SUM,MPI_COMM_WORLD);
    lhs = sqrt(lhs);
    rhs = sqrt(rhs);
    if( fp!=NULL ){
      if( rhs!=0. ){   /* skip sweeps where the relative change is undefined */
        fprintf(fp,"%d %le\n",k,lhs/rhs);
      }
    }
    k++;
  }

  if( fp!=NULL ){
    if( fclose(fp)!=0 ){
      perror("residual.dat");
    }
  }

  /* distribute the solution to all processes */
  for(i=0; i<n; i++ ){
    MPI_Bcast(&(x[i]),1,MPI_DOUBLE,i%nproz,MPI_COMM_WORLD);
  }

  /* finally check the result */
  if( myrank==0 ){
    for(i=0; i<n; i++){
      e[i] = 0.0;
      for(j=0; j<n; j++){
        e[i] += A[i][j]*x[j];
      }
      e[i] = d[i]-e[i];
    }
    printf("\n");
    print_vector(e,n);
  }

  free(e);
  free(x);
  free(d);
  free(Af);
  free(A);

  MPI_Finalize();
  return 0;
}


