/*
 * Decompiled with CFR 0.152.
 */
package haubold.hmm.algorithm;

import haubold.hmm.algorithm.HiddenMarkovModel;
import haubold.hmm.algorithm.XiGamma;

/**
 * Performs one re-estimation step of the Baum-Welch algorithm.
 *
 * <p>Given the {@code xi} and {@code gamma} tables computed for an observation sequence, it
 * recomputes the initial, transition and emission probabilities of a {@link HiddenMarkovModel}.
 * The model passed in is updated in place through its setters and is also returned.
 */
public class Reestimation {
    /** The model most recently passed to {@link #reestimateHMM}. */
    private HiddenMarkovModel hmm;

    /**
     * Re-estimates the initial, transition and emission probabilities of {@code hmm}.
     *
     * @param hmm the model to update; it is modified in place
     * @param observedStates the observation symbol index at each time step
     * @param xiGamma holder of {@code xi}, indexed {@code [from][to][t]}, and {@code gamma},
     *     indexed {@code [state][t]}
     * @return {@code hmm}, after its parameters have been replaced
     */
    public HiddenMarkovModel reestimateHMM(HiddenMarkovModel hmm, int[] observedStates, XiGamma xiGamma) {
        this.hmm = hmm;
        double[][][] xi = xiGamma.getXi();
        double[][] gamma = xiGamma.getGamma();
        hmm = this.reestimateInitialProb(hmm, gamma);
        hmm = this.reestimateTransitionProb(hmm, xi, gamma);
        hmm = this.reestimateEmissionProb(hmm, gamma, observedStates);
        return hmm;
    }

    /**
     * Sets the initial probability of each state i to {@code gamma[i][0]}, the occupancy of
     * state i at the first time step.
     */
    private HiddenMarkovModel reestimateInitialProb(HiddenMarkovModel hmm, double[][] gamma) {
        double[] ip = hmm.getInitialProbabilities();
        int n = hmm.getNumStates();
        for (int i = 0; i < n; ++i) {
            ip[i] = gamma[i][0];
        }
        hmm.setInitialProbabilities(ip);
        return hmm;
    }

    /**
     * Sets {@code tp[i][j] = sum_t xi[i][j][t] / sum_t gamma[i][t]}, where t runs over
     * {@code 0 .. xi[0][0].length - 2}.
     *
     * <p>If state i has zero total occupancy over that range, the ratio would be 0/0 = NaN, so
     * row i is left at its previous values instead.
     */
    private HiddenMarkovModel reestimateTransitionProb(HiddenMarkovModel hmm, double[][][] xi, double[][] gamma) {
        int n = hmm.getNumStates();
        double[][] tp = hmm.getTransitionProbabilities();
        // NOTE(review): the last stored xi entry (index len - 1) is excluded. That matches the
        // textbook sum over t = 0..T-2 if xi holds T entries per state pair. If it holds only
        // T-1, one term is dropped. Confirm against the code that builds XiGamma.
        int len = xi[0][0].length;
        for (int i = 0; i < n; ++i) {
            // The denominator depends only on i, so it is summed once per row.
            double occupancy = 0.0;
            for (int t = 0; t < len - 1; ++t) {
                occupancy += gamma[i][t];
            }
            if (occupancy == 0.0) {
                // State never occupied: keep the previous row rather than writing NaN.
                continue;
            }
            for (int j = 0; j < n; ++j) {
                double expectedTransitions = 0.0;
                for (int t = 0; t < len - 1; ++t) {
                    expectedTransitions += xi[i][j][t];
                }
                tp[i][j] = expectedTransitions / occupancy;
            }
        }
        hmm.setTransitionProbabilities(tp);
        return hmm;
    }

    /**
     * Sets {@code op[j][s]} to the occupancy of state j at the time steps where symbol s was
     * observed, divided by the total occupancy of state j over all {@code gamma[0].length} steps.
     *
     * <p>Observations outside {@code [0, numObservationSymbols)} count toward the denominator
     * only. If state j has zero total occupancy, row j is left at its previous values instead of
     * becoming NaN.
     */
    private HiddenMarkovModel reestimateEmissionProb(HiddenMarkovModel hmm, double[][] gamma, int[] obsStat) {
        int m = hmm.getNumObservationSymbols();
        int n = hmm.getNumStates();
        double[][] op = hmm.getObservationProbabilities();
        int len = gamma[0].length;
        for (int j = 0; j < n; ++j) {
            // A single pass over time accumulates every symbol's numerator and the shared
            // denominator. Terms are added in the same t order as a per-symbol scan, so the
            // results are bit-identical, but the cost drops from O(m*n*T) to O(n*(T+m)).
            double[] perSymbol = new double[m];
            double occupancy = 0.0;
            for (int t = 0; t < len; ++t) {
                int symbol = obsStat[t];
                if (symbol >= 0 && symbol < m) {
                    perSymbol[symbol] += gamma[j][t];
                }
                occupancy += gamma[j][t];
            }
            if (occupancy == 0.0) {
                // State never occupied: keep the previous row rather than writing NaN.
                continue;
            }
            for (int s = 0; s < m; ++s) {
                op[j][s] = perSymbol[s] / occupancy;
            }
        }
        hmm.setObservationProbabilities(op);
        return hmm;
    }

    /** Returns the model most recently passed to {@link #reestimateHMM}, or null if never called. */
    public HiddenMarkovModel getHmm() {
        return this.hmm;
    }
}

