package edu.utexas.cs.tamerProject.agents.combo;

import edu.utexas.cs.tamerProject.agents.GeneralAgent;
import edu.utexas.cs.tamerProject.agents.sarsaLambda.SarsaLambdaAgent;
import edu.utexas.cs.tamerProject.agents.tamer.TamerAgent;
import edu.utexas.cs.tamerProject.experiment.LogTrainer;
import edu.utexas.cs.tamerProject.experiment.RecordHandler;
import edu.utexas.cs.tamerProject.modeling.IncGDLinearModel;
import edu.utexas.cs.tamerProject.params.Params;
import edu.utexas.cs.tamerProject.trainInterface.TrainerListener;
import edu.utexas.cs.tamerProject.visualization.EligModDisplay;
import java.awt.HeadlessException;
import java.util.ArrayList;
import javax.swing.SwingUtilities;
import org.rlcommunity.rlglue.codec.AgentInterface;
import org.rlcommunity.rlglue.codec.types.Action;
import org.rlcommunity.rlglue.codec.types.Observation;
import org.rlcommunity.rlglue.codec.util.AgentLoader;
import rlVizLib.general.ParameterHolder;
import rlVizLib.utilities.UtilityShop;

/* loaded from: input_file:edu/utexas/cs/tamerProject/agents/combo/TamerRLAgent.class */
/**
 * Agent that wraps a {@link TamerAgent} and a {@link SarsaLambdaAgent} and combines them
 * according to {@link #COMBINATION_METHOD}. The wrapped agents are created in the constructor,
 * configured as non-top-level agents in {@link #agent_init(String)}, and driven from this
 * class's {@code agent_start}/{@code agent_step}/{@code agent_end} callbacks.
 */
public class TamerRLAgent extends GeneralAgent implements AgentInterface {
    public SarsaLambdaAgent rlAgent;
    public TamerAgent tamerAgent;
    public HInfluence hInf;
    private String H_INFLUENCE_METHOD;
    private boolean SIMUL_LEARNING;
    public int COMBINATION_METHOD;
    private double INITIAL_COMB_PARAM;
    public boolean USING_PY_MC_MODEL;
    private int H_NUM;
    private boolean tamerControl;

    // Values accepted by COMBINATION_METHOD. They are initialized here and nowhere else:
    // a final field that already has an initializer may not be assigned again in a constructor.
    public final int RL_ON_H_AS_R = -2;
    public final int TAMER_ONLY = -1;
    public final int RL_ONLY = 0;
    public final int REW_SHAPING = 1;
    public final int FEAT_ADD = 2;
    public final int Q_INIT = 3;
    public final int Q_AUGM = 4;
    public final int EXTRA_ACT = 5;
    public final int ACT_BIASING = 6;
    public final int BERNOULLI_ACT = 7;
    public final int STATE_POT_FCN_SHAPING = 8;
    public final int SA_POT_FCN_SHAPING = 9;
    public final int PROB_ACT_W_OSCILL_DAMP = 10;

    // Gates creation of the eligibility-module display in agent_init.
    private final boolean DISPLAY_ELIG_MOD = true;

    /**
     * Parses the command-line style arguments understood by this agent and then forwards the
     * same array to both wrapped agents.
     *
     * <p>Recognized flags: {@code -combMethod <int>}, {@code -combParam <double>},
     * {@code -eligTrace}, {@code -simulLearning} and {@code -pyMCModel <int>}. A flag that
     * requires a value is ignored when it is the last element of the array.
     *
     * @param strArr the argument array
     * @throws NumberFormatException if a numeric flag value cannot be parsed
     */
    @Override
    public void processPreInitArgs(String[] strArr) {
        super.processPreInitArgs(strArr);
        for (int i = 0; i < strArr.length; i++) {
            String flag = strArr[i];
            boolean hasValue = i + 1 < strArr.length;
            if (flag.equals("-combMethod") && hasValue) {
                this.COMBINATION_METHOD = Integer.parseInt(strArr[i + 1]);
                System.out.println("this.COMBINATION_METHOD set to: " + this.COMBINATION_METHOD);
            } else if (flag.equals("-combParam") && hasValue) {
                this.INITIAL_COMB_PARAM = Double.parseDouble(strArr[i + 1]);
                System.out.println("this.INITIAL_COMB_PARAM set to: " + this.INITIAL_COMB_PARAM);
            } else if (flag.equals("-eligTrace")) {
                this.H_INFLUENCE_METHOD = "eligTrace";
                System.out.println("this.H_INFLUENCE_METHOD set to: " + this.H_INFLUENCE_METHOD);
            } else if (flag.equals("-simulLearning")) {
                this.SIMUL_LEARNING = true;
                System.out.println("this.SIMUL_LEARNING set to: " + this.SIMUL_LEARNING);
            } else if (flag.equals("-pyMCModel") && hasValue) {
                usePYMCModel(Integer.parseInt(strArr[i + 1]));
            }
        }
        this.rlAgent.processPreInitArgs(strArr);
        this.tamerAgent.processPreInitArgs(strArr);
        System.out.println("------------------");
    }

    /**
     * Turns on use of a stored model and records which one to load.
     *
     * @param i 1-based model number; it is used as {@code i - 1} to index the model list in
     *          {@code loadPyMCWts}
     */
    public void usePYMCModel(int i) {
        this.USING_PY_MC_MODEL = true;
        this.H_NUM = i;
        System.out.println("this.USING_PY_MC_MODEL set to: " + this.USING_PY_MC_MODEL);
        System.out.println("this.h_num set to: " + this.H_NUM);
    }

    /**
     * Initializes records for this agent and for each wrapped agent that has been created.
     */
    @Override
    public void initRecords() {
        super.initRecords();
        // The wrapped agents are null-checked because they are only assigned at the end of
        // the constructor.
        if (this.tamerAgent != null) {
            this.tamerAgent.initRecords();
        }
        if (this.rlAgent != null) {
            this.rlAgent.initRecords();
        }
    }

    /**
     * Creates the agent with its default configuration (annealed-parameter influence, no
     * simultaneous learning, action biasing with an initial combination parameter of 10)
     * and instantiates both wrapped agents.
     */
    public TamerRLAgent() {
        this.H_INFLUENCE_METHOD = "annealedParam";
        this.SIMUL_LEARNING = false;
        this.COMBINATION_METHOD = ACT_BIASING;
        this.INITIAL_COMB_PARAM = 10.0d;
        this.USING_PY_MC_MODEL = false;
        this.H_NUM = -1;
        this.tamerAgent = new TamerAgent();
        this.rlAgent = new SarsaLambdaAgent();
    }

    /**
     * Same as the no-argument constructor.
     *
     * @param parameterHolder not read by this constructor
     */
    public TamerRLAgent(ParameterHolder parameterHolder) {
        this();
    }

    /**
     * Initializes parameters on this agent and on both wrapped agents.
     *
     * @param str value forwarded unchanged to {@code initParams} of the superclass and of
     *            both wrapped agents
     */
    @Override
    public void initParams(String str) {
        super.initParams(str);
        this.rlAgent.initParams(str);
        this.tamerAgent.initParams(str);
    }

    /**
     * Builds a new {@link ParameterHolder} with version details taken from a new
     * {@code DetailsProvider}.
     *
     * @return the newly created parameter holder
     */
    public static ParameterHolder getDefaultParameters() {
        ParameterHolder holder = new ParameterHolder();
        UtilityShop.setVersionDetails(holder, new DetailsProvider());
        return holder;
    }

    /**
     * Entry point: creates an agent, applies the arguments, then either runs it through an
     * {@link AgentLoader} (when the {@code glue} flag is set) or calls {@code runSelf()}.
     *
     * @param strArr command-line arguments, passed to {@link #processPreInitArgs(String[])}
     */
    public static void main(String[] strArr) {
        TamerRLAgent agent = new TamerRLAgent();
        agent.processPreInitArgs(strArr);
        if (agent.glue) {
            new AgentLoader(agent).run();
        } else {
            agent.runSelf();
        }
    }

    /**
     * Initializes this agent and both wrapped agents, builds the {@link HInfluence} object and
     * wires the TAMER model into the RL agent as required by {@link #COMBINATION_METHOD}.
     *
     * @param str initialization string, forwarded to the wrapped agents' {@code agent_init}
     */
    @Override
    public void agent_init(String str) {
        System.out.println("Agent init started");
        GeneralAgent.agent_init(str, this);
        this.tamerControl = this.COMBINATION_METHOD == TAMER_ONLY;
        if (this.SIMUL_LEARNING) {
            this.countTrainingEps = true;
        }

        // --- TAMER agent ---
        this.tamerAgent.setIsTopLevelAgent(false);
        this.tamerAgent.enableGUI = false;
        if (this.trainFromLog && !this.SIMUL_LEARNING) {
            // Without simultaneous learning, log training is handed to the TAMER agent
            // (unless it is unused) and switched off on this agent.
            if (this.COMBINATION_METHOD != RL_ONLY) {
                this.tamerAgent.trainFromLog = true;
                this.tamerAgent.trainLogPath = this.trainLogPath;
            }
            this.trainFromLog = false;
        }
        if (this.USING_PY_MC_MODEL) {
            setTamerForPyMC();
        }
        this.tamerAgent.agent_init(str);
        if (this.USING_PY_MC_MODEL) {
            loadPyMCWts();
        }

        // --- influence of the TAMER model ---
        boolean usingBernoulliAct = false;
        boolean fixDecayFactorsAtOne = false;
        if (this.COMBINATION_METHOD == BERNOULLI_ACT) {
            usingBernoulliAct = true;
        } else if (this.COMBINATION_METHOD != ACT_BIASING && this.COMBINATION_METHOD != REW_SHAPING) {
            fixDecayFactorsAtOne = true;
            this.H_INFLUENCE_METHOD = "annealedParam";
        }
        // For SA_POT_FCN_SHAPING the initial parameter is scaled by the learning discount factor.
        double initialInfluence = this.INITIAL_COMB_PARAM
                * (this.COMBINATION_METHOD == SA_POT_FCN_SHAPING ? this.discountFactorForLearning.getValue() : 1.0d);
        this.hInf = new HInfluence(this.H_INFLUENCE_METHOD, initialInfluence,
                getEnvName(this.taskSpecObj.getExtraString()), this.tamerAgent.params, this, usingBernoulliAct);
        if (fixDecayFactorsAtOne) {
            this.hInf.setEpDecayFactor(1.0d);
            this.hInf.setStepDecayFactor(1.0d);
        }

        // --- RL agent ---
        this.rlAgent.setIsTopLevelAgent(false);
        this.rlAgent.enableGUI = false;
        if (this.USING_PY_MC_MODEL) {
            setRLForPyMC();
        }
        if (this.COMBINATION_METHOD == FEAT_ADD) {
            this.rlAgent.addModelBasedFeat(str, this.tamerAgent.model, this.tamerAgent.featGen);
        }
        this.rlAgent.agent_init(str);
        if (this.COMBINATION_METHOD == ACT_BIASING || this.COMBINATION_METHOD == SA_POT_FCN_SHAPING) {
            this.rlAgent.actSelector.addModelForActBias(this.tamerAgent.model, this.hInf);
        }
        if (this.COMBINATION_METHOD == Q_AUGM) {
            this.rlAgent.qAugModel = this.tamerAgent.model;
            this.rlAgent.actSelector.addModelForActBias(this.rlAgent.qAugModel, this.hInf);
        }
        if (this.COMBINATION_METHOD == RL_ON_H_AS_R) {
            this.rlAgent.actSelector.setRewModel(this.tamerAgent.model);
        }

        // --- user interface ---
        System.out.println("this.recordRew: " + this.recordRew);
        if (!this.recordRew) {
            this.enableGUI = false;
        }
        if (!GeneralAgent.isApplet && this.enableGUI) {
            try {
                SwingUtilities.invokeLater(new Runnable() {
                    @Override
                    public void run() {
                        TrainerListener.createAndShowGUI(TamerRLAgent.this);
                    }
                });
            } catch (HeadlessException e) {
                System.out.println("Exception in TamerRLAgent while trying to create reinforcement window: " + e.toString());
            }
        }
        if (this.DISPLAY_ELIG_MOD && this.H_INFLUENCE_METHOD.equals("eligTrace")) {
            try {
                SwingUtilities.invokeLater(new Runnable() {
                    @Override
                    public void run() {
                        EligModDisplay.createAndShowDisplay(TamerRLAgent.this);
                    }
                });
            } catch (Exception e2) {
                System.out.println("Exception in TamerRLAgent while trying to create display for eligibility module display: " + e2.toString());
            }
        }
        endInitHelper();
    }

    /**
     * Final step of initialization: when {@code trainFromLog} is still set, trains this agent
     * on the log at {@code trainLogPath} for {@code logTrainEpochs} epochs.
     */
    @Override
    public void endInitHelper() {
        if (this.trainFromLog) {
            System.out.println("Training from log for simul learning.");
            LogTrainer.trainOnLog(this.trainLogPath, this, this.logTrainEpochs);
        } else {
            System.out.println("Not training from log.");
        }
    }

    /**
     * Starts an episode. The first action is obtained from
     * {@link #agent_step(double, Observation, double, Action)} with a reward of zero and is
     * then passed on to the wrapped agents.
     *
     * @param observation the first observation of the episode
     * @param d           start time of the step (forwarded as the time argument)
     * @param action      action to use, or {@code null} to let the agent select one
     * @return the action chosen for the first step
     */
    @Override
    public Action agent_start(Observation observation, double d, Action action) {
        startHelper();
        if (this.COMBINATION_METHOD != RL_ONLY) {
            this.tamerAgent.agent_start(observation, d, new Action());
        }
        this.currObsAndAct.setAct(agent_step(0.0d, observation, d, action));
        this.rlAgent.agent_start(observation, d, this.currObsAndAct.getAct());
        if (this.COMBINATION_METHOD != RL_ONLY) {
            this.tamerAgent.lastObsAndAct.setAct(this.currObsAndAct.getAct().duplicate());
        }
        return this.currObsAndAct.getAct();
    }

    /**
     * Processes one time step: updates the TAMER agent and the influence object, chooses an
     * action, and updates the RL agent with the (possibly manipulated) reward.
     *
     * @param d           reward received for the previous step
     * @param observation the current observation
     * @param d2          start time of this step
     * @param action      action to use, or {@code null} to let the agent select one
     * @return the action chosen for this step
     */
    @Override
    public Action agent_step(double d, Observation observation, double d2, Action action) {
        this.stepStartTime = d2;
        stepStartHelper(d);
        this.tamerAgent.hRewList = new ArrayList<>(this.hRewThisStep);
        this.hInf.recordTimeStepEnd(d2);
        if (this.stepsThisEp > 1 && this.COMBINATION_METHOD != RL_ONLY) {
            this.currObsAndAct.setAct(new Action());
            this.tamerAgent.agent_step(d, observation, this.stepStartTime, this.currObsAndAct.getAct());
        }
        this.hInf.stepUpdate(this.inTrainSess, this.stepStartTime);

        if (action == null) {
            this.currObsAndAct.setAct(this.rlAgent.actSelector.selectAction(observation, this.lastObsAndAct.getAct()));
            // The TAMER agent's greedy action replaces the RL choice either with the
            // probability given by hInf (BERNOULLI_ACT) or whenever TAMER is in control.
            if ((this.COMBINATION_METHOD == BERNOULLI_ACT
                    && this.random.nextDouble() < this.hInf.getHInfluence(observation, (Action) null))
                    || this.tamerControl) {
                this.currObsAndAct.setAct(this.tamerAgent.actSelector.greedyActSelect(observation, this.lastObsAndAct.getAct()));
            }
        } else {
            this.currObsAndAct.setAct(action);
        }
        manualActIfFailed(observation);

        this.tamerAgent.lastObsAndAct.setAct(this.currObsAndAct.getAct().duplicate());
        if (this.COMBINATION_METHOD != RL_ONLY) {
            this.tamerAgent.hLearner.recordTimeStepStart(
                    this.tamerAgent.featGen.getFeats(observation, this.currObsAndAct.getAct()), d2);
        }
        double manipulatedRew = getManipulatedRew(d, observation);
        if (this.stepsThisEp > 1) {
            this.rlAgent.agent_step(manipulatedRew, observation, d2, this.currObsAndAct.getAct());
        }
        stepEndHelper(d, observation);
        this.hInf.recordTimeStepStart(observation, this.currObsAndAct.getAct(), d2);
        return this.currObsAndAct.getAct();
    }

    /**
     * Ends an episode: finishes the influence object's episode and forwards the end of the
     * episode to the wrapped agents (the RL agent receives the manipulated reward).
     *
     * @param d  final reward of the episode
     * @param d2 time at which the episode ended
     */
    @Override
    public void agent_end(double d, double d2) {
        this.stepStartTime = d2;
        endHelper(d);
        this.tamerAgent.hRewList = new ArrayList<>(this.hRewThisStep);
        // There is no next observation at the end of an episode, hence null.
        double manipulatedRew = getManipulatedRew(d, null);
        this.hInf.episodeEndUpdate();
        if (this.COMBINATION_METHOD != RL_ONLY) {
            this.tamerAgent.agent_end(d, d2);
        }
        this.rlAgent.agent_end(manipulatedRew, d2);
    }

    /** Does nothing. */
    @Override
    public void agent_cleanup() {
    }

    /**
     * Computes the reward handed to the RL agent for the current combination method.
     *
     * @param d           the reward as received
     * @param observation the current observation, or {@code null} at the end of an episode
     * @return {@code d} unchanged for methods that do not alter the reward; otherwise the
     *         reward with the TAMER-derived term added (or, for RL_ON_H_AS_R, substituted)
     */
    private double getManipulatedRew(double d, Observation observation) {
        double rew = d;
        if (this.COMBINATION_METHOD == REW_SHAPING) {
            if (this.stepsThisEp > 1) {
                rew += this.hInf.getHInfluence(this.lastObsAndAct.getObs(), this.lastObsAndAct.getAct())
                        * this.tamerAgent.getVal(this.lastObsAndAct.getObs(), this.lastObsAndAct.getAct());
            }
        } else if (this.COMBINATION_METHOD == RL_ON_H_AS_R) {
            if (this.stepsThisEp > 1) {
                rew = this.tamerAgent.getVal(this.lastObsAndAct.getObs(), this.lastObsAndAct.getAct());
            }
        } else if (this.COMBINATION_METHOD == STATE_POT_FCN_SHAPING) {
            rew += this.INITIAL_COMB_PARAM
                    * this.tamerAgent.getStatePotForTrans(this.lastObsAndAct.getObs(), observation);
        } else if (this.COMBINATION_METHOD == SA_POT_FCN_SHAPING) {
            rew += this.INITIAL_COMB_PARAM
                    * this.tamerAgent.getSAPotForTrans(this.lastObsAndAct.getObs(), this.lastObsAndAct.getAct(),
                            observation, this.currObsAndAct.getAct());
        }
        return rew;
    }

    /**
     * Handles a key press: {@code '/'} adds a human reward of +1, {@code 'z'} adds -1, and the
     * space bar toggles the training session when user toggling is allowed.
     *
     * @param c the key that was pressed; it is also echoed to standard output
     */
    @Override
    public void receiveKeyInput(char c) {
        System.out.println(c);
        if (c == '/') {
            addHRew(1.0d);
        } else if (c == 'z') {
            addHRew(-1.0d);
        } else if (c == ' ' && this.allowUserToggledTraining) {
            toggleInTrainSess();
        }
    }

    /**
     * Flips the training-session flag of the TAMER agent and of this agent. Each flag is
     * negated independently of the other.
     */
    @Override
    public void toggleInTrainSess() {
        this.tamerAgent.setInTrainSess(!this.tamerAgent.getInTrainSess());
        this.inTrainSess = !this.inTrainSess;
    }

    /**
     * Flips whether the TAMER agent's greedy action overrides the RL agent's choice. Has no
     * effect (other than a message) in TAMER_ONLY mode.
     */
    public void toggleTamerControl() {
        if (this.COMBINATION_METHOD == TAMER_ONLY) {
            System.out.println("In TAMER-ONLY mode. TAMER agent cannot cede control.");
            return;
        }
        this.tamerControl = !this.tamerControl;
        if (this.tamerControl) {
            System.out.println("\n\nTAMER agent taking control.\n");
        } else {
            System.out.println("\n\nTAMER agent ceding control.\n");
        }
    }

    /**
     * Sets the master log switch on this agent and on both wrapped agents.
     *
     * @param z the new switch value
     */
    @Override
    public void setMasterLogSwitch(boolean z) {
        this.masterLogSwitch = z;
        this.tamerAgent.setMasterLogSwitch(z);
        this.rlAgent.setMasterLogSwitch(z);
    }

    /** Replaces the TAMER agent's params with those for the current environment and applies the PyMC settings. */
    private void setTamerForPyMC() {
        String agentClass = this.tamerAgent.getClass().toString();
        this.tamerAgent.params = Params.getParams(agentClass, getEnvName(this.taskSpecObj.getExtraString()));
        this.tamerAgent.params.setPyMCParams(agentClass, false);
    }

    /**
     * Loads the weights of stored model number {@code H_NUM} (1-based) from
     * {@code <working dir without "/bin">/data/mc_tamer/models/} and installs them in the TAMER
     * agent's model. On any failure, including an out-of-range model number, an error is
     * printed and the JVM exits with a non-zero status.
     */
    private void loadPyMCWts() {
        System.out.println("\nLoading H-hat_{" + this.H_NUM + "} from AAMAS-10 TAMER+RL Mountain Car experiments. ");
        double[] weights = null;
        try {
            String modelDir = RecordHandler.getPresentWorkingDir().replace("/bin", "") + "/data/mc_tamer/models/";
            String[] modelPaths = {
                modelDir + "juhyun-1228942397.86-100.model",
                modelDir + "ikarpov-1228858017.78-100.model"
            };
            // Fail with a readable message instead of a bare array-index error.
            if (this.H_NUM < 1 || this.H_NUM > modelPaths.length) {
                throw new IllegalArgumentException("model number " + this.H_NUM
                        + " is out of range; expected a value from 1 to " + modelPaths.length);
            }
            weights = RecordHandler.getDoubleArrayFromStr(RecordHandler.getStrArray(modelPaths[this.H_NUM - 1])[0]);
            System.out.println("numWts: " + weights.length);
        } catch (Exception e) {
            System.err.println("Error in TamerRLAgent.loadPyMCWts: " + e.getMessage() + "\nExiting.");
            System.err.println("If you have already created a data directory somewhere, consider creating a symbolic link to it to make the above path valid.");
            // This is a fatal error, so report failure to the calling process.
            System.exit(1);
        }
        ((IncGDLinearModel) this.tamerAgent.model).setModelParams(weights);
    }

    /** Replaces the RL agent's params with those for the current environment and applies the PyMC settings. */
    private void setRLForPyMC() {
        String agentClass = this.rlAgent.getClass().toString();
        this.rlAgent.params = Params.getParams(agentClass, getEnvName(this.taskSpecObj.getExtraString()));
        this.rlAgent.params.setPyMCParams(agentClass, false);
    }

    /**
     * Overrides the current action with a hand-coded one when all of the following hold: the
     * environment name is "Mountain-Car", {@code numEpsBeforeStop} is not -1, the total reward
     * is below {@code numEpsBeforeStop * -200}, and more than 150 steps have been taken in this
     * episode. The replacement action is 0 when {@code observation.doubleArray[1]} is negative
     * and 2 otherwise.
     *
     * @param observation the current observation
     */
    public void manualActIfFailed(Observation observation) {
        // Constant-first comparison so an unset environment name cannot cause a NullPointerException.
        if (!"Mountain-Car".equals(this.envName)
                || this.numEpsBeforeStop == -1
                || this.totalRew >= this.numEpsBeforeStop * (-200)
                || this.stepsThisEp <= 150) {
            return;
        }
        Action manualAct = new Action();
        manualAct.intArray = new int[1];
        manualAct.intArray[0] = observation.doubleArray[1] < 0.0d ? 0 : 2;
        this.currObsAndAct.setAct(manualAct);
    }
}
