// Copyright (C) 2016 EDF
// All Rights Reserved
// This code is published under the GNU Lesser General Public License (GNU LGPL)
#ifdef USE_MPI
#include <boost/mpi.hpp>
#include <boost/mpi/collectives.hpp>
#include "StOpt/core/parallelism/all_gatherv.hpp"
#endif
#include "StOpt/semilagrangien/SimulateStepSemilagrang.h"
#include "StOpt/core/utils/eigenGeners.h"
#include "StOpt/core/grids/FullGridGeners.h"
#include "StOpt/core/grids/RegularSpaceGridGeners.h"
#include "StOpt/core/grids/GeneralSpaceGridGeners.h"
#include "StOpt/core/grids/SparseSpaceGridNoBoundGeners.h"
#include "StOpt/core/grids/SparseSpaceGridBoundGeners.h"

using namespace std;
using namespace StOpt;
using namespace Eigen;

/// \brief Build the simulator for one time step from function values stored in an archive.
///
/// The vector of arrays (one entry per regime) is read back from \p p_ar using the
/// name \p p_name + "Val" and the decimal string form of \p p_iStep. For each restored
/// array a spectral interpolator is created on \p p_gridNext and wrapped in a
/// SemiLagrangEspCond object.
///
/// \param p_ar        binary archive to read the function values from
/// \param p_iStep     step number, used (as a string) to identify the record in the archive
/// \param p_name      base name of the record ("Val" is appended)
/// \param p_gridNext  grid on which the restored values are interpolated
/// \param p_pOptimize optimizer giving the number of regimes and the volatility modification flag
SimulateStepSemilagrang::SimulateStepSemilagrang(gs::BinaryFileArchive &p_ar,  const int &p_iStep,  const string &p_name, const   shared_ptr<SpaceGrid>   &p_gridNext,
        const  shared_ptr<StOpt::OptimizerSLBase > &p_pOptimize):
    m_gridNext(p_gridNext), m_specInterp(p_pOptimize->getNbRegime()), m_semiLag(p_pOptimize->getNbRegime()), m_pOptimize(p_pOptimize)
{
    // identifiers of the record holding the function values at this step
    const string archiveKey = p_name + "Val";
    const string stepCategory = std::to_string(p_iStep);

    // read back one array of function values per regime
    vector< shared_ptr<ArrayXd> > storedValues;
    gs::Reference< vector< shared_ptr<ArrayXd> > > storedRef(p_ar, archiveKey.c_str(), stepCategory.c_str());
    storedRef.restore(0, &storedValues);

    // NOTE(review): m_specInterp and m_semiLag are sized with getNbRegime() while this loop
    // runs over the number of restored arrays; presumably both are equal -- confirm.
    const size_t nbStored = storedValues.size();
    size_t iRegime = 0;
    while (iRegime < nbStored)
    {
        // spectral interpolator of the regime, then the semi Lagrangian object using it
        const ArrayXd &regimeValues = *storedValues[iRegime];
        m_specInterp[iRegime] = m_gridNext->createInterpolatorSpectral(regimeValues);
        m_semiLag[iRegime] = make_shared<SemiLagrangEspCond>(m_specInterp[iRegime], m_gridNext->getExtremeValues(), m_pOptimize->getBModifVol());
        ++iRegime;
    }
}

/// \brief Advance all simulations by one time step.
///
/// For each simulation (one column of \p p_statevector) the spectral interpolators of all
/// regimes are evaluated at the current state, then the optimizer's stepSimulate is called
/// with the state, the regime, the gaussian column and the function value column of that
/// simulation.
///
/// With USE_MPI the simulations are split in contiguous ranges between the processes
/// (the first "remainder" processes get one more simulation) and the results are gathered
/// on every process afterwards. With _OPENMP the simulations of a process are spread
/// on threads.
///
/// \param p_gaussian     gaussian values, one column per simulation
/// \param p_statevector  states, one column per simulation (input and output)
/// \param p_iReg         regime of each simulation (passed to stepSimulate; presumably
///                       updated by it -- confirm against OptimizerSLBase)
/// \param p_phiInOut     function values, one column per simulation (input and output)
void SimulateStepSemilagrang::oneStep(const ArrayXXd   &p_gaussian, ArrayXXd &p_statevector, ArrayXi &p_iReg, ArrayXXd  &p_phiInOut) const
{
    // range [iFirstSim, iLastSim) of the simulations treated by this process
    int iFirstSim = 0;
    int iLastSim = static_cast<int>(p_statevector.cols());
#ifdef USE_MPI
    boost::mpi::communicator world;
    const int rank = world.rank();
    const int nbProc = world.size();
    // parallelism : contiguous split, the first nRestSim processes take one extra simulation
    const int nsimPProc = static_cast<int>(p_statevector.cols() / nbProc);
    const int nRestSim = static_cast<int>(p_statevector.cols() % nbProc);
    iFirstSim = rank * nsimPProc + (rank < nRestSim ? rank : nRestSim);
    iLastSim  = iFirstSim + nsimPProc + (rank < nRestSim ? 1 : 0);
#endif
    // single loop shared by the MPI and the sequential versions
#ifdef _OPENMP
    #pragma omp parallel for
#endif
    for (int is = iFirstSim; is <  iLastSim; ++is)
    {
        // value of the function in each regime at the current state
        ArrayXd phiInPt(m_semiLag.size());
        for (size_t iReg = 0; iReg < m_semiLag.size(); ++iReg)
            phiInPt[iReg] = m_specInterp[iReg]->apply(p_statevector.col(is));
        m_pOptimize->stepSimulate(*m_gridNext, m_semiLag, p_statevector.col(is), p_iReg(is), p_gaussian.col(is), phiInPt, p_phiInOut.col(is));
    }
#ifdef USE_MPI
    // Copy the local results in separate buffers : the gather writes in the full arrays,
    // and MPI does not allow the send buffer to overlap the receive buffer.
    // Columns are contiguous in Eigen's default (column major) storage, so the
    // concatenation of the local buffers in rank order rebuilds the full arrays.
    const int nbLocalSim = iLastSim - iFirstSim;
    ArrayXXd statePerSim = p_statevector.middleCols(iFirstSim, nbLocalSim);
    ArrayXXd valueFunctionPerSim = p_phiInOut.middleCols(iFirstSim, nbLocalSim);
    ArrayXi regimePerSim = p_iReg.segment(iFirstSim, nbLocalSim);
    boost::mpi::all_gatherv<double>(world, statePerSim.data(), statePerSim.size(), p_statevector.data());
    boost::mpi::all_gatherv<double>(world, valueFunctionPerSim.data(), valueFunctionPerSim.size(), p_phiInOut.data());
    // The regimes are in/out too : without this gather each process would only hold
    // up to date regimes for its own range of simulations.
    boost::mpi::all_gatherv<int>(world, regimePerSim.data(), regimePerSim.size(), p_iReg.data());
#endif
}
