/*
 * Decompiled with CFR 0.152.
 */
package haubold.hmm.algorithm;

import haubold.hmm.algorithm.Backward;
import haubold.hmm.algorithm.Forward;
import haubold.hmm.algorithm.HiddenMarkovModel;
import haubold.hmm.algorithm.Reestimation;
import haubold.hmm.algorithm.XiGamma;

/**
 * Iteratively re-estimates the parameters of a {@link HiddenMarkovModel} from a single
 * observation sequence using forward/backward passes, xi/gamma expectations and a
 * re-estimation step (Baum-Welch).
 *
 * <p>Iteration stops after {@code maxIt} re-estimation steps, or earlier when the absolute
 * change in log-likelihood between two successive models drops below {@code tol}.
 *
 * <p>Instances are mutable and keep per-run state; they are not safe for concurrent use.
 */
public class BaumWelch {
    /** Convergence threshold on the absolute change of the log-likelihood. */
    private double tol = 1.0E-5;
    /** Current model; replaced by each re-estimation step. */
    private HiddenMarkovModel hmm;
    /** Observation sequence as symbol indices. */
    private int[] obsSeq;
    /** Maximum number of re-estimation steps. */
    private int maxIt;
    /** Log-likelihood of {@link #hmm} as reported by the last forward pass. */
    private double logP;
    private Forward forward;
    private Backward backward;
    private XiGamma xiGamma;
    private Reestimation reestimation;

    /** Creates an unconfigured trainer; call the setters before {@link #newHmm()}. */
    public BaumWelch() {
    }

    /**
     * Creates a trainer ready to run.
     *
     * @param initialHmm model whose probabilities seed the training
     * @param obsSeq     observation sequence (symbol indices)
     * @param maxIt      maximum number of re-estimation steps
     */
    public BaumWelch(HiddenMarkovModel initialHmm, int[] obsSeq, int maxIt) {
        this.obsSeq = obsSeq;
        this.maxIt = maxIt;
        this.initializeHmm(initialHmm);
    }

    /**
     * Builds the working model from {@code initialHmm} and creates fresh helper objects.
     *
     * <p>NOTE(review): the probability arrays are passed through the setters as-is; whether
     * they are copied depends on {@link HiddenMarkovModel}'s setters — confirm before relying
     * on {@code initialHmm} staying untouched.
     */
    private void initializeHmm(HiddenMarkovModel initialHmm) {
        HiddenMarkovModel working = new HiddenMarkovModel();
        working.setInitialProbabilities(initialHmm.getInitialProbabilities());
        working.setTransitionProbabilities(initialHmm.getTransitionProbabilities());
        working.setObservationProbabilities(initialHmm.getObservationProbabilities());
        this.hmm = working;
        this.forward = new Forward();
        this.backward = new Backward();
        this.xiGamma = new XiGamma();
        this.reestimation = new Reestimation();
    }

    /**
     * Runs the training loop and returns the final model.
     *
     * <p>After this call {@link #getHmm()} returns the same model and {@link #getLogP()}
     * returns that model's log-likelihood.
     *
     * @return the re-estimated model (the initial working model if {@code maxIt <= 0})
     * @throws IllegalStateException if no initial model or observation sequence was set
     */
    public HiddenMarkovModel newHmm() {
        if (this.hmm == null || this.obsSeq == null) {
            throw new IllegalStateException(
                    "initial model and observation sequence must be set before training");
        }
        this.forward.computeForward(this.hmm, this.obsSeq);
        this.logP = this.forward.getLogP();
        this.computeExpectations();
        for (int it = 0; it < this.maxIt; ++it) {
            // Assumes reestimateHMM returns the same model later exposed via
            // reestimation.getHmm(); using this.hmm keeps every pass on one reference.
            this.hmm = this.reestimation.reestimateHMM(this.hmm, this.obsSeq, this.xiGamma);
            this.forward.computeForward(this.hmm, this.obsSeq);
            double newLogP = this.forward.getLogP();
            boolean converged = Math.abs(this.logP - newLogP) < this.tol;
            // Always record the new likelihood so logP describes the model being returned.
            this.logP = newLogP;
            if (converged) {
                break;
            }
            this.computeExpectations();
        }
        return this.hmm;
    }

    /**
     * Runs the backward pass and xi/gamma computation for the current model, reusing the
     * scaling factors from the most recent forward pass.
     */
    private void computeExpectations() {
        this.backward.computeBackward(this.hmm, this.obsSeq, this.forward.getScale());
        this.xiGamma.computeXiGamma(this.forward.getForwardProbabilities(),
                this.backward.getBackwardProbabilities(), this.hmm, this.obsSeq);
    }

    /** @return the convergence threshold on the log-likelihood change */
    public double getTol() {
        return this.tol;
    }

    /** @param tolerance convergence threshold on the absolute log-likelihood change */
    public void setTol(double tolerance) {
        this.tol = tolerance;
    }

    /** @return the current working model (the trained model after {@link #newHmm()}) */
    public HiddenMarkovModel getHmm() {
        return this.hmm;
    }

    /**
     * Replaces the working model with one built from {@code initialModel} and resets helpers.
     *
     * @param initialModel model whose probabilities seed the next training run
     */
    public void setHmm(HiddenMarkovModel initialModel) {
        this.initializeHmm(initialModel);
    }

    /** @param observations observation sequence (symbol indices); stored by reference */
    public void setObsSeq(int[] observations) {
        this.obsSeq = observations;
    }

    /** @param maxIterations maximum number of re-estimation steps */
    public void setMaxIt(int maxIterations) {
        this.maxIt = maxIterations;
    }

    /** @return log-likelihood of the current model from the last forward pass (0.0 before any run) */
    public double getLogP() {
        return this.logP;
    }
}

