package edu.utexas.cs.tamerProject.modeling;

import edu.utexas.cs.tamerProject.featGen.FeatGenerator;
import edu.utexas.cs.tamerProject.utilities.MutableDouble;
import java.util.Arrays;

/* loaded from: input_file:edu/utexas/cs/tamerProject/modeling/IncGDLinearModel.class */
/**
 * Linear regression model trained one sample at a time by gradient descent, with optional
 * eligibility traces and an optional bias weight.
 *
 * <p>Two sets of weights are maintained:
 * <ul>
 *   <li>{@code complSampleWts} / {@code complSampleBiasWt}: weights trained only on samples that are
 *       added permanently, i.e. via {@code addInstance}, or, inside
 *       {@code addInstancesWReplacement}, samples whose {@code usedCredit} exceeds {@code APPROX_ONE}
 *       or whose {@code unique} is -1.</li>
 *   <li>{@code weights} / {@code biasWt}: the weights used by {@code predictLabel}. Usually the very
 *       same array as {@code complSampleWts}; after {@code addInstancesWReplacement} it is a separate
 *       copy that has additionally been trained on the remaining (temporary) samples, so those
 *       temporary updates are discarded on the next call.</li>
 * </ul>
 */
public class IncGDLinearModel extends IncModel {
    private double stepSize;
    /** Multiplied with {@code decayFactor} to decay traces; only read when decayFactor != 0. */
    private MutableDouble discountFactor;
    private double[] weights;
    private double[] complSampleWts;
    private double biasWt;
    public double complSampleBiasWt;
    private double[] traces;
    private boolean useBiasWt;
    private double decayFactor = 0.0d;
    private String traceStyle = "replacing";
    /** Threshold on {@code Sample.usedCredit} above which a sample is treated as permanent. */
    final double APPROX_ONE = 0.99999d;

    /**
     * Creates a model with {@code numFeats} feature weights.
     *
     * @param numFeats number of features (length of the weight, and trace vectors)
     * @param stepSize gradient-descent step size
     * @param featGenerator feature generator stored on the model (may be null, as in {@code main})
     * @param initWt initial value of every feature weight and of the bias weight
     * @param useBiasWt whether predictions include the bias weight
     */
    public IncGDLinearModel(int numFeats, double stepSize, FeatGenerator featGenerator, double initWt, boolean useBiasWt) {
        this.complSampleWts = new double[numFeats];
        this.traces = new double[numFeats];
        clearSamplesAndReset();
        this.stepSize = stepSize;
        this.featGen = featGenerator;
        Arrays.fill(this.complSampleWts, initWt);
        // Prediction weights start out as the very same array as the permanent weights.
        this.weights = this.complSampleWts;
        this.useBiasWt = useBiasWt;
        this.biasWt = initWt;
        this.complSampleBiasWt = initWt;
    }

    /** Returns the live (not copied) prediction weight array. */
    public double[] getWeights() {
        return this.weights;
    }

    /** Replaces the prediction weights with a copy of {@code dArr}. */
    public void setWeights(double[] dArr) {
        this.weights = (double[]) dArr.clone();
    }

    /** Replaces the permanent-sample weights with a copy of {@code dArr}. */
    public void setComplSampleWts(double[] dArr) {
        this.complSampleWts = (double[]) dArr.clone();
    }

    /** Replaces the eligibility traces with a copy of {@code dArr}. */
    public void setTraces(double[] dArr) {
        this.traces = (double[]) dArr.clone();
    }

    /** Sets the discount factor to a new, unshared {@link MutableDouble} holding {@code d}. */
    public void setDiscountFactor(double d) {
        this.discountFactor = new MutableDouble(d);
    }

    /** Sets the discount factor to the given (possibly shared) {@link MutableDouble}. */
    public void setDiscountFactor(MutableDouble mutableDouble) {
        this.discountFactor = mutableDouble;
    }

    /**
     * Returns a copy of this model: step size, bias usage, trace parameters, weights, bias weights
     * and traces. Arrays are copied; the {@code discountFactor} object is shared with the copy.
     */
    public IncGDLinearModel duplicate() {
        IncGDLinearModel copy = new IncGDLinearModel(this.weights.length, this.stepSize, this.featGen, 0.0d, this.useBiasWt);
        copy.setEligTraceParams(this.decayFactor, this.discountFactor, this.traceStyle);
        copy.setComplSampleWts(this.complSampleWts);
        if (this.weights == this.complSampleWts) {
            // Preserve the aliasing, which gradDescUpdate uses to decide where bias updates go.
            copy.weights = copy.complSampleWts;
        } else {
            copy.setWeights(this.weights);
        }
        copy.setTraces(this.traces);
        copy.biasWt = this.biasWt;
        copy.complSampleBiasWt = this.complSampleBiasWt;
        return copy;
    }

    /**
     * Configures eligibility traces.
     *
     * @param d trace decay factor; 0 disables trace carry-over
     * @param mutableDouble discount factor (stored by reference); must be non-null if {@code d != 0}
     * @param str trace style, {@code "accumulating"} or {@code "replacing"}
     */
    public void setEligTraceParams(double d, MutableDouble mutableDouble, String str) {
        this.decayFactor = d;
        this.discountFactor = mutableDouble;
        this.traceStyle = str;
    }

    /**
     * Performs one gradient-descent step on {@code this.weights} along the eligibility traces.
     * The bias weight goes to {@code biasWt} only when {@code weights} is a temporary copy, and to
     * both {@code complSampleBiasWt} and {@code biasWt} when it aliases {@code complSampleWts}.
     *
     * @param sample training sample (features, label, and importance weight)
     * @param predictionOffset amount added to the model's prediction before computing the error
     */
    private void gradDescUpdate(Sample sample, double predictionOffset) {
        if (this.verbose) {
            System.out.println("---- before -----");
            System.out.println("this.weights:" + Arrays.toString(this.weights));
            System.out.println("this.complSampleWts:" + Arrays.toString(this.complSampleWts));
            System.out.println("this.biasWt:" + this.biasWt);
            System.out.println("this.complSampleBiasWt:" + this.complSampleBiasWt);
            System.out.println("Prediction before update: " + predictLabel(sample.feats));
            System.out.println("Prediction augmentation: " + predictionOffset);
        }
        double prediction = predictLabel(sample.feats) + predictionOffset;
        updateEligTraces(sample.feats);
        double error = sample.label - prediction;
        double wtedErr = this.stepSize * error * sample.weight;
        if (this.verbose) {
            System.out.println("Label: " + sample.label);
            System.out.println("Error: " + error);
            System.out.println("wtedErr: " + wtedErr);
        }
        for (int i = 0; i < this.weights.length; i++) {
            this.weights[i] += this.traces[i] * wtedErr;
            if (Double.isInfinite(this.weights[i])) {
                System.err.println("weight is infinite from trace and err: " + this.traces[i] + ", " + wtedErr);
            }
        }
        if (this.useBiasWt) {
            if (this.weights != this.complSampleWts) {
                this.biasWt += wtedErr;
            } else {
                this.complSampleBiasWt += wtedErr;
                this.biasWt = this.complSampleBiasWt;
            }
        }
    }

    /** Permanently trains on {@code sample} with no prediction offset. */
    @Override // edu.utexas.cs.tamerProject.modeling.IncModel, edu.utexas.cs.tamerProject.modeling.RegressionModel
    public void addInstance(Sample sample) {
        addInstance(sample, 0.0d);
    }

    /**
     * Permanently trains on {@code sample}: discards any temporary weights, then updates the
     * permanent weights.
     *
     * @param predictionOffset amount added to the prediction before computing the error
     */
    @Override // edu.utexas.cs.tamerProject.modeling.IncModel
    public void addInstance(Sample sample, double predictionOffset) {
        this.weights = this.complSampleWts;
        this.biasWt = this.complSampleBiasWt;
        gradDescUpdate(sample, predictionOffset);
        if (this.verbose) {
            System.out.println("Prediction after update: " + predictLabel(sample.feats));
            System.out.println("---- after -----");
            System.out.println("this.weights:" + Arrays.toString(this.weights));
            System.out.println("this.complSampleWts:" + Arrays.toString(this.complSampleWts));
            System.out.println("this.biasWt:" + this.biasWt);
            System.out.println("this.complSampleBiasWt:" + this.complSampleBiasWt);
        }
    }

    /** Permanently trains on {@code sample} with a step size of 1, restoring the step size afterwards. */
    @Override // edu.utexas.cs.tamerProject.modeling.RegressionModel
    public void addBiasInstance(Sample sample) {
        double savedStepSize = this.stepSize;
        this.stepSize = 1.0d;
        try {
            addInstance(sample);
        } finally {
            this.stepSize = savedStepSize;
        }
    }

    /** Permanently trains on each sample in order. */
    @Override // edu.utexas.cs.tamerProject.modeling.RegressionModel
    public void addInstances(Sample[] sampleArr) {
        for (Sample sample : sampleArr) {
            addInstance(sample);
        }
    }

    /** Unsupported for this model: prints an error and terminates the JVM with status 1. */
    public void addInstanceWReplacement(Sample sample) {
        System.err.println("addInstanceWReplacement is not supported for IncGDLinearModel. Exiting.");
        System.exit(1);
    }

    /**
     * Trains permanently on samples with {@code usedCredit > APPROX_ONE} or {@code unique == -1}.
     * It then copies the permanent weights into temporary prediction weights and trains those on the
     * remaining samples with non-zero weight. Temporary training is incompatible with eligibility
     * traces; if it is needed while {@code decayFactor != 0}, the JVM exits with status 1.
     */
    @Override // edu.utexas.cs.tamerProject.modeling.RegressionModel
    public void addInstancesWReplacement(Sample[] sampleArr) {
        for (Sample sample : sampleArr) {
            if (sample.usedCredit > APPROX_ONE || sample.unique == -1) {
                addInstance(sample);
            }
        }
        this.weights = Arrays.copyOf(this.complSampleWts, this.complSampleWts.length);
        this.biasWt = this.complSampleBiasWt;
        for (Sample sample : sampleArr) {
            if (sample.usedCredit <= APPROX_ONE && sample.weight != 0.0d && sample.unique != -1) {
                if (this.decayFactor != 0.0d) {
                    System.err.println(String.valueOf(getClass().getName()) + " does not support both eligibility traces and temporary samples. Exiting.");
                    System.err.println("traceDecayFactor: " + this.decayFactor);
                    System.exit(1);
                }
                gradDescUpdate(sample, 0.0d);
            }
        }
    }

    /** No-op: the model is updated incrementally as samples are added. */
    @Override // edu.utexas.cs.tamerProject.modeling.RegressionModel
    public void buildModel() {
    }

    /**
     * Returns the dot product of the prediction weights with {@code dArr}, plus the bias weight when
     * enabled. Logs to stderr if the result is infinite. {@code dArr} must have at least as many
     * entries as there are weights.
     */
    @Override // edu.utexas.cs.tamerProject.modeling.RegressionModel
    public double predictLabel(double[] dArr) {
        double sum = 0.0d;
        for (int i = 0; i < this.weights.length; i++) {
            sum += this.weights[i] * dArr[i];
        }
        if (this.useBiasWt) {
            sum += this.biasWt;
        }
        if (Double.isInfinite(sum)) {
            System.err.println(String.valueOf(getClass().getSimpleName()) + " calculating infinite prediction.");
        }
        return sum;
    }

    /**
     * Overwrites the permanent feature weights with {@code dArr} (bias weights are untouched) and
     * makes them the prediction weights. Exits the JVM with status 1 on a length mismatch.
     */
    public void setModelParams(double[] dArr) {
        if (this.complSampleWts.length != dArr.length) {
            System.err.println("Mismatch in number of complSampleWts in IncGDLinearModel.setModelParams(). Exiting");
            System.exit(1);
        }
        System.arraycopy(dArr, 0, this.complSampleWts, 0, dArr.length);
        this.weights = this.complSampleWts;
    }

    /** Zeroes all feature weights, bias weights and traces; prediction weights alias the permanent ones. */
    @Override // edu.utexas.cs.tamerProject.modeling.RegressionModel
    public void clearSamplesAndReset() {
        Arrays.fill(this.complSampleWts, 0.0d);
        Arrays.fill(this.traces, 0.0d);
        this.weights = this.complSampleWts;
        // Reset the bias too, otherwise a "reset" model keeps predicting the old offset.
        this.biasWt = 0.0d;
        this.complSampleBiasWt = 0.0d;
    }

    /**
     * Decays each trace by {@code decayFactor * discountFactor} and then adds the feature value
     * ("accumulating", or whenever that product is 0), or takes the max with it ("replacing").
     * Any other trace style prints an error and exits the JVM with status 1.
     */
    private void updateEligTraces(double[] dArr) {
        // Only consult discountFactor when traces actually carry over, so the default configuration
        // (decayFactor == 0) works even if no discount factor was ever set.
        double discount = (this.decayFactor == 0.0d) ? 0.0d : this.discountFactor.getValue();
        boolean accumulate = this.decayFactor * discount == 0.0d || "accumulating".equals(this.traceStyle);
        if (!accumulate && !"replacing".equals(this.traceStyle)) {
            System.err.println("Trace style " + this.traceStyle + "  is not supported in " + getClass() + ". Exiting.");
            System.exit(1);
        }
        for (int i = 0; i < dArr.length; i++) {
            double decayed = this.traces[i] * this.decayFactor * discount;
            this.traces[i] = accumulate ? decayed + dArr[i] : Math.max(dArr[i], decayed);
        }
    }

    /** Small demo: fits two samples repeatedly and prints predictions. */
    public static void main(String[] strArr) {
        IncGDLinearModel incGDLinearModel = new IncGDLinearModel(3, 0.2d, null, 0.0d, true);
        Sample sample = new Sample(new double[]{1.0d, 1.0d, 1.0d}, 20.0d, 1.0d, 1);
        sample.usedCredit = 1.0d;
        Sample sample2 = new Sample(new double[]{0.0d, 1.0d, 1.0d}, 10.0d, 1.0d, 1);
        sample2.usedCredit = 1.0d;
        Sample[] sampleArr = {sample2, sample};
        double[] dArr = {1.0d, 1.0d, 1.0d};
        for (int i = 0; i < 24; i++) {
            incGDLinearModel.addInstancesWReplacement(sampleArr);
            System.out.println("predicted label: " + incGDLinearModel.predictLabel(dArr));
        }
        System.out.println("predicted label: " + incGDLinearModel.predictLabel(new double[]{0.0d, 1.0d, 1.0d}));
        System.out.println("predicted label: " + incGDLinearModel.predictLabel(dArr));
        System.out.println("predicted label: " + incGDLinearModel.predictLabel(new double[]{0.0d, 1.0d, 1.0d}));
    }
}
