package edu.utexas.cs.tamerProject.agents.hyperbolic;

import edu.utexas.cs.tamerProject.actSelect.ActionSelect;
import edu.utexas.cs.tamerProject.agents.GeneralAgent;
import edu.utexas.cs.tamerProject.agents.tamer.HRew;
import edu.utexas.cs.tamerProject.agents.tamer.TamerAgent;
import edu.utexas.cs.tamerProject.experiment.RecordHandler;
import edu.utexas.cs.tamerProject.modeling.IncGDLinearModel;
import edu.utexas.cs.tamerProject.modeling.Sample;
import edu.utexas.cs.tamerProject.params.Params;
import edu.utexas.cs.tamerProject.utilities.MutableDouble;
import edu.utexas.cs.tamerProject.utilities.Stopwatch;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import org.rlcommunity.rlglue.codec.types.Action;
import org.rlcommunity.rlglue.codec.types.Observation;

/* loaded from: input_file:edu/utexas/cs/tamerProject/agents/hyperbolic/HyperSarsaLambdaAgent.class */
public class HyperSarsaLambdaAgent extends GeneralAgent {
    /** Nested TAMER agent; agent_init passes its model to the action selector as the reward model. */
    public TamerAgent tamerAgent;
    /** Directory of the predicted-human-reward log; assigned in agent_init. */
    private String writePredHRewDir;
    /** Full path of the predicted-human-reward log file; assigned in agent_init. */
    private String writePredHRewPath;
    /** Most recently set hyperbolic discount parameter; replaced by setHyperDiscParam. */
    private MutableDouble hyperDiscountParam = new MutableDouble(1.0d);
    /** Discount parameter that agent_init applies via setHyperDiscParam. */
    private double DEFAULT_HYPER_PARAM = 1.0d;
    /** Per-discount-factor models whose weights are averaged into this.model; (re)built by setHyperDiscParam. */
    private MicroModel[] microModels = null;
    /** Number of micro models to build; setNumMicroModels refuses to change it once microModels exists. */
    private int numMicroModels = 99;
    /** When true, agent_init configures the TAMER agent for, and loads, pre-trained PyMC weights. */
    public boolean USING_PY_MC_MODEL = false;
    /** 1-based index of the pre-trained model file read by loadPyMCWts; -1 until usePYMCModel is called. */
    private int H_NUM = -1;
    /** Running sum over the episode of the TAMER agent's value for each processed (observation, action) pair. */
    private double predHRewThisEp = 0.0d;

    /** Creates the agent together with the nested TAMER agent it delegates human-reward modeling to. */
    public HyperSarsaLambdaAgent() {
        this.tamerAgent = new TamerAgent();
    }

    /**
     * Sets how many micro models will be built the next time they are instantiated.
     *
     * <p>Must be called before the micro models exist; otherwise the process prints an error and
     * terminates with exit status 1.
     *
     * @param i the desired number of micro models
     */
    public void setNumMicroModels(int i) {
        boolean alreadyInstantiated = this.microModels != null;
        if (alreadyInstantiated) {
            System.err.println("Changing number of micro models after they're already instantiated. This action is not supported. Exiting.");
            System.exit(1);
        }
        this.numMicroModels = i;
    }

    /**
     * Passes the pre-initialization arguments to the superclass implementation and then to the
     * nested TAMER agent, so both see the same argument array.
     *
     * @param strArr the argument array, forwarded unchanged
     */
    @Override // edu.utexas.cs.tamerProject.agents.GeneralAgent
    public void processPreInitArgs(String[] strArr) {
        super.processPreInitArgs(strArr);
        this.tamerAgent.processPreInitArgs(strArr);
    }

    /**
     * Scans the post-initialization arguments for {@code -discountParam <value>} and applies every
     * occurrence through {@link #setHyperDiscParam(double)}.
     *
     * <p>A trailing {@code -discountParam} with no following value is ignored.
     *
     * @param strArr the argument array to scan
     * @throws NumberFormatException if the token following {@code -discountParam} is not a valid double
     */
    @Override
    public void processPostInitArgs(String[] strArr) {
        System.out.println("\n[------process post-init args------] " + Arrays.toString(strArr));
        for (int idx = 0; idx < strArr.length; idx++) {
            boolean isDiscountFlag = strArr[idx].equals("-discountParam");
            boolean hasValue = idx + 1 < strArr.length;
            if (isDiscountFlag && hasValue) {
                setHyperDiscParam(Double.parseDouble(strArr[idx + 1]));
            }
        }
    }

    /**
     * Installs a new hyperbolic discount parameter and rebuilds all micro models from scratch.
     *
     * <p>Micro model {@code k} (0-based) gets discount factor
     * {@code ((k + 1) / (numMicroModels + 1)) ^ d} and a duplicate of the current model. The main
     * model is cleared and reset only after all duplicates have been taken.
     *
     * @param d the new hyperbolic discount parameter
     */
    public void setHyperDiscParam(double d) {
        this.hyperDiscountParam = new MutableDouble(d);
        this.actSelector.setDiscountParam(d);

        final int modelCount = this.numMicroModels;
        this.microModels = new MicroModel[modelCount];
        for (int k = 0; k < modelCount; k++) {
            // Evenly spaced fractions in (0, 1), raised to the discount parameter.
            double fraction = (k + 1) / (modelCount + 1.0d);
            double discountFactor = Math.pow(fraction, d);
            this.microModels[k] = new MicroModel(discountFactor, ((IncGDLinearModel) this.model).duplicate());
        }

        this.model.clearSamplesAndReset();
        System.out.println("hyperbolic discount param set to: " + d);
    }

    /**
     * Initializes parameters through the superclass and then has the nested TAMER agent initialize
     * its own parameters from the same string.
     *
     * @param str the string forwarded unchanged to both initParams calls
     */
    @Override // edu.utexas.cs.tamerProject.agents.GeneralAgent
    public void initParams(String str) {
        super.initParams(str);
        this.tamerAgent.initParams(str);
    }

    /**
     * Runs the built-in self test on one fresh agent, then feeds the command-line arguments to the
     * pre-init argument processing of a second fresh agent.
     *
     * @param strArr command-line arguments
     */
    public static void main(String[] strArr) {
        HyperSarsaLambdaAgent selfTestAgent = new HyperSarsaLambdaAgent();
        selfTestAgent.test();

        HyperSarsaLambdaAgent argsAgent = new HyperSarsaLambdaAgent();
        argsAgent.processPreInitArgs(strArr);
    }

    /**
     * Initializes this agent and its nested TAMER agent for a new experiment.
     *
     * <p>In order, this method:
     * <ul>
     *   <li>runs the shared {@code GeneralAgent.agent_init} and exits (status 1) unless the model is
     *       an {@code IncGDLinearModel};</li>
     *   <li>sets the model's discount factor to 0.0;</li>
     *   <li>configures the TAMER agent as a non-top-level agent without GUI, hands any
     *       train-from-log request over to it, and initializes it (wrapped by the PyMC setup and
     *       weight loading when {@code USING_PY_MC_MODEL} is set);</li>
     *   <li>creates a hyperbolic-discount action selector whose reward model is the TAMER model;</li>
     *   <li>builds the micro models with the default discount parameter;</li>
     *   <li>prepares the predicted-human-reward log when reward recording is enabled.</li>
     * </ul>
     *
     * @param str the task specification string
     */
    @Override
    public void agent_init(String str) {
        GeneralAgent.agent_init(str, this);
        System.out.println("rlAgent params: " + this.params.toOneLineStr());

        boolean modelIsLinear = this.model instanceof IncGDLinearModel;
        if (!modelIsLinear) {
            System.err.println("HyperSarsaLambdaAgent does not support non-linear models. Exiting.");
            System.exit(1);
        }
        ((IncGDLinearModel) this.model).setDiscountFactor(0.0d);

        // Nested TAMER agent setup.
        if (this.tamerAgent == null) {
            this.tamerAgent = new TamerAgent();
        }
        this.tamerAgent.setIsTopLevelAgent(false);
        this.tamerAgent.enableGUI = false;
        if (this.trainFromLog) {
            // Training from a log is delegated to the TAMER agent and switched off here.
            this.tamerAgent.trainFromLog = true;
            this.tamerAgent.trainLogPath = this.trainLogPath;
            this.trainFromLog = false;
        }
        if (this.USING_PY_MC_MODEL) {
            setTamerForPyMC();
        }
        this.tamerAgent.agent_init(str);
        if (this.USING_PY_MC_MODEL) {
            loadPyMCWts();
        }

        // Action selection with hyperbolic discounting over the TAMER reward model.
        this.actSelector = new ActionSelect(this.model, this.params.selectionMethod, this.params.selectionParams, this.currObsAndAct.getAct().duplicate());
        this.actSelector.setDiscountType(ActionSelect.DiscountTypes.HYPER);
        this.actSelector.setRewModel(this.tamerAgent.model);

        // setHyperDiscParam allocates this array again before filling it.
        this.microModels = new MicroModel[this.numMicroModels];
        setHyperDiscParam(this.DEFAULT_HYPER_PARAM);

        // Predicted-human-reward logging.
        this.writePredHRewDir = RLLIBRARY_PATH + "/data/" + this.expName;
        this.writePredHRewPath = this.writePredHRewDir + "/HRew-" + this.unique + ".rew";
        boolean rewLoggingEnabled = this.masterLogSwitch && this.recordRew;
        if (rewLoggingEnabled) {
            System.out.println("Reward log base path: " + this.writePredHRewDir);
            if (this.recHandler.canWriteToFile) {
                // NOTE(review): creates writeLogDir, not writePredHRewDir -- confirm they are the same directory.
                new File(this.writeLogDir).mkdir();
            }
            System.out.println("this.writePredHRewPath: " + this.writePredHRewPath);
            this.recHandler.writeParamsToRewLog(this.writePredHRewPath, this.params);
        }
        endInitHelper();
    }

    /**
     * Begins an episode: runs startHelper(), zeroes the episode's predicted-human-reward total,
     * starts the nested TAMER agent, and obtains the first action by delegating to agent_step.
     *
     * @param observation the first observation of the episode
     * @param d the time value handed to the TAMER agent and used by agent_step as the step start time
     * @param action an action to use as-is, or {@code null} to let the action selector choose
     *        (decided inside agent_step)
     * @return the action stored in {@code currObsAndAct} after the step
     */
    @Override // edu.utexas.cs.tamerProject.agents.GeneralAgent
    public Action agent_start(Observation observation, double d, Action action) {
        startHelper();
        this.predHRewThisEp = 0.0d;
        // The TAMER agent is started with an empty placeholder Action; its recorded action is overwritten below.
        this.tamerAgent.agent_start(observation, d, new Action());
        // First argument 0.0 is presumably the reward for the first step -- confirm against agent_step's callers.
        this.currObsAndAct.setAct(agent_step(0.0d, observation, d, action));
        // Keep the TAMER agent's last action in sync with the action actually chosen (as a copy).
        this.tamerAgent.lastObsAndAct.setAct(this.currObsAndAct.getAct().duplicate());
        return this.currObsAndAct.getAct();
    }

    /**
     * Performs one time step: forwards human reward to the TAMER agent, chooses (or accepts) the
     * action for the new observation, and updates the micro models from the previous time step.
     *
     * @param d value forwarded to stepStartHelper, the TAMER agent's agent_step, processPrevTimeStep
     *        and stepEndHelper (presumably the environmental reward -- confirm)
     * @param observation the current observation
     * @param d2 the start time of this step
     * @param action an action to use as-is, or {@code null} to let the action selector choose
     * @return the action stored in {@code currObsAndAct} for this step
     */
    @Override // edu.utexas.cs.tamerProject.agents.GeneralAgent
    public Action agent_step(double d, Observation observation, double d2, Action action) {
        this.stepStartTime = d2;
        stepStartHelper(d);
        // Give the TAMER agent its own copy of the human reward collected for this step.
        this.tamerAgent.hRewList = new ArrayList<>(this.hRewThisStep);
        if (this.stepsThisEp > 1) {
            // Step the TAMER agent with a fresh placeholder Action; the real action is set further down.
            this.currObsAndAct.setAct(new Action());
            this.tamerAgent.agent_step(d, observation, this.stepStartTime, this.currObsAndAct.getAct());
        }
        // Helper defined outside this method; presumably syncs lastObsAndAct with the TAMER agent's -- confirm.
        overwriteLastObsAndAct(this.tamerAgent);
        if (action == null) {
            // No forced action: let the action selector pick, given the previous action.
            this.currObsAndAct.setAct(this.actSelector.selectAction(observation, this.lastObsAndAct.getAct()));
        } else {
            this.currObsAndAct.setAct(action);
        }
        // Diagnostic print of the model's outputs for all possible actions, emitted only at step 399.
        if (this.stepsThisEp == 399) {
            System.out.println("HyperSarsa act vals: " + Arrays.toString(this.model.getStateActOutputs(observation, this.model.getPossActions(observation))));
        }
        // Mirror the chosen action (as a copy) into the TAMER agent and record the step start with its features.
        this.tamerAgent.lastObsAndAct.setAct(this.currObsAndAct.getAct().duplicate());
        this.tamerAgent.hLearner.recordTimeStepStart(this.tamerAgent.featGen.getFeats(observation, this.currObsAndAct.getAct()), d2);
        processPrevTimeStep(d, observation);
        stepEndHelper(d, observation);
        return this.currObsAndAct.getAct();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Finishes a time step: runs the superclass bookkeeping, calls
     * {@code overwriteLastObsAndAct} with the TAMER agent, and, only when the episode has reached
     * exactly 10,000,000 steps, announces it (top-level agents only) and appends the episode's
     * predicted human reward total to the reward log if reward recording is enabled.
     *
     * @param d value forwarded unchanged to the superclass helper
     * @param observation observation forwarded unchanged to the superclass helper
     */
    @Override
    public void stepEndHelper(double d, Observation observation) {
        super.stepEndHelper(d, observation);
        overwriteLastObsAndAct(this.tamerAgent);

        if (this.stepsThisEp != 10000000) {
            return;
        }
        if (this.isTopLevelAgent) {
            System.out.println("At end of steps!!");
        }
        if (this.recordRew && this.masterLogSwitch) {
            this.recHandler.writeLineToRewLog(this.writePredHRewPath, String.valueOf(this.predHRewThisEp), true);
        }
    }

    /**
     * Ends the episode: runs endHelper, forwards this step's human reward and the end-of-episode
     * call to the TAMER agent, updates the micro models for the final transition (no next
     * observation), and anneals the action selector.
     *
     * @param d value forwarded to endHelper, the TAMER agent's agent_end and processPrevTimeStep
     *        (presumably the final reward -- confirm)
     * @param d2 the time recorded as the step start time and forwarded to the TAMER agent
     */
    @Override // edu.utexas.cs.tamerProject.agents.GeneralAgent
    public void agent_end(double d, double d2) {
        this.stepStartTime = d2;
        endHelper(d);
        // Give the TAMER agent its own copy of the human reward collected for this step.
        this.tamerAgent.hRewList = new ArrayList<>(this.hRewThisStep);
        this.tamerAgent.agent_end(d, d2);
        // A null observation marks the terminal transition: no next-state value is bootstrapped.
        processPrevTimeStep(d, null);
        this.actSelector.anneal();
    }

    /**
     * Runs the superclass end-of-episode bookkeeping and, when reward recording is enabled, appends
     * the episode's predicted human reward total to the reward log.
     *
     * @param d value forwarded unchanged to the superclass helper
     */
    @Override
    public void endHelper(double d) {
        super.endHelper(d);

        boolean shouldLog = this.recordRew && this.masterLogSwitch;
        if (!shouldLog) {
            return;
        }
        String episodeTotal = String.valueOf(this.predHRewThisEp);
        this.recHandler.writeLineToRewLog(this.writePredHRewPath, episodeTotal, true);
    }

    /**
     * Learns from the previous time step and refreshes the main model.
     *
     * <p>Does nothing while {@code stepsThisEp <= 1}. Otherwise the TAMER agent's value for the
     * previous (observation, action) pair is added to the episode total and used as the immediate
     * reward. Each micro model is then trained on the target
     * {@code reward + discountFactor * (its own prediction for the current pair)}, where the
     * prediction is 0 when there is no current observation. Finally the main model's weights are
     * replaced by the element-wise mean of all micro models' weights.
     *
     * @param d unused by this method
     * @param observation the current observation, or {@code null} at the end of an episode
     */
    private void processPrevTimeStep(double d, Observation observation) {
        if (this.stepsThisEp <= 1) {
            return;
        }

        double[] prevFeats = this.featGen.getFeats(this.lastObsAndAct.getObs(), this.lastObsAndAct.getAct());
        double predictedHRew = this.tamerAgent.getVal(this.lastObsAndAct.getObs(), this.lastObsAndAct.getAct());
        this.predHRewThisEp += predictedHRew;

        double[] currFeats = (observation == null)
                ? null
                : this.featGen.getFeats(observation, this.currObsAndAct.getAct());

        // One bootstrapped update per micro model, each with its own discount factor.
        for (MicroModel micro : this.microModels) {
            double nextValue = (currFeats == null) ? 0.0d : micro.predictLabel(currFeats);
            double target = predictedHRew + (micro.discountFactor * nextValue);
            micro.addExperience(new Sample(prevFeats, target, 1.0d));
        }

        // Main model weights = element-wise mean of the micro models' weights.
        double[] meanWeights = new double[((IncGDLinearModel) this.model).getWeights().length];
        for (MicroModel micro : this.microModels) {
            double[] microWeights = micro.getWeights();
            for (int j = 0; j < meanWeights.length; j++) {
                meanWeights[j] += microWeights[j];
            }
        }
        for (int j = 0; j < meanWeights.length; j++) {
            meanWeights[j] /= this.microModels.length;
        }
        ((IncGDLinearModel) this.model).setModelParams(meanWeights);
    }

    /**
     * Handles a key press from the trainer. The key is always echoed to standard output.
     *
     * <ul>
     *   <li>{@code '/'} adds human reward of +1.0;</li>
     *   <li>{@code 'z'} adds human reward of -1.0;</li>
     *   <li>space toggles the training session, if user toggling is allowed;</li>
     *   <li>any other key is ignored.</li>
     * </ul>
     *
     * @param c the character of the pressed key
     */
    @Override
    public void receiveKeyInput(char c) {
        System.out.println(c);
        switch (c) {
            case '/':
                addHRew(1.0d);
                break;
            case 'z':
                addHRew(-1.0d);
                break;
            case ' ':
                if (this.allowUserToggledTraining) {
                    toggleInTrainSess();
                }
                break;
            default:
                break;
        }
    }

    /**
     * Flips the training-session flag of the nested TAMER agent and, independently, this agent's
     * own flag. Each flag is negated from its own current value.
     */
    @Override
    public void toggleInTrainSess() {
        boolean tamerWasTraining = this.tamerAgent.getInTrainSess();
        this.tamerAgent.setInTrainSess(!tamerWasTraining);

        boolean wasTraining = this.inTrainSess;
        this.inTrainSess = !wasTraining;
    }

    /**
     * Switches the agent to use pre-trained PyMC weights and records which model to load.
     * Both new settings are echoed to standard output.
     *
     * @param i 1-based index of the pre-trained model file that loadPyMCWts will read
     */
    public void usePYMCModel(int i) {
        this.H_NUM = i;
        this.USING_PY_MC_MODEL = true;
        System.out.println("this.USING_PY_MC_MODEL set to: " + this.USING_PY_MC_MODEL);
        System.out.println("this.h_num set to: " + this.H_NUM);
    }

    /**
     * Replaces the TAMER agent's parameters with the ones for the current environment and then
     * applies the PyMC-specific parameter overrides to them.
     */
    private void setTamerForPyMC() {
        // NOTE(review): Class.toString() yields "class <name>", whereas test() passes getName() -- confirm Params expects this form.
        String tamerClassDesc = this.tamerAgent.getClass().toString();
        this.tamerAgent.params = Params.getParams(tamerClassDesc, getEnvName(this.taskSpecObj.getExtraString()));
        this.tamerAgent.params.setPyMCParams(tamerClassDesc, false);
    }

    /**
     * Loads pre-trained TAMER weights for model number {@code H_NUM} (1-based) from the
     * {@code data/mc_tamer/models} directory and installs them in the TAMER agent's model.
     *
     * <p>The data directory is resolved relative to the present working directory with any
     * {@code /bin} segment removed. If {@code H_NUM} does not select one of the known model files,
     * or the file cannot be read or parsed, an error is printed and the process terminates with
     * exit status 1.
     */
    private void loadPyMCWts() {
        System.out.println("\nLoading H-hat_{" + this.H_NUM + "} from AAMAS-10 TAMER+RL Mountain Car experiments. ");

        String[] modelFileNames = {"juhyun-1228942397.86-100.model", "ikarpov-1228858017.78-100.model"};

        // Reject an unset (-1) or out-of-range index up front instead of letting it surface as an
        // ArrayIndexOutOfBoundsException with a misleading "data directory" hint.
        if (this.H_NUM < 1 || this.H_NUM > modelFileNames.length) {
            System.err.println("Error in HyperSarsaLambdaAgent.loadPyMCWts: H_NUM must be between 1 and "
                    + modelFileNames.length + " but was " + this.H_NUM
                    + ". Call usePYMCModel(int) with a valid index.\nExiting.");
            System.exit(1);
        }

        double[] weights = null;
        try {
            String modelsDir = RecordHandler.getPresentWorkingDir().replace("/bin", "") + "/data/mc_tamer/models/";
            String modelPath = modelsDir + modelFileNames[this.H_NUM - 1];
            // Only the first entry of the returned array is parsed as the weight vector.
            weights = RecordHandler.getDoubleArrayFromStr(RecordHandler.getStrArray(modelPath)[0]);
            System.out.println("numWts: " + weights.length);
        } catch (Exception e) {
            // Print the exception itself: getMessage() alone can be null and hides the exception type.
            System.err.println("Error in HyperSarsaLambdaAgent.loadPyMCWts: " + e + "\nExiting.");
            System.err.println("If you have already created a data directory somewhere, consider creating a symbolic link to it to make the above path valid.");
            // Non-zero status so callers and scripts can detect the failure.
            System.exit(1);
        }
        ((IncGDLinearModel) this.tamerAgent.model).setModelParams(weights);
    }

    /**
     * Self test: checks that the agent learns hyperbolically discounted values on a hand-fed
     * 8-state chain.
     *
     * <p>For each discount parameter in {0, 0.5, 1, 1e6} the agent is reset and run for 1000
     * episodes. Each episode visits states 1..8 in order, adds a human reward of 1.0 right after
     * the episode starts, and ends with reward 1.0. Afterwards the learned value of every state is
     * compared with {@code 1 / (1 + param * stepsBackFromReward)}. A deviation of more than 3% is
     * fatal (exit status 1) unless both the learned and the correct value are below 0.001.
     *
     * <p>Verbose output is disabled for the duration of the test and enabled on return.
     */
    public void test() {
        final int numStates = 8;
        final int numEpisodes = 1000;
        final double timeStep = 0.2d;
        final String envName = "HandFed";

        this.numMicroModels = 99;
        this.verbose = false;

        // Parameters of this agent.
        this.params = Params.getParams(getClass().getName(), envName);
        this.params.modelClass = "IncGDLinearModel";
        this.params.featClass = "FeatGen_Discretize";
        this.params.featGenParams.put("numBinsPerDim", String.valueOf(numStates));
        this.params.stepSize = 0.2d;

        // Parameters of the nested TAMER agent.
        this.tamerAgent.params = Params.getParams(this.tamerAgent.getClass().getName(), envName);
        this.tamerAgent.params.modelClass = "IncGDLinearModel";
        this.tamerAgent.params.featClass = "FeatGen_Discretize";
        this.tamerAgent.params.featGenParams.put("numBinsPerDim", String.valueOf(numStates));
        this.tamerAgent.params.distClass = "previousStep";
        this.tamerAgent.params.stepSize = 1.0d;

        agent_init("VERSION RL-Glue-3.0 PROBLEMTYPE episodic DISCOUNTFACTOR 1.0 OBSERVATIONS DOUBLES (0.5 " + (numStates + 0.5d) + ")  ACTIONS INTS (0 0)  REWARDS (-10.0 10.0)  EXTRA EnvName:HandFed");

        Stopwatch timer = new Stopwatch();
        timer.startTimer();

        double[] discountParams = {0.0d, 0.5d, 1.0d, 1000000.0d};
        for (double discountParam : discountParams) {
            setHyperDiscParam(discountParam);
            this.tamerAgent.EP_END_PAUSE = 0;

            Observation obs = new Observation();
            obs.doubleArray = new double[1];
            Action act = new Action();
            act.doubleArray = new double[1];
            act.doubleArray[0] = 0.0d;

            double simTime = 0.0d;
            if (!getInTrainSess()) {
                toggleInTrainSess();
            }

            // Training: walk the chain from state 1 to state numStates in every episode.
            for (int episode = 0; episode < numEpisodes; episode++) {
                for (int stateIdx = 0; stateIdx < numStates; stateIdx++) {
                    simTime += timeStep;
                    obs.doubleArray[0] = stateIdx + 1;
                    if (stateIdx == 0) {
                        agent_start(obs, simTime, act);
                        // Human reward time-stamped half a time step before the current time.
                        this.hRewList.add(new HRew(1.0d, simTime - (timeStep * 0.5d)));
                    } else {
                        agent_step(0.0d, obs, simTime, act);
                    }
                }
                agent_end(1.0d, simTime);
            }
            System.out.println("Time elapsed: " + timer.getTimeElapsed());

            // Verification: compare each state's learned value with the analytic hyperbolic value.
            for (int state = 1; state <= numStates; state++) {
                int stepsBack = numStates - state;
                double correct = 1.0d / (1.0d + (discountParam * stepsBack));
                System.out.println("Correct value for " + stepsBack + " back from reward: " + correct);

                obs.doubleArray[0] = state;
                double learned = this.model.predictLabel(obs, act);
                System.out.println("Learned value: " + learned);

                double ratio = learned / correct;
                boolean outsideTolerance = ratio < 0.97d || ratio > 1.03d;
                boolean significant = Math.abs(learned) >= 0.001d || correct >= 0.001d;
                if (outsideTolerance && significant) {
                    System.err.println("Incorrect hyperbolically discounted value learned!! Exiting.");
                    System.err.println("Correct value: " + correct);
                    System.err.println("Learned value: " + learned);
                    System.err.flush();
                    System.exit(1);
                }
            }
        }
        this.verbose = true;
    }
}
