#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <math.h>
#include <mpi.h>


/* Returns a pseudo-random number from the closed interval
   [*rmin, *rmax], drawn with rand().  The scaling is carried out in
   double precision: a float cannot represent RAND_MAX (nor most values
   of rand()) exactly, so going through float would discard low-order
   bits of the result. */
double random_number(double *rmin, double *rmax){
  const double lo = *rmin;
  const double span = *rmax - lo;
  const double unit = (double)rand() / (double)RAND_MAX;   /* in [0,1] */
  return lo + unit * span;
}

/* Writes the n x n matrix A to stdout, one matrix row per output line,
   followed by one empty line. */
void print_matrix(double **A, int n) {
  int row = 0;
  while (row < n) {
    const double *entries = A[row];
    int col;
    for (col = 0; col < n; ++col) {
      printf("% 6.2e  ", entries[col]);
    }
    fputs("\n", stdout);
    ++row;
  }
  fputs("\n", stdout);
}


/* Writes the n entries of the vector c to stdout on a single line,
   terminated by a newline. */
void print_vector(double *c, int n) {
  const double *cur = c;
  int left;
  for (left = n; left > 0; --left) {
    printf("% 6.2e  ", *cur);
    ++cur;
  }
  fputs("\n", stdout);
}


/* Returns the scalar (dot) product of the vectors a and b, both of
   length n.  The terms are accumulated from index 0 upwards; for
   n <= 0 the result is 0. */
double scalar(double *a, double *b, int n){
  double acc = 0.;
  int k = 0;
  while (k < n) {
    acc += a[k] * b[k];
    ++k;
  }
  return acc;
}


/* Computes the full matrix-vector product f = A*x for an n x n matrix
   A given as an array of row pointers.  Every entry of f is reset to
   zero before its row is accumulated. */
void matrix_vector(double **A, double *x, double *f, int n){
  int row;
  for (row = 0; row < n; ++row) {
    const double *coeff = A[row];
    double *out = &f[row];
    int col = 0;
    *out = 0.;
    while (col < n) {
      *out += coeff[col] * x[col];
      ++col;
    }
  }
}


/* Computes one process' share of the matrix-vector product A*x.
   For each of the n rows only the columns start <= j < end are
   multiplied, so f receives partial sums that still have to be added
   up over all column ranges to give the complete product. */
void matrix_vector_par(double **A, double *x, double *f, int n,
                       int start, int end){
  int row;
  for (row = 0; row < n; ++row) {
    const double *coeff = A[row];
    double *out = &f[row];
    int col = start;
    *out = 0.;
    while (col < end) {
      *out += coeff[col] * x[col];
      ++col;
    }
  }
}


/* Stores x + al*y in u, entry by entry, for vectors of length n.
   Each entry of u depends only on the entries of x and y at the same
   index. */
void addmult(double *u, double *x, double *y, double al, int n){
  int k = 0;
  while (k < n) {
    u[k] = x[k] + al * y[k];
    ++k;
  }
}


int main(int argc, char **argv) {
  double **A, *d, *x, *e, *p, *r, *f, *g;
  double *Af;
  double om = 1.7;     /* omega: over-relaxation */
  double eps;          /* stop criterion */
  double rmin,rmax;    /* for generation of randum numbers */
  double tmp,tmp2,res; /* res is the resiuum */
  double res_ges,tmp_ges,tmp2_ges;
  double al,be;
  int i,j,k,n;
  int k_stop;          /* maximal iteration number */

  int myrank,nproz;
  int n_pro,n_pro_real,n_sent;
  int *n_loc, *n_pos;
  int dest,source;

  FILE *fp;

  MPI_Init(&argc,&argv);
  MPI_Comm_rank(MPI_COMM_WORLD,&myrank);
  MPI_Comm_size(MPI_COMM_WORLD,&nproz);

  n_loc = (int*)malloc(nproz*sizeof(int));
  n_pos = (int*)malloc(nproz*sizeof(int));

  n = 8;   /* dimension of the problem */
  k_stop = 100;
  eps = 1.0e-20;

  rmin = 0;
  rmax = 1;

  p = (double*)malloc(n*sizeof(double));   
  g = (double*)malloc(n*sizeof(double));   

  A  = (double**)malloc(n*sizeof(double*));
  Af = (double*)malloc(n*n*sizeof(double));

  for(i=0; i<n; i++){
    A[i] = &Af[i*n];
  }

  if ( myrank==0 ){

    d = (double*)malloc(n*sizeof(double));   
    x = (double*)malloc(n*sizeof(double)); 
    e = (double*)malloc(n*sizeof(double));
    r = (double*)malloc(n*sizeof(double));   
    f = (double*)malloc(n*sizeof(double));   

    /* initialisation */
    srand( (unsigned)time(0));
    for(i=0; i<n; i++) {
      for(j=0; j<=i; j++){
        A[i][j] = 0.;
        while( fabs(A[i][j]) < 1.0e-10 ){
          A[i][j] = random_number(&rmin,&rmax);
          A[j][i] = A[i][j];
        }
      }
      A[i][i] += (float)n+3.;   /* ensure diagonal dominant matrix */ 
      d[i] = random_number(&rmin,&rmax);
      x[i] = 0.0;      /* start vector */
    }
  }

  MPI_Bcast(&A[0][0],n*n,MPI_DOUBLE,0,MPI_COMM_WORLD);

  n_pro = n/nproz;

  if ( n%nproz!=0 ) n_pro_real = n_pro+1;
  else n_pro_real = n_pro;  /* n_pro_real: usual number
                                    of data per processor */

  n_loc[0] = n_pro_real;   /* the local number of data */
  n_pos[0] = 0;

  for (dest=1; dest<nproz; dest++){
    n_pos[dest] = n_pos[dest-1] + n_loc[dest-1];
    if (dest==nproz-1) n_loc[dest] = n - n_pos[nproz-1];
    else n_loc[dest] = n_pro_real;
    if ( myrank==dest && myrank!=0 ){
      x = (double*)malloc(n_loc[dest]*sizeof(double));
      f = (double*)malloc(n_loc[dest]*sizeof(double));
      r = (double*)malloc(n_loc[dest]*sizeof(double));
    }
  }

  k=0;
  res_ges = 1.;

  if( myrank==0 ){
    matrix_vector(A,x,f,n);
    for(i=0; i<n; i++){
      r[i] = d[i] - f[i];
      p[i] = r[i];
    }
  }

  MPI_Scatterv(r,n_loc,n_pos,MPI_DOUBLE,r,n_loc[myrank],
               MPI_DOUBLE,0,MPI_COMM_WORLD);
  MPI_Scatterv(x,n_loc,n_pos,MPI_DOUBLE,x,n_loc[myrank],
               MPI_DOUBLE,0,MPI_COMM_WORLD); 
  MPI_Bcast(p,n,MPI_DOUBLE,0,MPI_COMM_WORLD);

  tmp = scalar(r,r,n_loc[myrank]);
  MPI_Allreduce(&tmp,&tmp_ges,1,MPI_DOUBLE,MPI_SUM,MPI_COMM_WORLD);
  
  if( myrank==0 ){
    fp = fopen("residual.dat","w");
  }

  while( res_ges>eps && k<=k_stop ){
    matrix_vector_par(A,p,g,n,n_pos[myrank],n_pos[myrank]+n_loc[myrank]);
    for(dest=0;dest<nproz;dest++){
      MPI_Reduce(&(g[n_pos[dest]]),f,n_loc[dest],MPI_DOUBLE,MPI_SUM,
                 dest,MPI_COMM_WORLD);
    }

    tmp2 = scalar(&(p[n_pos[myrank]]),f,n_loc[myrank]);
    MPI_Allreduce(&tmp2,&tmp2_ges,1,MPI_DOUBLE,MPI_SUM,MPI_COMM_WORLD);
    al = -tmp_ges/tmp2_ges;
    addmult(x,x,&(p[n_pos[myrank]]),-al,n_loc[myrank]);
    addmult(r,r,f,al,n_loc[myrank]);
    res = scalar(r,r,n_loc[myrank]);
    MPI_Allreduce(&res,&res_ges,1,MPI_DOUBLE,MPI_SUM,MPI_COMM_WORLD);
    if( myrank==0 ){
      fprintf(fp,"%d %le\n",k,res_ges);
    }
    be = res_ges/tmp_ges;
    tmp_ges = res_ges;

    addmult(&(p[n_pos[myrank]]),r,&(p[n_pos[myrank]]),be,n_loc[myrank]);

    k++;
  }
  printf("%d: Iterationen %d\n",myrank,k-1);
  if( myrank==0 ){
    fclose(fp);
  }


  /* finally check the result */
  MPI_Gatherv(x,n_loc[myrank],MPI_DOUBLE,x,n_loc,n_pos,MPI_DOUBLE,
              0,MPI_COMM_WORLD);
  if ( myrank==0 ){
    for(i=0; i<n; i++){
      e[i] = 0.0; 
      for(j=0; j<n; j++){
        e[i] += A[i][j]*x[j];
      }
      e[i] = d[i]-e[i];
    }
    printf("\n");
    print_vector(e,n);  
  }

  MPI_Finalize();

}

