package ir.classifiers;

import ir.utilities.MoreMath;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Vector;

/* loaded from: input_file:ir/classifiers/CVLearningCurve.class */
/**
 * Generates k-fold cross-validated learning curves for a {@link Classifier}.
 *
 * <p>Examples are grouped by category, shuffled with a fixed seed, and dealt round-robin into
 * {@code numFolds} bins per category, so every fold receives roughly the same share of each
 * category. For every training-set fraction in {@link #points} and every fold, the classifier is
 * trained on that fraction of the other folds. It is then evaluated on the held-out fold (test
 * accuracy) and on its own training data (train accuracy). The per-point results are averaged
 * over folds and written as gnuplot data and script files.
 */
public class CVLearningCurve {
    /** Examples grouped by category index, as returned by {@code Example.getCategory()}. */
    protected Vector<Example>[] totalExamples;
    /** {@code foldBins[c][f]} holds the examples of category {@code c} assigned to fold {@code f}. */
    protected Vector<Example>[][] foldBins;
    /** The classifier being evaluated. */
    protected Classifier classifier;
    /** Seed for the shuffle performed before binning, so that runs are reproducible. */
    protected long randomSeed;
    /** Number of categories, taken from {@code classifier.getCategories().length}. */
    protected int numClasses;
    /** Size of a full training set: all examples minus one fold's share (rounded). */
    protected int totalNumTrain;
    /** Number of cross-validation folds (at least 2). */
    protected int numFolds;
    /** Fractions of {@link #totalNumTrain} at which the learning curve is sampled. */
    protected double[] points;
    /** Default sample points: 0%, 1%, 5%, then every 10% up to 100%. */
    protected static double[] DEFAULT_POINTS = {0.0d, 0.01d, 0.05d, 0.1d, 0.2d, 0.3d, 0.4d, 0.5d, 0.6d, 0.7d, 0.8d, 0.9d, 1.0d};
    /** When true, the full training and test sets of every fold are printed. */
    protected boolean debug;
    /** Accumulated training time in milliseconds. */
    protected double trainTime;
    /** Accumulated testing time in milliseconds. */
    protected double testTime;
    /** Number of test examples classified; used for the per-example test time. */
    protected int testTimeNum;
    /** Test-accuracy results, one entry per sample point. */
    protected PointResults[] testResults;
    /** Train-accuracy results, one entry per sample point. */
    protected PointResults[] trainResults;

    /**
     * Creates a learning-curve generator.
     *
     * @param numFolds number of cross-validation folds; must be at least 2
     * @param classifier the classifier to evaluate; its category count sizes the bins
     * @param examples all labelled examples
     * @param points training-set fractions at which to sample the curve
     * @param randomSeed seed for the pre-binning shuffle
     * @param debug whether to print the training/test sets of every fold
     * @throws IllegalArgumentException if {@code numFolds < 2}
     */
    @SuppressWarnings("unchecked") // generic array creation; the arrays only ever hold Vector<Example>
    public CVLearningCurve(int numFolds, Classifier classifier, List<Example> examples, double[] points, long randomSeed, boolean debug) {
        if (numFolds < 2) {
            throw new IllegalArgumentException("Cannot have less than 2 folds");
        }
        this.numFolds = numFolds;
        this.classifier = classifier;
        this.numClasses = classifier.getCategories().length;
        this.totalExamples = new Vector[this.numClasses];
        this.foldBins = new Vector[this.numClasses][this.numFolds];
        setTotalExamples(examples);
        this.points = points;
        this.testResults = new PointResults[points.length];
        this.trainResults = new PointResults[points.length];
        this.randomSeed = randomSeed;
        this.debug = debug;
        this.testTime = 0.0d;
        this.trainTime = 0.0d;
    }

    /**
     * Creates a 10-fold generator using {@link #DEFAULT_POINTS}, seed 1, and debugging off.
     *
     * @param classifier the classifier to evaluate
     * @param examples all labelled examples
     */
    public CVLearningCurve(Classifier classifier, List<Example> examples) {
        this(10, classifier, examples, DEFAULT_POINTS, 1L, false);
    }

    /** Returns the classifier being evaluated. */
    public Classifier getClassifier() {
        return this.classifier;
    }

    /** Replaces the classifier being evaluated. */
    public void setClassifier(Classifier classifier) {
        this.classifier = classifier;
    }

    /** Returns the examples grouped by category index. */
    public Vector[] getTotalExamples() {
        return this.totalExamples;
    }

    /** Replaces the per-category example vectors; {@link #totalNumTrain} is not recomputed. */
    public void setTotalExamples(Vector<Example>[] vectorArr) {
        this.totalExamples = vectorArr;
    }

    /** Returns the fold bins, indexed {@code [category][fold]}. */
    public Vector<Example>[][] getFoldBins() {
        return this.foldBins;
    }

    /** Replaces the fold bins, indexed {@code [category][fold]}. */
    public void setFoldBins(Vector<Example>[][] vectorArr) {
        this.foldBins = vectorArr;
    }

    /**
     * Replaces the example set, grouping the examples by category and recomputing
     * {@link #totalNumTrain} as all examples minus one fold's share.
     *
     * <p>Every category gets a (possibly empty) vector, so later steps never see a null category.
     *
     * @param list all labelled examples; each category index must be below {@link #numClasses}
     */
    @SuppressWarnings("unchecked") // generic array creation; the array only ever holds Vector<Example>
    public void setTotalExamples(List<Example> list) {
        this.totalNumTrain = (int) Math.round((1.0d - (1.0d / this.numFolds)) * list.size());
        Vector<Example>[] byCategory = new Vector[this.numClasses];
        for (int category = 0; category < byCategory.length; category++) {
            byCategory[category] = new Vector<>();
        }
        for (Example example : list) {
            byCategory[example.getCategory()].add(example);
        }
        this.totalExamples = byCategory;
    }

    /**
     * Runs the full experiment and writes the train and test curves.
     *
     * <p>Prints timing statistics, then writes {@code <name>Train.gplot}/{@code .data} and
     * {@code <name>.gplot}/{@code .data}, where {@code <name>} is the classifier's name.
     *
     * @throws Exception if training, testing, or writing the output files fails
     */
    public void run() throws Exception {
        System.out.println("Generating " + this.numFolds + " fold CV learning curves...");
        trainAndTest();
        System.out.println();
        System.out.println("Total Training time in seconds: " + (this.trainTime / 1000.0d));
        // Guard against no test examples at all, which would otherwise print NaN.
        double testTimePerExample = this.testTimeNum == 0 ? 0.0d : this.testTime / this.testTimeNum;
        System.out.println("Testing time per example in milliseconds: " + MoreMath.roundTo(testTimePerExample, 2));
        String name = this.classifier.getName();
        makeGnuplotFile(this.trainResults, name + "Train");
        System.out.println("GNUPLOT train accuracy file is " + name + "Train.gplot");
        makeGnuplotFile(this.testResults, name);
        System.out.println("GNUPLOT test accuracy file is " + name + ".gplot");
    }

    /**
     * Shuffles and bins the examples, then trains and tests on every fold at every sample point,
     * filling {@link #testResults} and {@link #trainResults}.
     */
    public void trainAndTest() {
        randomizeOrder();
        binExamples();
        for (int pointIndex = 0; pointIndex < this.points.length; pointIndex++) {
            double fraction = this.points[pointIndex];
            System.out.println("Train Percentage: " + (100.0d * fraction) + "%");
            this.testResults[pointIndex] = new PointResults(this.numFolds);
            this.trainResults[pointIndex] = new PointResults(this.numFolds);
            for (int fold = 0; fold < this.numFolds; fold++) {
                System.out.println("  Calculating results for fold " + fold);
                Vector<Example> trainCV = getTrainCV(fold, fraction);
                Vector<Example> testCV = getTestCV(fold);
                trainAndTestFold(trainCV, testCV, fold, this.testResults[pointIndex], this.trainResults[pointIndex]);
                if (this.debug) {
                    System.out.println("Training on:\n" + trainCV);
                    System.out.println("Testing on:\n" + testCV);
                }
            }
        }
    }

    /**
     * Trains on one training set and records test and train accuracy for one fold.
     *
     * <p>The accuracy on an empty training set is recorded as 1.0. The accuracy on an empty test
     * set is NaN, since there is nothing to measure.
     *
     * @param train examples to train on
     * @param test held-out examples to evaluate on
     * @param fold index of the held-out fold
     * @param testPoint receives the test accuracy for this fold
     * @param trainPoint receives the train accuracy for this fold
     */
    public void trainAndTestFold(Vector<Example> train, Vector<Example> test, int fold, PointResults testPoint, PointResults trainPoint) {
        long start = System.currentTimeMillis();
        this.classifier.train(train);
        this.trainTime += System.currentTimeMillis() - start;

        start = System.currentTimeMillis();
        int testCorrect = countCorrect(test);
        this.testTime += System.currentTimeMillis() - start;
        this.testTimeNum += test.size();
        double testAccuracy = (double) testCorrect / test.size();
        testPoint.setPoint(train.size());
        testPoint.addResult(fold, testAccuracy);

        // Training accuracy is evaluated outside the timed section.
        double trainAccuracy = train.isEmpty() ? 1.0d : (double) countCorrect(train) / train.size();
        trainPoint.setPoint(train.size());
        trainPoint.addResult(fold, trainAccuracy);
        System.out.println("    Train Accuracy = " + MoreMath.roundTo(100.0d * trainAccuracy, 3) + "%; Test Accuracy = " + MoreMath.roundTo(100.0d * testAccuracy, 3) + "%");
    }

    /** Counts the examples for which {@code classifier.test} returns true. */
    private int countCorrect(List<Example> examples) {
        int correct = 0;
        for (Example example : examples) {
            if (this.classifier.test(example)) {
                correct++;
            }
        }
        return correct;
    }

    /**
     * Deals each category's examples round-robin into {@link #numFolds} fresh bins, so every fold
     * gets roughly an equal share of each category.
     */
    public void binExamples() {
        for (int category = 0; category < this.numClasses; category++) {
            for (int fold = 0; fold < this.numFolds; fold++) {
                this.foldBins[category][fold] = new Vector<>();
            }
            Vector<Example> examples = this.totalExamples[category];
            if (examples == null) {
                continue; // a category with no examples leaves its bins empty
            }
            for (int k = 0; k < examples.size(); k++) {
                this.foldBins[category][k % this.numFolds].add(examples.get(k));
            }
        }
    }

    /**
     * Builds a training set of about {@code fraction * totalNumTrain} examples from the folds other
     * than {@code foldIndex}.
     *
     * <p>Whole folds are taken in order while they fit. The fold that would overflow the target
     * contributes the same proportion of each of its category bins, which keeps the category mix.
     *
     * @param foldIndex the held-out fold to exclude
     * @param fraction fraction of {@link #totalNumTrain} to use
     * @return the training examples
     */
    public Vector<Example> getTrainCV(int foldIndex, double fraction) {
        Vector<Example> train = new Vector<>();
        int target = (int) Math.round(fraction * this.totalNumTrain);
        for (int fold = 0; fold < this.numFolds && train.size() < target; fold++) {
            if (fold == foldIndex) {
                continue;
            }
            int foldSize = sizeOfFold(fold);
            if (train.size() + foldSize <= target) {
                for (int category = 0; category < this.numClasses; category++) {
                    train.addAll(this.foldBins[category][fold]);
                }
            } else {
                // foldSize > 0 here. Divide in floating point: integer division would always
                // yield 0 and silently drop the partial fold.
                double portion = (double) (target - train.size()) / foldSize;
                for (int category = 0; category < this.numClasses; category++) {
                    Vector<Example> bin = this.foldBins[category][fold];
                    int take = (int) Math.round(portion * bin.size());
                    for (int k = 0; k < take; k++) {
                        train.add(bin.get(k));
                    }
                }
            }
        }
        System.out.println("    Number of training examples: " + train.size());
        return train;
    }

    /** Returns the total number of examples in fold {@code i} across all categories. */
    protected int sizeOfFold(int i) {
        int size = 0;
        for (int category = 0; category < this.numClasses; category++) {
            size += this.foldBins[category][i].size();
        }
        return size;
    }

    /** Returns all examples in fold {@code i}, across all categories, as the test set. */
    public Vector<Example> getTestCV(int i) {
        Vector<Example> test = new Vector<>();
        for (int category = 0; category < this.numClasses; category++) {
            test.addAll(this.foldBins[category][i]);
        }
        return test;
    }

    /**
     * Shuffles each category's examples in place using a uniform Fisher–Yates shuffle seeded with
     * {@link #randomSeed}.
     */
    private void randomizeOrder() {
        Random random = new Random(this.randomSeed);
        for (int category = 0; category < this.numClasses; category++) {
            Vector<Example> examples = this.totalExamples[category];
            if (examples == null) {
                continue;
            }
            for (int k = examples.size() - 1; k > 0; k--) {
                // Draw from [0, k]. Drawing from [0, size) instead would bias the permutations.
                int swapWith = random.nextInt(k + 1);
                Example tmp = examples.get(k);
                examples.set(k, examples.get(swapWith));
                examples.set(swapWith, tmp);
            }
        }
    }

    /**
     * Writes {@code str + ".data"}. Each line holds a point's rounded training-set size, a tab,
     * and the mean of that point's per-fold results.
     *
     * @throws IOException if the file cannot be opened or written
     */
    void writeCurve(PointResults[] pointResultsArr, String str) throws IOException {
        try (PrintWriter out = new PrintWriter(new FileWriter(str + ".data"))) {
            for (PointResults pointResults : pointResultsArr) {
                double sum = 0.0d;
                int count = 0;
                for (double result : pointResults.getResults()) {
                    sum += result;
                    count++;
                }
                double point = pointResults.getPoint();
                out.println(Math.round(point) + "\t" + (sum / count));
            }
            // PrintWriter swallows IOExceptions; surface them as the signature promises.
            if (out.checkError()) {
                throw new IOException("Error writing " + str + ".data");
            }
        }
    }

    /**
     * Writes the data file via {@link #writeCurve}, then writes a gnuplot script
     * {@code str + ".gplot"} that plots it.
     *
     * @throws IOException if either file cannot be opened or written
     */
    void makeGnuplotFile(PointResults[] pointResultsArr, String str) throws IOException {
        writeCurve(pointResultsArr, str);
        try (PrintWriter out = new PrintWriter(new FileWriter(new File(str + ".gplot")))) {
            out.print("set xlabel \"Size of training set\"\nset ylabel \"Accuracy\"\n\nset terminal postscript color\nset size 0.75,0.75\n\nset style data linespoints\nset key bottom right\n\nplot '" + str + ".data' title \"" + str + "\"");
            if (out.checkError()) {
                throw new IOException("Error writing " + str + ".gplot");
            }
        }
    }
}
