/*
 * Project: MoleCuilder
 * Description: creates and alters molecular systems
 * Copyright (C)  2010 University of Bonn. All rights reserved.
 * Please see the LICENSE file or "Copyright notice" in builder.cpp for details.
 */

/*
 * MatrixUnitTest.cpp
 *
 *  Created on: Jul 7, 2010
 *      Author: crueger
 */

// include config.h
#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include <cppunit/CompilerOutputter.h>
#include <cppunit/extensions/TestFactoryRegistry.h>
#include <cppunit/ui/text/TestRunner.h>

#include <cmath>

#include "MatrixUnitTest.hpp"
#include "LinearAlgebra/Vector.hpp"
#include "LinearAlgebra/RealSpaceMatrix.hpp"
#include "Exceptions/NotInvertibleException.hpp"

#ifdef HAVE_TESTRUNNER
#include "UnitTestMain.hpp"
#endif /*HAVE_TESTRUNNER*/

// Registers the fixture into the 'registry'
CPPUNIT_TEST_SUITE_REGISTRATION( MatrixUnittest );

void MatrixUnittest::setUp(){
  zero = new RealSpaceMatrix();
  for(int i =NDIM;i--;) {
    for(int j =NDIM;j--;) {
      zero->at(i,j)=0.;
    }
  }
  one = new RealSpaceMatrix();
  for(int i =NDIM;i--;){
    one->at(i,i)=1.;
  }
  full=new RealSpaceMatrix();
  for(int i=NDIM;i--;){
    for(int j=NDIM;j--;){
      full->at(i,j)=1.;
    }
  }
  diagonal = new RealSpaceMatrix();
  for(int i=NDIM;i--;){
    diagonal->at(i,i)=i+1.;
  }
  perm1 = new RealSpaceMatrix();
  perm1->column(0) = unitVec[0];
  perm1->column(1) = unitVec[2];
  perm1->column(2) = unitVec[1];


  perm2 = new RealSpaceMatrix();
  perm2->column(0) = unitVec[1];
  perm2->column(1) = unitVec[0];
  perm2->column(2) = unitVec[2];

  perm3 = new RealSpaceMatrix();
  perm3->column(0) = unitVec[1];
  perm3->column(1) = unitVec[2];
  perm3->column(2) = unitVec[0];

  perm4 = new RealSpaceMatrix();
  perm4->column(0) = unitVec[2];
  perm4->column(1) = unitVec[1];
  perm4->column(2) = unitVec[0];

  perm5 = new RealSpaceMatrix();
  perm5->column(0) = unitVec[2];
  perm5->column(1) = unitVec[0];
  perm5->column(2) = unitVec[1];

}
void MatrixUnittest::tearDown(){
  delete zero;
  delete one;
  delete full;
  delete diagonal;
  delete perm1;
  delete perm2;
  delete perm3;
  delete perm4;
  delete perm5;
}

void MatrixUnittest::AccessTest(){
  RealSpaceMatrix mat;
  for(int i=NDIM;i--;){
    for(int j=NDIM;j--;){
      CPPUNIT_ASSERT_EQUAL(mat.at(i,j),0.);
    }
  }
  int k=1;
  for(int i=NDIM;i--;){
    for(int j=NDIM;j--;){
      mat.at(i,j)=k++;
    }
  }
  k=1;
  for(int i=NDIM;i--;){
    for(int j=NDIM;j--;){
      CPPUNIT_ASSERT_EQUAL(mat.at(i,j),(double)k);
      ++k;
    }
  }
}

void MatrixUnittest::VectorTest(){
  RealSpaceMatrix mat;
  for(int i=NDIM;i--;){
    CPPUNIT_ASSERT_EQUAL(mat.row(i),zeroVec);
    CPPUNIT_ASSERT_EQUAL(mat.column(i),zeroVec);
  }
  CPPUNIT_ASSERT_EQUAL(mat.diagonal(),zeroVec);

  mat.setIdentity();
  CPPUNIT_ASSERT_EQUAL(mat.row(0),unitVec[0]);
  CPPUNIT_ASSERT_EQUAL(mat.row(1),unitVec[1]);
  CPPUNIT_ASSERT_EQUAL(mat.row(2),unitVec[2]);
  CPPUNIT_ASSERT_EQUAL(mat.column(0),unitVec[0]);
  CPPUNIT_ASSERT_EQUAL(mat.column(1),unitVec[1]);
  CPPUNIT_ASSERT_EQUAL(mat.column(2),unitVec[2]);

  Vector t1=Vector(1.,1.,1.);
  Vector t2=Vector(2.,2.,2.);
  Vector t3=Vector(3.,3.,3.);
  Vector t4=Vector(1.,2.,3.);

  mat.row(0)=t1;
  mat.row(1)=t2;
  mat.row(2)=t3;
  CPPUNIT_ASSERT_EQUAL(mat.row(0),t1);
  CPPUNIT_ASSERT_EQUAL(mat.row(1),t2);
  CPPUNIT_ASSERT_EQUAL(mat.row(2),t3);
  CPPUNIT_ASSERT_EQUAL(mat.column(0),t4);
  CPPUNIT_ASSERT_EQUAL(mat.column(1),t4);
  CPPUNIT_ASSERT_EQUAL(mat.column(2),t4);
  CPPUNIT_ASSERT_EQUAL(mat.diagonal(),t4);
  for(int i=NDIM;i--;){
    for(int j=NDIM;j--;){
      CPPUNIT_ASSERT_EQUAL(mat.at(i,j),i+1.);
    }
  }

  mat.column(0)=t1;
  mat.column(1)=t2;
  mat.column(2)=t3;
  CPPUNIT_ASSERT_EQUAL(mat.column(0),t1);
  CPPUNIT_ASSERT_EQUAL(mat.column(1),t2);
  CPPUNIT_ASSERT_EQUAL(mat.column(2),t3);
  CPPUNIT_ASSERT_EQUAL(mat.row(0),t4);
  CPPUNIT_ASSERT_EQUAL(mat.row(1),t4);
  CPPUNIT_ASSERT_EQUAL(mat.row(2),t4);
  CPPUNIT_ASSERT_EQUAL(mat.diagonal(),t4);
  for(int i=NDIM;i--;){
    for(int j=NDIM;j--;){
      CPPUNIT_ASSERT_EQUAL(mat.at(i,j),j+1.);
    }
  }
}

void MatrixUnittest::TransposeTest(){
  RealSpaceMatrix res;

  // transpose of unit is unit
  res.setIdentity();
  res.transpose();
  CPPUNIT_ASSERT_EQUAL(res,*one);

  // transpose of transpose is same matrix
  res.setZero();
  res.set(2,2, 1.);
  CPPUNIT_ASSERT_EQUAL(res.transpose().transpose(),res);
}

void MatrixUnittest::OperationTest(){
  RealSpaceMatrix res;

  res =(*zero) *(*zero);
  std::cout << *zero << " times " << *zero << " is " << res << std::endl;
  CPPUNIT_ASSERT_EQUAL(res,*zero);
  res =(*zero) *(*one);
  CPPUNIT_ASSERT_EQUAL(res,*zero);
  res =(*zero) *(*full);
  CPPUNIT_ASSERT_EQUAL(res,*zero);
  res =(*zero) *(*diagonal);
  CPPUNIT_ASSERT_EQUAL(res,*zero);
  res =(*zero) *(*perm1);
  CPPUNIT_ASSERT_EQUAL(res,*zero);
  res =(*zero) *(*perm2);
  CPPUNIT_ASSERT_EQUAL(res,*zero);
  res =(*zero) *(*perm3);
  CPPUNIT_ASSERT_EQUAL(res,*zero);
  res =(*zero) *(*perm4);
  CPPUNIT_ASSERT_EQUAL(res,*zero);
  res =(*zero) *(*perm5);
  CPPUNIT_ASSERT_EQUAL(res,*zero);

  res =(*one)*(*one);
  CPPUNIT_ASSERT_EQUAL(res,*one);
  res =(*one)*(*full);
  CPPUNIT_ASSERT_EQUAL(res,*full);
  res =(*one)*(*diagonal);
  CPPUNIT_ASSERT_EQUAL(res,*diagonal);
  res =(*one)*(*perm1);
  CPPUNIT_ASSERT_EQUAL(res,*perm1);
  res =(*one)*(*perm2);
  CPPUNIT_ASSERT_EQUAL(res,*perm2);
  res =(*one)*(*perm3);
  CPPUNIT_ASSERT_EQUAL(res,*perm3);
  res =(*one)*(*perm4);
  CPPUNIT_ASSERT_EQUAL(res,*perm4);
  res =(*one)*(*perm5);
  CPPUNIT_ASSERT_EQUAL(res,*perm5);

  res = (*full)*(*perm1);
  CPPUNIT_ASSERT_EQUAL(res,*full);
  res = (*full)*(*perm2);
  CPPUNIT_ASSERT_EQUAL(res,*full);
  res = (*full)*(*perm3);
  CPPUNIT_ASSERT_EQUAL(res,*full);
  res = (*full)*(*perm4);
  CPPUNIT_ASSERT_EQUAL(res,*full);
  res = (*full)*(*perm5);
  CPPUNIT_ASSERT_EQUAL(res,*full);

  res = (*diagonal)*(*perm1);
  CPPUNIT_ASSERT_EQUAL(res.column(0),unitVec[0]);
  CPPUNIT_ASSERT_EQUAL(res.column(1),3*unitVec[2]);
  CPPUNIT_ASSERT_EQUAL(res.column(2),2*unitVec[1]);
  res = (*diagonal)*(*perm2);
  CPPUNIT_ASSERT_EQUAL(res.column(0),2*unitVec[1]);
  CPPUNIT_ASSERT_EQUAL(res.column(1),unitVec[0]);
  CPPUNIT_ASSERT_EQUAL(res.column(2),3*unitVec[2]);
  res = (*diagonal)*(*perm3);
  CPPUNIT_ASSERT_EQUAL(res.column(0),2*unitVec[1]);
  CPPUNIT_ASSERT_EQUAL(res.column(1),3*unitVec[2]);
  CPPUNIT_ASSERT_EQUAL(res.column(2),unitVec[0]);
  res = (*diagonal)*(*perm4);
  CPPUNIT_ASSERT_EQUAL(res.column(0),3*unitVec[2]);
  CPPUNIT_ASSERT_EQUAL(res.column(1),2*unitVec[1]);
  CPPUNIT_ASSERT_EQUAL(res.column(2),unitVec[0]);
  res = (*diagonal)*(*perm5);
  CPPUNIT_ASSERT_EQUAL(res.column(0),3*unitVec[2]);
  CPPUNIT_ASSERT_EQUAL(res.column(1),unitVec[0]);
  CPPUNIT_ASSERT_EQUAL(res.column(2),2*unitVec[1]);
}

void MatrixUnittest::RotationTest(){
  RealSpaceMatrix res;
  RealSpaceMatrix inverse;

  // zero rotation angles yields unity matrix
  res.setRotation(0,0,0);
  CPPUNIT_ASSERT_EQUAL(*one, res);

  // arbitrary rotation matrix has det = 1
  res.setRotation(M_PI/3.,1.,M_PI/7.);
  CPPUNIT_ASSERT(fabs(fabs(res.determinant()) -1.) < MYEPSILON);

  // inverse is rotation matrix with negative angles
  res.setRotation(M_PI/3.,0.,0.);
  inverse.setRotation(-M_PI/3.,0.,0.);
  CPPUNIT_ASSERT_EQUAL(*one, res * inverse);

  // ... or transposed
  res.setRotation(M_PI/3.,0.,0.);
  CPPUNIT_ASSERT_EQUAL(inverse, ((const RealSpaceMatrix) res).transpose());
}

void MatrixUnittest::InvertTest(){
  CPPUNIT_ASSERT_THROW(zero->invert(),NotInvertibleException);
  CPPUNIT_ASSERT_THROW(full->invert(),NotInvertibleException);

  RealSpaceMatrix res;
  res = (*one)*one->invert();
  CPPUNIT_ASSERT_EQUAL(res,*one);
  res = (*diagonal)*diagonal->invert();
  CPPUNIT_ASSERT_EQUAL(res,*one);
  res = (*perm1)*perm1->invert();
  CPPUNIT_ASSERT_EQUAL(res,*one);
  res = (*perm2)*perm2->invert();
  CPPUNIT_ASSERT_EQUAL(res,*one);
  res = (*perm3)*perm3->invert();
  CPPUNIT_ASSERT_EQUAL(res,*one);
  res = (*perm4)*perm4->invert();
  CPPUNIT_ASSERT_EQUAL(res,*one);
  res = (*perm5)*perm5->invert();
  CPPUNIT_ASSERT_EQUAL(res,*one);
}


void MatrixUnittest::DeterminantTest(){
  CPPUNIT_ASSERT_EQUAL(zero->determinant(),0.);
  CPPUNIT_ASSERT_EQUAL(one->determinant(),1.);
  CPPUNIT_ASSERT_EQUAL(diagonal->determinant(),6.);
  CPPUNIT_ASSERT_EQUAL(full->determinant(),0.);
  CPPUNIT_ASSERT_EQUAL(perm1->determinant(),-1.);
  CPPUNIT_ASSERT_EQUAL(perm2->determinant(),-1.);
  CPPUNIT_ASSERT_EQUAL(perm3->determinant(),1.);
  CPPUNIT_ASSERT_EQUAL(perm4->determinant(),-1.);
  CPPUNIT_ASSERT_EQUAL(perm5->determinant(),1.);
}

void MatrixUnittest::VecMultTest(){
  CPPUNIT_ASSERT_EQUAL((*zero)*unitVec[0],zeroVec);
  CPPUNIT_ASSERT_EQUAL((*zero)*unitVec[1],zeroVec);
  CPPUNIT_ASSERT_EQUAL((*zero)*unitVec[2],zeroVec);
  CPPUNIT_ASSERT_EQUAL((*zero)*zeroVec,zeroVec);

  CPPUNIT_ASSERT_EQUAL((*one)*unitVec[0],unitVec[0]);
  CPPUNIT_ASSERT_EQUAL((*one)*unitVec[1],unitVec[1]);
  CPPUNIT_ASSERT_EQUAL((*one)*unitVec[2],unitVec[2]);
  CPPUNIT_ASSERT_EQUAL((*one)*zeroVec,zeroVec);

  CPPUNIT_ASSERT_EQUAL((*diagonal)*unitVec[0],unitVec[0]);
  CPPUNIT_ASSERT_EQUAL((*diagonal)*unitVec[1],2*unitVec[1]);
  CPPUNIT_ASSERT_EQUAL((*diagonal)*unitVec[2],3*unitVec[2]);
  CPPUNIT_ASSERT_EQUAL((*diagonal)*zeroVec,zeroVec);

  CPPUNIT_ASSERT_EQUAL((*perm1)*unitVec[0],unitVec[0]);
  CPPUNIT_ASSERT_EQUAL((*perm1)*unitVec[1],unitVec[2]);
  CPPUNIT_ASSERT_EQUAL((*perm1)*unitVec[2],unitVec[1]);
  CPPUNIT_ASSERT_EQUAL((*perm1)*zeroVec,zeroVec);

  CPPUNIT_ASSERT_EQUAL((*perm2)*unitVec[0],unitVec[1]);
  CPPUNIT_ASSERT_EQUAL((*perm2)*unitVec[1],unitVec[0]);
  CPPUNIT_ASSERT_EQUAL((*perm2)*unitVec[2],unitVec[2]);
  CPPUNIT_ASSERT_EQUAL((*perm2)*zeroVec,zeroVec);

  CPPUNIT_ASSERT_EQUAL((*perm3)*unitVec[0],unitVec[1]);
  CPPUNIT_ASSERT_EQUAL((*perm3)*unitVec[1],unitVec[2]);
  CPPUNIT_ASSERT_EQUAL((*perm3)*unitVec[2],unitVec[0]);
  CPPUNIT_ASSERT_EQUAL((*perm3)*zeroVec,zeroVec);

  CPPUNIT_ASSERT_EQUAL((*perm4)*unitVec[0],unitVec[2]);
  CPPUNIT_ASSERT_EQUAL((*perm4)*unitVec[1],unitVec[1]);
  CPPUNIT_ASSERT_EQUAL((*perm4)*unitVec[2],unitVec[0]);
  CPPUNIT_ASSERT_EQUAL((*perm4)*zeroVec,zeroVec);

  CPPUNIT_ASSERT_EQUAL((*perm5)*unitVec[0],unitVec[2]);
  CPPUNIT_ASSERT_EQUAL((*perm5)*unitVec[1],unitVec[0]);
  CPPUNIT_ASSERT_EQUAL((*perm5)*unitVec[2],unitVec[1]);
  CPPUNIT_ASSERT_EQUAL((*perm5)*zeroVec,zeroVec);

  Vector t = Vector(1.,2.,3.);
  CPPUNIT_ASSERT_EQUAL((*perm1)*t,Vector(1,3,2));
  CPPUNIT_ASSERT_EQUAL((*perm2)*t,Vector(2,1,3));
  CPPUNIT_ASSERT_EQUAL((*perm3)*t,Vector(3,1,2));
  CPPUNIT_ASSERT_EQUAL((*perm4)*t,Vector(3,2,1));
  CPPUNIT_ASSERT_EQUAL((*perm5)*t,Vector(2,3,1));
}
