//-----------------------------------------------------------
// File and Version Information:
// $Id$
//
// Description:
//      Implementation of class Kalman
//      see Kalman.h for details
//
// Environment:
//      Software developed for the PANDA Detector at FAIR.
//
// Author List:
//      Sebastian Neubert    TUM            (original author)
//
//
//-----------------------------------------------------------

// Panda Headers ----------------------

// This Class' Header ------------------
#include "Kalman.h"

// C/C++ Headers ----------------------
#include "assert.h"
#include <cmath>
#include <iostream>
#include <sstream>
#include "TMath.h"
// Collaborating Class Headers --------
#include "Track.h"
#include "AbsRecoHit.h"
#include "AbsTrackRep.h"
#include "FitParams.h"
#include "FitterExceptions.h"
  
// Class Member definitions -----------

 
// Default constructor: nothing to initialise here.
Kalman::Kalman() {}

// Destructor: nothing to release here.
Kalman::~Kalman() {}

// Fit a track from its very first hit: rewind the track's "next hit to fit"
// cursor to index 0, then delegate the actual filtering to continueTrack().
void
Kalman::processTrack(Track* trk){
  const unsigned int firstHit = 0;
  trk->setNextHitToFit(firstHit);
  continueTrack(trk);
}

// Run the Kalman filter over the hits of `trk` that have not been fitted yet,
// i.e. from trk->getNextHitToFit() up to the last hit, for every track rep.
//
// A rep whose status flag is non-zero is skipped. If processHit() throws a
// FitterException for a rep, the exception is reported on stdout, the rep is
// flagged (status 1) and filtering continues with the remaining reps/hits.
// On return the track's next-hit cursor points past the last hit.
// Smoothing is not performed here; call smoothing() separately if needed.
void
Kalman::continueTrack(Track* trk){
  const unsigned int nhits=trk->getNumHits();
  const unsigned int starthit=trk->getNextHitToFit();
  if(starthit==nhits) {
    std::cout<<"Kalman::continueTrack::Already at end of Track!"<<std::endl;
    return;
  }
  const int nreps=trk->getNumReps();

  // loop over remaining hits; unsigned index matches nhits/starthit
  for(unsigned int ihit=starthit; ihit<nhits; ++ihit){
    AbsRecoHit* ahit=trk->getHit(ihit);
    // loop over reps
    for(int irep=0; irep<nreps; ++irep){
      AbsTrackRep* arep=trk->getTrackRep(irep);
      if(arep->getStatusFlag()!=0) continue; // rep already failed: skip it
      try {
        processHit(ahit,arep,static_cast<int>(ihit));
      }
      catch(FitterException& e) {
        // mark the rep as failed; it will be skipped for all further hits
        std::cout << e.what() << std::endl;
        e.info();
        arep->setStatusFlag(1);
      }
    }
  }
  trk->setNextHitToFit(nhits);
}


// Chi-square of the predicted residual of `hit` with respect to `rep`.
//
// The rep is extrapolated (rep->predict) onto the hit's detector plane, and
// the residual r between hit and predicted state is weighted with the
// inverse of its covariance:
//   chi2 = r^T R^-1 r,   R = V + H C H^T
// where V is the hit covariance, H the measurement matrix and C the
// predicted state covariance. The plus sign is correct for a *predicted*
// residual, whose covariance is V + H C H^T.
// A NaN determinant of R is reported on stdout but does not abort.
double
Kalman::getChi2Hit(AbsRecoHit* hit, AbsTrackRep* rep)
{
  // prototypes for the prediction output
  int repDim=rep->getDim();
  TMatrixT<double> state(repDim,1);
  TMatrixT<double> cov(repDim,repDim);
  TMatrixT<double> jacobian(repDim,repDim);
  DetPlane pl=hit->getDetPlane(rep);
  rep->predict(pl,state,cov,jacobian);

  // measurement matrix, hit covariance and residual at the prediction
  hit->setHMatrix(rep,state);
  TMatrixT<double> H=hit->getHMatrix();
  TMatrixT<double> V=hit->getHitCov(pl);
  TMatrixT<double> r=hit->residualVector(rep,state);

  // covariance of the predicted residual: R = V + H C H^T
  TMatrixT<double> R(V);
  TMatrixT<double> covsum1(cov,TMatrixT<double>::kMultTranspose,H);
  TMatrixT<double> covsum(H,TMatrixT<double>::kMult,covsum1);
  R+=covsum;

  // chi2 = r^T R^-1 r  (kTransposeMult leaves r untouched, unlike r.T())
  double det=0;
  R.Invert(&det);
  if(TMath::IsNaN(det))std::cout<<"predicted residual: det nan!"<<std::endl;
  TMatrixT<double> chisq(r,TMatrixT<double>::kTransposeMult,R*r);
  assert(chisq.GetNoElements()==1);
  return chisq[0][0];
}


  
// One Kalman-filter step of `rep` at `hit` (stored under `hitIndex`).
//
// 1. Predict: the rep is extrapolated onto the hit's detector plane. The
//    predicted state, covariance and Jacobian are stored in the rep's
//    FitParams.
// 2. Update: state += K * res and cov -= K H cov, with the gain K from gain().
//    The rep's state, covariance and reference plane are then set, and the
//    filtered state/cov are stored in FitParams.
// 3. Filtered chi2: r^T R^-1 r with the filtered residual r and its
//    covariance R = V - H C_filt H^T. The result is added to the rep.
//
// Throws FitterException if the filtered chi2 is NaN (the rep's chi2 is then
// left unchanged). gain() may also throw FitterException.
void
Kalman::processHit(AbsRecoHit* hit, AbsTrackRep* rep,int hitIndex){

  // per-hit storage for predicted/filtered quantities of this rep
  FitParams* params = rep->getFitParams();

  // --- prediction -----------------------------------------------------
  int repDim=rep->getDim();
  TMatrixT<double> state(repDim,1);
  TMatrixT<double> cov(repDim,repDim);
  TMatrixT<double> jacobian(repDim,repDim);

  // virtual detector plane of this hit, as seen by the rep
  DetPlane pl=hit->getDetPlane(rep);

  // let the rep do the prediction (fills state, cov and jacobian)
  rep->predict(pl,state,cov,jacobian);

  // Diagnostic: report covariance elements whose sign differs from the
  // rep's covariance as returned by getCov(). The bounds come from the
  // matrices themselves rather than a hard-coded 5x5.
  TMatrixT<double> origcov=rep->getCov();
  const int nrows=TMath::Min(cov.GetNrows(),origcov.GetNrows());
  const int ncols=TMath::Min(cov.GetNcols(),origcov.GetNcols());
  for(int i=0; i<nrows; ++i){
    for(int j=0; j<ncols; ++j){
      if(cov[i][j]*origcov[i][j]<0){
        std::cout<<"AT HIT#"<<hitIndex<<" COV ELEMENT "<<i<<","
                 <<j<<" CHANGED ITS SIGN!"<<std::endl;
      }
    }
  }

  params->addfStatePred(hitIndex,state);
  params->addfCovPred(hitIndex,cov);
  params->addfJacobian(hitIndex,jacobian);

  // --- update ---------------------------------------------------------
  // measurement matrix at the prediction, and hit covariance
  hit->setHMatrix(rep,state);
  TMatrixT<double> H=hit->getHMatrix();
  TMatrixT<double> V=hit->getHitCov(pl);

  TMatrixT<double> Gain(gain(cov,V,H));

  // Experimental workaround kept from the original ("CHECK : q/p => 1/p"):
  // if the first state component is negative, the residual is evaluated
  // with its absolute value and the first component of the update is
  // negated. NOTE(review): this departs from the textbook Kalman update;
  // confirm it is still intended. Matrix sizes now follow the actual
  // Gain/residual dimensions instead of a hard-coded 2x1 / 5x1.
  const bool negFirst = state[0][0] < 0;
  TMatrixT<double> resState(state);
  if(negFirst) resState[0][0] = fabs(resState[0][0]);
  TMatrixT<double> res=hit->residualVector(rep,resState);
  TMatrixT<double> update(Gain,TMatrixT<double>::kMult,res);
  if(negFirst) update[0][0] = -update[0][0];

  state+=update;           // prediction overwritten by filtered state
  cov-=Gain*(H*cov);       // filtered covariance C = (1 - K H) C_pred

  // update TrackRep and store filtered quantities
  rep->setState(state);
  rep->setCov(cov);
  rep->setReferencePlane(pl);
  params->addfStateFilt(hitIndex,state);
  params->addfCovFilt(hitIndex,cov);

  // --- filtered chi2 --------------------------------------------------
  // filtered residual and its covariance R = V - H C_filt H^T
  TMatrixT<double> r=hit->residualVector(rep,state);
  TMatrixT<double> R(V);
  TMatrixT<double> covsum1(cov,TMatrixT<double>::kMultTranspose,H);
  TMatrixT<double> covsum(H,TMatrixT<double>::kMult,covsum1);
  R-=covsum; // minus sign: filtered (not predicted) residual

  // chi2 = r^T R^-1 r
  double det=0;
  TMatrixT<double> Rsave(R); // un-inverted copy for diagnostics
  R.Invert(&det);
  if(TMath::IsNaN(det))std::cout<<"filtered residual: det nan!"<<std::endl;
  TMatrixT<double> chisq(r,TMatrixT<double>::kTransposeMult,R*r);
  assert(chisq.GetNoElements()==1);

  if(TMath::IsNaN(chisq[0][0])){
    FitterException exc("chi2 is nan",__LINE__,__FILE__);
    std::vector<double> numbers;
    numbers.push_back(det);
    exc.setNumbers("det",numbers);
    std::vector< TMatrixT<double> > matrices;
    matrices.push_back(r);
    matrices.push_back(V);
    matrices.push_back(Rsave);
    matrices.push_back(R);
    matrices.push_back(state);
    matrices.push_back(cov);
    matrices.push_back(Gain);
    exc.setMatrices("r, V, Rsave, R, state, cov and Gain",matrices);
    throw exc;
  }
  // only accumulate a valid increment
  rep->addChiSqu(chisq[0][0]);
}


// Kalman gain K = C H^T (V + H C H^T)^-1
//   cov    : predicted state covariance C
//   HitCov : hit (measurement) covariance V
//   H      : measurement matrix
// Throws FitterException if (V + H C H^T) has a NaN or zero determinant. In
// the zero case, the exception carries the un-inverted matrix as "covsum".
TMatrixT<double>
Kalman::gain(const TMatrixT<double>& cov, 
             const TMatrixT<double>& HitCov,
             const TMatrixT<double>& H){

  // covsum = V + H C H^T
  TMatrixT<double> covsum1(cov,TMatrixT<double>::kMultTranspose,H);
  TMatrixT<double> covsum(H,TMatrixT<double>::kMult,covsum1);
  covsum+=HitCov;

  // keep the matrix before Invert() overwrites it, for error reporting
  TMatrixT<double> covsumSave(covsum);

  double det=0;
  covsum.Invert(&det);
  if(TMath::IsNaN(det)) throw FitterException("Kalman Gain: det of covsum is nan",__LINE__,__FILE__);
  if(det==0){
    FitterException exc("cannot invert covsum in Kalman Gain - det=0",
                        __LINE__,__FILE__);
    std::vector< TMatrixT<double> > matrices;
    matrices.push_back(cov);
    matrices.push_back(HitCov);
    matrices.push_back(covsum1);
    matrices.push_back(covsumSave);
    exc.setMatrices("cov, HitCov, covsum1 and covsum",matrices);
    throw exc;
  }

  // K = C H^T covsum^-1
  TMatrixT<double> gain1(H,TMatrixT<double>::kTransposeMult,covsum);
  TMatrixT<double> K(cov,TMatrixT<double>::kMult,gain1);
  return K;
}

// Backward (Rauch-Tung-Striebel style) smoothing pass over all hits and reps.
//
// Reads the predicted/filtered states, covariances and Jacobians that
// processHit() stored in each rep's FitParams, and writes the smoothed
// ("Fin") state and covariance for every hit index:
//   last hit : Fin = Filt
//   k < n-1  : A      = C_k|k F_{k+1}^T C_{k+1|k}^-1
//              x_k|n  = x_k|k - A (x_{k+1|k} - x_{k+1|n})
//              C_k|n  = C_k|k - A (C_{k+1|k} - C_{k+1|n}) A^T
// A track without hits is left untouched. A singular or NaN predicted
// covariance is reported on stdout; the pass continues.
void Kalman::smoothing(Track* trk) {
  std::cout<<"Kalman::smoothing"<<std::endl;
  int nhits=trk->getNumHits();
  int nreps=trk->getNumReps();
  if(nhits<1) return; // nothing to smooth; would otherwise index hit -1

  // for the last hit set the final state and cov to the filtered values
  for(int irep=0; irep<nreps; irep++){
    AbsTrackRep* rep=trk->getTrackRep(irep);
    FitParams* par = rep->getFitParams();
    TMatrixT<double> lastStateFilt;
    TMatrixT<double> lastCovFilt;
    par->getfStateFilt(nhits-1,lastStateFilt);
    par->getfCovFilt(nhits-1,lastCovFilt);
    par->addStateFin(nhits-1,lastStateFilt);
    par->addCovFin(nhits-1,lastCovFilt);
  }

  // go backwards over the hit indices from n-2 to 0
  for(int k=nhits-2; k>=0; k--){
    for(int irep=0; irep<nreps; irep++){
      AbsTrackRep* rep=trk->getTrackRep(irep);
      FitParams* par = rep->getFitParams();

      TMatrixT<double> p_k_k;
      TMatrixT<double> p_kplus1_k;
      TMatrixT<double> p_kplus1_n;
      TMatrixT<double> C_k_k;
      TMatrixT<double> C_kplus1_k;
      TMatrixT<double> C_kplus1_n;
      TMatrixT<double> F_kplus1;
      par->getfStatePred(k+1,p_kplus1_k);
      par->getfStateFilt(k,p_k_k);
      par->getStateFin(k+1,p_kplus1_n);
      par->getfCovPred(k+1,C_kplus1_k);
      par->getfCovFilt(k,C_k_k);
      par->getCovFin(k+1,C_kplus1_n);
      par->getfJacobian(k+1,F_kplus1);

      TMatrixT<double> C_kplus1_k_inv(C_kplus1_k);
      double det = 0.;
      C_kplus1_k_inv.Invert(&det);
      if(det==0 || std::isnan(det)){
        std::cout<<"Kalman::smoothing: predicted cov at hit "<<k+1
                 <<" not invertible (det="<<det<<")"<<std::endl;
      }

      // smoother gain A = C_k|k F^T C_{k+1|k}^-1
      TMatrixT<double> A1(F_kplus1,
                          TMatrixT<double>::kTransposeMult,
                          C_kplus1_k_inv);
      TMatrixT<double> A(C_k_k,TMatrixT<double>::kMult,A1);

      TMatrixT<double> stateFin(p_k_k-A*(p_kplus1_k-p_kplus1_n));
      TMatrixT<double> covFin2(C_kplus1_k-C_kplus1_n,
                               TMatrixT<double>::kMultTranspose,
                               A);
      TMatrixT<double> covFin(C_k_k-(A*covFin2));

      par->addStateFin(k,stateFin);
      par->addCovFin(k,covFin);
    }
  }
}






