package edu.utexas.cs.tamerProject.modeling;

import edu.utexas.cs.tamerProject.modeling.kNN.KNN;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.Arrays;
import weka.attributeSelection.GreedyStepwise;
import weka.attributeSelection.WrapperSubsetEval;
import weka.classifiers.functions.LinearRegression;
import weka.classifiers.lazy.IBk;
import weka.classifiers.trees.M5P;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;

/* loaded from: input_file:edu/utexas/cs/tamerProject/modeling/WekaRegressor.class */
/**
 * Regression model that keeps its training samples in a Weka dataset and
 * delegates learning to one of several Weka (or project kNN) classifiers
 * selected by name.
 *
 * <p>The dataset holds one numeric attribute per feature plus a final numeric
 * attribute named {@code "Label"} that is used as the class (target) attribute.
 *
 * <p>NOTE(review): this class implements {@link Externalizable} but declares no
 * public no-argument constructor, which Java's Externalizable deserialization
 * requires. Reading an instance back through an object stream is therefore
 * expected to fail until one is added — confirm against WekaModel's
 * constructors before adding it.
 */
public class WekaRegressor extends WekaModel implements Externalizable {
    private static final long serialVersionUID = 248;

    /**
     * Creates a regressor by delegating to {@code WekaModel(int)}.
     *
     * @param i value forwarded to the superclass constructor (presumably the
     *          number of feature attributes — verify against WekaModel)
     */
    public WekaRegressor(int i) {
        super(i);
    }

    /**
     * Creates a regressor by delegating to {@code WekaModel(String, int)}.
     *
     * @param str value forwarded to the superclass constructor (presumably the
     *            model name later given to {@link #setUpClassifiers(String)})
     * @param i   value forwarded to the superclass constructor (presumably the
     *            number of feature attributes — verify against WekaModel)
     */
    public WekaRegressor(String str, int i) {
        super(str, i);
    }

    /**
     * Builds the attribute schema and an empty dataset.
     *
     * <p>Creates {@code i} numeric attributes named "1".."i", followed by a
     * "Label" attribute that becomes the class attribute, so
     * {@code numAttributes} ends up as {@code i + 1}.
     *
     * @param i number of feature attributes (the label is added on top)
     */
    @Override
    protected void initData(int i) {
        this.numAttributes = i + 1;
        System.out.println("Instantiating batch model with " + this.numAttributes + " attributes.");
        System.out.println("Class: " + this);
        this.attrInfo = new FastVector();
        for (int featureNumber = 1; featureNumber <= i; featureNumber++) {
            this.attrInfo.addElement(new Attribute(Integer.toString(featureNumber)));
        }
        this.attrInfo.addElement(new Attribute("Label"));
        this.data = new Instances("Test", this.attrInfo, 0);
        // The label was appended last, so the class index is the final attribute.
        this.data.setClassIndex(this.data.numAttributes() - 1);
    }

    /**
     * Registers the classifier selected by {@code str} and makes it the active
     * regressor.
     *
     * <p>Supported names: {@code ""} or {@code "IBk"} (Weka IBk), {@code "M5P"},
     * and {@code "KDTree"}, {@code "BallTree"}, {@code "CoverTree"},
     * {@code "LinearNNSearch"} (project {@code KNN} constructed with that name).
     * An attribute-selecting wrapper around the first registered classifier is
     * also appended to {@code classifiers}; the active {@code classifier} is
     * always the first entry of the list.
     *
     * @param str model name
     * @throws IllegalArgumentException if no classifier is registered after
     *         processing {@code str} (unknown or null name, or construction of
     *         the requested classifier failed); previously this surfaced as an
     *         uninformative {@code IndexOutOfBoundsException}
     */
    @Override
    protected void setUpClassifiers(String str) {
        this.modelName = str;
        System.out.println("Given model name: " + str);

        Exception constructionFailure = null;
        try {
            if ("".equals(str) || "IBk".equals(str)) {
                this.classifiers.add(new IBk());
            } else if ("M5P".equals(str)) {
                this.classifiers.add(new M5P());
            } else if ("KDTree".equals(str) || "BallTree".equals(str)
                    || "CoverTree".equals(str) || "LinearNNSearch".equals(str)) {
                // The KNN wrapper takes the neighbour-search name verbatim.
                this.classifiers.add(new KNN(str));
            }
        } catch (Exception e) {
            constructionFailure = e;
            System.out.println("Exception while initializing classifiers: " + e);
        }

        // Fail with a clear message instead of letting classifiers.get(0)
        // blow up below when nothing could be registered.
        if (this.classifiers.isEmpty()) {
            throw new IllegalArgumentException(
                    "No regressor could be created for model name \"" + str
                            + "\". Supported names: \"\", IBk, M5P, KDTree, BallTree, CoverTree, LinearNNSearch.",
                    constructionFailure);
        }

        try {
            // Secondary model: backward greedy wrapper-based attribute selection
            // around the primary classifier. A failure here is only reported;
            // the primary classifier remains usable.
            WrapperSubsetEval subsetEvaluator = new WrapperSubsetEval();
            subsetEvaluator.setClassifier(this.classifiers.get(0));
            GreedyStepwise backwardSearch = new GreedyStepwise();
            backwardSearch.setSearchBackwards(true);
            AttributeSelectedClassifier selectingClassifier = new AttributeSelectedClassifier();
            selectingClassifier.setClassifier(this.classifiers.get(0));
            selectingClassifier.setEvaluator(subsetEvaluator);
            selectingClassifier.setSearch(backwardSearch);
            this.classifiers.add(selectingClassifier);
        } catch (Exception e) {
            System.out.println("Exception while initializing classifiers: " + e);
        }

        this.classifier = this.classifiers.get(0);
        System.out.println("Regressor in use: " + this.classifier);
    }

    /**
     * Appends a sample to the training data and records its {@code unique}
     * value at the same position in {@code uniques}.
     *
     * <p>If the sample's attribute count differs from {@code numAttributes},
     * an error is printed and the JVM exits with status 1.
     *
     * @param sample sample whose attributes and weight form the new instance
     */
    public void addInstance(Sample sample) {
        double weight = sample.weight;
        double[] attributes = sample.getAttributes();
        exitOnAttributeCountMismatch(attributes.length);
        this.data.add(new Instance(weight, attributes));
        this.uniques.add(Double.valueOf(sample.unique));
    }

    /**
     * Adds a sample, first removing any stored sample with the same
     * {@code unique} value.
     *
     * <p>Samples with zero weight are ignored entirely. A {@code unique} value
     * of -1 is never treated as a match, so such samples are always appended.
     *
     * @param sample sample to add or to replace an existing entry with
     */
    public void addInstanceWReplacement(Sample sample) {
        if (sample.weight == 0.0d) {
            return;
        }
        int existingIndex = this.uniques.indexOf(Double.valueOf(sample.unique));
        if (existingIndex != -1 && sample.unique != -1) {
            // uniques and data are kept index-aligned; remove from both.
            this.uniques.remove(existingIndex);
            this.data.delete(existingIndex);
        }
        addInstance(sample);
    }

    /**
     * Builds the model via {@code WekaModel.buildModel()} and, when the active
     * classifier is exactly Weka's {@code IBk}, sets its neighbour count to
     * {@code max(1, floor(sqrt(number of training instances)))}.
     *
     * <p>Any exception is reported on standard output (with stack trace) and
     * swallowed; the model is then left unbuilt.
     */
    @Override
    public void buildModel() {
        try {
            super.buildModel();
            // Exact class match on purpose: subclasses of IBk are not retuned.
            if (IBk.class.equals(this.classifier.getClass())) {
                int neighbourCount = Math.max(1, (int) Math.sqrt(this.data.numInstances()));
                ((IBk) this.classifier).setKNN(neighbourCount);
                System.out.println("number of neighbors for k-NN: " + neighbourCount);
            }
        } catch (Exception e) {
            System.out.println("Exception while building classifier: " + e);
            System.out.println(".Classifier " + this.classifier.getClass().toString() + " will not be built.");
            System.out.println("Cause: " + e.getCause());
            System.out.println("\nStack trace: ");
            e.printStackTrace();
        }
    }

    /**
     * Converts a sample into a Weka instance without adding it to the dataset.
     *
     * <p>If the sample's attribute count differs from {@code numAttributes},
     * an error is printed and the JVM exits with status 1.
     *
     * @param sample sample to convert
     * @return a new instance carrying the sample's weight and attributes
     */
    @Override
    public Instance makeInstance(Sample sample) {
        double[] attributes = sample.getAttributes();
        exitOnAttributeCountMismatch(attributes.length);
        return new Instance(sample.weight, attributes);
    }

    /**
     * Verifies that a sample's attribute count matches the model's; on a
     * mismatch prints a diagnostic to standard error and terminates the JVM.
     */
    private void exitOnAttributeCountMismatch(int sampleAttributeCount) {
        if (this.numAttributes == sampleAttributeCount) {
            return;
        }
        System.err.println("The number of attributes used to instantiate the model doesn't match the number in the sample to be added.");
        System.err.println("Number from instantiation: " + this.numAttributes + ". Number in sample: " + sampleAttributeCount);
        System.exit(1);
    }

    /**
     * Serializes only the training dataset ({@code data}); no other field is
     * written. Also prints a conspicuous debug marker to standard output.
     *
     * @param objectOutput stream to write to
     * @throws IOException if writing fails
     */
    @Override
    public void writeExternal(ObjectOutput objectOutput) throws IOException {
        System.out.println("\n\n\n\n\n\n\n\n\nCALLED!!!!!!!!!!!!!!!!\n\n\n\n");
        objectOutput.writeObject(this.data);
    }

    /**
     * Restores the training dataset ({@code data}) written by
     * {@link #writeExternal(ObjectOutput)}; no other field is restored here.
     * Also prints a conspicuous debug marker to standard output.
     *
     * @param objectInput stream to read from
     * @throws IOException            if reading fails
     * @throws ClassNotFoundException if the dataset's class cannot be resolved
     */
    @Override
    public void readExternal(ObjectInput objectInput) throws IOException, ClassNotFoundException {
        System.out.println("\n\n\n\n\n\n\n\n\nCALLED!!!!!!!!!!!!!!!!\n\n\n\n");
        this.data = (Instances) objectInput.readObject();
    }

    /**
     * Manual smoke test: fits a Weka {@code LinearRegression} on three
     * hand-written rows, prints two predictions, then repeats the exercise
     * through a {@code WekaRegressor} configured with "KDTree".
     *
     * @param strArr unused command-line arguments
     */
    public static void main(String[] strArr) {
        // Each row is {pos, vel, hreinf}; hreinf is the regression target.
        double[] firstRow = {1.0d, 2.0d, -4.0d};
        double[] secondRow = {1.0d, 4.0d, -8.0d};
        double[] thirdRow = {1.0d, 8.0d, -16.0d};

        FastVector schema = new FastVector();
        schema.addElement(new Attribute("pos"));
        schema.addElement(new Attribute("vel"));
        schema.addElement(new Attribute("hreinf"));
        Instances dataset = new Instances("Test", schema, 0);
        System.out.println("The instances metainfo: " + dataset + "\n\n");
        dataset.setClassIndex(dataset.numAttributes() - 1);

        Instance firstInstance = new Instance(1.0d, firstRow);
        dataset.add(firstInstance);
        System.out.println("The instance: " + firstInstance + "\n\n");
        Instance secondInstance = new Instance(1.0d, secondRow);
        dataset.add(secondInstance);
        System.out.println("The instance: " + secondInstance + "\n\n");
        Instance thirdInstance = new Instance(1.0d, thirdRow);
        dataset.add(thirdInstance);
        System.out.println("The instance: " + thirdInstance + "\n\n");

        System.out.println("Number of attributes: " + dataset.numAttributes());
        System.out.println("Class index: " + dataset.classIndex());

        LinearRegression linearModel = new LinearRegression();
        try {
            linearModel.buildClassifier(dataset);
        } catch (Exception e) {
            System.out.println("Exception while building classifier: " + e);
        }
        System.out.println(linearModel);

        // The prediction variable is shared on purpose: if a query fails, the
        // previously held value is what gets printed.
        double prediction = 0.0d;
        try {
            prediction = linearModel.classifyInstance(
                    new Instance(1.0d, new double[]{1.0d, 2.0d, Instance.missingValue()}));
        } catch (Exception e) {
            System.out.println("Exception while classifying instance: " + e);
        }
        System.out.println("\nclassification: " + prediction);
        try {
            prediction = linearModel.classifyInstance(
                    new Instance(1.0d, new double[]{1.0d, 3.0d, Instance.missingValue()}));
        } catch (Exception e) {
            System.out.println("Exception while classifying instance: " + e);
        }
        System.out.println("\nclassification: " + prediction);

        System.out.println("\n\n-------Testing class methods-------\n\n");
        WekaRegressor regressor = new WekaRegressor("KDTree", 2);
        regressor.addInstance(new Sample(Arrays.copyOfRange(firstRow, 0, 2), firstRow[2], 1.0d));
        regressor.addInstance(new Sample(Arrays.copyOfRange(secondRow, 0, 2), secondRow[2], 1.0d));
        regressor.addInstance(new Sample(Arrays.copyOfRange(thirdRow, 0, 2), thirdRow[2], 1.0d));
        regressor.buildModel();
        System.out.println("\nclassification: " + regressor.classifyInstance(new double[]{1.0d, 2.0d}));
        System.out.println("\nclassification: " + regressor.classifyInstance(new double[]{1.0d, 3.0d}));
    }
}
