package edu.utexas.cs.tamerProject.experiments;

import edu.utexas.cs.tamerProject.actSelect.ActionSelect;
import edu.utexas.cs.tamerProject.agents.GeneralAgent;
import edu.utexas.cs.tamerProject.agents.combo.TamerRLAgent;
import edu.utexas.cs.tamerProject.agents.sarsaLambda.SarsaLForceContAgent;
import edu.utexas.cs.tamerProject.agents.sarsaLambda.SarsaLambdaAgent;
import edu.utexas.cs.tamerProject.agents.tamer.TamerAgent;
import edu.utexas.cs.tamerProject.applet.RunLocalExperiment;
import edu.utexas.cs.tamerProject.experiment.LogTrainer;
import edu.utexas.cs.tamerProject.experiment.RecordHandler;
import edu.utexas.cs.tamerProject.params.Params;
import java.util.Arrays;
import org.rlcommunity.environments.mountaincar.MountainCar;

/* loaded from: input_file:edu/utexas/cs/tamerProject/experiments/DiscFactorExp.class */
/**
 * Command-line driven experiment that builds a combined TAMER+RL agent, runs it on the
 * Mountain-Car environment through {@link RunLocalExperiment}, and configures it from
 * {@code -flag value} style arguments.
 *
 * <p>Arguments are consumed in three passes: {@link #processArgs} (experiment-level options),
 * {@link #processPreInitArgs} (agent parameters that must be set before the agent is
 * initialised) and {@link #processPostInitArgs} (settings applied after initialisation).
 * {@link #makeUnique} additionally relies on a fixed argument <em>position</em> layout; see
 * the {@code *_I} constants.
 */
public class DiscFactorExp {
    /** When true, {@link #main} ignores its arguments and uses {@link #getDebugArgsStrArray()}. */
    static boolean debug = false;

    /** Path of the training log fed to {@code LogTrainer}; set by {@code -trainLogPathInExp}. */
    public String trainLogPath = "";

    /** Epoch count handed to {@code LogTrainer.trainOnLog}; raised for the linear TAMER model. */
    public int logTrainEpochs = 1;

    /** Value assigned to {@link #logTrainEpochs} when {@code -tamerModel linear} is selected. */
    private static final int TRAIN_EPOCHS_FOR_LINEAR = 100;

    /** Set by {@code -makeTaskCont}; selects {@code SarsaLForceContAgent} as the RL agent. */
    private boolean makeTaskContinuous = false;

    // Positional indices into the argument array, matching the layout produced by
    // getDebugArgsStrArray(). makeUnique() reads the first four of them.
    static final int DISC_PARAM_I = 1;
    static final int INIT_VALUE_I = 3;
    static final int TRAIN_PATH_I = 5;
    static final int TASK_CONT_I = 6;
    static final int EXP_NAME_I = 8;
    static final int CRED_TYPE_I = 10;

    /**
     * Applies the experiment-level options and then builds the agent.
     *
     * @param strArr the experiment arguments
     * @return the configured (not yet initialised) combined agent
     */
    public GeneralAgent createAgent(String[] strArr) {
        processArgs(strArr);
        return createTamerRLAgent(strArr);
    }

    /**
     * Builds a {@link TamerRLAgent} with its RL and TAMER sub-agents, loads their parameter
     * sets, applies this experiment's defaults and finally the pre-init arguments.
     *
     * @param strArr the experiment arguments; must satisfy the positional layout required by
     *               {@link #makeUnique}
     * @return the configured combined agent
     * @throws IllegalArgumentException if {@code strArr} is too short for {@link #makeUnique}
     */
    public TamerRLAgent createTamerRLAgent(String[] strArr) {
        TamerRLAgent agent = new TamerRLAgent();
        if (this.makeTaskContinuous) {
            agent.rlAgent = new SarsaLForceContAgent();
        } else {
            agent.rlAgent = new SarsaLambdaAgent();
        }
        agent.tamerAgent = new TamerAgent();
        System.out.println("\n\nAgent in DiscFactorExp: " + agent.tamerAgent);

        String unique = makeUnique(strArr);
        agent.setUnique("tamerrl%" + unique);
        agent.rlAgent.setUnique("sarsa%" + unique);
        agent.setRecordRew(true);
        agent.rlAgent.setRecordRew(true);

        // NOTE(review): -2 is presumably a named combination-method constant on TamerRLAgent
        // that the compiler inlined — confirm against TamerRLAgent and use the name if so.
        agent.COMBINATION_METHOD = -2;
        agent.envName = "Mountain-Car";
        agent.enableGUI = false;
        agent.tamerAgent.EP_END_PAUSE = 0;

        // Each agent gets the parameter set registered for its own class and the environment.
        agent.params = Params.getParams(agent.getClass().getName(), agent.envName);
        agent.tamerAgent.params = Params.getParams(agent.tamerAgent.getClass().getName(), agent.envName);
        agent.rlAgent.params = Params.getParams(agent.rlAgent.getClass().getName(), agent.envName);

        // Experiment defaults first, so that command-line options can override them.
        setAgentParams(agent);
        processPreInitArgs(strArr, agent);
        return agent;
    }

    /**
     * Runs one complete experiment: creates the agent and a {@link MountainCar} environment,
     * initialises the local experiment runner, applies post-init adjustments, starts the
     * experiment and blocks until the runner reports that it has finished.
     *
     * @param strArr the experiment arguments
     */
    public void runOneExp(String[] strArr) {
        RunLocalExperiment runner = new RunLocalExperiment();
        setRunLocalExpOptions();
        GeneralAgent agent = createAgent(strArr);
        runner.theAgent = agent;
        runner.theEnvironment = new MountainCar();
        runner.init();
        runner.initExp();
        adjustAgentAfterItsInit(strArr, agent, this.logTrainEpochs, this.trainLogPath);
        System.out.println("About to start experiment");
        runner.startExp();
        // Poll the completion flag; the sleep argument is presumably milliseconds — TODO confirm.
        while (!runner.expFinished) {
            GeneralAgent.sleep(10000.0d);
        }
    }

    /**
     * Applies everything that has to happen after the agent has been initialised: post-init
     * arguments, training of the TAMER sub-agent from a log, leaving the training session if
     * one is active, and enabling tree search on the RL agent's action selector.
     *
     * @param strArr       the experiment arguments
     * @param generalAgent the agent to adjust; must be a {@link TamerRLAgent}
     * @param i            number of epochs passed to {@code LogTrainer.trainOnLog}
     * @param str          path of the training log passed to {@code LogTrainer.trainOnLog}
     * @throws ClassCastException if {@code generalAgent} is not a {@link TamerRLAgent}
     */
    public static void adjustAgentAfterItsInit(String[] strArr, GeneralAgent generalAgent, int i, String str) {
        TamerRLAgent combo = (TamerRLAgent) generalAgent;
        processPostInitArgs(strArr, combo);
        LogTrainer.trainOnLog(str, combo.tamerAgent, i);
        if (generalAgent.getInTrainSess()) {
            generalAgent.toggleInTrainSess();
        }
        System.out.println("tamer in training? " + combo.tamerAgent.getInTrainSess());
        combo.rlAgent.actSelector.setTreeSearchFlag(true);
    }

    /**
     * Builds an identifier of the form {@code <discount>%<initValue>%<cont|epis>%<logName>}
     * from fixed argument positions ({@link #DISC_PARAM_I}, {@link #INIT_VALUE_I},
     * {@link #TASK_CONT_I}, {@link #TRAIN_PATH_I}). The log name is the last path segment of
     * the training-log path with {@code ".log"} and {@code "recTraj-"} removed.
     *
     * @param strArr the experiment arguments in the positional layout described above
     * @return the identifier string
     * @throws IllegalArgumentException if {@code strArr} is null or has fewer than
     *                                  {@code TASK_CONT_I + 1} elements
     */
    public static String makeUnique(String[] strArr) {
        if (strArr == null || strArr.length <= TASK_CONT_I) {
            throw new IllegalArgumentException(
                    "makeUnique needs at least " + (TASK_CONT_I + 1)
                    + " positional arguments but got: " + Arrays.toString(strArr));
        }
        // split() drops trailing empty segments, so a path made only of '/' yields no parts.
        String[] pathParts = strArr[TRAIN_PATH_I].split("/");
        String lastPart = pathParts.length == 0 ? "" : pathParts[pathParts.length - 1];
        String logFileName = lastPart.replace(".log", "").replace("recTraj-", "");
        System.out.println("logFileName: " + logFileName);

        String taskType = "-makeTaskCont".equals(strArr[TASK_CONT_I]) ? "cont" : "epis";
        return strArr[DISC_PARAM_I] + "%" + strArr[INIT_VALUE_I] + "%" + taskType + "%" + logFileName;
    }

    /** Sets the static run options of {@link RunLocalExperiment} used by this experiment. */
    public static void setRunLocalExpOptions() {
        RunLocalExperiment.numEpisodes = 4000;
        RunLocalExperiment.maxStepsPerEpisode = 400;
        RunLocalExperiment.stepDurInMilliSecs = 0.0d;
    }

    /**
     * Reads the experiment-level options: {@code -trainLogPathInExp <path>} sets
     * {@link #trainLogPath}, and {@code -makeTaskCont} switches the task to continuous.
     * A {@code -trainLogPathInExp} with no following value is ignored.
     *
     * @param strArr the experiment arguments
     */
    public void processArgs(String[] strArr) {
        System.out.println("\n[------process args in exp------] " + Arrays.toString(strArr));
        for (int i = 0; i < strArr.length; i++) {
            String flag = strArr[i];
            boolean hasValue = i + 1 < strArr.length;
            if (hasValue && "-trainLogPathInExp".equals(flag)) {
                this.trainLogPath = strArr[i + 1];
                System.out.println("this.trainLogPath set to: " + this.trainLogPath);
            } else if ("-makeTaskCont".equals(flag)) {
                this.makeTaskContinuous = true;
                System.out.println("forcing task to be continuous");
            }
        }
    }

    /**
     * Applies the options that must be set before the agent is initialised:
     * {@code -initialValue}, {@code -tamerModel}, {@code -credType} and {@code -expName}.
     * Every option takes one value; an option given as the very last argument (without a
     * value) is ignored. An unrecognised value prints a message and exits the JVM with
     * status 1.
     *
     * @param strArr       the experiment arguments
     * @param tamerRLAgent the agent whose parameters are modified
     */
    public void processPreInitArgs(String[] strArr, TamerRLAgent tamerRLAgent) {
        System.out.println("\n[------process pre-init args------] " + Arrays.toString(strArr));
        // Every option here needs a value, so the last element can never be a usable flag.
        for (int i = 0; i + 1 < strArr.length; i++) {
            String flag = strArr[i];
            String value = strArr[i + 1];
            if ("-initialValue".equals(flag)) {
                applyInitialValue(value, tamerRLAgent);
            } else if ("-tamerModel".equals(flag)) {
                applyTamerModel(value, tamerRLAgent);
            } else if ("-credType".equals(flag)) {
                applyCredType(value, tamerRLAgent);
            } else if ("-expName".equals(flag)) {
                tamerRLAgent.setExpName(value);
                tamerRLAgent.rlAgent.setExpName(value);
                tamerRLAgent.tamerAgent.setExpName(value);
                System.out.println("agent.expName set to: " + tamerRLAgent.getExpName());
            }
        }
    }

    /** Handles {@code -initialValue}; only {@code "zero"} is accepted. */
    private static void applyInitialValue(String value, TamerRLAgent tamerRLAgent) {
        if ("zero".equals(value)) {
            tamerRLAgent.rlAgent.params.initWtsValue = 0.0d;
        } else {
            exitWithMessage("\nIllegal SarsaAgent initial values type. Exiting.\n\n");
        }
        System.out.println("Sarsa's Q-model's initial weights set to: " + tamerRLAgent.rlAgent.params.initWtsValue);
    }

    /** Handles {@code -tamerModel}; accepts {@code "linear"} or {@code "kNN"}. */
    private void applyTamerModel(String value, TamerRLAgent tamerRLAgent) {
        Params tamerParams = tamerRLAgent.tamerAgent.params;
        if ("linear".equals(value)) {
            System.out.println("Setting model to linear model");
            tamerParams.featClass = "FeatGen_RBFs";
            tamerParams.modelClass = "IncGDLinearModel";
            tamerParams.featGenParams.put("basisFcnsPerDim", "40");
            tamerParams.featGenParams.put("relWidth", "0.08");
            tamerParams.featGenParams.put("biasFeatVal", "0.1");
            tamerParams.featGenParams.put("normMin", "-1");
            tamerParams.featGenParams.put("normMax", "1");
            tamerParams.initModelWSamples = false;
            tamerParams.initWtsValue = 0.0d;
            tamerParams.stepSize = 0.001d;
            // The linear model is trained on the log for more epochs than the default.
            this.logTrainEpochs = TRAIN_EPOCHS_FOR_LINEAR;
        } else if ("kNN".equals(value)) {
            tamerParams.modelClass = "WekaModelPerActionModel";
            tamerParams.featClass = "FeatGen_NoChange";
            tamerParams.initModelWSamples = false;
            tamerParams.numBiasingSamples = 100;
            tamerParams.biasSampleWt = 0.1d;
            tamerParams.wekaModelName = "IBk";
        } else {
            exitWithMessage("\nIllegal TamerAgent model type. Exiting.\n\n");
        }
        System.out.println("agent model set to: " + value);
    }

    /**
     * Handles {@code -credType}; accepts {@code "aggregate"}, {@code "aggregRewOnly"},
     * {@code "indivAlways"} or {@code "indivRewOnly"}, each of which maps to one combination
     * of the {@code delayWtedIndivRew} and {@code noUpdateWhenNoRew} flags.
     */
    private static void applyCredType(String value, TamerRLAgent tamerRLAgent) {
        Params tamerParams = tamerRLAgent.tamerAgent.params;
        if ("aggregate".equals(value)) {
            tamerParams.delayWtedIndivRew = false;
            tamerParams.noUpdateWhenNoRew = false;
        } else if ("aggregRewOnly".equals(value)) {
            tamerParams.delayWtedIndivRew = false;
            tamerParams.noUpdateWhenNoRew = true;
        } else if ("indivAlways".equals(value)) {
            tamerParams.delayWtedIndivRew = true;
            tamerParams.noUpdateWhenNoRew = false;
        } else if ("indivRewOnly".equals(value)) {
            tamerParams.delayWtedIndivRew = true;
            tamerParams.noUpdateWhenNoRew = true;
        } else {
            exitWithMessage("\nIllegal TamerAgent credit assignment type. Exiting.\n\n");
        }
        System.out.println("agent.credType set to: " + value);
    }

    /** Prints {@code message} to standard output and terminates the JVM with status 1. */
    private static void exitWithMessage(String message) {
        System.out.println(message);
        System.exit(1);
    }

    /**
     * Applies the options that are set after the agent has been initialised:
     * {@code -trainEpLimit <int>} and {@code -discountFactor <double>}. An option given as
     * the very last argument (without a value) is ignored.
     *
     * @param strArr       the experiment arguments
     * @param tamerRLAgent the initialised agent to modify
     * @throws NumberFormatException if an option value cannot be parsed as a number
     */
    public static void processPostInitArgs(String[] strArr, TamerRLAgent tamerRLAgent) {
        System.out.println("\n[------process post-init args------] " + Arrays.toString(strArr));
        for (int i = 0; i + 1 < strArr.length; i++) {
            String flag = strArr[i];
            String value = strArr[i + 1];
            if ("-trainEpLimit".equals(flag)) {
                tamerRLAgent.tamerAgent.trainEpLimit = Integer.parseInt(value);
                System.out.println("agent.trainEpLimit set to: " + tamerRLAgent.tamerAgent.trainEpLimit);
            } else if ("-discountFactor".equals(flag)) {
                double discountFactor = Double.parseDouble(value);
                // The learner takes the factor directly; the action selector takes the
                // parameter derived from it by ActionSelect.discFactorToParam.
                tamerRLAgent.rlAgent.setDiscountFactorForLearning(discountFactor);
                tamerRLAgent.rlAgent.actSelector.setDiscountParam(ActionSelect.discFactorToParam(discountFactor));
                System.out.println("discount factor set to: " + discountFactor);
            }
        }
    }

    /**
     * Installs this experiment's default parameters on the combined agent and on both of its
     * sub-agents. Called before {@link #processPreInitArgs}, which may override some of them.
     */
    private static void setAgentParams(TamerRLAgent tamerRLAgent) {
        tamerRLAgent.params.extrapolateFutureRew = false;

        // RL (Sarsa) sub-agent: linear model over RBF features, e-greedy with tree search.
        Params rlParams = tamerRLAgent.rlAgent.params;
        rlParams.featClass = "FeatGen_RBFs";
        rlParams.modelClass = "IncGDLinearModel";
        rlParams.initModelWSamples = false;
        rlParams.traceDecayFactor = 0.84d;
        rlParams.featGenParams.put("basisFcnsPerDim", "40");
        rlParams.featGenParams.put("relWidth", "0.08");
        rlParams.featGenParams.put("normMin", "-1");
        rlParams.featGenParams.put("normMax", "1");
        rlParams.featGenParams.put("biasFeatVal", "0.1");
        rlParams.initSampleValue = 0.0d;
        rlParams.numBiasingSamples = 0;
        rlParams.biasSampleWt = 0.5d;
        rlParams.modelAddsBiasFeat = false;
        rlParams.traceType = "replacing";
        rlParams.initWtsValue = 0.0d;
        rlParams.stepSize = 0.01d;
        rlParams.selectionMethod = "e-greedy";
        rlParams.selectionParams.put("epsilon", "0");
        rlParams.selectionParams.put("epsilonAnnealRate", "0.998");
        rlParams.selectionParams.put("treeSearch", "true");
        rlParams.selectionParams.put("greedyLeafPathLength", "0");
        rlParams.selectionParams.put("exhaustiveSearchDepth", "3");
        rlParams.selectionParams.put("randomizeSearchDepth", "true");

        // TAMER sub-agent: same model settings as the "kNN" choice of -tamerModel.
        Params tamerParams = tamerRLAgent.tamerAgent.params;
        tamerParams.distClass = "uniform";
        tamerParams.creditDelay = 0.2d;
        tamerParams.windowSize = 0.6d;
        tamerParams.extrapolateFutureRew = false;
        tamerParams.delayWtedIndivRew = false;
        tamerParams.noUpdateWhenNoRew = false;
        tamerParams.modelClass = "WekaModelPerActionModel";
        tamerParams.featClass = "FeatGen_NoChange";
        tamerParams.selectionMethod = "greedy";
        tamerParams.initModelWSamples = false;
        tamerParams.numBiasingSamples = 100;
        tamerParams.biasSampleWt = 0.1d;
        tamerParams.wekaModelName = "IBk";
        tamerParams.traceDecayFactor = 0.0d;
    }

    /**
     * Returns a fixed argument list for debugging runs. Its element order is the positional
     * layout that {@link #makeUnique} expects. The training-log path is built from the
     * working directory reported by {@code RecordHandler}, with {@code "/bin"} removed.
     *
     * @return the debug argument array
     */
    public static String[] getDebugArgsStrArray() {
        String baseDir = RecordHandler.getPresentWorkingDir().replace("/bin", "");
        return new String[] {
            "-discountFactor", "0.1",
            "-initialValue", "zero",
            "-trainLogPathInExp", baseDir + "/data/mc_tamer/recTraj-ikarpov-1228858017.78.log",
            "-makeTaskCont",
            "-expName", "test",
            "-credType", "aggregRewOnly",
            "-trainEpLimit", "20",
            "-tamerModel", "linear"
        };
    }

    /**
     * Entry point: runs one experiment with the given arguments, or with the debug arguments
     * when {@link #debug} is set.
     *
     * @param strArr the command-line arguments
     */
    public static void main(String[] strArr) {
        if (debug) {
            strArr = getDebugArgsStrArray();
        }
        new DiscFactorExp().runOneExp(strArr);
    }
}
