// LinSys.cpp
//
// Solvers for full and sparse linear systems
//
// Jiangguo (James) Liu, ColoState, 01/2007--10/2011
// With the great help of Rachel Cali (Spring 2007)


#include <cmath>
#include <cstdlib>
#include <iostream>
#include "LinSys.h"
#include "matrix.h"
#include "vector.h"
using namespace std;


// Solve a small full (dense) linear system A*x = b by 
// Gaussian Elimination with Partial Pivoting (GEPP) 
//
//    A -- square coefficient matrix, read through A(i,j) with 1-based 
//         indices; it is copied, never modified
//    b -- right-hand side, read through b(i) with 1-based indices 
// Returns the solution vector x. 
// Prints a message and calls exit(EXIT_FAILURE) when A is not square 
// or when an exactly zero pivot shows that A is singular. 

Vector slvFullLinSysGEPP(const FullMatrix &A, const Vector &b) 
{
   int i, j, k, im;
   double cmax, lier, tmp;

   const int m = A.rowSize();
   const int n = A.colSize();

   if (m!=n) {
      cout << "Not a square matrix! EXit!\n";
      exit(EXIT_FAILURE);
   }

   // Row-major working copies: entry (i,j) lives at p[(i-1)*n+(j-1)]
   double *p = new double[n*n];
   double *q = new double[n];

   for (i=1; i<=n; ++i) 
      for (j=1; j<=n; ++j) 
         p[(i-1)*n+(j-1)] = A(i,j);

   for (i=1; i<=n; ++i)  q[i-1] = b(i);

   // Forward elimination, one column at a time
   for (j=1; j<=(n-1); ++j) {
      // Partial pivoting: among rows j..n pick the one whose entry 
      // in column j has the largest magnitude
      im = j;
      cmax = fabs(p[(j-1)*n+(j-1)]);
      for (i=j+1; i<=n; ++i) {
         tmp = fabs(p[(i-1)*n+(j-1)]);
         if (tmp>cmax) {
            cmax = tmp;
            im = i;
         }
      }

      // Nothing but zeros on and below the diagonal: no usable pivot
      if (cmax==0) {
         cout << "Singular matrix! Exit!\n";
         cout << "trouble at diagonal" << j << "\n";
         delete[] p;
         delete[] q;
         exit(EXIT_FAILURE);
      }

      // Bring the pivot row up.  Columns left of j are never read 
      // again, so only columns j..n need to be exchanged.
      if (im!=j) {
         for (k=j; k<=n; ++k) {
            tmp = p[(j-1)*n+(k-1)];
            p[(j-1)*n+(k-1)] = p[(im-1)*n+(k-1)];
            p[(im-1)*n+(k-1)] = tmp;
         }
         tmp = q[j-1];
         q[j-1] = q[im-1];
         q[im-1] = tmp;
      }

      // Eliminate column j from the rows below.  The entries below 
      // the diagonal are not explicitly zeroed because the back 
      // substitution only reads the upper triangle.
      for (i=j+1; i<=n; ++i) {
         lier = p[(i-1)*n+(j-1)]/p[(j-1)*n+(j-1)];
         q[i-1] -= lier*q[j-1];
         for (k=j+1; k<=n; ++k) 
            p[(i-1)*n+(k-1)] -= lier*p[(j-1)*n+(k-1)];
      }
   }

   // The last diagonal entry is never visited by the loop above
   if (n>=1 && fabs(p[n*n-1])==0) {
      cout << "Singular matrix! Exit!\n";
      cout << "trouble at diagonal" << n << "\n";
      delete[] p;
      delete[] q;
      exit(EXIT_FAILURE);
   }

   // Back substitution on the upper triangular system
   Vector x(n);

   for (i=n; i>=1; --i) {
      tmp = q[i-1];
      for (k=i+1; k<=n; ++k) 
         tmp -= p[(i-1)*n+(k-1)]*x(k);
      x(i) = tmp/p[(i-1)*n+(i-1)];
   }

   // Two separate statements: "delete[] p, q;" is a comma expression 
   // that frees p only.
   delete[] p;
   delete[] q;

   return x;
}


// Solve a small full lower triangular system L*x = b 
// by forward substitution 
//
//    L -- square matrix, read through L(i,j) with 1-based indices; 
//         only the diagonal and the entries below it enter the solve
//    b -- right-hand side, read through b(i) 
// Returns the solution vector x. 
// A nonzero entry above the diagonal only triggers a warning (the 
// entry is ignored); a non-square matrix or a zero on the diagonal 
// prints a message and calls exit(EXIT_FAILURE). 

Vector slvFullLowerTrigSys(const FullMatrix &L, const Vector &b) 
{
   int i, j;
   double sum;

   int m = L.rowSize();
   int n = L.colSize();

   if (m!=n) {
      cout << "Not a square matrix!\n";
      exit(EXIT_FAILURE);
   }

   // Look for anything nonzero strictly above the diagonal.  This is 
   // deliberately not fatal; the warning is reported once.
   bool isLower = true;
   for (i=1; i<=(n-1) && isLower; ++i) {
      for (j=i+1; j<=n; ++j) {
         if (L(i,j)!=0) {
            isLower = false;
            break;
         }
      }
   }
   if (!isLower)  cout << "Not a lower triangular matrix!\n";

   // A zero diagonal entry would be divided by below, so stop here
   for (i=1; i<=n; ++i) {
      if (L(i,i)==0.0) {
         cout << "Singular lower triangular matrix!\n";
         exit(EXIT_FAILURE);
      }
   }

   Vector x(n);

   // Forward substitution; for i=1 the inner loop is empty
   for (i=1; i<=n; ++i) {
      sum = b(i);
      for (j=1; j<=(i-1); ++j)  sum -= L(i,j)*x(j);
      x(i) = sum/L(i,i);
   }

   return x;
}


// Solve a small full upper triangular system U*x = b 
// by back substitution 
//
//    U -- square matrix, read through U(i,j) with 1-based indices
//    b -- right-hand side, read through b(i) 
// Returns the solution vector x. 
// Prints a message and calls exit(EXIT_FAILURE) when U is not square, 
// has a nonzero entry below the diagonal, or has a zero on the diagonal. 

Vector slvFullUpperTrigSys(const FullMatrix &U, const Vector &b) 
{
   const int nRows = U.rowSize();
   const int nCols = U.colSize();

   if (nRows!=nCols) {
      cout << "Not a square matrix!\n";
      exit(EXIT_FAILURE);
   }

   const int n = nCols;

   // Everything strictly below the diagonal has to vanish
   for (int row=2; row<=n; ++row) {
      int col = 1;
      while (col<row) {
         if (U(row,col)!=0) {
            cout << "Not an upper triangular matrix!\n";
            exit(EXIT_FAILURE);
         }
         ++col;
      }
   }

   // Every diagonal entry is used as a divisor below
   for (int d=1; d<=n; ++d) {
      if (U(d,d)==0.0) {
         cout << "Singular upper triangular matrix!\n";
         exit(EXIT_FAILURE);
      }
   }

   Vector x(n);

   // Last unknown first, then walk upwards row by row
   x(n) = b(n)/U(n,n);
   int row = n-1;
   while (row>=1) {
      double acc = b(row);
      for (int col=row+1; col<=n; ++col)  acc -= U(row,col)*x(col);
      x(row) = acc/U(row,row);
      --row;
   }

   return x;
}


// Solve a small full SPD system A*x = b by the Cholesky factorization 
//
//    A -- square matrix, read through A(i,j) with 1-based indices; 
//         only the diagonal and the entries below it are read
//    b -- right-hand side 
// Builds L with A = L*L', then solves L*y = b (slvFullLowerTrigSys) 
// and L'*x = y (slvFullUpperTrigSys) and returns x. 
// Prints a message and calls exit(EXIT_FAILURE) when A is not square 
// or when a pivot is not strictly positive, i.e. A is not positive 
// definite (up to roundoff). 

Vector slvFullSpdSysCholesky(const FullMatrix &A, const Vector &b) 
{
   int i, j, k;
   double diag, sum1, sum2;

   int m = A.rowSize();
   int n = A.colSize();

   if (m!=n) {
      cout << "Not a square matrix! EXit!\n";
      exit(EXIT_FAILURE);
   }

   // NOTE(review): the strict upper triangle of L is never written 
   // below; this presumably relies on FullMatrix(n,n) starting out 
   // as all zeros -- confirm against the FullMatrix constructor.
   FullMatrix L(n,n);  // The lower triangular ChoLesky
   FullMatrix U(n,n);  // U = L' (transpose)

   // Column-by-column factorization
   for (j=1; j<=n; ++j) {
      sum1 = 0.0;
      for (k=1; k<=(j-1); ++k)  sum1 += L(j,k)*L(j,k);

      // The pivot must be strictly positive: otherwise sqrt() yields 
      // NaN or 0 and the division below spreads NaN/inf through L.  
      // Written as !(diag>0) so that a NaN pivot is rejected as well.
      diag = A(j,j)-sum1;
      if (!(diag>0.0)) {
         cout << "Not a positive definite matrix! Exit!\n";
         cout << "trouble at diagonal " << j << "\n";
         exit(EXIT_FAILURE);
      }
      L(j,j) = sqrt(diag);

      for (i=j+1; i<=n; ++i) {
         sum2 = 0.0;
         for (k=1; k<=(j-1); ++k)  sum2 += L(i,k)*L(j,k);
         L(i,j) = (A(i,j)-sum2)/L(j,j);
      }
   }

   U = transpose(L);
   Vector x(n), y(n);

   y = slvFullLowerTrigSys(L, b);  // forward solve  L*y = b
   x = slvFullUpperTrigSys(U, y);  // backward solve L'*x = y

   return x;
}


// Solve a small full SPD system A*x = b by the Conjugate Gradient method
//
//           x -- in: initial guess; out: the last iterate 
//         itr -- out: number of fully completed iterations (the 
//                iteration in which a stopping test fires is not counted) 
//      maxItr -- upper bound on the number of iterations 
//   threshold -- stop once ||r|| < threshold*||b||  (r = residual) 
//         tol -- stop once the search direction has l2norm() < tol 

void slvFullSpdSysCG(Vector &x, const FullMatrix &A, const Vector &b, 
   int &itr, int maxItr, double threshold, double tol) 
{
   double rhsNorm = b.l2norm();
   int dim = x.size();

   Vector resid(dim);  // residual b - A*x
   Vector Adir(dim);   // A applied to the search direction
   Vector dir(dim);    // search direction

   itr = 0;
   resid = b-A*x;
   dir = resid;
   double rrOld = dotProd(resid,resid);

   for ( ; itr<maxItr; ++itr) {
      if (dir.l2norm()<tol)  break;

      Adir = A*dir;
      double step = rrOld/dotProd(Adir,dir);
      x += step*dir;
      resid -= step*Adir;

      double rrNew = dotProd(resid,resid);
      double residNorm = sqrt(rrNew);
      if (residNorm<threshold*rhsNorm)  break;

      // New direction: residual plus a multiple of the old direction
      dir = resid + (rrNew/rrOld)*dir;
      rrOld = rrNew;
   }
}


// Solve a diagonal system A*x = b entry by entry: 
// x(i) = b(i)/A.getEntry(i) for i = 1..b.size() 
// No check is made for zero entries returned by A.getEntry(). 

Vector slvDiagSys(const DiagMatrix &A, const Vector &b) 
{
   const int len = b.size();
   Vector sol(len);

   int k = 1;
   while (k<=len) {
      sol(k) = b(k)/A.getEntry(k);
      ++k;
   }

   return sol;
}


// Solve a nonsymmetric sparse linear system by 
// BiCGStab (Preconditioned Bi-Conjugate Gradient Stabilized)
// Courtesy of Victor Ginting 
// cf: p.27 of the SIAM Templates book
//
// Arguments: 
//           x -- in: initial guess; out: approximate solution to Ax = b
//           A -- the sparse system matrix 
//           b -- the right-hand side 
//           B -- diagonal preconditioner, applied through slvDiagSys 
//    max_iter -- in: maximal number of iterations; 
//                out: on a return value of 0, the number of iterations 
//                performed; left unchanged on the other return values 
//         tol -- in: relative tolerance.  The stopping threshold is 
//                max(atol, tol*(r0,r0)) with r0 the initial residual; 
//                out: the last computed value of (r,r) or (s,s) 
//        atol -- absolute lower bound for the stopping threshold 
//     printit -- nonzero: print the residual quantities per iteration 
// All residual quantities are Vector*Vector products, i.e. presumably 
// squared 2-norms (they are printed as "(r,r)" and "(s,s)"). 
//
// The return value indicates 
//    0 -- convergence within max_iter iterations 
//    1 -- no convergence within max_iter iterations 
//    2 -- breakdown: (rtilde,r) became exactly zero 
//    3 -- breakdown: omega became exactly zero 

int slvSpaLinSysBiCGStab(Vector &x, 
   const SparseMatrix &A, const Vector &b, 
   const DiagMatrix &B, 
   int &max_iter, double &tol, double atol, int printit)
{
   // alpha, omega and rho2 are first read in iteration 2, after 
   // iteration 1 has assigned them; the initial values are never used.
   double alpha = 0.0, beta, omega = 0.0, res, rho1, rho2 = 0.0;

   int n = A.getColSize();
   Vector p(n), phat(n), r(n), rtilde(n), s(n), shat(n), t(n), v(n);

   r = b - A*x;
   rtilde = r;

   res = r*r;
   if (printit)  cout << "itr " << 0 << ", (r,r)=" << res << "\n";

   // Turn the relative tolerance into the actual threshold
   tol *= res;
   tol = (atol > tol) ? atol : tol;

   // The initial guess is already good enough
   if (res<=tol) {
      tol = res;
      max_iter = 0;
      return 0;
   }

   for (int i=1; i<=max_iter; i++) {
      rho1 = rtilde*r;

      if (rho1==0) {
         tol = res;
         if (printit)  cout << "itr " << i << ", (r,r)=" << res << "\n";
         return 2;
      }

      if (i==1)
         p = r;
      else {
         // p = r + beta*(p - omega*v), done in two steps
         beta = (rho1/rho2)*(alpha/omega);
         p = p - omega*v;
         p = r + beta*p;
      }

      phat = slvDiagSys(B,p);
      v = A*phat;

      alpha = rho1/(rtilde*v);
      s = r - alpha*v;
      // The early-exit test is on the intermediate residual s (this 
      // used to read v*v, contradicting the "(s,s)" label below)
      res = s*s;

      if (res<tol) {
         x = x + alpha*phat;
         tol = res;
         max_iter = i;
         if (printit)  cout << "itr " << i << ", (s,s)=" << res << "\n";
         return 0;
      }

      if (printit)  cout << "itr " << i << ", (s,s)=" << res << ";";

      shat = slvDiagSys(B,s);
      t = A*shat;
      omega = (s*t)/(t*t);

      x += alpha*phat;
      x += omega*shat;
      r = s - omega*t;

      rho2 = rho1;
      res = r*r;

      if (printit)  cout << "(r,r)=" << res << "\n";

      if (res<tol) {
         tol = res;
         max_iter = i;
         return 0;
      }

      // omega is a divisor in the next beta: cannot continue
      if (omega==0) {
         tol = res;
         return 3;
      }
   }

   tol = res;
   return 1;
}


// Solve a sparse SPD system A*x = b by the Conjugate Gradient method
//
//           x -- in: initial guess; out: the last iterate 
//         itr -- out: number of fully completed iterations (the 
//                iteration in which a stopping test fires is not counted) 
//      maxItr -- upper bound on the number of iterations 
//   threshold -- stop once ||r|| < threshold*||b||  (r = residual) 
//         tol -- stop once the search direction has l2norm() < tol 

void slvSpaSpdSysCG(Vector &x, 
   const SparseMatrix &A, const Vector &b,
   int &itr, int maxItr, double threshold, double tol) 
{
   double normB = b.l2norm();
   int len = x.size();

   Vector res(len);   // residual b - A*x
   Vector Ad(len);    // A times the search direction
   Vector d(len);     // search direction

   itr = 0;
   res = b-A*x;
   d = res;
   double rho = dotProd(res,res);

   while (true) {
      if (!(itr<maxItr))  break;
      if (d.l2norm()<tol)  break;

      Ad = A*d;
      double stepLen = rho/dotProd(Ad,d);
      x += stepLen*d;
      res -= stepLen*Ad;

      double rhoNext = dotProd(res,res);
      if (sqrt(rhoNext)<threshold*normB)  break;

      // Next direction from the new residual and the old direction
      d = res + (rhoNext/rho)*d;
      rho = rhoNext;
      itr++;
   }
}


// Solve a sparse block SPD system A*x = b by the Conjugate Gradient method
//
//           x -- in: initial guess; out: the last iterate 
//         itr -- out: number of fully completed iterations (the 
//                iteration in which a stopping test fires is not counted) 
//      maxItr -- upper bound on the number of iterations 
//   threshold -- stop once ||r|| < threshold*||b||  (r = residual) 
//         tol -- stop once the search direction has l2norm() < tol 

void slvSpaBlkSpdSysCG(Vector &x, 
   const SparseBlockMatrix &A, const Vector &b, 
   int &itr, int maxItr, double threshold, double tol) 
{
   double bNorm = b.l2norm();
   int sz = x.size();

   Vector rk(sz);    // current residual
   Vector Apk(sz);   // A times the current search direction
   Vector pk(sz);    // current search direction

   itr = 0;
   rk = b-A*x;
   pk = rk;
   double rkrk = dotProd(rk,rk);

   for ( ; itr<maxItr; itr++) {
      // Direction has (numerically) vanished
      if (pk.l2norm()<tol)  break;

      Apk = A*pk;
      double ak = rkrk/dotProd(Apk,pk);
      x += ak*pk;
      rk -= ak*Apk;

      double rkrkNew = dotProd(rk,rk);
      double rkNorm = sqrt(rkrkNew);
      // Residual is small relative to the right-hand side
      if (rkNorm<threshold*bNorm)  break;

      pk = rk + (rkrkNew/rkrk)*pk;
      rkrk = rkrkNew;
   }
}


// Solve a block diagonal SPD system by the Cholesky factorization 
// assuming each diagonal block is small and SPD
//
// NOT IMPLEMENTED YET: the body is empty, so x is returned exactly as 
// it was passed in and neither A nor b is read. 
//
// Draft of the intended implementation (never enabled): 
//    int nb = A.numBlks();
//    for (int ib=1; ib<=nb; ++ib) {
//       FullMatrix B(A.get(ib));
//       // Vector r(b.bgnBlk(ib), b.dimBlk(ib));
//       // Vector xb(slvFullSpdSysCholesky(B,r));
//    }

void slvBlkDiagSpdSysCholesky(Vector &x, 
   const BlockDiagMatrix &A, const Vector &b) 
{
   // Intentionally left empty -- see the note above.
}


// Solve a spa(rse) block SPD system A*x = b by the Conjugate Gradient 
// Use a block diagonal SPD matrix (B) as the preconditioner 
//
//           x -- in: initial guess; out: the last iterate 
//           B -- preconditioner; applied by calling 
//                slvBlkDiagSpdSysCholesky(s, B, r) 
//         itr -- out: number of fully completed iterations (the 
//                iteration in which a stopping test fires is not counted) 
//      maxItr -- upper bound on the number of iterations 
//   threshold -- stop once ||r|| < threshold*||b||  (r = residual) 
//         tol -- stop once the search direction has l2norm() < tol 
// NOTE(review): slvBlkDiagSpdSysCholesky currently has an empty body in 
// this file, so the preconditioned residual is never actually filled in. 

void slvSpaBlkSpdSysPCG(Vector &x, 
   const SparseBlockMatrix &A, const Vector &b, 
   const BlockDiagMatrix &B, 
   int &itr, int maxItr, double threshold, double tol) 
{
   double rhsNorm = b.l2norm();
   int dim = x.size();

   Vector resid(dim);    // residual b - A*x
   Vector precRes(dim);  // preconditioned residual
   Vector Adir(dim);     // A applied to the search direction
   Vector dir(dim);      // search direction

   itr = 0;
   resid = b-A*x;
   slvBlkDiagSpdSysCholesky(precRes, B, resid);
   double rzOld = dotProd(resid,precRes);
   dir = precRes;

   for ( ; itr<maxItr; ++itr) {
      if (dir.l2norm()<tol)  break;

      Adir = A*dir;
      double step = rzOld/dotProd(Adir,dir);
      x += step*dir;
      resid -= step*Adir;

      slvBlkDiagSpdSysCholesky(precRes, B, resid);
      double rzNew = dotProd(resid,precRes);

      // Convergence is judged on the unpreconditioned residual
      double residNorm = sqrt(dotProd(resid,resid));
      if (residNorm<threshold*rhsNorm)  break;

      dir = precRes + (rzNew/rzOld)*dir;
      rzOld = rzNew;
   }
}
