package edu.utexas.cs.tamerProject.featGen;

import edu.utexas.cs.tamerProject.experiment.RecordHandler;
import edu.utexas.cs.tamerProject.modeling.RegressionModel;
import edu.utexas.cs.tamerProject.utilities.Stopwatch;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Random;
import org.rlcommunity.rlglue.codec.types.Action;
import org.rlcommunity.rlglue.codec.types.Observation;

/* loaded from: input_file:edu/utexas/cs/tamerProject/featGen/FeatGen_RBFs.class */
/**
 * Feature generator that represents a state as the activations of a regular
 * grid of radial basis functions (RBFs), replicated once per static action.
 *
 * <p>The grid places {@code basisFcnsPerDim} centres along every observation
 * dimension, evenly spaced from that dimension's minimum to its maximum. The
 * feature vector returned by {@link #getSAFeats} is zero everywhere except in
 * the slice belonging to the given action; an optional constant bias feature
 * can be appended to each action's slice via {@link #setBiasFeatPerAct}.
 *
 * <p>Distances are measured after rescaling every observation dimension to a
 * common normalized span ({@code [0, 1]} by default, changeable through
 * {@link #setNormBounds}).
 */
public class FeatGen_RBFs extends FeatGenerator {
    /** Lower bound of the normalized span used until {@link #setNormBounds} is called. */
    private static final double DEFAULT_NORM_MIN = 0.0d;
    /** Upper bound of the normalized span used until {@link #setNormBounds} is called. */
    private static final double DEFAULT_NORM_MAX = 1.0d;

    /** Width of a basis function relative to the spacing between neighbouring centres. */
    private final double relWidth;
    /** Divisor applied to the scaled squared distance inside the exponential. */
    private double width;
    /** Number of RBF centres placed along each observation dimension. */
    private final int basisFcnsPerDim;
    /** Total number of observation dimensions (integer ones first, then double ones). */
    private final int numObsDims;
    /** RBF centres, expressed in raw (un-normalized) observation coordinates. */
    ArrayList<double[]> means;
    /** Size (max - min) of each observation dimension. */
    double[] theObsRangeSizes;
    /** Per-dimension factor that rescales a raw difference into the normalized span. */
    double[] dimDistNormFactor;
    /** Whether a constant bias feature is appended to every action's slice. */
    boolean addBiasFeatPerAct = false;
    /** Value of the bias feature when {@link #addBiasFeatPerAct} is set. */
    double biasFeatVal = 0.0d;
    /** When true, activations use the fast approximation {@link #exp(double)} instead of {@link Math#exp}. */
    boolean approxFeats = true;

    /**
     * Builds the RBF grid over the observation space.
     *
     * @param obsIntRanges    {min, max} pairs of the integer observation dimensions
     * @param obsDoubleRanges {min, max} pairs of the real-valued observation dimensions
     * @param actIntRanges    forwarded to the superclass (presumably the integer action ranges)
     * @param actDoubleRanges forwarded to the superclass (presumably the real-valued action ranges)
     * @param basisFcnsPerDim number of RBF centres per observation dimension
     * @param relWidth        basis-function width relative to the centre spacing
     */
    public FeatGen_RBFs(int[][] obsIntRanges, double[][] obsDoubleRanges, int[][] actIntRanges,
            double[][] actDoubleRanges, int basisFcnsPerDim, double relWidth) {
        super(obsIntRanges, obsDoubleRanges, actIntRanges, actDoubleRanges);
        this.basisFcnsPerDim = basisFcnsPerDim;
        this.relWidth = relWidth;
        this.numObsDims = obsIntRanges.length + obsDoubleRanges.length;
        this.width = computeWidth(DEFAULT_NORM_MIN, DEFAULT_NORM_MAX);
        System.out.println("width: " + this.width);
        double[] normBounds = {DEFAULT_NORM_MIN, DEFAULT_NORM_MAX};
        this.means = getRBFMeans(getTheObsRangesAndSetNormalization(normBounds));
        this.numFeatures = this.means.size() * possStaticActions.size();
    }

    /** Width for a normalized span {@code [normMin, normMax]}: span * relWidth / (basisFcnsPerDim - 1). */
    private double computeWidth(double normMin, double normMax) {
        return ((normMax - normMin) * this.relWidth) / (this.basisFcnsPerDim - 1);
    }

    /**
     * Collects the {min, max} range of every observation dimension (integer
     * dimensions first) and, as a side effect, recomputes
     * {@link #theObsRangeSizes} and {@link #dimDistNormFactor} for the given
     * normalized bounds.
     *
     * @param normBounds two-element array {normMin, normMax}
     * @return one {min, max} row per observation dimension
     */
    private double[][] getTheObsRangesAndSetNormalization(double[] normBounds) {
        double[][] ranges = new double[this.numObsDims][2];
        this.theObsRangeSizes = new double[this.numObsDims];
        this.dimDistNormFactor = new double[this.numObsDims];
        double normSpan = normBounds[1] - normBounds[0];
        int dim = 0;
        for (int[] intRange : this.theObsIntRanges) {
            recordDimRange(ranges, dim++, intRange[0], intRange[1], normSpan);
        }
        for (double[] doubleRange : this.theObsDoubleRanges) {
            recordDimRange(ranges, dim++, doubleRange[0], doubleRange[1], normSpan);
        }
        return ranges;
    }

    /** Stores one dimension's bounds, its size and its normalization factor. */
    private void recordDimRange(double[][] ranges, int dim, double min, double max, double normSpan) {
        ranges[dim][0] = min;
        ranges[dim][1] = max;
        this.theObsRangeSizes[dim] = max - min;
        this.dimDistNormFactor[dim] = normSpan / this.theObsRangeSizes[dim];
    }

    /**
     * Changes the normalized span that observation dimensions are rescaled to
     * and updates the basis-function width accordingly. The RBF centres are
     * not moved.
     *
     * @param d  lower bound of the normalized span
     * @param d2 upper bound of the normalized span
     */
    public void setNormBounds(double d, double d2) {
        System.out.print("Setting norm bounds to [" + d + ", " + d2 + "]. ");
        this.width = computeWidth(d, d2);
        System.out.println("**New width: " + this.width + "**");
        getTheObsRangesAndSetNormalization(new double[] {d, d2});
    }

    /** Prints the "not implemented" notice for this class and terminates the JVM with status 1. */
    private void exitNotImplemented() {
        System.err.println("This method is not implemented in " + getClass() + ". Exiting.");
        System.exit(1);
    }

    /** Not supported by this generator; terminates the JVM. */
    @Override // edu.utexas.cs.tamerProject.featGen.FeatGenerator
    public void setSupplModel(RegressionModel regressionModel, FeatGenerator featGenerator) {
        exitNotImplemented();
    }

    /** Enumerates every grid centre for the given per-dimension ranges. */
    private ArrayList<double[]> getRBFMeans(double[][] obsRanges) {
        return recurseForRBFMeans(obsRanges, new double[0]);
    }

    /**
     * Recursively extends a partially built centre by one dimension at a time
     * until it covers every observation dimension.
     *
     * @param dArr  {min, max} range of every observation dimension
     * @param dArr2 coordinates chosen so far (one per already-handled dimension)
     * @return all complete centres reachable from the partial one, with the
     *         earliest dimension varying slowest
     */
    protected ArrayList<double[]> recurseForRBFMeans(double[][] dArr, double[] dArr2) {
        ArrayList<double[]> centres = new ArrayList<>();
        if (dArr2.length == dArr.length) {
            centres.add(dArr2);
            return centres;
        }
        int dim = dArr2.length;
        int lastIndex = this.basisFcnsPerDim - 1;
        for (int i = 0; i < this.basisFcnsPerDim; i++) {
            // Must be a floating-point division: int / int would truncate to 0 for
            // every index but the last and collapse the grid onto the range ends.
            // A single basis function has no spacing, so it sits at mid-range.
            double fraction = lastIndex > 0 ? (double) i / lastIndex : 0.5d;
            double[] extended = Arrays.copyOf(dArr2, dim + 1);
            extended[dim] = (this.theObsRangeSizes[dim] * fraction) + dArr[dim][0];
            centres.addAll(recurseForRBFMeans(dArr, extended));
        }
        return centres;
    }

    /** Number of features in one action's slice: one per centre plus the optional bias. */
    private int numFeatsPerAction() {
        return this.means.size() + (this.addBiasFeatPerAct ? 1 : 0);
    }

    /**
     * Enables the per-action bias feature and enlarges the feature count to
     * make room for it.
     *
     * @param d constant value of the bias feature
     */
    public void setBiasFeatPerAct(double d) {
        this.addBiasFeatPerAct = true;
        this.biasFeatVal = d;
        this.numFeatures = numFeatsPerAction() * possStaticActions.size();
    }

    /** Concatenates the integer and real-valued observation variables into one vector. */
    private double[] getStateVars(int[] intVars, double[] doubleVars) {
        double[] stateVars = new double[this.numObsDims];
        int next = 0;
        for (int value : intVars) {
            stateVars[next++] = value;
        }
        for (double value : doubleVars) {
            stateVars[next++] = value;
        }
        return stateVars;
    }

    /**
     * Computes the state-action feature vector: RBF activations of the
     * observation placed in the slice of the given action, zeros elsewhere.
     */
    @Override // edu.utexas.cs.tamerProject.featGen.FeatGenerator
    public double[] getSAFeats(Observation observation, Action action) {
        double[] feats = new double[this.numFeatures];
        int sliceStart = numFeatsPerAction() * getActIntIndex(action.intArray);
        fillWithStateFeats(feats, sliceStart, observation.intArray, observation.doubleArray);
        return feats;
    }

    /**
     * Writes one activation per centre into {@code feats} starting at
     * {@code start}, followed by the bias value when the bias is enabled.
     */
    private void fillWithStateFeats(double[] feats, int start, int[] intVars, double[] doubleVars) {
        double[] stateVars = getStateVars(intVars, doubleVars);
        int slot = start;
        for (double[] mean : this.means) {
            double exponent = ((-0.5d) * getSqrdEucDist(mean, stateVars)) / this.width;
            feats[slot++] = this.approxFeats ? exp(exponent) : Math.exp(exponent);
        }
        if (this.addBiasFeatPerAct) {
            feats[slot] = this.biasFeatVal;
        }
    }

    /** Computes the action-independent state features (activations plus optional bias). */
    @Override // edu.utexas.cs.tamerProject.featGen.FeatGenerator
    public double[] getSFeats(Observation observation) {
        double[] feats = new double[numFeatsPerAction()];
        fillWithStateFeats(feats, 0, observation.intArray, observation.doubleArray);
        return feats;
    }

    /** Squared Euclidean distance between a centre and a state after per-dimension normalization. */
    private double getSqrdEucDist(double[] mean, double[] stateVars) {
        double sum = 0.0d;
        for (int dim = 0; dim < mean.length; dim++) {
            double scaledDiff = (stateVars[dim] - mean[dim]) * this.dimDistNormFactor[dim];
            sum += scaledDiff * scaledDiff;
        }
        return sum;
    }

    /** Not supported by this generator; terminates the JVM. */
    @Override // edu.utexas.cs.tamerProject.featGen.FeatGenerator
    public int[] getNumFeatValsPerFeatI() {
        exitNotImplemented();
        return new int[0];
    }

    /** Not supported by this generator; terminates the JVM. */
    @Override // edu.utexas.cs.tamerProject.featGen.FeatGenerator
    public int[] getActionFeatIndices() {
        exitNotImplemented();
        return new int[0];
    }

    /** Not supported by this generator; terminates the JVM. */
    public double[] getSSFeats(int[] iArr, double[] dArr, int[] iArr2, double[] dArr2) {
        exitNotImplemented();
        return new double[0];
    }

    /** Overwrites the last slot of every action's slice with the bias value, if the bias is enabled. */
    private void fillBiasSlots(double[] feats) {
        if (!this.addBiasFeatPerAct) {
            return;
        }
        int sliceLength = this.means.size() + 1;
        for (int act = 0; act < possStaticActions.size(); act++) {
            feats[(sliceLength * (act + 1)) - 1] = this.biasFeatVal;
        }
    }

    /** Upper bound of every state-action feature: 1 for activations, the bias value for bias slots. */
    @Override // edu.utexas.cs.tamerProject.featGen.FeatGenerator
    public double[] getMaxPossFeats() {
        double[] maxFeats = new double[this.numFeatures];
        Arrays.fill(maxFeats, 1.0d);
        fillBiasSlots(maxFeats);
        return maxFeats;
    }

    /** Lower bound of every state-action feature: 0 for activations, the bias value for bias slots. */
    @Override // edu.utexas.cs.tamerProject.featGen.FeatGenerator
    public double[] getMinPossFeats() {
        double[] minFeats = new double[this.numFeatures];
        fillBiasSlots(minFeats);
        return minFeats;
    }

    /** Upper bound of every state feature: 1 for activations, the bias value for the bias slot. */
    @Override // edu.utexas.cs.tamerProject.featGen.FeatGenerator
    public double[] getMaxPossSFeats() {
        double[] maxFeats = new double[numFeatsPerAction()];
        Arrays.fill(maxFeats, 1.0d);
        if (this.addBiasFeatPerAct) {
            maxFeats[this.means.size()] = this.biasFeatVal;
        }
        return maxFeats;
    }

    /** Lower bound of every state feature: 0 for activations, the bias value for the bias slot. */
    @Override // edu.utexas.cs.tamerProject.featGen.FeatGenerator
    public double[] getMinPossSFeats() {
        double[] minFeats = new double[numFeatsPerAction()];
        if (this.addBiasFeatPerAct) {
            minFeats[this.means.size()] = this.biasFeatVal;
        }
        return minFeats;
    }

    /**
     * Manual test driver: prints features for a toy 2-D problem, times the
     * approximate RBF evaluation, then compares features and model output
     * against files saved by the Python TAMER implementation (read from
     * hard-coded local paths).
     */
    public static void main(String[] strArr) {
        final int numTimingObs = 1000;
        final int numTimingCalls = 5000000;

        Observation observation = new Observation();
        Action action = new Action();
        int[][] obsIntRanges = new int[0][0];
        double[][] obsDoubleRanges = {{0.0d, 2.0d}, {0.0d, 2.0d}};
        // deepToString: plain toString on a 2-D array prints identity hashes only.
        System.out.println("theObsDoubleRanges: " + Arrays.deepToString(obsDoubleRanges));
        int[][] actIntRanges = {{0, 1}};
        double[][] actDoubleRanges = new double[0][0];
        FeatGen_RBFs toyGen = new FeatGen_RBFs(obsIntRanges, obsDoubleRanges, actIntRanges, actDoubleRanges, 2, 0.08d);
        toyGen.setNormBounds(-1.0d, 1.0d);

        observation.intArray = new int[0];
        observation.doubleArray = new double[] {0.8d, 0.8d};
        action.intArray = new int[] {0};
        action.doubleArray = null;
        System.out.println(describeInput(observation, action));
        System.out.println("Feats: " + Arrays.toString(toyGen.getSAFeats(observation, action)) + "\n");

        observation.doubleArray[0] = 1.2d;
        observation.doubleArray[1] = 1.0d;
        // Evaluate the action that is reported as the input.
        action.intArray[0] = 1;
        System.out.println(describeInput(observation, action));
        System.out.println("Feats: " + Arrays.toString(toyGen.getSAFeats(observation, action)) + "\n");

        // --- Timing of the approximate RBF evaluation ---
        Random random = new Random(21L);
        Stopwatch stopwatch = new Stopwatch();
        Observation[] timingObs = new Observation[numTimingObs];
        for (int i = 0; i < numTimingObs; i++) {
            timingObs[i] = new Observation();
            timingObs[i].doubleArray = new double[2];
            timingObs[i].doubleArray[0] = random.nextDouble() * obsDoubleRanges[0][1];
            timingObs[i].doubleArray[1] = random.nextDouble() * obsDoubleRanges[1][1];
        }
        boolean savedApproxFeats = toyGen.approxFeats;
        // NOTE(review): no exact (Math.exp) timing run exists between these two
        // timer starts; the call sequence is kept as found.
        toyGen.approxFeats = false;
        stopwatch.startTimer();
        toyGen.approxFeats = true;
        stopwatch.startTimer();
        for (int call = 0; call < numTimingCalls; call++) {
            toyGen.getSAFeats(timingObs[call % numTimingObs], action);
        }
        System.out.println("Time for " + (2 * numTimingCalls) + " approximate RBF evaluations: " + stopwatch.getTimeElapsed());
        toyGen.approxFeats = savedApproxFeats;

        // --- Comparison against the Python implementation ---
        System.out.println("\n\nPython TAMER's Mountain car features and model test");
        double[] weights = loadDoubleArrayOrExit("/Users/bradknox/projects/rl-library-data/mc_tamer/models/ikarpov-1228858017.78-100.model");
        System.out.println("\nLoading saved features from Python code for comparison.");
        double[] pythonFeats = loadDoubleArrayOrExit("/Users/bradknox/projects/rl-library-data/mc_tamer/models/feats.python");
        System.out.println("model weights: " + Arrays.toString(weights));

        obsDoubleRanges[0][0] = -1.2d;
        obsDoubleRanges[0][1] = 0.6d;
        obsDoubleRanges[1][0] = -0.07d;
        obsDoubleRanges[1][1] = 0.07d;
        int[][] mcActIntRanges = {{0, 2}};
        System.out.println("Creating new FeatGen_RBFs instance.");
        FeatGen_RBFs mcGen = new FeatGen_RBFs(obsIntRanges, obsDoubleRanges, mcActIntRanges, actDoubleRanges, 40, 0.08d);
        System.out.println("Created.");
        mcGen.setNormBounds(-1.0d, 1.0d);
        mcGen.setBiasFeatPerAct(0.1d);

        observation.doubleArray[0] = 0.0d;
        observation.doubleArray[1] = 0.0d;
        action.intArray[0] = 0;
        System.out.println(describeInput(observation, action));
        System.out.println("Model output: " + weightedSum(mcGen.getSAFeats(observation, action), weights));

        observation.doubleArray[0] = 0.1d;
        observation.doubleArray[1] = 0.01d;
        action.intArray[0] = 1;
        System.out.println(describeInput(observation, action));
        double[] javaFeats = mcGen.getSAFeats(observation, action);
        System.out.println("Feats: " + Arrays.toString(javaFeats) + "\n");
        System.out.println("Num feats: " + javaFeats.length);
        System.out.println("width: " + mcGen.width);
        System.out.println("Num weights: " + weights.length);
        for (int i = 0; i < weights.length; i++) {
            if (!areAlmostTheSame(javaFeats[i], pythonFeats[i])) {
                System.out.println("Mistmatch at index " + i + ". Python: " + pythonFeats[i] + ". Java: " + javaFeats[i]);
            }
        }
        System.out.println("Model output: " + weightedSum(javaFeats, weights));
    }

    /** Formats the observation's real-valued variables and the action's integer variables for display. */
    private static String describeInput(Observation observation, Action action) {
        return "Input: " + Arrays.toString(observation.doubleArray) + ", " + Arrays.toString(action.intArray);
    }

    /** Sum of feats[i] * weights[i] over every index of {@code weights}. */
    private static double weightedSum(double[] feats, double[] weights) {
        double sum = 0.0d;
        for (int i = 0; i < weights.length; i++) {
            sum += feats[i] * weights[i];
        }
        return sum;
    }

    /**
     * Parses the first record of the file at {@code path} as a double array.
     * On any failure, reports the error with the exception's own stack trace
     * and terminates the JVM with a non-zero status.
     */
    private static double[] loadDoubleArrayOrExit(String path) {
        try {
            return RecordHandler.getDoubleArrayFromStr(RecordHandler.getStrArray(path)[0]);
        } catch (Exception e) {
            System.err.println("Error: " + e.getMessage() + "\nExiting.");
            System.err.println(Arrays.toString(e.getStackTrace()));
            System.exit(1);
            // Only reached if exit is prevented; keeps the compiler satisfied.
            throw new IllegalStateException("Could not load " + path, e);
        }
    }

    /**
     * Relative comparison with tolerance 1e-4. When the reference value
     * {@code d2} is exactly zero, {@code d} must be exactly zero as well.
     */
    private static boolean areAlmostTheSame(double d, double d2) {
        if (d2 == 0.0d) {
            return d == 0.0d;
        }
        return Math.abs((d / d2) - 1.0d) < 1.0E-4d;
    }

    /**
     * Fast approximation of {@code e^d} that builds the bit pattern of the
     * result directly: a linear function of {@code d} is written into the
     * upper 32 bits of a double (the constants match Schraudolph's 1999
     * scheme: 2^20 / ln 2 and 1023 * 2^20 - 60801).
     *
     * <p>Inputs below -709 yield 0 and inputs above 709 yield
     * {@link Double#MAX_VALUE}. A NaN input yields 0, because casting NaN to
     * {@code long} gives 0.
     *
     * @param d exponent
     * @return approximate value of {@code Math.exp(d)}
     */
    public static double exp(double d) {
        if (d < -709.0d) {
            return 0.0d;
        }
        if (d > 709.0d) {
            return Double.MAX_VALUE;
        }
        long highWord = (long) ((1512775.0d * d) + 1.072632447E9d);
        return Double.longBitsToDouble(highWord << 32);
    }
}
