package edu.utexas.cs.tamerProject.params;

import edu.utexas.cs.tamerProject.agents.GeneralAgent;
import edu.utexas.cs.tamerProject.agents.sarsaLambda.SarsaLambdaAgent;
import java.util.Arrays;
import java.util.Random;
import org.rlcommunity.rlglue.codec.AgentInterface;
import org.rlcommunity.rlglue.codec.taskspec.TaskSpec;
import org.rlcommunity.rlglue.codec.types.Action;
import org.rlcommunity.rlglue.codec.types.Observation;
import org.rlcommunity.rlglue.codec.util.AgentLoader;

/* loaded from: input_file:edu/utexas/cs/tamerProject/params/ParamSearchAgent.class */
/**
 * RL-Glue agent that searches the hyperparameters of a {@link SarsaLambdaAgent} one parameter at
 * a time.
 *
 * <p>For the parameter currently being searched, {@code NUM_TEST_VALS} evenly spaced values are
 * tested. Each value is used for {@code RUNS_PER_PARAM_SET} runs, and each run lasts up to
 * {@code EPS_PER_RUN} episodes. The total reward of every run is summed per tested value.
 *
 * <p>Once all values have been tested (one "level"), the search range is narrowed around the
 * winning value. After {@code PARAM_TEST_LEVELS[param]} levels, the winner is stored in
 * {@code bestObservedParams} if its reward sum beats the best seen so far. A different parameter
 * is then chosen at random.
 *
 * <p>Parameter vector layout (see {@link #changeParamVals()}):
 * <ul>
 *   <li>0 = {@code stepSize}</li>
 *   <li>1 = {@code traceDecayFactor}</li>
 *   <li>2 = selection param {@code "epsilon"}</li>
 *   <li>3 = feature-generator param {@code "basisFcnsPerDim"} (truncated to int)</li>
 *   <li>4 = feature-generator param {@code "relWidth"}</li>
 * </ul>
 *
 * <p>If the searched agent's reward within the current episode falls below
 * {@link #RESCUE_REW_THRESHOLD}, a separately initialized "rescue" agent answers the remaining
 * steps of that episode. The run's total reward is then charged a penalty of {@code rewFloor}.
 */
public class ParamSearchAgent implements AgentInterface {
    /** Per-episode reward of {@code rlAgent} below which the rescue agent takes over. */
    private static final double RESCUE_REW_THRESHOLD = -5000.0d;
    /** Number of evenly spaced values tested for the active parameter at each level. */
    private static final int TEST_VALS_PER_LEVEL = 5;
    /** Per-episode slope of the initial run-reward floor (multiplied by EPS_PER_RUN). */
    private static final double INIT_REW_FLOOR_PER_EP = -130.0d;
    /** Constant offset of the initial run-reward floor. */
    private static final double INIT_REW_FLOOR_OFFSET = -4000.0d;

    /** [low, high] search range of every parameter; selected per environment in agent_init. */
    private double[][] PARAM_RANGES;
    /** Number of values tested per level (length of paramValsToTest and rewSums). */
    private int NUM_TEST_VALS;
    /** Best value found so far for each parameter. */
    private double[] bestObservedParams;
    /** Parameter vector used for the current run. */
    private double[] currParams;
    /** Runs executed for each tested value before moving to the next value. */
    private int RUNS_PER_PARAM_SET;
    protected SarsaLambdaAgent rlAgent;
    protected SarsaLambdaAgent rescueRLAgent;
    /** Params object shared with (and mutated for) each freshly created rlAgent. */
    private Params agentParams;
    private String taskSpec;
    /** Values of the active parameter being tested at the current level. */
    private double[] paramValsToTest;
    /** Summed run reward for each entry of paramValsToTest at the current level. */
    private double[] rewSums;
    private final double[][] PARAM_RANGES_MC = {
        new double[]{0.001d, 0.2d},
        new double[]{0.5d, 1.0d},
        new double[]{0.0d, 0.6d},
        new double[]{15.0d, 41.0d},
        new double[]{0.01d, 1.0d}};
    private final double[][] PARAM_RANGES_CARTPOLE = {
        new double[]{0.001d, 0.2d},
        new double[]{0.5d, 1.0d},
        new double[]{0.0d, 0.6d},
        new double[]{5.0d, 12.0d},
        new double[]{0.01d, 0.25d}};
    private final double[][] PARAM_RANGES_ACROBOT = {
        new double[]{0.001d, 0.2d},
        new double[]{0.5d, 1.0d},
        new double[]{0.0d, 0.6d},
        new double[]{5.0d, 12.0d},
        new double[]{0.01d, 0.25d}};
    /** Number of narrowing levels run for each parameter before switching to another one. */
    private int[] PARAM_TEST_LEVELS = {2, 2, 2, 1, 2};
    /** Best reward sum (over RUNS_PER_PARAM_SET runs) achieved by any accepted value. */
    private double bestRewSum = -1.0E11d;
    private int EPS_PER_RUN;
    /**
     * Cumulative run reward below which a run is ended early.
     * Set in agent_init (it depends on EPS_PER_RUN). Later reset from bestRewSum.
     */
    private double rewFloor;
    private int numEpsFinishedForRun = 0;
    private int numRunsFinishedForParamSet = 0;
    private int currLevel = 0;
    private int lastParamI = -1;
    private int currParamI = -1;
    private int numLevelsFinished = 0;
    private int currParamTestValI = 0;
    private Random randGen = new Random();

    /**
     * Builds currParams from the best observed params, with the active parameter replaced by the
     * value currently under test. Copies the result into the shared agentParams.
     */
    private void changeParamVals() {
        this.currParams = this.bestObservedParams.clone();
        this.currParams[this.currParamI] = this.paramValsToTest[this.currParamTestValI];
        System.out.println("Changed param " + this.currParamI + " to " + this.currParams[this.currParamI]);
        this.agentParams.stepSize = this.currParams[0];
        this.agentParams.traceDecayFactor = this.currParams[1];
        this.agentParams.selectionParams.put("epsilon", String.valueOf(this.currParams[2]));
        this.agentParams.featGenParams.put("basisFcnsPerDim", String.valueOf((int) this.currParams[3]));
        this.agentParams.featGenParams.put("relWidth", String.valueOf(this.currParams[4]));
    }

    /** Reads the current parameter vector back out of agentParams (same layout as currParams). */
    private double[] getCurrParamVec() {
        double[] vec = new double[this.PARAM_RANGES.length];
        vec[0] = this.agentParams.stepSize;
        vec[1] = this.agentParams.traceDecayFactor;
        vec[2] = Double.parseDouble(this.agentParams.selectionParams.get("epsilon"));
        vec[3] = Double.parseDouble(this.agentParams.featGenParams.get("basisFcnsPerDim"));
        vec[4] = Double.parseDouble(this.agentParams.featGenParams.get("relWidth"));
        return vec;
    }

    /**
     * Selects environment-specific search settings, creates the searched and rescue agents, and
     * configures the searched agent with the first value to test.
     *
     * @param str RL-Glue task spec string
     */
    @Override
    public void agent_init(String str) {
        this.taskSpec = str;
        String envName = GeneralAgent.getEnvName(new TaskSpec(str).getExtraString());
        if ("Mountain-Car".equals(envName)) {
            this.PARAM_RANGES = this.PARAM_RANGES_MC;
            this.RUNS_PER_PARAM_SET = 5;
            this.EPS_PER_RUN = 500;
        } else if ("CartPole".equals(envName)) {
            this.PARAM_RANGES = this.PARAM_RANGES_CARTPOLE;
            this.RUNS_PER_PARAM_SET = 2;
            this.EPS_PER_RUN = 150;
        } else if ("Acrobot".equals(envName)) {
            this.PARAM_RANGES = this.PARAM_RANGES_ACROBOT;
            this.RUNS_PER_PARAM_SET = 2;
            this.EPS_PER_RUN = 500;
        } else {
            this.PARAM_RANGES = this.PARAM_RANGES_CARTPOLE;
            this.RUNS_PER_PARAM_SET = 2;
            this.EPS_PER_RUN = 100;
        }
        // Computed here rather than in a field initializer. A field initializer runs before
        // EPS_PER_RUN is assigned, so it would always collapse to the offset alone.
        this.rewFloor = (INIT_REW_FLOOR_PER_EP * this.EPS_PER_RUN) + INIT_REW_FLOOR_OFFSET;
        // The number of tested values is independent of the number of parameters.
        this.NUM_TEST_VALS = TEST_VALS_PER_LEVEL;
        this.rewSums = new double[this.NUM_TEST_VALS];
        this.rescueRLAgent = new SarsaLambdaAgent();
        this.rescueRLAgent.agent_init(str);
        // Temporary agent, initialized only to obtain its default params.
        this.rlAgent = new SarsaLambdaAgent();
        this.rlAgent.agent_init(str);
        this.agentParams = this.rlAgent.params;
        this.currParamI = 4;
        System.out.println("\n\nthis.currParamI: " + this.currParamI);
        setNewTestVals(this.PARAM_RANGES[this.currParamI]);
        System.out.println("this.PARAM_RANGES[this.currParamI]: " + Arrays.toString(this.PARAM_RANGES[this.currParamI]));
        this.bestObservedParams = getCurrParamVec();
        System.out.println("Best observed params: " + Arrays.toString(this.bestObservedParams));
        changeParamVals();
        System.out.println("Current params: " + Arrays.toString(this.currParams));
        this.rlAgent = new SarsaLambdaAgent();
        this.rlAgent.params = this.agentParams;
        this.rlAgent.agent_init(str);
    }

    /** Starts an episode on both agents so the rescue agent can take over mid-episode. */
    @Override
    public Action agent_start(Observation observation) {
        this.rescueRLAgent.agent_start(observation);
        return this.rlAgent.agent_start(observation);
    }

    /** Delegates to the searched agent, or to the rescue agent once the episode is rescued. */
    @Override
    public Action agent_step(double d, Observation observation) {
        if (isRescued()) {
            System.out.print("rescue" + this.rlAgent.stepsThisEp);
            return this.rescueRLAgent.agent_step(d, observation);
        }
        return this.rlAgent.agent_step(d, observation);
    }

    /**
     * Ends the episode on whichever agent was acting. A rescued episode is charged rewFloor.
     * Finishes the run after EPS_PER_RUN episodes, or as soon as the run's total reward drops
     * below rewFloor.
     */
    @Override
    public void agent_end(double d) {
        if (isRescued()) {
            this.rescueRLAgent.agent_end(d);
            this.rlAgent.totalRew += this.rewFloor;
        } else {
            this.rlAgent.agent_end(d);
        }
        this.numEpsFinishedForRun++;
        boolean belowFloor = this.rlAgent.totalRew < this.rewFloor;
        if (belowFloor) {
            System.out.println("REW_FLOOR: " + this.rewFloor);
            System.out.println("Total reward: " + this.rlAgent.totalRew);
        }
        if (belowFloor || this.numEpsFinishedForRun % this.EPS_PER_RUN == 0) {
            this.numRunsFinishedForParamSet++;
            runEnd();
            this.numEpsFinishedForRun = 0;
        }
    }

    /** True once the searched agent's reward this episode has fallen below the rescue threshold. */
    private boolean isRescued() {
        return this.rlAgent.rewThisEp < RESCUE_REW_THRESHOLD;
    }

    /**
     * Records the finished run's reward, advances to the next test value / level / parameter as
     * needed, and creates a fresh searched agent configured for the next run.
     */
    private void runEnd() {
        this.rewSums[this.currParamTestValI] += this.rlAgent.totalRew;
        System.out.println("rewSums: " + Arrays.toString(this.rewSums));
        if (this.numRunsFinishedForParamSet % this.RUNS_PER_PARAM_SET == 0) {
            this.numRunsFinishedForParamSet = 0;
            this.currParamTestValI++;
            if (this.currParamTestValI >= this.paramValsToTest.length) {
                finishLevel();
            }
        }
        System.out.println("\n\nBest observed params: " + Arrays.toString(this.bestObservedParams));
        changeParamVals();
        System.out.println("this.getCurrParamVec(): " + Arrays.toString(getCurrParamVec()));
        System.out.println("Current params: " + Arrays.toString(this.currParams));
        this.rlAgent = new SarsaLambdaAgent();
        this.rlAgent.params = this.agentParams;
        this.rlAgent.agent_init(this.taskSpec);
    }

    /**
     * Closes out a level once every test value has been run.
     *
     * <p>If this parameter has finished all of its levels, possibly accepts the winner as the new
     * best value and switches to another parameter's full range. Otherwise narrows the range
     * around the winner.
     */
    private void finishLevel() {
        this.currParamTestValI = 0;
        int bestValI = getIndexOfMaxVal(this.rewSums);
        double bestValRewSum = this.rewSums[bestValI];
        this.rewSums = new double[this.NUM_TEST_VALS];
        this.numLevelsFinished++;
        double[] newRange;
        if (this.numLevelsFinished >= this.PARAM_TEST_LEVELS[this.currParamI]) {
            System.out.print("\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n");
            if (bestValRewSum > this.bestRewSum) {
                this.bestObservedParams[this.currParamI] = this.paramValsToTest[bestValI];
                this.bestRewSum = bestValRewSum;
                System.out.println("\n\n\n\n\nNew best reward sum: " + bestValRewSum + "\n\n\n");
                if (this.bestRewSum < 0.0d) {
                    // bestRewSum spans RUNS_PER_PARAM_SET runs; the floor is per run.
                    this.rewFloor = (this.bestRewSum / this.RUNS_PER_PARAM_SET) * 1.3d;
                }
            }
            int nextParamI = chooseNextParamIndex();
            this.lastParamI = this.currParamI;
            this.currParamI = nextParamI;
            newRange = new double[]{this.PARAM_RANGES[nextParamI][0], this.PARAM_RANGES[nextParamI][1]};
            this.numLevelsFinished = 0;
        } else {
            newRange = narrowedRange(bestValI);
        }
        setNewTestVals(newRange);
    }

    /**
     * Picks a random parameter index that differs from the current parameter. When there are more
     * than two parameters, it also differs from the previous one. That exclusion would otherwise
     * make the loop unable to terminate.
     */
    private int chooseNextParamIndex() {
        int numParams = this.PARAM_RANGES.length;
        if (numParams < 2) {
            return this.currParamI;
        }
        boolean avoidLast = numParams > 2;
        int candidate;
        do {
            candidate = this.randGen.nextInt(numParams);
        } while (candidate == this.currParamI || (avoidLast && candidate == this.lastParamI));
        return candidate;
    }

    /**
     * Returns a narrower [low, high] range around {@code paramValsToTest[bestValI]}. The range
     * reaches most of the way toward the neighboring test values, and stops at the tested
     * extremes when the winner is at an end.
     */
    private double[] narrowedRange(int bestValI) {
        double[] vals = this.paramValsToTest;
        double n = this.NUM_TEST_VALS;
        double[] range = new double[2];
        if (bestValI == 0) {
            range[0] = vals[0];
            range[1] = vals[0] + ((vals[1] - vals[0]) * (n / (n + 1.0d)));
        } else if (bestValI == this.NUM_TEST_VALS - 1) {
            range[0] = vals[bestValI - 1] + ((vals[bestValI] - vals[bestValI - 1]) * (1.0d / (n + 1.0d)));
            range[1] = vals[this.NUM_TEST_VALS - 1];
        } else {
            range[0] = vals[bestValI - 1] + ((vals[bestValI] - vals[bestValI - 1]) * (1.0d / (n + 2.0d)));
            range[1] = vals[bestValI] + ((vals[bestValI + 1] - vals[bestValI]) * ((n + 1.0d) / (n + 2.0d)));
        }
        return range;
    }

    /** Returns the index of the first maximum of dArr, printing the comparison trace. */
    private static int getIndexOfMaxVal(double[] dArr) {
        int i = 0;
        System.out.println("starting max at 0: " + dArr[0]);
        for (int i2 = 1; i2 < dArr.length; i2++) {
            if (dArr[i2] > dArr[i]) {
                i = i2;
                System.out.println("new max at " + i2 + ": " + dArr[i]);
            } else {
                System.out.println("val at " + i2 + " not a new max: " + dArr[i2]);
            }
        }
        return i;
    }

    /**
     * Fills paramValsToTest with NUM_TEST_VALS evenly spaced values from dArr[0] to dArr[1]
     * inclusive.
     */
    private void setNewTestVals(double[] dArr) {
        this.paramValsToTest = new double[this.NUM_TEST_VALS];
        double low = dArr[0];
        double width = dArr[1] - dArr[0];
        int lastI = this.NUM_TEST_VALS - 1;
        for (int i = 0; i < this.NUM_TEST_VALS; i++) {
            // Floating-point division is required. With int division every fraction except the
            // last would truncate to 0. A single test value uses the low end of the range.
            double frac = (lastI == 0) ? 0.0d : ((double) i / lastI);
            this.paramValsToTest[i] = low + (width * frac);
        }
        System.out.println("this.paramValsToTest: " + Arrays.toString(this.paramValsToTest));
    }

    @Override
    public String agent_message(String str) {
        return null;
    }

    /** Cleans up both the searched agent and the rescue agent. */
    @Override
    public void agent_cleanup() {
        if (this.rlAgent != null) {
            this.rlAgent.agent_cleanup();
        }
        if (this.rescueRLAgent != null) {
            this.rescueRLAgent.agent_cleanup();
        }
    }

    public static void main(String[] strArr) {
        new AgentLoader(new ParamSearchAgent()).run();
    }
}
