package edu.utexas.cs.tamerProject.modeling;

import edu.utexas.cs.tamerProject.featGen.FeatGenerator;
import java.util.ArrayList;
import java.util.Objects;

/* loaded from: input_file:edu/utexas/cs/tamerProject/modeling/WekaModelPerActionModel.class */
/**
 * Regression model that keeps one independent {@link WekaRegressor} per discrete action.
 *
 * <p>A full feature vector contains action features at the positions reported by
 * {@link FeatGenerator#getActionFeatIndices()} plus the remaining (non-action) features. The
 * action features select which per-action regressor handles the vector. Only the non-action
 * features are forwarded to that regressor.
 */
public class WekaModelPerActionModel extends RegressionModel {
    /** One regressor per action, indexed by the value computed in {@link #getActI}. */
    WekaRegressor[] wMIs;
    /** Number of non-action features each per-action regressor is built with. */
    int numAttributes;
    /** Positions of the action features inside a full feature vector. */
    int[] actionFeatIndices;
    /** Per-position count of possible feature values, as reported by the FeatGenerator. */
    int[] numFeatValsPerFeatI;

    /**
     * Creates one regressor for every combination of action-feature values.
     *
     * <p>The number of regressors is the product of the value counts of all action-feature
     * positions.
     *
     * @param str passed unchanged to every {@link WekaRegressor} constructor (presumably a
     *     classifier name/option string — confirm against WekaRegressor)
     * @param featGenerator source of the feature layout; must not be null
     * @throws NullPointerException if {@code featGenerator} is null
     */
    public WekaModelPerActionModel(String str, FeatGenerator featGenerator) {
        Objects.requireNonNull(featGenerator, "featGenerator");
        this.actionFeatIndices = featGenerator.getActionFeatIndices();
        this.numFeatValsPerFeatI = featGenerator.getNumFeatValsPerFeatI();
        this.numAttributes = this.numFeatValsPerFeatI.length - this.actionFeatIndices.length;
        this.featGen = featGenerator;

        int numActions = 1;
        for (int actFeatIndex : this.actionFeatIndices) {
            numActions *= this.numFeatValsPerFeatI[actFeatIndex];
        }
        this.wMIs = new WekaRegressor[numActions];
        for (int i = 0; i < numActions; i++) {
            this.wMIs[i] = new WekaRegressor(str, this.numAttributes);
        }
        System.out.println("Initiating model with " + this.numAttributes + " attributes.");
    }

    /**
     * Maps the action features of a full feature vector to a regressor index.
     *
     * <p>NOTE(review): this assumes {@code FeatGenerator.getActIntIndex} returns a value in
     * {@code [0, wMIs.length)}. An unrecognized action would surface at the call site as an
     * ArrayIndexOutOfBoundsException.
     */
    private int getActI(double[] dArr) {
        return FeatGenerator.getActIntIndex(getActFeats(dArr), FeatGenerator.possStaticActions);
    }

    /** Extracts the action features of {@code dArr}, truncating each value to an int. */
    private int[] getActFeats(double[] dArr) {
        int[] actFeats = new int[this.actionFeatIndices.length];
        for (int i = 0; i < this.actionFeatIndices.length; i++) {
            actFeats[i] = (int) dArr[this.actionFeatIndices[i]];
        }
        return actFeats;
    }

    /**
     * Returns a new array with every action-feature position removed.
     *
     * <p>The remaining features keep their original relative order.
     */
    private double[] removeActFeats(double[] dArr) {
        // Fill a worst-case-sized buffer, then trim; avoids boxing every value.
        double[] buffer = new double[dArr.length];
        int kept = 0;
        for (int i = 0; i < dArr.length; i++) {
            if (!inIntArray(this.actionFeatIndices, i)) {
                buffer[kept++] = dArr[i];
            }
        }
        double[] stateFeats = new double[kept];
        System.arraycopy(buffer, 0, stateFeats, 0, kept);
        return stateFeats;
    }

    /** Returns true if {@code i} occurs anywhere in {@code iArr}. */
    private boolean inIntArray(int[] iArr, int i) {
        for (int candidate : iArr) {
            if (candidate == i) {
                return true;
            }
        }
        return false;
    }

    /** Copies {@code sample} with its action features stripped from the feature vector. */
    private Sample toStateSample(Sample sample) {
        return new Sample(removeActFeats(sample.feats), sample.label, sample.weight, sample.unique);
    }

    /** Routes the sample, minus its action features, to the regressor for its action. */
    @Override
    public void addInstance(Sample sample) {
        int actI = getActI(sample.feats);
        this.wMIs[actI].addInstance(toStateSample(sample));
    }

    /** Adds each sample in order via {@link #addInstance(Sample)}. */
    @Override
    public void addInstances(Sample[] sampleArr) {
        for (Sample sample : sampleArr) {
            addInstance(sample);
        }
    }

    /**
     * Routes the sample, minus its action features, to its action's regressor.
     *
     * <p>The sample is added through {@code WekaRegressor.addInstanceWReplacement}.
     */
    public void addInstanceWReplacement(Sample sample) {
        int actI = getActI(sample.feats);
        this.wMIs[actI].addInstanceWReplacement(toStateSample(sample));
    }

    /** Adds each sample in order via {@link #addInstanceWReplacement(Sample)}. */
    @Override
    public void addInstancesWReplacement(Sample[] sampleArr) {
        for (Sample sample : sampleArr) {
            addInstanceWReplacement(sample);
        }
    }

    /** Builds every per-action regressor. */
    @Override
    public void buildModel() {
        for (WekaRegressor regressor : this.wMIs) {
            regressor.buildModel();
        }
    }

    /** Predicts using the regressor selected by the action features of {@code dArr}. */
    @Override
    public double predictLabel(double[] dArr) {
        int actI = getActI(dArr);
        return this.wMIs[actI].classifyInstance(removeActFeats(dArr));
    }

    /** Calls {@code clearSamplesAndReset()} on every per-action regressor. */
    @Override
    public void clearSamplesAndReset() {
        for (WekaRegressor regressor : this.wMIs) {
            regressor.clearSamplesAndReset();
        }
    }

    /** Calls {@code changeClassifier()} on every per-action regressor. */
    @Override
    public void changeClassifier() {
        for (WekaRegressor regressor : this.wMIs) {
            regressor.changeClassifier();
        }
    }

    /**
     * Loads data into every regressor.
     *
     * <p>Regressor {@code i} receives the third argument with {@code i} appended.
     */
    @Override
    public void loadDataAsArff(String str, String str2, String str3) {
        for (int i = 0; i < this.wMIs.length; i++) {
            this.wMIs[i].loadDataAsArff(str, str2, String.valueOf(str3) + i);
        }
    }

    /**
     * Saves the data of every regressor.
     *
     * <p>Regressor {@code i} receives the third argument with {@code i} appended.
     */
    @Override
    public void saveDataAsArff(String str, double d, String str2) {
        for (int i = 0; i < this.wMIs.length; i++) {
            this.wMIs[i].saveDataAsArff(str, d, String.valueOf(str2) + i);
        }
    }

    /**
     * Ad-hoc smoke test that trains on a few 3-feature samples and prints one prediction.
     *
     * <p>NOTE(review): this passes a {@code null} FeatGenerator, which the constructor rejects.
     * It therefore fails immediately. Supply a real FeatGenerator whose layout matches the
     * 3-element feature vectors below before relying on it.
     */
    public static void main(String[] strArr) {
        WekaModelPerActionModel model = new WekaModelPerActionModel("", null);
        Sample sample = new Sample(new double[] {1.0d, 1.0d, 1.0d}, 20.0d, 1.0d, 1);
        model.addInstance(sample);
        model.addInstance(sample);
        model.addInstance(sample);
        Sample sample2 = new Sample(new double[] {0.0d, 1.0d, 1.0d}, 10.0d, 1.0d, 1);
        model.addInstance(sample2);
        model.addInstance(sample2);
        model.addInstance(new Sample(new double[] {1.0d, 0.0d, 1.0d}, 10.0d, 1.0d, 1));
        model.addInstance(new Sample(new double[] {0.0d, 0.0d, 2.0d}, 20.0d, 1.0d, 1));
        model.buildModel();
        System.out.println("predicted label: " + model.predictLabel(new double[] {0.0d, 1.0d, 1.0d}));
    }
}
