/*
 * Project: MoleCuilder
 * Description: creates and alters molecular systems
 * Copyright (C)  2010 University of Bonn. All rights reserved.
 * Please see the LICENSE file or "Copyright notice" in builder.cpp for details.
 */

/*
 * MatrixUnittest.cpp
 *
 *  Created on: Jul 7, 2010
 *      Author: crueger
 */

// include config.h
#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include <cppunit/CompilerOutputter.h>
#include <cppunit/extensions/TestFactoryRegistry.h>
#include <cppunit/ui/text/TestRunner.h>

#include <cmath>

#include "MatrixUnittest.hpp"
#include "LinearAlgebra/Matrix.hpp"
#include "LinearAlgebra/Vector.hpp"
#include "Exceptions/NotInvertibleException.hpp"

#ifdef HAVE_TESTRUNNER
#include "UnitTestMain.hpp"
#endif /*HAVE_TESTRUNNER*/

// Registers the fixture into the 'registry'
CPPUNIT_TEST_SUITE_REGISTRATION( MatrixUnittest );

void MatrixUnittest::setUp(){
  zero = new Matrix();
  one = new Matrix();
  for(int i =NDIM;i--;){
    one->at(i,i)=1.;
  }
  full=new Matrix();
  for(int i=NDIM;i--;){
    for(int j=NDIM;j--;){
      full->at(i,j)=1.;
    }
  }
  diagonal = new Matrix();
  for(int i=NDIM;i--;){
    diagonal->at(i,i)=i+1.;
  }
  perm1 = new Matrix();
  perm1->column(0) = e1;
  perm1->column(1) = e3;
  perm1->column(2) = e2;


  perm2 = new Matrix();
  perm2->column(0) = e2;
  perm2->column(1) = e1;
  perm2->column(2) = e3;

  perm3 = new Matrix();
  perm3->column(0) = e2;
  perm3->column(1) = e3;
  perm3->column(2) = e1;

  perm4 = new Matrix();
  perm4->column(0) = e3;
  perm4->column(1) = e2;
  perm4->column(2) = e1;

  perm5 = new Matrix();
  perm5->column(0) = e3;
  perm5->column(1) = e1;
  perm5->column(2) = e2;

}
void MatrixUnittest::tearDown(){
  delete zero;
  delete one;
  delete full;
  delete diagonal;
  delete perm1;
  delete perm2;
  delete perm3;
  delete perm4;
  delete perm5;
}

void MatrixUnittest::AccessTest(){
  Matrix mat;
  for(int i=NDIM;i--;){
    for(int j=NDIM;j--;){
      CPPUNIT_ASSERT_EQUAL(mat.at(i,j),0.);
    }
  }
  int k=1;
  for(int i=NDIM;i--;){
    for(int j=NDIM;j--;){
      mat.at(i,j)=k++;
    }
  }
  k=1;
  for(int i=NDIM;i--;){
    for(int j=NDIM;j--;){
      CPPUNIT_ASSERT_EQUAL(mat.at(i,j),(double)k);
      ++k;
    }
  }
}

void MatrixUnittest::VectorTest(){
  Matrix mat;
  for(int i=NDIM;i--;){
    CPPUNIT_ASSERT_EQUAL(mat.row(i),zeroVec);
    CPPUNIT_ASSERT_EQUAL(mat.column(i),zeroVec);
  }
  CPPUNIT_ASSERT_EQUAL(mat.diagonal(),zeroVec);

  mat.one();
  CPPUNIT_ASSERT_EQUAL(mat.row(0),e1);
  CPPUNIT_ASSERT_EQUAL(mat.row(1),e2);
  CPPUNIT_ASSERT_EQUAL(mat.row(2),e3);
  CPPUNIT_ASSERT_EQUAL(mat.column(0),e1);
  CPPUNIT_ASSERT_EQUAL(mat.column(1),e2);
  CPPUNIT_ASSERT_EQUAL(mat.column(2),e3);

  Vector t1=Vector(1.,1.,1.);
  Vector t2=Vector(2.,2.,2.);
  Vector t3=Vector(3.,3.,3.);
  Vector t4=Vector(1.,2.,3.);

  mat.row(0)=t1;
  mat.row(1)=t2;
  mat.row(2)=t3;
  CPPUNIT_ASSERT_EQUAL(mat.row(0),t1);
  CPPUNIT_ASSERT_EQUAL(mat.row(1),t2);
  CPPUNIT_ASSERT_EQUAL(mat.row(2),t3);
  CPPUNIT_ASSERT_EQUAL(mat.column(0),t4);
  CPPUNIT_ASSERT_EQUAL(mat.column(1),t4);
  CPPUNIT_ASSERT_EQUAL(mat.column(2),t4);
  CPPUNIT_ASSERT_EQUAL(mat.diagonal(),t4);
  for(int i=NDIM;i--;){
    for(int j=NDIM;j--;){
      CPPUNIT_ASSERT_EQUAL(mat.at(i,j),i+1.);
    }
  }

  mat.column(0)=t1;
  mat.column(1)=t2;
  mat.column(2)=t3;
  CPPUNIT_ASSERT_EQUAL(mat.column(0),t1);
  CPPUNIT_ASSERT_EQUAL(mat.column(1),t2);
  CPPUNIT_ASSERT_EQUAL(mat.column(2),t3);
  CPPUNIT_ASSERT_EQUAL(mat.row(0),t4);
  CPPUNIT_ASSERT_EQUAL(mat.row(1),t4);
  CPPUNIT_ASSERT_EQUAL(mat.row(2),t4);
  CPPUNIT_ASSERT_EQUAL(mat.diagonal(),t4);
  for(int i=NDIM;i--;){
    for(int j=NDIM;j--;){
      CPPUNIT_ASSERT_EQUAL(mat.at(i,j),j+1.);
    }
  }
}

void MatrixUnittest::TransposeTest(){
  Matrix res;

  // transpose of unit is unit
  res.one();
  (const Matrix)res.transpose();
  CPPUNIT_ASSERT_EQUAL(res,*one);

  // transpose of transpose is same matrix
  res.zero();
  res.set(2,2, 1.);
  CPPUNIT_ASSERT_EQUAL(res.transpose().transpose(),res);
}

void MatrixUnittest::OperationTest(){
  Matrix res;

  res =(*zero) *(*zero);
  CPPUNIT_ASSERT_EQUAL(res,*zero);
  res =(*zero) *(*one);
  CPPUNIT_ASSERT_EQUAL(res,*zero);
  res =(*zero) *(*full);
  CPPUNIT_ASSERT_EQUAL(res,*zero);
  res =(*zero) *(*diagonal);
  CPPUNIT_ASSERT_EQUAL(res,*zero);
  res =(*zero) *(*perm1);
  CPPUNIT_ASSERT_EQUAL(res,*zero);
  res =(*zero) *(*perm2);
  CPPUNIT_ASSERT_EQUAL(res,*zero);
  res =(*zero) *(*perm3);
  CPPUNIT_ASSERT_EQUAL(res,*zero);
  res =(*zero) *(*perm4);
  CPPUNIT_ASSERT_EQUAL(res,*zero);
  res =(*zero) *(*perm5);
  CPPUNIT_ASSERT_EQUAL(res,*zero);

  res =(*one)*(*one);
  CPPUNIT_ASSERT_EQUAL(res,*one);
  res =(*one)*(*full);
  CPPUNIT_ASSERT_EQUAL(res,*full);
  res =(*one)*(*diagonal);
  CPPUNIT_ASSERT_EQUAL(res,*diagonal);
  res =(*one)*(*perm1);
  CPPUNIT_ASSERT_EQUAL(res,*perm1);
  res =(*one)*(*perm2);
  CPPUNIT_ASSERT_EQUAL(res,*perm2);
  res =(*one)*(*perm3);
  CPPUNIT_ASSERT_EQUAL(res,*perm3);
  res =(*one)*(*perm4);
  CPPUNIT_ASSERT_EQUAL(res,*perm4);
  res =(*one)*(*perm5);
  CPPUNIT_ASSERT_EQUAL(res,*perm5);

  res = (*full)*(*perm1);
  CPPUNIT_ASSERT_EQUAL(res,*full);
  res = (*full)*(*perm2);
  CPPUNIT_ASSERT_EQUAL(res,*full);
  res = (*full)*(*perm3);
  CPPUNIT_ASSERT_EQUAL(res,*full);
  res = (*full)*(*perm4);
  CPPUNIT_ASSERT_EQUAL(res,*full);
  res = (*full)*(*perm5);
  CPPUNIT_ASSERT_EQUAL(res,*full);

  res = (*diagonal)*(*perm1);
  CPPUNIT_ASSERT_EQUAL(res.column(0),e1);
  CPPUNIT_ASSERT_EQUAL(res.column(1),3*e3);
  CPPUNIT_ASSERT_EQUAL(res.column(2),2*e2);
  res = (*diagonal)*(*perm2);
  CPPUNIT_ASSERT_EQUAL(res.column(0),2*e2);
  CPPUNIT_ASSERT_EQUAL(res.column(1),e1);
  CPPUNIT_ASSERT_EQUAL(res.column(2),3*e3);
  res = (*diagonal)*(*perm3);
  CPPUNIT_ASSERT_EQUAL(res.column(0),2*e2);
  CPPUNIT_ASSERT_EQUAL(res.column(1),3*e3);
  CPPUNIT_ASSERT_EQUAL(res.column(2),e1);
  res = (*diagonal)*(*perm4);
  CPPUNIT_ASSERT_EQUAL(res.column(0),3*e3);
  CPPUNIT_ASSERT_EQUAL(res.column(1),2*e2);
  CPPUNIT_ASSERT_EQUAL(res.column(2),e1);
  res = (*diagonal)*(*perm5);
  CPPUNIT_ASSERT_EQUAL(res.column(0),3*e3);
  CPPUNIT_ASSERT_EQUAL(res.column(1),e1);
  CPPUNIT_ASSERT_EQUAL(res.column(2),2*e2);
}

void MatrixUnittest::RotationTest(){
  Matrix res;
  Matrix inverse;

  // zero rotation angles yields unity matrix
  res.rotation(0,0,0);
  CPPUNIT_ASSERT_EQUAL(*one, res);

  // arbitrary rotation matrix has det = 1
  res.rotation(M_PI/3.,1.,M_PI/7.);
  CPPUNIT_ASSERT(fabs(fabs(res.determinant()) -1.) < MYEPSILON);

  // inverse is rotation matrix with negative angles
  res.rotation(M_PI/3.,0.,0.);
  inverse.rotation(-M_PI/3.,0.,0.);
  CPPUNIT_ASSERT_EQUAL(*one, res * inverse);

  // ... or transposed
  res.rotation(M_PI/3.,0.,0.);
  CPPUNIT_ASSERT_EQUAL(inverse, ((const Matrix) res).transpose());
}

void MatrixUnittest::InvertTest(){
  CPPUNIT_ASSERT_THROW(zero->invert(),NotInvertibleException);
  CPPUNIT_ASSERT_THROW(full->invert(),NotInvertibleException);

  Matrix res;
  res = (*one)*one->invert();
  CPPUNIT_ASSERT_EQUAL(res,*one);
  res = (*diagonal)*diagonal->invert();
  CPPUNIT_ASSERT_EQUAL(res,*one);
  res = (*perm1)*perm1->invert();
  CPPUNIT_ASSERT_EQUAL(res,*one);
  res = (*perm2)*perm2->invert();
  CPPUNIT_ASSERT_EQUAL(res,*one);
  res = (*perm3)*perm3->invert();
  CPPUNIT_ASSERT_EQUAL(res,*one);
  res = (*perm4)*perm4->invert();
  CPPUNIT_ASSERT_EQUAL(res,*one);
  res = (*perm5)*perm5->invert();
  CPPUNIT_ASSERT_EQUAL(res,*one);
}


void MatrixUnittest::DeterminantTest(){
  CPPUNIT_ASSERT_EQUAL(zero->determinant(),0.);
  CPPUNIT_ASSERT_EQUAL(one->determinant(),1.);
  CPPUNIT_ASSERT_EQUAL(diagonal->determinant(),6.);
  CPPUNIT_ASSERT_EQUAL(full->determinant(),0.);
  CPPUNIT_ASSERT_EQUAL(perm1->determinant(),-1.);
  CPPUNIT_ASSERT_EQUAL(perm2->determinant(),-1.);
  CPPUNIT_ASSERT_EQUAL(perm3->determinant(),1.);
  CPPUNIT_ASSERT_EQUAL(perm4->determinant(),-1.);
  CPPUNIT_ASSERT_EQUAL(perm5->determinant(),1.);
}

void MatrixUnittest::VecMultTest(){
  CPPUNIT_ASSERT_EQUAL((*zero)*e1,zeroVec);
  CPPUNIT_ASSERT_EQUAL((*zero)*e2,zeroVec);
  CPPUNIT_ASSERT_EQUAL((*zero)*e3,zeroVec);
  CPPUNIT_ASSERT_EQUAL((*zero)*zeroVec,zeroVec);

  CPPUNIT_ASSERT_EQUAL((*one)*e1,e1);
  CPPUNIT_ASSERT_EQUAL((*one)*e2,e2);
  CPPUNIT_ASSERT_EQUAL((*one)*e3,e3);
  CPPUNIT_ASSERT_EQUAL((*one)*zeroVec,zeroVec);

  CPPUNIT_ASSERT_EQUAL((*diagonal)*e1,e1);
  CPPUNIT_ASSERT_EQUAL((*diagonal)*e2,2*e2);
  CPPUNIT_ASSERT_EQUAL((*diagonal)*e3,3*e3);
  CPPUNIT_ASSERT_EQUAL((*diagonal)*zeroVec,zeroVec);

  CPPUNIT_ASSERT_EQUAL((*perm1)*e1,e1);
  CPPUNIT_ASSERT_EQUAL((*perm1)*e2,e3);
  CPPUNIT_ASSERT_EQUAL((*perm1)*e3,e2);
  CPPUNIT_ASSERT_EQUAL((*perm1)*zeroVec,zeroVec);

  CPPUNIT_ASSERT_EQUAL((*perm2)*e1,e2);
  CPPUNIT_ASSERT_EQUAL((*perm2)*e2,e1);
  CPPUNIT_ASSERT_EQUAL((*perm2)*e3,e3);
  CPPUNIT_ASSERT_EQUAL((*perm2)*zeroVec,zeroVec);

  CPPUNIT_ASSERT_EQUAL((*perm3)*e1,e2);
  CPPUNIT_ASSERT_EQUAL((*perm3)*e2,e3);
  CPPUNIT_ASSERT_EQUAL((*perm3)*e3,e1);
  CPPUNIT_ASSERT_EQUAL((*perm3)*zeroVec,zeroVec);

  CPPUNIT_ASSERT_EQUAL((*perm4)*e1,e3);
  CPPUNIT_ASSERT_EQUAL((*perm4)*e2,e2);
  CPPUNIT_ASSERT_EQUAL((*perm4)*e3,e1);
  CPPUNIT_ASSERT_EQUAL((*perm4)*zeroVec,zeroVec);

  CPPUNIT_ASSERT_EQUAL((*perm5)*e1,e3);
  CPPUNIT_ASSERT_EQUAL((*perm5)*e2,e1);
  CPPUNIT_ASSERT_EQUAL((*perm5)*e3,e2);
  CPPUNIT_ASSERT_EQUAL((*perm5)*zeroVec,zeroVec);

  Vector t = Vector(1.,2.,3.);
  CPPUNIT_ASSERT_EQUAL((*perm1)*t,Vector(1,3,2));
  CPPUNIT_ASSERT_EQUAL((*perm2)*t,Vector(2,1,3));
  CPPUNIT_ASSERT_EQUAL((*perm3)*t,Vector(3,1,2));
  CPPUNIT_ASSERT_EQUAL((*perm4)*t,Vector(3,2,1));
  CPPUNIT_ASSERT_EQUAL((*perm5)*t,Vector(2,3,1));
}
