// $Id: valmatrix.h 1247 2021-12-13 18:08:41Z ge $
/// \file valmatrix.h
/// \brief contains valmatrix class
///
/// $Revision: 1247 $
/// \author Gerald Weber <gweberbh@gmail.com>
#ifndef VALMATRIX_H
#define VALMATRIX_H "$Id: valmatrix.h 1247 2021-12-13 18:08:41Z ge $"
#include <valarray>
#include <complex>
#include <limits>
#include <iostream>
#include "ErrorCodes.h"

namespace gbc
  {
  /// \brief Dense matrix stored row by row in a flat vector.
  ///
  /// Element (r,c) is kept at index r*Columns+c of \c Array. \c Shape is a
  /// hint: when it equals \c Diagonal, multiplication and stream I/O only
  /// read the diagonal entries.
  ///
  /// \tparam _Tp       element type
  /// \tparam _VectorTp storage type; the member functions use the
  ///                   std::valarray interface (slices, resize, operator+=).
  template<class _Tp, class _VectorTp=std::valarray<_Tp> >
  class valmatrix
    {
    private:
    /// Scalar type whose std::numeric_limits set the stream output
    /// precision: T for std::complex<T> elements, the element type otherwise
    /// (std::numeric_limits is not specialized for std::complex).
    template<class _Up> struct precision_type
      {typedef _Up type;};
    template<class _Up> struct precision_type<std::complex<_Up> >
      {typedef _Up type;};

    public:
    typedef _Tp value_type;
    typedef _VectorTp vector_type;

    static const int Undefined  = 0; ///< No shape defined
    static const int Diagonal   = 1; ///< Diagonal matrix
    /// NOTE(review): ColumnWise has the same value as Diagonal, so a matrix
    /// tagged ColumnWise is handled as Diagonal by every Shape test in this
    /// class. Kept as is because Shape values are written to matrix files.
    static const int ColumnWise = 1;

    size_t Rows;       ///< Number of rows
    size_t Columns;    ///< Number of columns
    size_t Current;    ///< Next storage index written by operator<<(const _Tp2&)
    vector_type Array; ///< Row-major storage of Rows*Columns elements
    int Shape;         ///< Shape tag (Undefined, Diagonal, ...)
    bool Lock_size;    ///< If true, operator>> keeps the current dimensions

    /// \brief Creates an empty 0x0 matrix.
    valmatrix(void): Rows(), Columns(), Current(), Array(), 
                     Shape(Undefined), Lock_size(false)  {}

    /// \brief Creates an r x c matrix with every element set to \p t.
    valmatrix(size_t r, size_t c, value_type t=value_type()): 
      Rows(r), Columns(c), Current(), Array(t,r*c), 
      Shape(Undefined), Lock_size(false) {}

    /// \brief Name written as header by operator<< and searched for by operator>>.
    inline static std::string class_name(void) {return std::string("valmatrix<>");}

    /// \brief Total number of stored elements.
    inline size_t size(void) const {return Array.size();}

    /// \brief Locks (or unlocks) the dimensions used when reading with operator>>.
    inline void lock_size(bool ls=true) {Lock_size=ls;}

    /// \brief Resizes the number of rows and columns of the array.
    ///
    /// The storage is only reallocated when the total element count changes
    /// (for std::valarray this resets every element to value_type()). When
    /// only the split between rows and columns changes, the elements stay in place.
    /// \return true if the dimensions changed, false otherwise.
    inline bool resize(const size_t& r, const size_t& c)
      {
      // Copy first: r or c may refer to Rows/Columns of this very object
      // (e.g. resize(Columns,Rows)), and those are overwritten below.
      const size_t new_rows=r, new_columns=c;
      if (new_rows*new_columns != Columns*Rows)
        {
        Array.resize(new_rows*new_columns);
        Rows=new_rows; Columns=new_columns;
        return true;
        }
      if ((new_rows != Rows) || (new_columns != Columns))
        {
        Rows=new_rows; Columns=new_columns;
        return true;
        }
      return false;
      }
    
    /// \brief Resets the matrix to an empty, unlocked, shapeless 0x0 state.
    inline void clear(void)
      {
      Rows=Columns=size_t();
      Array.resize(size_t());
      Shape=Undefined; 
      Lock_size=false;
      // The write cursor of operator<< refers to the discarded storage.
      Current=size_t();
      }

    /// \brief Assigns a generic vector to the array.
    ///
    /// \p val must hold exactly Rows*Columns elements; each one is converted
    /// to value_type. Dimensions and Shape are left unchanged.
    template<class _Tp2>
    inline valmatrix& operator=(const std::valarray<_Tp2>& val)
      {
      if (Rows*Columns != val.size()) 
        {
        CERR_ERROR(ERRTAVCS) << "Trying to assign vector of size " << val.size()
                 << " to " << Rows << "x" << Columns << " matrix" << std::endl;
        CERR_TERM
        }
      for (size_t n=0; n < Rows*Columns; ++n) Array[n]=value_type(val[n]); 
      return *this;
      }

    /// \brief Assigns another valmatrix (dimensions, Shape and elements).
    ///
    /// Current and Lock_size of this object are not changed.
    inline valmatrix& operator=(const valmatrix& val)
      {
      Shape=val.Shape;
      resize(val.Rows,val.Columns);
      Array=val.Array;
      return *this;
      }

    /// \brief Assigns a generic array to the array.
    ///
    /// Reads Rows*Columns consecutive elements from \p val; the caller must
    /// guarantee that many exist (this is not checked).
    inline valmatrix& operator=(const value_type* val)
      {
      for (size_t n=0; n < Rows*Columns; ++n) Array[n]=value_type(val[n]); 
      return *this;
      }

    /// \brief Read-only access to the (r,c)-th element.
    inline value_type operator() (size_t r, size_t c) const
      {return Array[r*Columns+c];}

    /// \brief Access to the (r,c)-th element.
    inline value_type& operator() (size_t r, size_t c) 
      {return Array[r*Columns+c];}

    /// \brief Provides access to the r-th element of the internal std::valarray
    inline value_type operator[] (size_t r) const 
      {return Array[r];}

    /// \brief Provides access to the r-th element of the internal std::valarray
    inline value_type& operator[] (size_t r) 
      {return Array[r];}

    /// \brief Provides std::slice access to the internal std::valarray.
    ///
    /// Returns the std::slice_array proxy produced by the valarray; assigning
    /// to it writes through to the matrix elements.
    inline std::slice_array<value_type> operator[] (std::slice sl) 
      {return Array[sl];}

    /// \brief Returns a copy of the underlying storage.
    inline operator vector_type (void) {return Array;}

    /// \brief Returns row \p r as a 1 x Columns matrix.
    inline valmatrix row(size_t r) const
      {
      valmatrix res(1,Columns);
      res.Array=Array[std::slice(r*Columns,Columns,1)];
      return res;
      }

    /// \brief Returns column \p c as a Rows x 1 matrix.
    inline valmatrix column(size_t c) const
      {
      valmatrix res(Rows,1);
      res.Array=Array[std::slice(c,Rows,Columns)];
      return res;
      }

    /// \brief Overwrites column \p c with the elements of \p vl.
    ///
    /// Expects vl to hold exactly Rows elements (e.g. a Rows x 1 matrix).
    inline void column(size_t c, const valmatrix& vl)
      {
      Array[std::slice(c,Rows,Columns)]=vl.Array;
      }

    /// \brief Overwrites column \p c with the first Rows entries of \p vl.
    inline void column(size_t c, const vector_type& vl)
      {
      Array[std::slice(c,Rows,Columns)]=vl[std::slice(0,Rows,1)];
      }

    /// \brief Copies rows [first_row,last_row) and columns
    /// [first_column,last_column) into a new matrix.
    ///
    /// The result is tagged Diagonal only when this matrix is Diagonal and
    /// the block sits symmetrically on the diagonal.
    inline valmatrix submatrix(size_t first_row, size_t first_column, 
                               size_t last_row,  size_t last_column) const
      {
      // Inverted bounds would underflow the size_t dimensions below.
      if ((Rows < last_row) || (Columns < last_column) ||
          (last_row < first_row) || (last_column < first_column))
        {
        CERR_ERROR(ERRSUBNPOS) << "Submatrix of ( " << first_row << "," << first_column
                  << "," << last_row << "," << last_column
                 << " from a matrix of size " << Rows << "x" << Columns << " is not possible." << std::endl;
        CERR_TERM
        }
      valmatrix res(last_row-first_row,last_column-first_column);
      if ((this->Shape == Diagonal) && (first_row==first_column) && (last_row==last_column)) 
        res.Shape=Diagonal;
        
      for(size_t i=0; i < res.Rows; ++i)
        for(size_t j=0; j < res.Columns; ++j)
           res(i,j)=(*this)(i+first_row,j+first_column);
      return res;
      }

    /// \brief Makes this a square Rows x Rows diagonal matrix with diagonal
    /// entries vl[0..Rows-1]; all off-diagonal entries become value_type().
    template<class _Tp2>
    inline void diagonal(const _Tp2& vl)
      {
      // Copy the requested diagonal first so that vl may alias this matrix.
      vector_type values(Rows);
      for (size_t n=0; n < Rows; ++n) values[n]=vl[n];
      resize(Rows,Rows);
      // resize() keeps the old elements when the element count is unchanged,
      // so clear explicitly to avoid stale off-diagonal entries.
      Array=value_type();
      for (size_t n=0; n < Rows; ++n) (*this)(n,n)=values[n];
      Shape=Diagonal;
      }

    /// \brief Sets every diagonal element to \p vl and all others to value_type().
    ///
    /// Dimensions are kept; only the min(Rows,Columns) existing diagonal
    /// positions are written.
    inline void diagonal(const value_type& vl)
      {
      // Copy first: vl may refer to an element that is zeroed below.
      const value_type value=vl;
      const size_t count=(Rows < Columns) ? Rows : Columns;
      Array=value_type();
      for (size_t n=0; n < count; ++n) (*this)(n,n)=value;
      Shape=Diagonal;
      }

    /// \brief Sets every element to value_type() and clears the Shape tag.
    inline void zero(void)
      {
      for (size_t n=0; n< Rows*Columns; ++n) Array[n]=value_type();
      Shape=Undefined;
      }

    /// \brief Resizes to r x c and then zeroes every element.
    inline void zero(size_t r, size_t c)
      {
      resize(r,c);
      zero();
      }

    /// \brief Sum of the diagonal elements (min(Rows,Columns) of them).
    inline value_type trace(void) const
      {
      const size_t count=(Rows < Columns) ? Rows : Columns;
      value_type tr=value_type();
      for (size_t n=0; n< count; ++n) tr+=(*this)(n,n);
      return tr;
      }

    /// \brief Writes \p val at the position Current and advances Current.
    ///
    /// Once the end of the storage is reached, writing wraps back to index 0.
    /// NOTE(review): on an empty matrix this writes Array[0] out of range.
    template<class _Tp2>
    inline valmatrix& operator<<(const _Tp2& val)
      {
      if (Current >= Array.size()) Current=size_t();
      Array[Current++]=val;
      return *this;
      }

    /// \brief Multiplies every element by \p val.
    template<class _Tp2>
    inline void operator*=(const _Tp2& val)
      {
      this->Array *= val;
      } 

    /// \brief Reverses the order of all stored elements (index k moves to
    /// size()-1-k), i.e. mirrors both rows and columns.
    inline void reverse(void)
      {
      // An explicit loop instead of a slice with stride -1, which only worked
      // through wrapping size_t pointer arithmetic.
      const size_t total=Array.size();
      vector_type tmp(total);
      for (size_t k=0; k < total; ++k) tmp[k]=Array[total-1-k];
      Array=tmp;
      }

    /// \brief Mirrors the matrix left-right (column c becomes Columns-1-c).
    inline void reverse_columns(void)
      {
      vector_type res(Rows*Columns);
      for (size_t c=0; c < Columns; ++c)
        res[std::slice(c,Rows,Columns)]=Array[std::slice(Columns-c-1,Rows,Columns)];
      Array=res;
      }

    /// \brief Returns the transpose of \p vm (Shape of the result is Undefined).
    inline friend valmatrix transpose(const valmatrix& vm)
      {
      valmatrix res(vm.Columns,vm.Rows);
      for(size_t i=0; i < res.Rows; ++i)
        for(size_t j=0; j < res.Columns; ++j)
           res(i,j)=vm(j,i);
      return res;
      }

    /// \brief Element-wise sum of two matrices of identical dimensions.
    ///
    /// The result keeps the Shape of \p v1, except that a Diagonal tag is
    /// dropped when \p v2 is not Diagonal.
    inline friend valmatrix operator+(const valmatrix& v1, const valmatrix& v2)
      {
      // Both dimensions must agree: valarray addition of different sizes is undefined.
      if ((v1.Rows != v2.Rows) || (v1.Columns != v2.Columns))
        {
        CERR_ERROR(ERRMATDNA) << "Matrices of dimensions "
        << v1.Rows << "x" << v1.Columns << " and " 
        << v2.Rows << "x" << v2.Columns << " can not be added"; CERR_TERM
        }
      valmatrix res=v1;
      res.Array+=v2.Array;
      if ((res.Shape == Diagonal) && (v2.Shape != Diagonal)) res.Shape=Undefined;
      return res;
      }

    /// \brief Element-wise difference of two matrices of identical dimensions.
    ///
    /// The result keeps the Shape of \p v1, except that a Diagonal tag is
    /// dropped when \p v2 is not Diagonal.
    inline friend valmatrix operator-(const valmatrix& v1, const valmatrix& v2)
      {
      // Both dimensions must agree: valarray subtraction of different sizes is undefined.
      if ((v1.Rows != v2.Rows) || (v1.Columns != v2.Columns))
        {
        CERR_ERROR(ERRMATDNS) << "Matrices of dimensions "
        << v1.Rows << "x" << v1.Columns << " and " 
        << v2.Rows << "x" << v2.Columns << " can not be substracted"; CERR_TERM
        }
      valmatrix res=v1;
      res.Array-=v2.Array;
      if ((res.Shape == Diagonal) && (v2.Shape != Diagonal)) res.Shape=Undefined;
      return res;
      }

    /// \brief Matrix product v1*v2 (v1.Columns must equal v2.Rows).
    ///
    /// If either operand is tagged Diagonal, only its diagonal is read,
    /// which reduces each entry to a single product.
    template<class _Tp2>
    inline friend valmatrix operator*(const valmatrix<_Tp2>& v1, const valmatrix& v2) 
      {
      if (v1.Columns != v2.Rows)
        {
        CERR_ERROR(ERRMATDNM) << "Matrices of dimensions "
        << v1.Rows << "x" << v1.Columns << " and " 
        << v2.Rows << "x" << v2.Columns << " can not be multiplied"; CERR_TERM
        }
      valmatrix res(v1.Rows,v2.Columns);
      const size_t common_dim=v1.Columns;
      // Shape tests do not change inside the loops.
      const bool left_diagonal=(v1.Shape==Diagonal);
      const bool right_diagonal=(v2.Shape==Diagonal);

      for(size_t i=0; i < res.Rows; ++i)
        {
        for(size_t j=0; j < res.Columns; ++j)
          {
          if (left_diagonal)
            res(i,j)=v1(i,i)*v2(i,j);
          else if (right_diagonal)
            res(i,j)=v1(i,j)*v2(j,j);
          else
            {
            value_type sum=value_type();
            for (size_t k=0; k < common_dim; ++k) sum=sum+v1(i,k)*v2(k,j);
            res(i,j)=sum;
            }
          }
        }
      return res;
      }

    /// \brief Returns v1 with every element multiplied by \p v (Shape of v1 kept).
    inline friend valmatrix operator*(const valmatrix& v1, value_type v)
      {
      valmatrix res=v1;
      res.Array *= v;
      return res;
      } 

    /// \brief Writes the matrix as a header line "class_name Shape Rows Columns"
    /// followed by one line per row (a single diagonal element per line when
    /// Shape is Diagonal).
    ///
    /// Elements use scientific notation with enough digits for the element's
    /// scalar type; the stream's flags and precision are restored afterwards.
    inline friend std::ostream& operator<<(std::ostream& out, const valmatrix& vl)
      {
      out << vl.class_name() << " " << vl.Shape << " "<< vl.Rows << " " << vl.Columns << std::endl;
      const std::ios::fmtflags oldflags=out.flags();
      const std::streamsize oldprecision=out.precision();
      out.setf(std::ios::scientific);
      out.precision(std::numeric_limits<typename precision_type<value_type>::type>::digits10+1);
      for (size_t n=0; n < vl.Rows; ++n)
        {
        if (vl.Shape == valmatrix::Diagonal) 
          out << vl(n,n) << " ";
        else
          {
          for (size_t m=0; m < vl.Columns; ++m) out << vl(n,m) << " ";
          }
        // Ends the line in both cases.
        out << std::endl;
        }
      out.flags(oldflags);
      out.precision(oldprecision);
      return out;
      }

    /// \brief Reads a matrix in the format produced by operator<<.
    ///
    /// Tokens are skipped until the class_name() header is found. Without a
    /// size lock the matrix takes the file's dimensions. With a lock, the
    /// file matrix must be at least as large; surplus columns of each row are
    /// skipped.
    inline friend std::istream& operator>>(std::istream& in, valmatrix& vl)
      {
      const std::string header=vl.class_name();
      // Stop on any stream failure, not only eof, so that a stream already in
      // a failed state cannot keep this loop spinning forever.
      std::string token;
      while ((token != header) && (in >> token)) {}
      if (token != header) 
        {CERR_ERROR(ERRMFNCOMP) << "Matrix file seems not to be compatible with "<< vl.class_name() << std::endl; CERR_TERM}
      in >> vl.Shape;

      size_t InRows=0, InColumns=0;
      in >> InRows >> InColumns;
      if (vl.Lock_size)
        {
        if ((InRows < vl.Rows) || (InColumns < vl.Columns))
          {CERR_ERROR(ERRMFSST) << "Matrix in file is of size "<< InRows << "x" << InColumns
                    << " which smaller than " << vl.Rows << "x" << vl.Columns << std::endl; CERR_TERM}
        } 
      else vl.resize(InRows,InColumns);

      if (vl.Shape == valmatrix::Diagonal)
        {
        vl.Array=value_type();
        for (size_t n=0; n < vl.Rows; ++n) in >> vl(n,n);
        }
      else
        {
        for (size_t n=0; n < vl.Rows; ++n)
          { 
          for (size_t m=0; m < vl.Columns; ++m) in >> vl(n,m);
          if (vl.Columns < InColumns)
            {
            // Discard the columns of this file row that do not fit.
            value_type ign=value_type();
            for (size_t m=vl.Columns; m < InColumns; ++m) in >> ign;
            }
          }
        }
      return in;
      }

    /// \brief Copies \p vl into this matrix, adopting its dimensions.
    ///
    /// Generic version for real element types (requires Array to be
    /// assignable from a valarray<double>); valmatrix<std::complex<double> >
    /// has its own specialization.
    inline void real(const valmatrix<double>& vl)
      {
      resize(vl.Rows,vl.Columns);
      this->Array=vl.Array;
      }

    /// \brief Generic counterpart of the complex imag() setter.
    ///
    /// NOTE(review): for real element types this copies \p vl wholesale,
    /// exactly like real(); confirm that this is the intended meaning.
    inline void imag(const valmatrix<double>& vl)
      {
      resize(vl.Rows,vl.Columns);
      this->Array=vl.Array;
      }

    /// \brief Complex conjugation; a no-op for real element types.
    inline void conj(void) {}

    };

  /// \brief Stream header name used for matrices of double.
  template<>
  inline std::string valmatrix<double>::class_name(void) 
    {
    const char* const name="valmatrix<double>";
    return std::string(name);
    }


  /// \brief Returns the element-wise real part of a complex matrix.
  ///
  /// The result has the dimensions of \p vl; its Shape is left Undefined.
  template<class _Tp>
  inline valmatrix<_Tp> real(const valmatrix<std::complex<_Tp> >& vl)
    {
    const size_t count=vl.Rows*vl.Columns;
    valmatrix<_Tp> out(vl.Rows,vl.Columns);
    size_t idx=0;
    while (idx < count)
      {
      out[idx]=real(vl[idx]);
      ++idx;
      }
    return out;
    }

  /// \brief Returns the element-wise imaginary part of a complex matrix.
  ///
  /// The result has the dimensions of \p vl; its Shape is left Undefined.
  template<class _Tp>
  inline valmatrix<_Tp> imag(const valmatrix<std::complex<_Tp> >& vl)
    {
    const size_t count=vl.Rows*vl.Columns;
    valmatrix<_Tp> out(vl.Rows,vl.Columns);
    size_t idx=0;
    while (idx < count)
      {
      out[idx]=imag(vl[idx]);
      ++idx;
      }
    return out;
    }

  /// \brief Replaces the real parts of this complex matrix by the entries of \p vl.
  ///
  /// The matrix first takes the dimensions of \p vl via resize(). Every
  /// element then keeps its imaginary part and gets the matching entry of
  /// vl as its real part. When resize() changes the element count,
  /// Array.resize() resets the storage, so the old imaginary parts are lost.
  template<>
  inline void valmatrix<std::complex<double> >::real(const valmatrix<double>& vl)
    {
    this->resize(vl.Rows,vl.Columns);
    const size_t count=vl.Rows*vl.Columns;
    for (size_t idx=0; idx < count; ++idx)
      {
      const double im=std::imag(this->Array[idx]);
      this->Array[idx]=std::complex<double>(vl[idx],im);
      }
    }

  /// \brief Replaces the imaginary parts of this complex matrix by the entries of \p vl.
  ///
  /// The matrix first takes the dimensions of \p vl via resize(). Every
  /// element then keeps its real part and gets the matching entry of vl as
  /// its imaginary part. When resize() changes the element count,
  /// Array.resize() resets the storage, so the old real parts are lost.
  template<>
  inline void valmatrix<std::complex<double> >::imag(const valmatrix<double>& vl)
    {
    this->resize(vl.Rows,vl.Columns);
    const size_t count=vl.Rows*vl.Columns;
    for (size_t idx=0; idx < count; ++idx)
      {
      const double re=std::real(this->Array[idx]);
      this->Array[idx]=std::complex<double>(re,vl[idx]);
      }
    }

  /// \brief Replaces every element of this complex matrix by its complex conjugate.
  template<>
  inline void valmatrix<std::complex<double> >::conj(void) 
    {
    const size_t count=Rows*Columns;
    for (size_t idx=0; idx < count; ++idx)
      {
      std::complex<double>& element=Array[idx];
      element=std::conj(element);
      }
    }


  };
#endif
