package edu.utexas.cs.tamerProject.modeling.kNN;

import weka.classifiers.Classifier;
import weka.core.EuclideanDistance;
import weka.core.Instance;
import weka.core.Instances;

/* loaded from: input_file:edu/utexas/cs/tamerProject/modeling/kNN/KNN.class */
/**
 * Distance-weighted k-nearest-neighbour predictor built on a pluggable
 * {@link NearestNeighbourSearch} structure.
 *
 * <p>The prediction for a query is the mean, over the {@code k} nearest stored
 * instances, of a blend between each neighbour's class value and
 * {@link #baselineBias}. The blend weight decays with the neighbour's distance
 * at a rate controlled by {@link #biasStrength}.
 */
public class KNN extends Classifier {
    private static final long serialVersionUID = 6667L;

    /** Search structure used to store the training instances and find neighbours. */
    NearestNeighbourSearch nNSearch;

    /** Number of neighbours to query; recomputed by {@link #buildClassifier(Instances)}. */
    int k;

    /** Scales how quickly a neighbour's weight falls off with distance. */
    double biasStrength = 0.1d;

    /** Value predictions are pulled towards as neighbours get farther away. */
    double baselineBias = 0.0d;

    /** Running sum of weighted neighbour contributions from the most recent prediction. */
    double neighborSum;

    /** Creates a predictor backed by a {@code KDTree}. */
    public KNN() {
        this.nNSearch = new KDTree();
    }

    /**
     * Creates a predictor backed by the named search structure and enables
     * performance measurement on it.
     *
     * @param str one of {@code "KDTree"}, {@code "BallTree"}, {@code "CoverTree"}
     *            or {@code "LinearNNSearch"}
     * @throws IllegalArgumentException if {@code str} is {@code null} or not one
     *                                  of the supported names
     */
    public KNN(String str) {
        this.nNSearch = createSearch(str);
        this.nNSearch.setMeasurePerformance(true);
    }

    /** Maps a search-structure name to a fresh instance, rejecting unknown names. */
    private static NearestNeighbourSearch createSearch(String name) {
        // Literal-first equals() keeps a null name on the error path below
        // instead of failing with a bare NullPointerException.
        if ("KDTree".equals(name)) {
            return new KDTree();
        } else if ("BallTree".equals(name)) {
            return new BallTree();
        } else if ("CoverTree".equals(name)) {
            return new CoverTree();
        } else if ("LinearNNSearch".equals(name)) {
            return new LinearNNSearch();
        }
        throw new IllegalArgumentException("Unknown nearest-neighbour search type: " + name
                + " (expected KDTree, BallTree, CoverTree or LinearNNSearch)");
    }

    /**
     * Loads the training data into the search structure and chooses {@code k}.
     *
     * <p>A Euclidean distance with {@code setDontNormalize(true)} is installed
     * on the search structure. {@code k} is set to the integer part of the
     * fourth root of the number of instances, with a minimum of 1.
     *
     * @param instances the training data
     * @throws Exception if the search structure rejects the distance function or the data
     */
    @Override // weka.classifiers.Classifier
    public void buildClassifier(Instances instances) throws Exception {
        EuclideanDistance euclideanDistance = new EuclideanDistance(instances);
        euclideanDistance.setDontNormalize(true);
        this.nNSearch.setDistanceFunction(euclideanDistance);
        this.nNSearch.setInstances(instances);
        this.k = Math.max(1, (int) Math.sqrt(Math.sqrt(instances.numInstances())));
    }

    /**
     * Predicts a value for {@code instance} from its nearest stored neighbours.
     *
     * <p>Each neighbour at distance {@code d} receives the weight
     * {@code w = max(1 - d * biasStrength, 1 / (1 + 10 * biasStrength * d))}
     * and contributes {@code w * classValue + (1 - w) * baselineBias}. The
     * result is the mean contribution over the returned neighbours.
     *
     * @param instance the query instance
     * @return the prediction, or {@link #baselineBias} when the search
     *         structure holds no instances or returns no neighbours
     * @throws Exception if the neighbour search fails
     */
    @Override // weka.classifiers.Classifier
    public double classifyInstance(Instance instance) throws Exception {
        Instances stored = this.nNSearch.getInstances();
        if (stored == null || stored.numInstances() == 0) {
            return this.baselineBias;
        }
        Instances kNearestNeighbours = this.nNSearch.kNearestNeighbours(instance, this.k);
        int neighbourCount = kNearestNeighbours.numInstances();
        this.neighborSum = 0.0d;
        if (neighbourCount == 0) {
            // No neighbours: fall back to the baseline rather than computing 0.0 / 0 (NaN).
            return this.baselineBias;
        }
        for (int i = 0; i < neighbourCount; i++) {
            Instance neighbour = kNearestNeighbours.instance(i);
            double distance = this.nNSearch.getDistanceFunction().distance(instance, neighbour);
            // The linear term is the larger one while distance * biasStrength < 0.9;
            // beyond that the reciprocal term takes over and keeps the weight positive.
            double weight = Math.max(
                    1.0d - (distance * this.biasStrength),
                    1.0d / (1.0d + ((10.0d * this.biasStrength) * distance)));
            this.neighborSum += (neighbour.classValue() * weight) + ((1.0d - weight) * this.baselineBias);
        }
        return this.neighborSum / neighbourCount;
    }
}
