package weka.classifiers.functions.pace;

import java.util.Random;
import weka.core.RevisionUtils;
import weka.core.matrix.DoubleVector;
import weka.core.matrix.Maths;

/* loaded from: input_file:weka/classifiers/functions/pace/NormalMixture.class */
/**
 * Mixture of unit-variance normal distributions, together with estimators
 * (empirical Bayes, nested and subset) that are evaluated against the fitted
 * mixing distribution held in the inherited {@code mixingDistribution} field.
 *
 * <p>All normal densities and distribution functions used here are called with
 * a standard deviation of {@code 1.0}.
 */
public class NormalMixture extends MixtureDistribution {
    /** Upper bound on the accumulated tail probability used by {@link #separable}. */
    protected double separatingThreshold = 0.05d;

    /** Estimates whose absolute value is at most this threshold are zeroed by {@link #trim}. */
    protected double trimingThreshold = 0.7d;

    /** Length of each fitting interval built by {@link #fittingIntervals}. */
    protected double fittingIntervalLength = 3.0d;

    /**
     * Returns the separating threshold.
     *
     * @return the current separating threshold
     */
    public double getSeparatingThreshold() {
        return this.separatingThreshold;
    }

    /**
     * Sets the separating threshold.
     *
     * @param d the new separating threshold
     */
    public void setSeparatingThreshold(double d) {
        this.separatingThreshold = d;
    }

    /**
     * Returns the trimming threshold.
     *
     * @return the current trimming threshold
     */
    public double getTrimingThreshold() {
        return this.trimingThreshold;
    }

    /**
     * Sets the trimming threshold.
     *
     * @param d the new trimming threshold
     */
    public void setTrimingThreshold(double d) {
        this.trimingThreshold = d;
    }

    /**
     * Tests whether a point is separable from a contiguous range of data points.
     * The values {@code Maths.pnorm(-|x - data[k]|)} are summed over
     * {@code k = i0..i1} (both inclusive) and the sum is compared with the
     * separating threshold.
     *
     * @param data the data values
     * @param i0 index of the first data point considered
     * @param i1 index of the last data point considered (inclusive)
     * @param x the point being tested
     * @return {@code true} if the accumulated value is strictly below the
     *     separating threshold
     */
    @Override // weka.classifiers.functions.pace.MixtureDistribution
    public boolean separable(DoubleVector data, int i0, int i1, double x) {
        double accumulated = 0.0d;
        for (int k = i0; k <= i1; k++) {
            accumulated += Maths.pnorm(-Math.abs(x - data.get(k)));
        }
        return accumulated < this.separatingThreshold;
    }

    /**
     * Returns the support points, which are simply a copy of the data.
     *
     * @param data the data values; must contain at least two elements
     * @param ne not used by this implementation
     * @return a copy of {@code data}
     * @throws IllegalArgumentException if {@code data} has fewer than two elements
     */
    @Override // weka.classifiers.functions.pace.MixtureDistribution
    public DoubleVector supportPoints(DoubleVector data, int ne) {
        if (data.size() < 2) {
            throw new IllegalArgumentException("data size < 2");
        }
        return data.copy();
    }

    /**
     * Builds the fitting intervals for the given data. Two intervals of length
     * {@code fittingIntervalLength} are produced per data value {@code x}:
     * {@code [x, x + length]} in the first half of the rows and
     * {@code [x - length, x]} in the second half.
     *
     * @param data the data values
     * @return a two-column matrix; column 0 receives the left end points and
     *     column 1 the right end points
     */
    @Override // weka.classifiers.functions.pace.MixtureDistribution
    public PaceMatrix fittingIntervals(DoubleVector data) {
        DoubleVector left = data.cat(data.minus(this.fittingIntervalLength));
        DoubleVector right = data.plus(this.fittingIntervalLength).cat(data);
        PaceMatrix intervals = new PaceMatrix(left.size(), 2);
        intervals.setMatrix(0, left.size() - 1, 0, left);
        intervals.setMatrix(0, right.size() - 1, 1, right);
        return intervals;
    }

    /**
     * Builds the probability matrix for the given support points and intervals.
     * Entry {@code (i, j)} is
     * {@code pnorm(intervals[i][1], points[j], 1) - pnorm(intervals[i][0], points[j], 1)}.
     *
     * @param points the support points
     * @param intervals matrix whose columns 0 and 1 are read as the two end
     *     points of each interval
     * @return a matrix with one row per interval and one column per support point
     */
    @Override // weka.classifiers.functions.pace.MixtureDistribution
    public PaceMatrix probabilityMatrix(DoubleVector points, PaceMatrix intervals) {
        int pointCount = points.size();
        int intervalCount = intervals.getRowDimension();
        PaceMatrix probabilities = new PaceMatrix(intervalCount, pointCount);
        for (int i = 0; i < intervalCount; i++) {
            for (int j = 0; j < pointCount; j++) {
                double upper = Maths.pnorm(intervals.get(i, 1), points.get(j), 1.0d);
                double lower = Maths.pnorm(intervals.get(i, 0), points.get(j), 1.0d);
                probabilities.set(i, j, upper - lower);
            }
        }
        return probabilities;
    }

    /**
     * Computes the empirical Bayes estimate of a single value: the weighted
     * average of the mixing distribution's point values, with weights given by
     * the exponentiated (shifted) normal log-densities times the mixing
     * distribution's function values.
     *
     * @param x the observed value
     * @return {@code x} itself when {@code |x| > 10}, otherwise the weighted average
     */
    public double empiricalBayesEstimate(double x) {
        if (Math.abs(x) > 10.0d) {
            return x;
        }
        DoubleVector logDensities = Maths.dnormLog(x, this.mixingDistribution.getPointValues(), 1.0d);
        // Shift by the maximum before exponentiating; the common factor cancels
        // in the ratio below (presumably map(...) applies Math.exp element-wise).
        logDensities.minusEquals(logDensities.max());
        DoubleVector weights = logDensities.map("java.lang.Math", "exp");
        weights.timesEquals(this.mixingDistribution.getFunctionValues());
        return this.mixingDistribution.getPointValues().innerProduct(weights) / weights.sum();
    }

    /**
     * Computes the empirical Bayes estimate of every element of a vector and
     * then trims the result with {@link #trim}.
     *
     * @param x the observed values
     * @return a new vector holding the trimmed estimates
     */
    public DoubleVector empiricalBayesEstimate(DoubleVector x) {
        DoubleVector estimates = new DoubleVector(x.size());
        for (int i = 0; i < x.size(); i++) {
            estimates.set(i, empiricalBayesEstimate(x.get(i)));
        }
        trim(estimates);
        return estimates;
    }

    /**
     * Computes the nested-model estimate. The values {@link #hf(double)} are
     * accumulated along the vector; every element after the index of the
     * maximum of that cumulative sequence is set to zero, and the result is trimmed.
     *
     * @param x the observed values
     * @return a new vector holding the estimate; {@code x} itself is not modified
     */
    public DoubleVector nestedEstimate(DoubleVector x) {
        DoubleVector cumulative = new DoubleVector(x.size());
        for (int i = 0; i < x.size(); i++) {
            cumulative.set(i, hf(x.get(i)));
        }
        cumulative.cumulateInPlace();
        int cutoff = cumulative.indexOfMax();
        DoubleVector estimate = x.copy();
        if (cutoff < x.size() - 1) {
            estimate.set(cutoff + 1, x.size() - 1, 0.0d);
        }
        trim(estimate);
        return estimate;
    }

    /**
     * Computes the subset-selection estimate: elements whose {@link #h(double)}
     * value is not positive are set to zero, and the result is trimmed.
     *
     * @param x the observed values
     * @return a new vector holding the estimate; {@code x} itself is not modified
     */
    public DoubleVector subsetEstimate(DoubleVector x) {
        DoubleVector hValues = h(x);
        DoubleVector estimate = x.copy();
        for (int i = 0; i < x.size(); i++) {
            if (hValues.get(i) <= 0.0d) {
                estimate.set(i, 0.0d);
            }
        }
        trim(estimate);
        return estimate;
    }

    /**
     * Sets to zero, in place, every element whose absolute value does not exceed
     * the trimming threshold.
     *
     * @param x the vector to trim; modified in place
     */
    public void trim(DoubleVector x) {
        for (int i = 0; i < x.size(); i++) {
            if (Math.abs(x.get(i)) <= this.trimingThreshold) {
                x.set(i, 0.0d);
            }
        }
    }

    /**
     * Computes the weighted average of {@code 2 * x * p - x^2} over the mixing
     * distribution's point values {@code p}, using the same weights as
     * {@link #empiricalBayesEstimate(double)}.
     *
     * @param x the value
     * @return the normalised weighted sum
     */
    public double hf(double x) {
        DoubleVector points = this.mixingDistribution.getPointValues();
        DoubleVector values = this.mixingDistribution.getFunctionValues();
        DoubleVector logDensities = Maths.dnormLog(x, points, 1.0d);
        // Shift by the maximum before exponentiating; the common factor cancels
        // in the ratio below.
        logDensities.minusEquals(logDensities.max());
        DoubleVector weights = logDensities.map("java.lang.Math", "exp");
        weights.timesEquals(values);
        return points.times(2.0d * x).minusEquals(x * x).innerProduct(weights) / weights.sum();
    }

    /**
     * Computes the inner product of {@code 2 * x * p - x^2} (over the mixing
     * distribution's point values {@code p}) with the normal densities
     * {@code dnorm(x, p, 1)} multiplied by the mixing distribution's function
     * values. Unlike {@link #hf(double)}, the result is not normalised.
     *
     * @param x the value
     * @return the weighted sum
     */
    public double h(double x) {
        DoubleVector points = this.mixingDistribution.getPointValues();
        DoubleVector weights = Maths.dnorm(x, points, 1.0d).timesEquals(this.mixingDistribution.getFunctionValues());
        return points.times(2.0d * x).minusEquals(x * x).innerProduct(weights);
    }

    /**
     * Applies {@link #h(double)} to every element of a vector.
     *
     * @param x the values
     * @return a new vector with {@code h(x[i])} at index {@code i}
     */
    public DoubleVector h(DoubleVector x) {
        DoubleVector result = new DoubleVector(x.size());
        for (int i = 0; i < x.size(); i++) {
            result.set(i, h(x.get(i)));
        }
        return result;
    }

    /**
     * Computes the sum of {@code Maths.dchisq(x, p)} over the mixing
     * distribution's point values {@code p}, weighted by its function values.
     *
     * <p>NOTE(review): every other method in this class uses unit-variance
     * normal densities, so the use of {@code dchisq} here looks like it may have
     * been carried over from a chi-square mixture; confirm whether
     * {@code Maths.dnorm(x, p, 1.0)} was intended. Behaviour is left unchanged.
     *
     * @param x the value
     * @return the weighted sum
     */
    public double f(double x) {
        DoubleVector points = this.mixingDistribution.getPointValues();
        return Maths.dchisq(x, points).timesEquals(this.mixingDistribution.getFunctionValues()).sum();
    }

    /**
     * Applies {@link #f(double)} to every element of a vector.
     *
     * @param x the values
     * @return a new vector with {@code f(x[i])} at index {@code i}
     */
    public DoubleVector f(DoubleVector x) {
        DoubleVector result = new DoubleVector(x.size());
        for (int i = 0; i < x.size(); i++) {
            // Evaluate f on the INPUT element. The previous code called h(...) on
            // the still-unfilled output vector and so ignored the argument.
            result.set(i, f(x.get(i)));
        }
        return result;
    }

    /**
     * Returns the string form of the fitted mixing distribution.
     *
     * @return {@code mixingDistribution.toString()}
     */
    @Override // weka.classifiers.functions.pace.MixtureDistribution
    public String toString() {
        return this.mixingDistribution.toString();
    }

    /**
     * Returns the revision string.
     *
     * @return the revision extracted from the embedded revision keyword
     */
    @Override // weka.core.RevisionHandler
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 1.5 $");
    }

    /**
     * Demonstration entry point: draws 50 values from N(0, 1) and 50 from
     * N(5, 1), fits the mixing distribution, and prints the nested, subset and
     * empirical Bayes estimates together with their quadratic losses.
     *
     * @param strArr command-line arguments; ignored
     */
    public static void main(String[] strArr) {
        DoubleVector observations = Maths.rnorm(50, 0.0d, 1.0d, new Random()).cat(Maths.rnorm(50, 5.0d, 1.0d, new Random()));
        DoubleVector trueMeans = new DoubleVector(50, 0.0d).cat(new DoubleVector(50, 5.0d));
        System.out.println("==========================================================");
        System.out.println("This is to test the estimation of the mixing\ndistribution of the mixture of unit variance normal\ndistributions. The example mixture used is of the form: \n\n   0.5 * N(mu1, 1) + 0.5 * N(mu2, 1)\n");
        System.out.println("It also tests three estimators: the subset\nselector, the nested model selector, and the empirical Bayes\nestimator. Quadratic losses of the estimators are given, \nand are taken as the measure of their performance.");
        System.out.println("==========================================================");
        System.out.println("mu1 = 0.0 mu2 = 5.0\n");
        System.out.println(observations.size() + " observations are: \n\n" + observations);
        System.out.println("\nQuadratic loss of the raw data (i.e., the MLE) = " + observations.sum2(trueMeans));
        System.out.println("==========================================================");
        NormalMixture mixture = new NormalMixture();
        mixture.fit(observations, 1);
        System.out.println("The estimated mixing distribution is:\n" + mixture);
        // The nested estimator is applied to the reversed data and the result is
        // reversed back to the original order.
        DoubleVector nested = mixture.nestedEstimate(observations.rev()).rev();
        System.out.println("\nThe Nested Estimate = \n" + nested);
        System.out.println("Quadratic loss = " + nested.sum2(trueMeans));
        DoubleVector subset = mixture.subsetEstimate(observations);
        System.out.println("\nThe Subset Estimate = \n" + subset);
        System.out.println("Quadratic loss = " + subset.sum2(trueMeans));
        DoubleVector empiricalBayes = mixture.empiricalBayesEstimate(observations);
        System.out.println("\nThe Empirical Bayes Estimate = \n" + empiricalBayes);
        System.out.println("Quadratic loss = " + empiricalBayes.sum2(trueMeans));
    }
}
