// $Id: RegressionModel.h 1369 2024-11-29 14:20:00Z ge $
/// \file RegressionModel.h
/// \brief contains the Regression function class
///
/// $Revision: 1369 $
/// \author Gerald Weber <gweberbh@gmail.com>
#ifndef GBC_REGRESSIONMODEL_H
#define GBC_REGRESSIONMODEL_H "$Id: RegressionModel.h 1369 2024-11-29 14:20:00Z ge $"
#include "Regression.h"
#include "SequenceInfo.h"
#include "SequenceDataset.h"
#include "Actions.h"
#include "ErrorCodes.h"

namespace gbc
{
template<class _Tp=double, class _Action=ActionIndex>
class RegressionModel
  {
  public:
    // ---- Type aliases -------------------------------------------------
    typedef _Tp value_type;                                     ///< Numeric type of temperatures and coefficients (template default: double)
    typedef std::valarray<value_type> vector_type;
    typedef std::pair<std::string,std::string> sequence_type;   ///< Pair of sequence strings (printed as first/second)
    typedef size_t length_type;                                 ///< Sequence length used as a map key
    typedef std::map<length_type,value_type> length_map_type;   ///< length -> coefficient
    typedef std::map<value_type,length_map_type> salt_length_map_type; ///< group key (salt, Ct or key) -> length -> coefficient
    typedef std::map<value_type,value_type> salt_index_map_type;       ///< salt concentration -> coefficient
    typedef SequenceInfo<true,value_type,sequence_type> sequence_info_type; //true refers to periodicity, for HTM this needs to be true
    typedef std::deque<sequence_info_type> sequence_info_deque_type;
    typedef std::map<length_type,    ///< Sequence length
                     sequence_info_deque_type> length_si_type;

    typedef SequenceDataset<sequence_info_type> sequence_dataset_type;
    typedef typename sequence_dataset_type::string_deque_type string_deque_type;
    typedef typename sequence_info_type::salt_map_type salt_map_type;
    
    
    enum group_flags {Salt,Ct,Key}; //group flags for prediction_method=-2

    // ---- Data members -------------------------------------------------
    sequence_dataset_type dataset;           ///< The sequences, their measured temperatures and the groupings used by the regressions

    HeterogenousTM<value_type>* HTM;         ///< Address of the calculation object passed to the constructor
    value_type thermal_equivalence_expoent;   ///< Regression exponent for thermal_equivalence, usually 0.5
    value_type N_expoent;                     ///< Regression exponent for N, usually 0.5
    size_t Min_a_regression;                  ///< Minimum number of datapoints for a0 and a1 regression
    length_type Min_length_regression;             ///< Minimum length for including in length regression
    int prediction_method;                   ///< -1: one regression over all data, -2: a coefficients per group, 1: a, 2: b, 3: c coefficients (see predict)
    group_flags prediction_group;            ///< by default prediction is grouped by Salt

    bool Recalculate_regression;             ///< Flag to check if regression should be redone, usually true but is set to false if regression file is read
    bool Fixed_regression;                   ///< For fixed regression we use coefficients from file and never calculate regressions

    std::string Debug;                       ///< Debug flag, should be __FILE__

    std::vector<value_type> log_salt;        ///< Natural logarithms of salt concentrations, filled by calculate_regression_b and read_c
    salt_length_map_type a0, a1;             ///< Per group and length: T = a0 + a1*thermal_equivalence^exponent
    salt_index_map_type b00, b01, b10, b11;  ///< Per salt: a0 = b00 + b10*N^N_expoent, a1 = b01 + b11*N^N_expoent
    value_type c000, c100, c001, c101, c010, c110, c011, c111, all0, all1; ///< bXY = c0XY + c1XY*log(salt); all0/all1 are the coefficients of the single regression over all data

    std::string data_set_type;               ///< Not referenced by the methods of this class shown in this header

    typedef ReferenceSet<NeighbourSequence<> >         neighbours_ref_type;
    typedef ReferenceSet<Duplex<> >                    duplex_ref_type;
    neighbours_ref_type  Neighbours_set;          ///< The set of different base pair neighbours
    duplex_ref_type      Duplex_set;              ///< The set of different base pairs

    salt_map_type       Parameter_salt_concentration; ///< the salt concentration for which the parameters are valid

    /// \brief Binds the model to the calculation object \a htm and sets the defaults.
    ///
    /// Defaults: at least 3 data points per a-regression, no minimum length,
    /// prediction method 2, grouping by salt, regressions recalculated on demand.
    /// The c and "all" coefficients are value-initialised so that predict(),
    /// write_c() and write_all() never read indeterminate values when no
    /// regression has been calculated or read yet.
    RegressionModel(HeterogenousTM<value_type>& htm, ///< Here we plug-in the calculation
                    value_type eo=0.5,               ///< Regression exponent for the thermal equivalence
                    value_type ne=0.5)               ///< Regression exponent for the sequence length dependence
      : HTM(&htm), thermal_equivalence_expoent(eo), N_expoent(ne), 
        Min_a_regression(3), Min_length_regression(0), prediction_method(2), prediction_group(Salt), Recalculate_regression(true), Fixed_regression(false),
        // same order as the member declaration, avoids reorder warnings
        c000(), c100(), c001(), c101(), c010(), c110(), c011(), c111(), all0(), all1()
        {
        // a species concentration is not required for this model's datasets
        dataset.Species_concentration_necessary=false;
        }


    /// \brief Registers one measurement in the dataset and returns its record.
    ///
    /// The thermal equivalence and the thermal index (thermal equivalence raised
    /// to thermal_equivalence_expoent) are calculated on the returned record.
    /// NOTE(review): the record is held by value here, so whether the entry
    /// stored inside the dataset also receives these values depends on what
    /// dataset.add_temperature returns -- confirm.
    sequence_info_type add_temperature(const std::string& seq,        ///< The sequence string 
             const std::string& comp,       ///< The complementary sequence string 
             const value_type& salt_conc,   ///< The salt concentration
             const value_type& temperature, ///< The experimental melting temperature
             const value_type& spec_conc)   ///< Species concentration.
      {
      sequence_info_type record=dataset.add_temperature(seq,comp,salt_conc,temperature,spec_conc);

      // the index is derived from the equivalence, so the equivalence comes first
      record.thermal_equivalence=calculate_thermal_equivalence(&record);
      record.thermal_index=pow(record.thermal_equivalence,thermal_equivalence_expoent);

      return record;
      }
      
    /// \brief Hands the list of parameter files to the dataset (rules) and to the calculation object (parameters).
    ///
    /// Rules are read from parameter files, make sure that this is called before read_xml,
    /// otherwise the rules are not applied to the dataset.
    /// \param par_file_deque the parameter file names, passed unchanged to both readers
    void read_rules_and_parameters_from_list_of_files(string_deque_type &par_file_deque)
      {
      dataset.read_rules_from_list_of_files(par_file_deque);
      HTM->get_parameters_from_list_of_files(par_file_deque);
      }

    /// \brief Reads previously calculated c and b regression coefficients.
    ///
    /// We do not need the a coefficient since they are specific for N, instead
    /// we generate them from b coefficients. Therefore, generating predictions
    /// from these files may lead to poorer results.
    void read_c(std::string filename)
      {
      std::ifstream reg;
      COUT_INFO(INFORRCF) << " reading regression coefficients from " << filename << std::endl;
      reg.open(filename.c_str());
      if (!reg.good())
        {
        CERR_ERROR(ERRCNOF) << "Could not open file " << filename << std::endl;
        CERR_TERM
        }

      Recalculate_regression=false;

      value_type ct;
      reg >> thermal_equivalence_expoent >> N_expoent >> ct;
      if (prediction_group != Ct) dataset.Species_concentration.push_back(ct); //ATTENTION Species_concentration is later used in read_xml, needs revision

      std::string tp;
      reg >> tp;
      while ((prediction_method == -2) and tp==std::string("a") and !reg.eof()) 
        {
        double sl, zero;
        reg >> sl; log_salt.push_back(log(sl));
        reg >> zero >> a0[sl][0] >> a1[sl][0];
        reg >> tp;
        }
          
      if (tp==std::string("c"))
        {
        reg >> c000 >> c100 
            >> c001 >> c101 
            >> c010 >> c110 
            >> c011 >> c111;
        reg >> tp;
        }
      while (tp==std::string("b"))
        {
        double sl;
        reg >> sl; log_salt.push_back(log(sl));
        COUT_INFO(INFORTBC) << "Reading type b coefficients for salt concentration " << sl << " mM" << std::endl; 
        reg >> b00[sl] >> b01[sl] >> b10[sl] >> b11[sl];
        reg >> tp;
        }
      if  (tp==std::string("all"))
        {
        reg >> all0 >> all1;
        reg >> tp;
        }

      reg.close();
      }

    void write_b(std::ostream& out)
      {
      typename salt_index_map_type::const_iterator st;
      for(st=b00.begin(); st != b00.end(); ++st)
        {
        value_type salt=st->first;
        out << "b "<< std::endl;
        out << salt << " " << b00[salt] << " " << b01[salt] << " " 
            << b10[salt] << " " << b11[salt] << std::endl;
        }
      }

    void write_a(std::ostream& out)
      {
      typename salt_length_map_type::const_iterator st0, st1;
      if (a0.size() == 0)
        out << "Set a0 empty"<< std::endl;
      for(st0=a0.begin(), st1=a1.begin(); st0 != a0.end(); ++st0, ++st1)
        {
        out << "a " << std::endl;
        out << st0->first << std::endl; //writes the salt or Ct concentration
        typename length_map_type::const_iterator l0, l1;
        for(l0=st0->second.begin(), l1=st1->second.begin(); 
            l0 != st0->second.end(); ++l0, ++l1)
          out << l0->first  << " "  //writes the length N
              << l0->second << " " //writes a0
              << l1->second << std::endl; //writes a1
        }
      }

    void write_c(std::ostream& out)
      {
      out << "c " << std::endl;
      out << c000 << " " << c100 << std::endl;
      out << c001 << " " << c101 << std::endl;
      out << c010 << " " << c110 << std::endl;
      out << c011 << " " << c111 << std::endl;
      }


    /// \brief Writes the "all" block: the tag line "all " followed by all0 and all1,
    /// the coefficients set by calculate_regression_simple.
    void write_all(std::ostream& out)
      {
      out << "all " << std::endl;
      out << all0 << " " << all1 << std::endl;
      }

    /// \brief Calculates the regression for all data
    ///
    /// Calculates the regression for all data without separating into length groups, -pm=-1
    inline void calculate_regression_simple(void)
      {
      Regression<value_type> thermal_equivalence_reg(thermal_equivalence_expoent);
      thermal_equivalence_reg.x(dataset.thermal_equivalence_vector()); 
      thermal_equivalence_reg.y(dataset.temperature_vector());
      thermal_equivalence_reg.calculate_regression();
      all0=thermal_equivalence_reg.c0;
      all1=thermal_equivalence_reg.c1;
      CERR_DEBUG(DRMO_THEQREG) << "all0= " << all0 << ", all1= " <<  all1 << std::endl;
      }

    /// \brief Calculates the a coefficients for every length group of one salt (or Ct/key) entry.
    ///
    /// For the group referenced by \a salt_it, each length of at least
    /// Min_length_regression that holds at least Min_a_regression samples gets
    /// its own regression of temperature against thermal equivalence. Only
    /// fits with a positive slope are stored in a0/a1.
    ///
    /// The debug output reads the coefficients from the regression object, not
    /// through a0[salt][length]: the latter would insert zero entries for
    /// rejected fits and these would later enter the regression over N.
    inline void calculate_regression_a(typename sequence_dataset_type::salt_length_dataset_type::iterator& salt_it)
      {
      if (salt_it->second.size()==0) CERR_WARN(WLSEMPTY) << "Length set empty" << std::endl;
      const value_type salt=salt_it->first;
      typename sequence_dataset_type::length_dataset_type::iterator length_it;
      for (length_it=salt_it->second.begin(); length_it != salt_it->second.end(); ++length_it)
        {
        const length_type length=length_it->first;
        if (length < Min_length_regression) continue;

        if (length_it->second.size() < Min_a_regression)
          {
          CERR_DEBUG(DRMO_NEDPRS) << ":Not enough data points for regression salt=" << salt << " length=" << length << ", available= " << length_it->second.size() << ", required=" << Min_a_regression << std::endl;
          continue;
          }

        Regression<value_type> thermal_equivalence_reg(thermal_equivalence_expoent);
        thermal_equivalence_reg.x(dataset.thermal_equivalence_vector(length_it)); 
        thermal_equivalence_reg.y(dataset.temperature_vector(length_it));
        thermal_equivalence_reg.calculate_regression();

        if (thermal_equivalence_reg.c1 > 0)
          {
          a0[salt][length]=thermal_equivalence_reg.c0;
          a1[salt][length]=thermal_equivalence_reg.c1;
          CERR_DEBUG(DRMO_SLENNEGA0) << __FILE__ << "salt=" << salt << " length=" << length << ", a0= " << thermal_equivalence_reg.c0 << ", a1= " <<  thermal_equivalence_reg.c1 << std::endl;
          }
        else
          {
          // a non-positive slope is rejected, nothing is stored for this length
          CERR_DEBUG(DRMO_SLENNEGA1) << "salt=" << salt << " length=" << length << " negative a1= " <<  thermal_equivalence_reg.c1 << std::endl;
          }
        }
      }

    /// \brief Calculates the b coefficients for each salt concentrations
    ///
    /// The b coefficient calculates the regression for varying N.
    /// Therefore we iterate over all length-groups.
    inline void calculate_regression_b(void)
      {
      log_salt.clear(); b00.clear(); b01.clear(); b10.clear(); b11.clear();
      typename sequence_dataset_type::salt_length_dataset_type::iterator salt_it;
      for (salt_it=dataset.Salt_length_dataset.begin(); salt_it != dataset.Salt_length_dataset.end(); ++salt_it)
        {
        log_salt.push_back(log(salt_it->first));
        calculate_regression_a(salt_it);
        if (a0[salt_it->first].size() > 1)
          {
          Regression<value_type> N_reg(N_expoent);
          N_reg.xy(a0[salt_it->first]);
          N_reg.calculate_regression();
          b00[salt_it->first]=N_reg.c0;
          b10[salt_it->first]=N_reg.c1;
          N_reg.xy(a1[salt_it->first]);
          N_reg.calculate_regression();
          b01[salt_it->first]=N_reg.c0;
          b11[salt_it->first]=N_reg.c1;
          }
        }
      }
      
    /// \brief Calculates the a coefficients for every entry of a grouped dataset (-pm=-2).
    ///
    /// \a set is keyed by the grouping quantity; the salt code path
    /// (calculate_regression_a) is reused for each entry.
    inline void calculate_regression_a_group(typename sequence_dataset_type::salt_length_dataset_type &set)
      {
      // calculate_regression_a takes the iterator by reference, so a named iterator is needed
      typename sequence_dataset_type::salt_length_dataset_type::iterator group=set.begin();
      while (group != set.end())
        {
        calculate_regression_a(group);
        ++group;
        }
      }

      /// \brief Calculates the a coefficients for each group key (-pm=-2).
    ///
    /// Runs calculate_regression_a_group, i.e. the code written for salt, on dataset.Key_dataset.
    inline void calculate_regression_a_Key(void)
      {
      calculate_regression_a_group(dataset.Key_dataset);
      }

    /// \brief Calculates the a coefficients for each Ct concentration (-pm=-2).
    ///
    /// Runs calculate_regression_a_group, i.e. the code written for salt, on dataset.Ct_dataset.
    inline void calculate_regression_a_Ct(void)
      {
      calculate_regression_a_group(dataset.Ct_dataset);
      }

    /// \brief Calculates the a coefficients for each salt concentration (-pm=-2).
    ///
    /// Runs calculate_regression_a_group on dataset.Salt_dataset.
    inline void calculate_regression_a_salt(void)
      {
      calculate_regression_a_group(dataset.Salt_dataset);
      }

    /// \brief Calculated all regression coefficients.
    ///
    /// If there is more than one salt concentration in the experimental
    /// data then the c coefficients are calculated. Otherwise we calculate only
    /// the b coefficients.
    inline void calculate_regression_abc(void)
      {
      if (Fixed_regression) return;
        
      CERR_DEBUG(DRMO_PREDMETH) << "prediction_method = " << prediction_method << std::endl;
      
      if (prediction_method == -1)
        {
	calculate_regression_simple();
        return;
	}
	
      if (prediction_method == -2) //like -1 but with seperate a coefficient for each salt
        {
        CERR_DEBUG(DRMO_PREDGRP)  << "prediction_group = " << prediction_group << std::endl;
      
        switch (prediction_group)
          {
          case Salt: calculate_regression_a_salt(); break;
          case Ct  : calculate_regression_a_Ct();   break;
          case Key : calculate_regression_a_Key();  break;
          }
        return;
	}
	
      calculate_regression_b();

      if (prediction_method == 3)
        {
        Regression<value_type> salt_reg;
        salt_reg.x(log_salt);

        salt_reg.y(b00);
        salt_reg.calculate_regression();
        c000=salt_reg.c0; c100=salt_reg.c1;

        salt_reg.y(b01);
        salt_reg.calculate_regression();
        c001=salt_reg.c0; c101=salt_reg.c1;

        salt_reg.y(b10);
        salt_reg.calculate_regression();
        c010=salt_reg.c0; c110=salt_reg.c1;

        salt_reg.y(b11);
        salt_reg.calculate_regression();
        c011=salt_reg.c0; c111=salt_reg.c1;
        }

     }

    /// \brief Calculate the regression and saves is to a file.
    inline void calculate_regression_and_save(std::ostream& out)
      {
      calculate_regression_abc();
      out << thermal_equivalence_expoent << " " << N_expoent << " " << dataset.Species_concentration[0] << std::endl;
      if (prediction_method == -1) {write_all(out); return;}
      if (prediction_method == -2) {write_a(out); return;}
      if (prediction_method == 3) write_c(out);
      write_b(out); write_a(out);
      }

   /// \brief Prints the dataset verification followed by the prediction method
   /// and, for method -2, the prediction group.
   inline void print_verify(std::ostream& out) //old name verify
      {
      dataset.print_verify(out);
      out << " prediction method=" << prediction_method << std::endl;

      if (prediction_method != -2) return; // the group only matters for method -2
      out << "prediction_group= " << prediction_group << std::endl;
      }

   /// \brief Lets the dataset calculate and print its verification, then prints
   /// the prediction method and, for method -2, the prediction group.
   inline void calculate_and_print_verify(std::ostream& out)
      {
      dataset.calculate_and_print_verify(out);
      out << " prediction method=" << prediction_method << std::endl;

      if (prediction_method != -2) return; // the group only matters for method -2
      out << "prediction_group= " << prediction_group << std::endl;
      }

   /// \brief Writes the action name and the Hamiltonian as \\Action{} and \\Model{}
   /// lines, followed by the TeX output of the dataset.
   inline void print_tex(std::ostream& out)
     {
     out << "\\Action{" << _Action::name() << "}" << std::endl;
     out << "\\Model{" << HTM->hamiltonian() << "}" << std::endl;
     dataset.print_tex(out);
     }

    /// \brief Calculates the thermal equivalence of one sequence.
    ///
    /// The sequence is handed to the calculation object and the value is
    /// obtained through the loop for the action type _Action.
    /// \param psi the sequence record to evaluate
    /// \return the value delivered by Looping<_Action>::result_value
    inline value_type calculate_thermal_equivalence(sequence_info_type* psi)
      {
      HTM->sequence(psi);

      value_type equivalence;
      Looping<_Action>::result_value(equivalence,*HTM,0);
      return equivalence;
      }

    /// \brief Calculates the partition function and the Helmholtz energy of one sequence.
    ///
    /// \param psi the sequence record to evaluate
    /// \return pair: first is the value delivered by the ActionZyNoRecalculate
    ///         loop, second is HTM->Fy read after that loop has run
    inline std::pair<value_type,value_type> calculate_partition_function_and_helmholtz_energy(sequence_info_type* psi)
      {
      HTM->sequence(psi);

      std::pair<value_type,value_type> zy_and_fy;
      Looping<ActionZyNoRecalculate>::result_value(zy_and_fy.first,*HTM,0);
      zy_and_fy.second=HTM->Fy; // read only after the loop above has run

      return zy_and_fy;
      }

    /// \brief Calculate average_y for a sequence
    ///
    /// The sequence is handed to the calculation object and the result of the
    /// ActionAverageY loop is stored in psi->average_y.
    /// \param psi the sequence record to evaluate and update
    inline void calculate_average_y(sequence_info_type* psi)
      {
      HTM->sequence(psi);
      Looping<ActionAverageY>::result_vector(psi->average_y,*HTM,0);
      }
      
    /// \brief Recalculates the thermal equivalence and thermal index of every sequence.
    ///
    /// The sequence information is evaluated once to populate TM_map, the
    /// matrices are retrieved or calculated, and then the melting index of each
    /// sequence is taken as its thermal equivalence.
    inline void recalculate_all_thermal_equivalences(void)
      {
      HTM->Always_recalculate_matrices=false;//change to true for debugging purposes

      //Sequence evaluation needs to be done only once to populate the TM_map.
      //TM_map has always at least one element which is the base expansion, usually CG_CG
      const bool tm_map_unpopulated=(HTM->TM_map.size() <= (size_t)1);
      if (tm_map_unpopulated)
        {
        for (auto &record : dataset.Raw_dataset)
          {
          HTM->pSI=&record;
          HTM->evaluate_sequence_information();
          }
        }

      HTM->retrieve_or_calculate_matrices();

      for (auto &record : dataset.Raw_dataset)
        {
        HTM->pSI=&record;
        record.thermal_equivalence=HTM->melting_index();
        record.thermal_index=pow(record.thermal_equivalence,thermal_equivalence_expoent);
        }
      }
      
    /// \brief Get partition function
    inline void get_partition_function_and_helmholtz_energy(void)
      {
      typename sequence_dataset_type::sequence_info_deque_type::iterator si=dataset.Raw_dataset.begin();
      for(si=dataset.Raw_dataset.begin(); si != dataset.Raw_dataset.end(); si++)
        { 
        std::pair<value_type,value_type> result=calculate_partition_function_and_helmholtz_energy(&(*si));
        si->partition_function = result.first;
        si->helmholtz_energy   = result.second;
        }
      }
      
    /// \brief Get partition function
    inline void calculate_print_all_average_y(std::ostream& out)
      {
      typename sequence_dataset_type::sequence_info_deque_type::iterator si=dataset.Raw_dataset.begin();
      for(si=dataset.Raw_dataset.begin(); si != dataset.Raw_dataset.end(); si++)
        { 
        calculate_average_y(&(*si));
        get_partition_function_and_helmholtz_energy();
        out << si->sequence.first << "/" << si->sequence.second << " ";
        out << si->partition_function << " " << si->helmholtz_energy << " ";
        out << si->average_y << std::endl;
        }
      }



    /// \brief Calculates the thermal equivalence of every sequence that has none yet.
    ///
    /// A thermal equivalence equal to value_type() marks a sequence as not yet
    /// calculated; for these the thermal equivalence and the thermal index are
    /// filled in.
    /// \return true if at least one sequence had to be calculated
    inline bool calculate_missing_thermal_equivalences(void)
      {
      bool found_missing=false;
      for (auto &record : dataset.Raw_dataset)
        {
        if (!(record.thermal_equivalence == value_type())) continue; // already calculated

        found_missing=true;
        record.thermal_equivalence=calculate_thermal_equivalence(&record);
        record.thermal_index=pow(record.thermal_equivalence,thermal_equivalence_expoent);
        }
      return found_missing;
      }

    /// \brief Predict temperature for a specific sequence.
    ///
    /// The prediction is a0 + a1*thermal_equivalence^thermal_equivalence_expoent,
    /// where the source of a0 and a1 depends on \a pmethod:
    ///  -  3: b coefficients from the c coefficients and log(s), then a from b and N
    ///  -  2: b coefficients stored for s, then a from b and N
    ///  -  1: a coefficients stored for s and N
    ///  - -1: the coefficients of the regression over all data
    ///  - -2: a coefficients stored for s under length 0
    ///  -  0: chosen here: 3 if s is unknown to the dataset, 2 if the length is
    ///        unknown for s, otherwise -1.
    ///        NOTE(review): -1 for a known salt and length relies on all0/all1
    ///        being available -- confirm this is the intended choice.
    /// Any other value leaves both coefficients at zero.
    ///
    /// s is the grouping quantity selected by prediction_group (Na+ salt
    /// concentration, species concentration or group key).
    ///
    /// Coefficient tables are read with find(): missing entries count as zero
    /// and, unlike with operator[], are not inserted into a0/a1/b**, so
    /// predicting never changes what write_a/write_b or a later regression see.
    ///
    /// \param si      sequence record; receives prediction_method and temperature.predicted
    /// \param pmethod prediction method, see above
    inline void predict(sequence_info_type& si, 
                 int pmethod=0)
      {
      const size_t l=si.BP_number;
      value_type   s=si.salt_concentration["Na+"];
      
      switch (prediction_group)
          {
          case Salt: s=si.salt_concentration["Na+"]; break;
          case Ct  : s=si.species_concentration;   break;
          case Key : s=si.group_key;  break;
          }

      if (pmethod == 0) //Try to determine which method to use
        {
        typename sequence_dataset_type::iterator dt=dataset.Salt_length_dataset.find(s);
        if (dt == dataset.Salt_length_dataset.end()) pmethod=3; //unknown salt concentration
        else
          {
          if (dt->second.find(l) == dt->second.end()) pmethod=2; //unknown length
          else                                        pmethod=-1; 
          }
        }

      CERR_DEBUG(DRMO_PREDICT) << "pmethod=" << pmethod << " length=" << si.BP_number << " seq=" << si.Exact_duplex.formatted_string() << std::endl;

      si.prediction_method=pmethod;
      value_type local_a0=0.0,local_a1=0.0;
      value_type local_b00=0.0, local_b01=0.0, local_b10=0.0, local_b11=0.0;

      // Read-only lookups, a missing entry yields value_type()
      auto b_for_s=[s](const salt_index_map_type& table) -> value_type
        {
        const typename salt_index_map_type::const_iterator hit=table.find(s);
        return (hit == table.end()) ? value_type() : hit->second;
        };
      auto a_for_s=[s](const salt_length_map_type& table, length_type length) -> value_type
        {
        const typename salt_length_map_type::const_iterator group=table.find(s);
        if (group == table.end()) return value_type();
        const typename length_map_type::const_iterator hit=group->second.find(length);
        return (hit == group->second.end()) ? value_type() : hit->second;
        };

      switch(pmethod)
        {
        case 3: local_b00=c000+c100*log(s);
                local_b10=c010+c110*log(s);
                local_b01=c001+c101*log(s);
                local_b11=c011+c111*log(s);

                local_a0=local_b00+local_b10*pow(l,N_expoent);
                local_a1=local_b01+local_b11*pow(l,N_expoent);
                break;
        case 2: if (b00.find(s) == b00.end())
                  {
                  CERR_ERROR(ERRNRPFSC) << "predict(): No regression parameters for salt concentration " << s << std::endl;
                  }
                local_b00=b_for_s(b00);
                local_b10=b_for_s(b10);
                local_b01=b_for_s(b01);
                local_b11=b_for_s(b11);

                CERR_DEBUG(DRMO_PREDICT) << "local_b00=" << local_b00 << " local_b10=" << local_b10 << " local_b01=" << local_b01 << " local_b11=" << local_b11 << std::endl;
     
                local_a0=local_b00+local_b10*pow(l,N_expoent);
                local_a1=local_b01+local_b11*pow(l,N_expoent);
                break;
        case 1: local_a0=a_for_s(a0,l);
                local_a1=a_for_s(a1,l);
                break;
        case -1: local_a0=all0; local_a1=all1;
                break;
        case -2: local_a0=a_for_s(a0,0); //all sequences at salt s
                local_a1=a_for_s(a1,0);
                break;
        default: break; // unknown method: coefficients stay zero
        }
      si.temperature.predicted=local_a0+local_a1*pow(si.thermal_equivalence,thermal_equivalence_expoent);
      if (si.Prediction_with_salt_correction) si.apply_salt_correction(Parameter_salt_concentration["Na+"],&(HTM->parameter_map)); 

      CERR_DEBUG(DRMO_PREDICT) << "local_a0=" << local_a0 << " local_a1=" << local_a1 << " thermal_equivalence=" << si.thermal_equivalence << std::endl;
      
      }

    /// \brief Calculates temperature predictions.
    ///
    /// This is a convenience function. First we check if there is any thermal
    /// equivalence parameter that needs to be calculated. Then we calculate
    /// the regression unless a file was given with regression data. If everything
    /// is OK we calculate the predicted temperatures.
    
    inline void predict_all(void)
      {
      //These two tests are made separately on purpose, otherwise there are conflicts when we use -O3
      if (calculate_missing_thermal_equivalences()) 
        if (Recalculate_regression) calculate_regression_abc(); 
        
      for(auto &si : dataset.Raw_dataset) predict(si,prediction_method);
      }
      
    /// \brief Reads the data files line by line, predicts and prints each batch
    /// immediately, then discards it.
    ///
    /// For every file in dataset.Data_file_deque the header is read, then each
    /// call to read_xml_line fills dataset.Raw_dataset; the batch is predicted,
    /// its temperature differences are collected, it is printed to \a out (with
    /// a heading before the very first record) and Raw_dataset is cleared. The
    /// temperature statistics are calculated once all files have been processed.
    ///
    /// predict_all() already covers the whole batch, so it is called once per
    /// batch and not once per record. Reading stops as soon as the stream is no
    /// longer good, which also ends the loop when a file could not be read at
    /// all. The stream is closed and reset between files so that each file
    /// starts from a clean stream.
    /// NOTE(review): open_and_read_xml_header is assumed to open \a xml itself -- confirm.
    inline void predict_and_print_sequentially(std::ostream& out)
      {
      bool first_line=true;
      std::ifstream xml;
      double multspconc=1;
      typename sequence_dataset_type::species_concentration_deque_type spconc;
      for (auto xmlfile : dataset.Data_file_deque)
        {
        dataset.open_and_read_xml_header(xml,xmlfile);
        while(xml.good()) 
          {
          dataset.read_xml_line(xml,xmlfile,spconc,multspconc);

          if (!dataset.Raw_dataset.empty()) predict_all();

          for(auto &si : dataset.Raw_dataset) 
            {
            dataset.collect_tm_difference(si);
          
            if (first_line)
              {
              si.print_head(out);
              first_line=false;
              }
            si.print(out);
            }
          dataset.Raw_dataset.clear();
          }
        if (xml.is_open()) xml.close();
        xml.clear(); // drop eof/fail flags before the next file
        }
      dataset.calculate_tm_statistics();
      }

    /// \brief Placeholder: this class does not implement set_svd_variables.
    ///
    /// Calling it reports an internal error and then invokes CERR_TERM;
    /// \a svd is not used.
    template<class _SVD>
    inline void set_svd_variables(const _SVD &svd)
      {
      CERR_IERROR(IERRFSVDVNF) << "function set_svd_variables not defined" << std::endl;
      CERR_TERM
      }

    /// \brief Placeholder: this class does not implement set_svd_matrices.
    ///
    /// Calling it reports an internal error and then invokes CERR_TERM;
    /// \a A and \a B are not used.
    template<class _Mat>
    inline void set_svd_matrices(_Mat &A, _Mat &B)
      {
      CERR_IERROR(IERRFSVDMNF) << "function set_svd_matrices not defined" << std::endl;
      CERR_TERM
      }
      
    /// \brief Placeholder: this class does not implement adjust_svd_units.
    ///
    /// Calling it reports an internal error and then invokes CERR_TERM;
    /// \a x is not used.
    template<class _Vec>
    inline void adjust_svd_units(_Vec &x)
      {
      CERR_IERROR(IERRFASVDUNF) << "function adjust_svd_units not defined" << std::endl;
      CERR_TERM
      } 
  };

}
#endif
