// Copyright 1999, 2002 Robert Buff
// Contact: http://robertbuff.com/uvm
//
// This file is part of Mtg-Book.
//
// Mtg-Book is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published
// by the Free Software Foundation; either version 2 of the License,
// or (at your option) any later version.
//
// Mtg-Book is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Mtg-Book; if not, write to the 
//
// Free Software Foundation, Inc.
// 59 Temple Place, Suite 330
// Boston, MA 02111-1307
// USA

#include "MtgIncl.h"
#include "MtgVasicekModel.h"
#include "MtgBootstrap.h"
#include "MtgDrift.h"
#include "MtgFactor.h"
#include "MtgPathSpace.h"
#include "MtgSamplePath.h"
#include "MtgShortRateEngine.h"
#include "MtgStepDrift.h"

MTG_BEGIN_NAMESPACE


//
//   t E v o l u t i o n
//

// Builds an evolution object for Model over the given path space.
// All state except the proxy back pointer is set up by the base class.
tVasicekModel::tEvolution::tEvolution( const tVasicekModel& Model,
    const tPathSpace& PathSpace )
    : super( Model, PathSpace )

{
        // m_Proxy keeps a back pointer to the evolution that owns it
    this->m_Proxy.m_pParent = this;
}


//
//   p r e p a r e
//

// Prepares the evolution for simulation: sizes m_Mean to the number of
// samples of the path space, selects parameter instantiation 0, and
// fills m_Mean with the mean reversion level of the first factor at
// each sample. The level is stored unscaled; theta is applied later in
// progress(). The source of the level is, in order of precedence:
//   - the model's drift object (resampled if its date base differs),
//   - the model's bootstrap spline (always resampled),
//   - the model's constant mean m_gConstMean.
// Returns OK, or the error code from building the resampled drift.
tRetCode tVasicekModel::tEvolution::prepare()

{
    const tVasicekModel& Owner =
        static_cast<const tVasicekModel&>( m_Model );
    const tDrift* pSource = Owner.m_pMeanDrift;

    m_Mean.numOfElems( m_PathSpace.numOfSamples() );
    setInstantiation( 0 );

        // no term structure supplied: flat mean reversion level
    if( pSource == 0 && Owner.m_pMeanBoot == 0 ) {
        m_Mean.fill( Owner.m_gConstMean );
        return OK;
    }

        // drift already lives on the path space's date base and can
        // be sampled directly
    if( pSource != 0 && pSource->sameDateBaseAs( m_PathSpace ) ) {
        for( int k = 0; k < m_Mean.numOfElems(); ++k )
            m_Mean[k] = pSource->forwardSI( k );
        return OK;
    }

        // otherwise transfer the drift (or the bootstrapped interest
        // spline) onto the path space's date base via a temporary
        // tStepDrift
    tStepDrift Resampled;
    static_cast<tDateBase&>( Resampled ) = m_PathSpace;

    tRetCode nRet;

    if( pSource != 0 )
        nRet = Resampled.addDrift( *pSource );
    else
        nRet = Resampled.addSpline( Owner.m_pMeanBoot->spline() );

    if( nRet != OK || ( nRet = Resampled.finalize() ) != OK )
        return nRet;

    for( int k = 0; k < m_Mean.numOfElems(); ++k )
        m_Mean[k] = Resampled.forwardSI( k );

    return OK;
}


//
//   s e t I n s t a n t i a t i o n
//

// Selects parameter combination nInstantiation for this evolution by
// letting the model decode the index into m_FactorSelect (the per-factor
// sigma/theta/alpha range indices read by progress()).
void tVasicekModel::tEvolution::setInstantiation( int nInstantiation )

{
    m_Model.setInstantiation( m_FactorSelect, nInstantiation );
}


//
//   p r o g r e s s
//

// Generates the short rate along the current sample path and stores it
// in Forward, one value per sample, starting at m_gInitial.
//
// For each step j the rate f is advanced by an explicit update
//
//   f += sum_l ( sigma_l * sqrt(dt) * Path[j][l]
//                + theta_l * dt * m_l(j) - alpha_l * dt * f )
//
// where m_0(j) = m_Mean[j] (time-dependent level, see prepare()) and
// m_l = 1 for all other factors. sigma_l, theta_l and alpha_l are taken
// from the ranges selected by m_FactorSelect.
// NOTE(review): Path[j][l] is presumably a standard normal draw for
// factor l at step j; confirm against tSamplePath.
// Always returns OK.
tRetCode tVasicekModel::tEvolution::progress(
    const tShortRateTermStruct& ShortRate, tHeap<double>& Forward )

{
    const tSamplePath& Path = path();
    const int nSteps = Path.numOfSamples();
    const int nFactors = Path.numOfFactors();

    const tVasicekModel& Owner =
        static_cast<const tVasicekModel&>( m_Model );

    MTG_ASSERT( nSteps == m_Mean.numOfElems() );
    MTG_ASSERT( nFactors == m_Model.numOfFactors() );

    Forward.numOfElems( nSteps );

    if( nSteps <= 0 )
        return OK;

    const double dt = ShortRate.dtSI();
    const double sqrtDt = sqrt( dt );

        // per-factor step coefficients: sigma * sqrt(dt) scales the
        // random shock, theta * dt the level, alpha * dt the reversion
    tHeap<double> Shock, Level, Revert;
    Shock.numOfElems( nFactors );
    Level.numOfElems( nFactors );
    Revert.numOfElems( nFactors );

    for( int l = 0; l < nFactors; ++l ) {
        const tFactorParam& Param = *Owner.m_FactorParam[l];
        const int nBase = 3 * l;

        const double gSig = Param.m_Sigma[m_FactorSelect[nBase]];
        const double gThe = Param.m_Theta[m_FactorSelect[nBase + 1]];
        const double gAlp = Param.m_Alpha[m_FactorSelect[nBase + 2]];

        Shock[l] = gSig * sqrtDt;
        Level[l] = gThe * dt;
        Revert[l] = gAlp * dt;
    }

    const int nLast = nSteps - 1;
    double f = m_gInitial;

    for( int j = 0; j < nLast; ++j ) {
        Forward[j] = f;

            // first factor reverts to the time-dependent level
        double df = Shock[0] * Path[j][0] + Level[0] * m_Mean[j] - Revert[0] * f;

            // remaining factors use a level of 1
        for( int l = 1; l < nFactors; ++l )
            df += Shock[l] * Path[j][l] + Level[l] - Revert[l] * f;

        f += df;
    }

    Forward[nLast] = f;
    return OK;
}


//
//   i n i t
//

// Sets the model-specific members to their defaults: no drift or
// bootstrap object for the mean reversion level, and a constant level
// of 1 (the value prepare() uses when neither object is present).
void tVasicekModel::init()

{
    m_gConstMean = 1;
    m_pMeanDrift = 0;
    m_pMeanBoot = 0;
}


//
//   c l e a n u p
//

// Deletes every heap-allocated tFactorParam owned by m_FactorParam and
// empties the list. The mean reversion references are not touched.
void tVasicekModel::cleanup()

{
    const int nParams = m_FactorParam.numOfElems();

    for( int k = 0; k < nParams; ++k )
        delete m_FactorParam[k];

    m_FactorParam.reset();
}


//
//   c o p y F r o m
//

// Makes this model a copy of Model. It performs these steps in order:
//   - discards the current factor parameters;
//   - duplicates Model's factor parameters into new heap objects;
//   - copies the constant mean level;
//   - takes over the bootstrap/drift references via setObjRefToZeroOrSo();
//   - copies the base class state.
// Copying from itself is a no-op.
void tVasicekModel::copyFrom( const tVasicekModel& Model )

{
    if( this != &Model ) {
        cleanup();

        const int nParams = Model.m_FactorParam.numOfElems();

        for( int k = 0; k < nParams; ++k ) {
                // the list owns the new object before it is filled,
                // so cleanup() will release it in any case
            m_FactorParam.append( new tFactorParam );
            *m_FactorParam.last() = *Model.m_FactorParam[k];
        }

        m_gConstMean = Model.m_gConstMean;
        setObjRefToZeroOrSo( m_pMeanBoot, Model.m_pMeanBoot );
        setObjRefToZeroOrSo( m_pMeanDrift, Model.m_pMeanDrift );

        super::copyFrom( Model );
    }
}


//
//   p a r s e F a c t o r
//

// Parses one factor specification of the form
//
//   <factor-id> [sigma] <range> [[theta] <range>] [[alpha] <range>]
//
// The current token must be the identifier of a tFactor object, which
// is registered via appendFactor(). The sigma range is mandatory. A
// theta or alpha range may be omitted, in which case it defaults to
// the single value 0. If its keyword is given, however, the range must
// follow. Every range must be non-empty, because the range sizes are
// used as divisors in numOfInstantiations() and setInstantiation().
// On success a tFactorParam holding the three ranges is appended to
// m_FactorParam. Returns OK or the parser's error code.
//
// NOTE(review): if parsing fails after appendFactor(), the factor stays
// registered without a matching tFactorParam; callers presumably
// discard the model on a parse error -- confirm.
tRetCode tVasicekModel::parseFactor( tParser& Parser,
    tParseInfoStub& Info )

{
    MTG_ASSERT( Parser.curToken() == xTokId );

    tRetCode nRet;
    tObject* pObj;
    tFactor* pFactor;

    if( ( pObj = Parser.findObject() ) == 0 )
        return Parser.setError( NOT_FOUND );

    if( ( pFactor = dynamic_cast<tFactor*>( pObj ) ) == 0 )
        return Parser.setError( OBJECT_MISMATCH );

    if( ( nRet = Parser.readToken() ) != OK )
        return nRet;

    if( ( nRet = appendFactor( pFactor ) ) != OK )
        return nRet;

        // Collect the ranges in a local object: every early return
        // below releases it automatically, and only a completely
        // parsed parameter set is copied to the heap.
    tFactorParam Param;

        // sigma: keyword optional, range mandatory
    if( Parser.curToken() == xTokSigma ) {
        if( ( nRet = Parser.readToken() ) != OK )
            return nRet;
    }

    if( ( nRet = Parser.scanPosPercentageRange( Param.m_Sigma ) ) != OK )
        return nRet;

    if( Param.m_Sigma.numOfElems() == 0 )
        return Parser.setError( MISSING_RANGE );

        // theta: defaults to 0 unless the keyword was given
    bool bNeedsTheta = false;

    if( Parser.curToken() == xTokTheta ) {
        if( ( nRet = Parser.readToken() ) != OK )
            return nRet;
        bNeedsTheta = true;
    }

    if( Parser.beginOfNumber() || Parser.beginOfObj() ) {
        if( ( nRet = Parser.scanNonNegPercentageRange(
                Param.m_Theta ) ) != OK ) {
            return nRet;
        }
        if( Param.m_Theta.numOfElems() == 0 )
            return Parser.setError( MISSING_RANGE );
    }
    else {
        if( bNeedsTheta )
            return Parser.setError( MISSING_RANGE );
        Param.m_Theta.append( 0 );
    }

        // alpha: defaults to 0 unless the keyword was given
    bool bNeedsAlpha = false;

    if( Parser.curToken() == xTokAlpha ) {
        if( ( nRet = Parser.readToken() ) != OK )
            return nRet;
        bNeedsAlpha = true;
    }

    if( Parser.beginOfNumber() || Parser.beginOfObj() ) {
        if( ( nRet = Parser.scanNonNegPercentageRange(
                Param.m_Alpha ) ) != OK ) {
            return nRet;
        }
        if( Param.m_Alpha.numOfElems() == 0 )
            return Parser.setError( MISSING_RANGE );
    }
    else {
        if( bNeedsAlpha )
            return Parser.setError( MISSING_RANGE );
        Param.m_Alpha.append( 0 );
    }

        // m_FactorParam owns heap objects (see cleanup())
    tFactorParam* pParam = new tFactorParam;
    *pParam = Param;
    m_FactorParam.append( pParam );
    return OK;
}


//
//   p a r s e P a r a m
//

// Parses one model-specific parameter at the current token:
//
//   [factor] <factor-id> ...   a factor with its sigma/theta/alpha
//                              ranges (handled by parseFactor())
//   mean <number>              constant mean reversion level
//   mean <bootstrap-id>        mean reversion level from a bootstrap
//   mean <drift-id>            mean reversion level from a drift
//
// The mean may be specified only once (ATTR_REDEFINITION otherwise).
// Any other token is handed to the base class. Returns OK or the
// parser's error code.
tRetCode tVasicekModel::parseParam( tParser& Parser,
    tParseInfoStub& Info )

{
    tRetCode nRet;
    tObject* pObj;
    tBootstrap* pBoot;
    tDrift* pDrift;

    tParseInfo& I = static_cast<tParseInfo&>( Info );

    switch( Parser.curToken() ) {
        case xTokFactor :
            if( ( nRet = Parser.readToken() ) != OK )
                return nRet;
            if( Parser.curToken() != xTokId )
                return Parser.setError( INVALID_KEYWORD );

                // fall through

        case xTokId :
            if( ( nRet = parseFactor( Parser, Info ) ) != OK )
                return nRet;
            break;

        case xTokMean :
                // m_bParam1 records that a constant mean was given
            if( m_pMeanBoot != 0 || m_pMeanDrift != 0 || I.m_bParam1 )
                return Parser.setError( ATTR_REDEFINITION );
            if( ( nRet = Parser.readToken() ) != OK )
                return nRet;
            if( Parser.beginOfNumber() ) {
                if( ( nRet = Parser.scanPosPercentage( m_gConstMean ) ) != OK )
                    return nRet;
                I.m_bParam1 = true;
            }
            else {
                if( Parser.curToken() != xTokId )
                    return Parser.setError( INVALID_KEYWORD );
                if( ( pObj = Parser.findObject() ) == 0 )
                    return Parser.setError( NOT_FOUND );
                if( ( pBoot = dynamic_cast<tBootstrap*>( pObj ) ) != 0 )
                    setObjRefToZeroOrSo( m_pMeanBoot, pBoot );
                else
                if( ( pDrift = dynamic_cast<tDrift*>( pObj ) ) != 0 )
                    setObjRefToZeroOrSo( m_pMeanDrift, pDrift );
                else
                    return Parser.setError( "bootstrap/drift expected" );
                if( ( nRet = Parser.readToken() ) != OK )
                    return nRet;
            }
            break;

        default :
            return super::parseParam( Parser, Info );
    }

    return OK;
}


//
//   p a r s e P o s t f i x
//

// Final check after all parameters have been read: a model without any
// factor is rejected with MISSING_FACTOR; otherwise the base class
// performs its own postfix processing.
tRetCode tVasicekModel::parsePostfix( tParser& Parser,
    tParseInfoStub& Info )

{
    if( numOfFactors() != 0 )
        return super::parsePostfix( Parser, Info );

    return Parser.setError( MISSING_FACTOR );
}


//
//   t V a s i c e k M o d e l
//

// Default constructor; the model-specific members receive their
// defaults from init().
tVasicekModel::tVasicekModel()

{
    init();
}


//
//   t V a s i c e k M o d e l
//

// Copy constructor: starts from the init() defaults (so copyFrom()
// sees valid pointers), then copies Model via copyFrom(), which
// duplicates the factor parameters.
tVasicekModel::tVasicekModel( const tVasicekModel& Model )

{
    init();
    copyFrom( Model );
}


//
//   ~ t V a s i c e k M o d e l
//

// Destructor: drops the mean reversion term-structure references and
// deletes the owned factor parameters.
tVasicekModel::~tVasicekModel()

{
        // NOTE(review): the one-argument setObjRefToZeroOrSo presumably
        // releases the reference and resets the pointer to 0; it is
        // defined outside this file -- confirm
    setObjRefToZeroOrSo( m_pMeanBoot );
    setObjRefToZeroOrSo( m_pMeanDrift );
    cleanup();
}


//
//   o p e r a t o r =
//

// Copy assignment via copyFrom(); self-assignment leaves the object
// unchanged.
tVasicekModel& tVasicekModel::operator=( const tVasicekModel& Model )

{
    if( this == &Model )
        return *this;

    copyFrom( Model );
    return *this;
}


//
//   c l o n e
//

// Returns a heap-allocated copy of this model made with the copy
// constructor; the caller receives the new object.
tObject* tVasicekModel::clone() const

{
    tVasicekModel* pCopy = new tVasicekModel( *this );
    return pCopy;
}


//
//   f i n a l i z e
//

// Finalizes the model. There is no Vasicek-specific work here; the
// call is forwarded to the base class and its result returned.
tRetCode tVasicekModel::finalize()

{
    return super::finalize();
}


//
//   n u m O f I n s t a n t i a t i o n s
//

// Returns the number of distinct parameter combinations. This is the
// product, over all factors, of the sizes of the sigma, theta and alpha
// ranges. Returns -1 if the product does not fit into an int.
int tVasicekModel::numOfInstantiations() const

{
    MTG_ASSERT( m_FactorParam.numOfElems() > 0 );

    int n = 1;

    for( int i = 0; i < m_FactorParam.numOfElems(); ++i ) {
        int s = m_FactorParam[i]->m_Sigma.numOfElems();
        int t = m_FactorParam[i]->m_Theta.numOfElems();
        int a = m_FactorParam[i]->m_Alpha.numOfElems();

            // every range is expected to hold at least one value;
            // a zero size would divide by zero below
        MTG_ASSERT( s > 0 && t > 0 && a > 0 );

            // n * (s * t * a) overflows exactly when n exceeds
            // INT_MAX / (s * t * a). The nested divisions compute that
            // bound without forming the possibly overflowing product
            // s * t * a. Equality still fits, hence '>' and not '>='.
        if( n > ( ( INT_MAX / s ) / t ) / a )
            return -1;
        n *= s * t * a;
    }

    return n;
}


//
//   s e t I n s t a n t i a t i o n
//

// Removes the least significant base-nRadix digit from nValue and
// returns it.
static int takeLowDigit( int& nValue, int nRadix )

{
    const int nDigit = nValue % nRadix;
    nValue /= nRadix;
    return nDigit;
}


// Decodes the instantiation index nInstantiation into range indices.
// For factor i, Select[3 * i], Select[3 * i + 1] and Select[3 * i + 2]
// receive the index into its sigma, theta and alpha range. The index
// is read as a mixed-radix number: the last factor's alpha index is the
// least significant digit and the first factor's sigma index the most
// significant one. An empty Select is sized to 3 entries per factor.
void tVasicekModel::setInstantiation( tHeap<int>& Select,
    int nInstantiation ) const

{
    MTG_ASSERT( nInstantiation >= 0 &&
                nInstantiation < numOfInstantiations() );

    if( Select.numOfElems() == 0 )
        Select.numOfElems( 3 * m_FactorParam.numOfElems() );

    int nPos = Select.numOfElems();
    int nRest = nInstantiation;

    for( int i = numOfFactors(); i-- > 0; ) {
        const tFactorParam& Param = *m_FactorParam[i];

        --nPos;
        Select[nPos] = takeLowDigit( nRest, Param.m_Alpha.numOfElems() );
        --nPos;
        Select[nPos] = takeLowDigit( nRest, Param.m_Theta.numOfElems() );
        --nPos;
        Select[nPos] = takeLowDigit( nRest, Param.m_Sigma.numOfElems() );
    }

    MTG_ASSERT( nPos == 0 && nRest == 0 );
}


//
//   c r e a t e E v o l u t i o n
//

// Creates and prepares a new tEvolution for PathSpace. On success the
// new object is stored in pEvolution; if prepare() fails, the object is
// deleted, pEvolution is left untouched, and the error is returned.
tRetCode tVasicekModel::createEvolution( const tPathSpace& PathSpace,
    tMCEngine::tEvolutionStub*& pEvolution ) const

{
    MTG_ASSERT( isFinalized() );

    tEvolution* pNew = new tEvolution( *this, PathSpace );
    tRetCode nRet = pNew->prepare();

    if( nRet == OK )
        pEvolution = pNew;
    else
        delete pNew;

    return nRet;
}


//
//   s a v e W e i g h t s H e a d e r
//

// Writes the header of a weights file. It consists of the factor count,
// then, for each factor, its quoted name, the parameter count 3, and
// the parameter names sigma, theta and alpha. Returns WRITE_ERROR if
// the stream is no longer good afterwards, OK otherwise.
tRetCode tVasicekModel::saveWeightsHeader( ostream& Out ) const

{
    const int nFactors = numOfFactors();

    Out << nFactors << " // number of factors" << endl;

    for( int k = 0; k < nFactors; ++k ) {
        Out << "\"" << factor( k )->name() << "\"";
        Out << " 3 // factor name, number of parameters\n";
        Out << "\"sigma\" \"theta\" \"alpha\" // parameter names";
        Out << endl;
    }

    if( Out.good() )
        return OK;
    return WRITE_ERROR;
}


//
//   s a v e W e i g h t s D a t a
//

// Writes the sigma, theta and alpha values selected by Select for every
// factor on one line, separated by single blanks. Select uses the
// layout produced by setInstantiation() (3 entries per factor).
// Returns WRITE_ERROR if the stream is no longer good, OK otherwise.
tRetCode tVasicekModel::saveWeightsData( ostream& Out, tHeap<int> Select ) const

{
    const int nFactors = numOfFactors();

    MTG_ASSERT( Select.numOfElems() == 3 * nFactors );

    for( int k = 0; k < nFactors; ++k ) {
        const tFactorParam& Param = *m_FactorParam[k];
        const int nBase = 3 * k;

        const double gSig = Param.m_Sigma[Select[nBase]];
        const double gThe = Param.m_Theta[Select[nBase + 1]];
        const double gAlp = Param.m_Alpha[Select[nBase + 2]];

            // blank separator only between factors
        if( k != 0 )
            Out << " ";
        Out << gSig << " " << gThe << " " << gAlp;
    }

    return Out.good() ? OK : WRITE_ERROR;
}

MTG_END_NAMESPACE