package ir.classifiers;

import ir.utilities.Weight;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:ir/classifiers/NaiveBayes.class */
/**
 * Multinomial naive Bayes classifier over bag-of-tokens examples.
 *
 * <p>Training estimates:
 * <ul>
 *   <li>a log prior for each category, from how often that category labels a training example;</li>
 *   <li>a log conditional probability for each (token, category) pair, from token counts summed
 *       per category.</li>
 * </ul>
 * Both estimates optionally use Laplace (add-one) smoothing. Classification scores each category
 * as its log prior plus the count-weighted sum of the log conditionals of the example's known
 * tokens, then predicts the highest-scoring category.
 */
public class NaiveBayes extends Classifier {
    /** Whether Laplace (add-one) smoothing is applied to priors and conditional probabilities. */
    boolean isLaplace = true;
    /** Probability used for a (token, category) pair with zero count when smoothing is off. */
    double EPSILON = 1.0E-6d;
    /** Model built by the most recent {@link #train(List)} call; {@code null} before training. */
    BayesResult trainResult;
    /** Name reported by {@link #getName()}. */
    public static final String name = "NaiveBayes";
    /** Number of categories, taken from the category array given to the constructor. */
    int numCategories;
    /** Number of distinct tokens seen by the most recent {@link #conditionalProbs(List)} call. */
    int numFeatures;
    /** Number of examples passed to the most recent {@link #train(List)} call. */
    int numExamples;
    /** Whether diagnostic output is printed to standard output. */
    boolean debug;

    /**
     * Creates an untrained classifier for the given categories.
     *
     * @param strArr category names; category index {@code i} is reported as {@code strArr[i]}
     * @param z      whether to print debugging output
     */
    public NaiveBayes(String[] strArr, boolean z) {
        this.categories = strArr;
        this.debug = z;
        this.numCategories = strArr.length;
    }

    /** Enables or disables debugging output. */
    public void setDebug(boolean z) {
        this.debug = z;
    }

    /** Enables or disables Laplace smoothing; takes effect on the next {@link #train(List)}. */
    public void setLaplace(boolean z) {
        this.isLaplace = z;
    }

    /** Sets the probability used for zero-count tokens when Laplace smoothing is off. */
    public void setEpsilon(double d) {
        this.EPSILON = d;
    }

    @Override // ir.classifiers.Classifier
    public String getName() {
        return name;
    }

    /** Returns the probability used for zero-count tokens when Laplace smoothing is off. */
    public double getEpsilon() {
        return this.EPSILON;
    }

    /** Returns the model from the last training run, or {@code null} if not yet trained. */
    public BayesResult getTrainResult() {
        return this.trainResult;
    }

    /** Returns whether Laplace smoothing is enabled. */
    public boolean getIsLaplace() {
        return this.isLaplace;
    }

    /**
     * Trains the classifier, replacing any previous model with log class priors and per-token
     * log conditional probabilities estimated from {@code list}.
     *
     * @param list training examples
     */
    @Override // ir.classifiers.Classifier
    public void train(List<Example> list) {
        this.trainResult = new BayesResult();
        this.numExamples = list.size();
        this.trainResult.setClassPriors(calculatePriors(list));
        this.trainResult.setFeatureTable(conditionalProbs(list));
        if (this.debug) {
            displayProbs(this.trainResult.getClassPriors(), this.trainResult.getFeatureTable());
        }
    }

    /**
     * Classifies {@code example} and reports whether the prediction matches its category.
     *
     * @param example example to classify
     * @return {@code true} if the highest-scoring category equals {@code example.getCategory()}
     * @throws IllegalStateException if the classifier has not been trained
     */
    @Override // ir.classifiers.Classifier
    public boolean test(Example example) {
        double[] scores = calculateProbs(example);
        int predicted = argMax(scores);
        if (this.debug) {
            System.out.print("Document: " + example.name + "\nResults: ");
            for (int c = 0; c < this.numCategories; c++) {
                System.out.print(this.categories[c] + "(" + scores[c] + ")\t");
            }
            System.out.println("\nCorrect class: " + example.getCategory() + ", Predicted class: " + predicted + "\n");
        }
        return predicted == example.getCategory();
    }

    /**
     * Computes the natural-log prior of each category from the examples' category labels.
     *
     * <p>With Laplace smoothing the prior is {@code (n_c + 1) / (N + numCategories)}; without it,
     * {@code n_c / N}. In the unsmoothed case a category with no examples gets {@code -Infinity}.
     *
     * @param list training examples; {@code N} is {@code list.size()}
     * @return log prior for each category index
     */
    protected double[] calculatePriors(List<Example> list) {
        // Normalise by the list actually supplied rather than the numExamples field, so the
        // method does not depend on train() having run first (identical in the train() path).
        int total = list.size();
        double[] logPriors = new double[this.numCategories]; // Java zero-initialises arrays
        for (Example example : list) {
            logPriors[example.getCategory()] += 1.0d;
        }
        for (int c = 0; c < this.numCategories; c++) {
            double prior = this.isLaplace
                    ? (logPriors[c] + 1.0d) / (total + this.numCategories)
                    : logPriors[c] / total;
            logPriors[c] = Math.log(prior);
        }
        if (this.debug) {
            System.out.println("\nLog Class Priors:");
            for (double logPrior : logPriors) {
                System.out.print(logPrior + " ");
            }
            System.out.println();
        }
        return logPriors;
    }

    /**
     * Builds the table of per-token log conditional probabilities {@code log P(token | category)}.
     *
     * <p>Token counts are summed per category. Each count is then divided by that category's
     * total token count:
     * <ul>
     *   <li>with Laplace smoothing, as {@code (count + 1) / (total + vocabularySize)};</li>
     *   <li>without it, as {@code count / total}, with {@link #EPSILON} replacing zero counts.</li>
     * </ul>
     * Also records the vocabulary size in {@link #numFeatures}.
     *
     * @param list training examples
     * @return map from token to an array of log probabilities indexed by category
     */
    protected Hashtable<String, double[]> conditionalProbs(List<Example> list) {
        Hashtable<String, double[]> table = new Hashtable<>();
        // Total token count observed for each category (denominator of the estimates).
        double[] categoryTotals = new double[this.numCategories];
        for (Example example : list) {
            if (this.debug) {
                System.out.println("\nExample: " + example);
                System.out.println("Number of tokens: " + example.getHashMapVector().hashMap.size());
            }
            int category = example.getCategory();
            for (Map.Entry<String, Weight> entry : example.getHashMapVector().entrySet()) {
                String token = entry.getKey();
                // Weights are truncated to int counts, so fractional weights lose their fraction.
                int count = (int) entry.getValue().getValue();
                if (this.debug) {
                    System.out.println("Counts of token: " + token);
                }
                // Hashtable never stores null values, so null here means "token not seen yet".
                double[] counts = table.get(token);
                if (counts == null) {
                    counts = new double[this.numCategories];
                    table.put(token, counts);
                }
                counts[category] += count;
                categoryTotals[category] += count;
                if (this.debug) {
                    for (double d : counts) {
                        System.out.print(d + " ");
                    }
                    System.out.println();
                }
            }
        }
        // Vocabulary size must be known before Laplace normalisation below.
        this.numFeatures = table.size();
        if (this.debug) {
            System.out.println("\nLog Probs before multiplying priors...\n");
        }
        for (Map.Entry<String, double[]> entry : table.entrySet()) {
            double[] probs = entry.getValue();
            for (int c = 0; c < this.numCategories; c++) {
                double p;
                if (this.isLaplace) {
                    p = (probs[c] + 1.0d) / (categoryTotals[c] + this.numFeatures);
                } else if (probs[c] == 0.0d) {
                    // Avoid log(0) for tokens never seen with this category.
                    p = this.EPSILON;
                } else {
                    p = probs[c] / categoryTotals[c];
                }
                probs[c] = Math.log(p);
            }
            if (this.debug) {
                System.out.println("Log probs of " + entry.getKey());
                for (double d : probs) {
                    System.out.print(d + " ");
                }
                System.out.println();
            }
        }
        return table;
    }

    /**
     * Scores every category for {@code example} as its log prior plus, for each token in the
     * example that appears in the trained table, {@code count * log P(token | category)}.
     * Tokens not seen during training are ignored.
     *
     * @param example example to score
     * @return unnormalised log score for each category index
     * @throws IllegalStateException if the classifier has not been trained
     */
    protected double[] calculateProbs(Example example) {
        if (this.trainResult == null) {
            throw new IllegalStateException("NaiveBayes must be trained before it can classify");
        }
        double[] scores = this.trainResult.getClassPriors().clone();
        Hashtable<String, double[]> featureTable = this.trainResult.getFeatureTable();
        for (Map.Entry<String, Weight> entry : example.getHashMapVector().entrySet()) {
            int count = (int) entry.getValue().getValue();
            double[] logProbs = featureTable.get(entry.getKey());
            if (logProbs == null) {
                continue; // token unseen in training contributes nothing
            }
            for (int c = 0; c < this.numCategories; c++) {
                scores[c] += count * logProbs[c];
            }
        }
        return scores;
    }

    /**
     * Prints {@code exp(logPrior[c] + logProb[c])} for every token and category, i.e. the
     * prior times the conditional probability, in probability space.
     *
     * @param dArr      log class priors indexed by category
     * @param hashtable token to per-category log conditional probabilities
     */
    protected void displayProbs(double[] dArr, Hashtable<String, double[]> hashtable) {
        System.out.println("\nAfter multiplying priors...");
        for (Map.Entry<String, double[]> entry : hashtable.entrySet()) {
            double[] logProbs = entry.getValue();
            System.out.print("\nFeature: " + entry.getKey() + ", Probs: ");
            for (int c = 0; c < logProbs.length; c++) {
                System.out.print(" " + Math.exp(dArr[c] + logProbs[c]));
            }
        }
        System.out.println();
    }
}
