#include<stdio.h>
#include<stdlib.h>
#include<math.h>
#include<mpi.h>

/*
 * Dumps the matrix A to standard output.
 *
 * A is indexed as A[i][j] with 0 <= i < imax and 0 <= j < jmax.  Each
 * output line holds all i-entries for one fixed j; lines are emitted from
 * the highest j down to j = 0, so the j-axis points upwards on screen.
 * A blank line terminates the dump.
 */
void print_matrix(double **A, int imax, int jmax) {
  int line, col;

  /* `line` counts the remaining output lines; the j-index is line-1. */
  for (line = jmax; line > 0; --line) {
    for (col = 0; col < imax; ++col)
      printf("% 6.2e  ", A[col][line - 1]);
    fputs("\n", stdout);
  }
  fputs("\n", stdout);
}


/*
 * Writes the entries A[1..imax][1..jmax] of the matrix A to a text file in
 * the current working directory.
 *
 * The file name is "p_<icarti><icartj>.<k>.dat", where icarti and icartj
 * are printed with at least two digits and k with at least four digits.
 * The values are written one per line, with i as the outer and j as the
 * inner loop index.  Row/column 0 and imax+1 / jmax+1 are not written.
 *
 * dx and dy are accepted for interface compatibility but are not used.
 *
 * If the file cannot be opened (or closing it reports an error), a
 * diagnostic is printed to stderr and the function returns without
 * aborting the program.
 */
void print_matrix_disc(double **A, int imax, int jmax, 
                       double dx, double dy, int icarti, int icartj, int k) {
  int i,j;
  char file[500];
  FILE *fp;

  (void)dx;  /* unused, kept for the existing call signature */
  (void)dy;

  snprintf(file,sizeof file,"p_%.2d%.2d.%.4d.dat",icarti,icartj,k);

  fp = fopen(file,"w");
  if( fp==NULL ){
    /* Without this check fprintf/fclose would be called on a NULL stream. */
    perror(file);
    return;
  }

  for(i=1; i<=imax; i++) {
    for(j=1; j<=jmax; j++){
      fprintf(fp,"%6.2e\n",A[i][j]);
    }
  }

  /* fclose flushes buffered output, so a failure here means lost data. */
  if( fclose(fp)!=0 )
    perror(file);
}


/*
 * Allocates `bytes` bytes and aborts the whole MPI job if the allocation
 * fails.  Must only be called between MPI_Init and MPI_Finalize.
 */
static void *checked_malloc(size_t bytes) {
  void *mem = malloc(bytes);
  if( mem==NULL ){
    fprintf(stderr,"out of memory (%lu bytes requested)\n",(unsigned long)bytes);
    MPI_Abort(MPI_COMM_WORLD,EXIT_FAILURE);
    exit(EXIT_FAILURE);
  }
  return mem;
}


/*
 * Iteratively solves the discrete Poisson problem
 *
 *     (p[i+1][j]-2p[i][j]+p[i-1][j])/dx^2
 *   + (p[i][j+1]-2p[i][j]+p[i][j-1])/dy^2  =  d[i][j]
 *
 * on a grid that is split over an ndimi x ndimj Cartesian process grid.
 * Every process owns imax x jmax interior cells, stored in p[1..imax][1..jmax];
 * index 0 and imax+1 / jmax+1 are ghost cells.  Ghost cells on a physical
 * boundary are set to minus the adjacent interior value, ghost cells on an
 * inner boundary are filled with the neighbour's edge values.
 *
 * Each iteration performs one in-place sweep over the local cells.  The
 * loop stops after k_stop sweeps or once the global L2 norm of the residual
 * drops to eps or below.  The local field is written to disk (via
 * print_matrix_disc) before the first sweep and after every sweep.
 *
 * The program must be started with exactly ndimi*ndimj processes; otherwise
 * the job is aborted.
 */
int main(int argc, char **argv){

  double **p,**d;
  double *p_vert_ts,*p_vert_tr,*p_vert_bs,*p_vert_br;
  double dx,dy;
  double res,eps,tmp,res_ges;
  int i,j,imax,jmax;
  int k,k_stop;

  int ndimi,ndimj;
  MPI_Status status;
  MPI_Comm comm_cart;
  int idims[2], iperiods[2], icart[2];
  int myrank,nproz;
  int i_lef,i_rig,i_top,i_bot;

  imax = 8; jmax = 8;       /* interior cells per process in i and j */
  eps = 1.0e-10;            /* convergence threshold for the global residual */
  k_stop = 50;              /* maximum number of sweeps */
  ndimi = 3; ndimj = 2;     /* extent of the process grid in i and j */
  dx = 1./((double)imax*ndimi);
  dy = 1./((double)jmax*ndimj);

  /* Initialisation: */
  MPI_Init(&argc,&argv);
  MPI_Comm_rank(MPI_COMM_WORLD,&myrank);
  MPI_Comm_size(MPI_COMM_WORLD,&nproz);

  /* The Cartesian topology below needs exactly ndimi*ndimj processes:
     too few makes MPI_Cart_create fail, too many leaves ranks without a
     valid communicator.  Stop the job instead of running into that. */
  if( ndimi*ndimj!=nproz ){
    fprintf(stderr,"number of processes inadequate\n");
    MPI_Abort(MPI_COMM_WORLD,EXIT_FAILURE);
    return EXIT_FAILURE;
  }

  /* Local fields including one layer of ghost cells on every side. */
  p = checked_malloc((imax+2)*sizeof *p);
  d = checked_malloc((imax+2)*sizeof *d);
  for (i=0; i<imax+2; i++){
    p[i] = checked_malloc((jmax+2)*sizeof **p);
    d[i] = checked_malloc((jmax+2)*sizeof **d);
  }

  /* Pack buffers for the vertical exchange (t/b = top/bottom,
     s/r = send/receive).  They hold doubles, so they are sized with the
     element type rather than a pointer type. */
  p_vert_ts = checked_malloc((imax+2)*sizeof *p_vert_ts);
  p_vert_bs = checked_malloc((imax+2)*sizeof *p_vert_bs);
  p_vert_tr = checked_malloc((imax+2)*sizeof *p_vert_tr);
  p_vert_br = checked_malloc((imax+2)*sizeof *p_vert_br);

  idims[0] = ndimi; idims[1] = ndimj;  /* dimensions in each direction */
  iperiods[0] = 0; iperiods[1] = 0;  /* no periodic boundaries */

  /* Create the new Cartesian-based communicator (no rank reordering). */
  MPI_Cart_create(MPI_COMM_WORLD,2,idims,iperiods,0,&comm_cart);

  /* All further communication uses comm_cart, so take the rank from it. */
  MPI_Comm_rank(comm_cart,&myrank);

  /* Get the own Cartesian coordinates */
  MPI_Cart_coords(comm_cart,myrank,2,icart);

  /* Get the ranks of the neighbours.  With displacement -1 the "source"
     returned by MPI_Cart_shift is the neighbour at coordinate+1 and the
     "destination" is the neighbour at coordinate-1; missing neighbours
     are reported as MPI_PROC_NULL. */
  MPI_Cart_shift(comm_cart,1,-1,&i_top,&i_bot);
  MPI_Cart_shift(comm_cart,0,-1,&i_rig,&i_lef);

  /* If desired: doublecheck the topology
  printf("%d: l:%d r:%d t:%d b:%d\n",myrank,i_lef,i_rig,i_top,i_bot);
  printf("%d: icart0: %d,  icart1: %d\n",myrank,icart[0],icart[1]);
  */

  /* Initial state: zero right-hand side, p = 1 inside, p = 0 in the ghosts. */
  for(i=0; i<imax+2; i++){
    for(j=0; j<jmax+2; j++){
      d[i][j] = 0.;
      if( i==0 || i==imax+1 || j==0 || j==jmax+1 )
        p[i][j] = 0.;
      else
        p[i][j] = 1.;
    }
  }

  res_ges = 10.*eps;  /* any value > eps, so that the loop is entered */
  k = 0;

  print_matrix_disc(p,imax,jmax,dx,dy,icart[0],icart[1],k);

  /* the iteration loop */

  while( k<k_stop && res_ges>eps ){

    res = 0.;

    /* Set boundary conditions only at real boundaries. */
    if( i_top==MPI_PROC_NULL )
      for(i=1; i<=imax; i++) p[i][jmax+1] = -p[i][jmax];
    if( i_bot==MPI_PROC_NULL )
      for(i=1; i<=imax; i++) p[i][0]      = -p[i][1];
    if( i_lef==MPI_PROC_NULL )
      for(j=1; j<=jmax; j++) p[0][j]      = -p[1][j];
    if( i_rig==MPI_PROC_NULL )
      for(j=1; j<=jmax; j++) p[imax+1][j] = -p[imax][j];

    /* Vertical communication: */
    /* A line of constant j is not contiguous in memory (p[i] are separate
       allocations), so the edge values are packed into buffers first. */
    if( i_top!=MPI_PROC_NULL )
      for (i=1; i<=imax; i++) p_vert_ts[i] = p[i][jmax];
    if( i_bot!=MPI_PROC_NULL )
      for (i=1; i<=imax; i++) p_vert_bs[i] = p[i][1];

    /* Sends/receives addressed to MPI_PROC_NULL complete immediately
       without transferring data, so no special-casing is needed here. */
    MPI_Sendrecv(&(p_vert_ts[1]),imax,MPI_DOUBLE,i_top,0,
                 &(p_vert_br[1]),imax,MPI_DOUBLE,i_bot,0,
                 comm_cart,&status);
    MPI_Sendrecv(&(p_vert_bs[1]),imax,MPI_DOUBLE,i_bot,0,
                 &(p_vert_tr[1]),imax,MPI_DOUBLE,i_top,0,
                 comm_cart,&status);

    /* Unpack only where data was really received. */
    if( i_top!=MPI_PROC_NULL )
      for (i=1; i<=imax; i++) p[i][jmax+1] = p_vert_tr[i];
    if( i_bot!=MPI_PROC_NULL )
      for (i=1; i<=imax; i++) p[i][0]      = p_vert_br[i];

    /* Horizontal communication: */
    /* p[i][1..jmax] is contiguous, so it can be sent/received in place. */
    MPI_Sendrecv(&(p[1][1]),jmax,MPI_DOUBLE,i_lef,0,
                 &(p[imax+1][1]),jmax,MPI_DOUBLE,i_rig,0,
                 comm_cart,&status);
    MPI_Sendrecv(&(p[imax][1]),jmax,MPI_DOUBLE,i_rig,0,
                 &(p[0][1]),jmax,MPI_DOUBLE,i_lef,0,
                 comm_cart,&status);

    /* One in-place sweep over the local interior cells. */
    for(i=1; i<=imax; i++){
      for(j=1; j<=jmax; j++){
        /* Residual of the five-point stencil at (i,j).  It is evaluated
           BEFORE the point is updated: directly after the update below
           the stencil equation holds at (i,j) up to round-off, so a
           residual taken there would always be (numerically) zero. */
        tmp = (p[i+1][j]-2.*p[i][j]+p[i-1][j])/(dx*dx)
            + (p[i][j+1]-2.*p[i][j]+p[i][j-1])/(dy*dy) - d[i][j];
        res += tmp*tmp;

        /* Solve the stencil equation for p[i][j]; both neighbour sums
           enter with the same sign. */
        p[i][j] = -1./ (2./(dx*dx)+2./(dy*dy)) * ( d[i][j]
            - (p[i+1][j]+p[i-1][j])/(dx*dx)
            - (p[i][j+1]+p[i][j-1])/(dy*dy) );
      }
    }

    k++;

    /* Global L2 norm of the residual; every rank gets the same value and
       therefore takes the same decision in the loop condition. */
    MPI_Allreduce(&res,&res_ges,1,MPI_DOUBLE,MPI_SUM,comm_cart);
    res_ges = sqrt(res_ges);
    if( myrank==0 )
      printf("%d Iterationen, Residuum: %f\n",k,res_ges);

    print_matrix_disc(p,imax,jmax,dx,dy,icart[0],icart[1],k);
  }

  /* Release all resources before shutting MPI down. */
  for (i=0; i<imax+2; i++){
    free(p[i]);
    free(d[i]);
  }
  free(p);
  free(d);
  free(p_vert_ts);
  free(p_vert_bs);
  free(p_vert_tr);
  free(p_vert_br);

  MPI_Comm_free(&comm_cart);
  MPI_Finalize();
  return 0;
}

