package edu.utexas.cs.tamerProject.agents.sarsaLambda;

import edu.utexas.cs.tamerProject.actSelect.ActionSelect;
import edu.utexas.cs.tamerProject.agents.GeneralAgent;
import edu.utexas.cs.tamerProject.modeling.CombinationModel;
import edu.utexas.cs.tamerProject.modeling.IncModel;
import edu.utexas.cs.tamerProject.modeling.RegressionModel;
import edu.utexas.cs.tamerProject.modeling.Sample;
import java.util.Arrays;
import org.rlcommunity.rlglue.codec.types.Action;
import org.rlcommunity.rlglue.codec.types.Observation;
import org.rlcommunity.rlglue.codec.util.AgentLoader;

/* loaded from: input_file:edu/utexas/cs/tamerProject/agents/sarsaLambda/SarsaLambdaAgent.class */
/**
 * Agent that trains its value model from one-step, on-policy bootstrapped
 * targets: the reward received plus the discounted prediction for the
 * state-action pair chosen next.
 *
 * <p>An optional augmenting model ({@link #qAugModel}) can be mixed into the
 * targets; {@link #agent_init(String)} resets it to {@code null}, in which case
 * the plain update is used. NOTE(review): whether eligibility traces are applied
 * depends on the model implementation, which is not visible in this class.
 */
public class SarsaLambdaAgent extends GeneralAgent {
    /** Value of {@code stepsThisEp} at which the current action values are printed. */
    private static final int ACT_VALS_REPORT_STEP = 399;

    /** Optional augmenting model; {@code null} disables augmentation. */
    public RegressionModel qAugModel;

    /**
     * Handles command-line style arguments that must be applied before
     * initialisation. After delegating to the superclass, recognises
     * {@code -initialValue zero}, which sets the initial weight value to 0;
     * any other value for that flag prints an error and terminates the JVM.
     *
     * @param strArr the argument tokens
     */
    @Override
    public void processPreInitArgs(String[] strArr) {
        super.processPreInitArgs(strArr);
        System.out.println("\n[------Sarsa process pre-init args------] " + Arrays.toString(strArr));
        for (int pos = 0; pos < strArr.length; pos++) {
            boolean isInitialValueFlag = strArr[pos].equals("-initialValue");
            // The flag is only meaningful when a value token follows it.
            if (!isInitialValueFlag || pos + 1 >= strArr.length) {
                continue;
            }
            String valueType = strArr[pos + 1];
            if (!valueType.equals("zero")) {
                System.out.println("\nIllegal SarsaAgent initial values type. Exiting.\n\n");
                System.exit(1);
            } else {
                this.params.initWtsValue = 0.0d;
            }
            System.out.println("Sarsa's Q-model's initial weights set to: " + this.params.initWtsValue);
        }
    }

    /**
     * Handles arguments that are applied after initialisation. Recognises
     * {@code -discountFactor <number>}, which updates both the discount factor
     * used for learning and the action selector's discount parameter.
     *
     * @param strArr the argument tokens
     * @throws NumberFormatException if the token following
     *     {@code -discountFactor} is not a parsable number
     */
    @Override
    public void processPostInitArgs(String[] strArr) {
        System.out.println("\n[------process post-init args------] " + Arrays.toString(strArr));
        int pos = 0;
        while (pos < strArr.length) {
            if (strArr[pos].equals("-discountFactor") && pos + 1 < strArr.length) {
                double discount = Double.parseDouble(strArr[pos + 1]);
                setDiscountFactorForLearning(discount);
                this.actSelector.setDiscountParam(ActionSelect.discFactorToParam(discount));
                System.out.println("discount factor set to: " + discount);
            }
            pos++;
        }
    }

    /**
     * Manual smoke run: initialises the agent on a hand-written task
     * description, starts an episode, then feeds 300 steps with reward 10
     * followed by 300 steps with reward -1, cycling the single observation
     * value through 0, 1, 2.
     */
    private void test() {
        agent_init("VERSION RL-Glue-3.0 PROBLEMTYPE episodic DISCOUNTFACTOR 1.0 OBSERVATIONS DOUBLES (0.0 2.0)  ACTIONS INTS (0 0)  REWARDS (-1.0 10.0)  EXTRA EnvName:HandFed");
        Observation obs = new Observation();
        obs.doubleArray = new double[1];
        obs.doubleArray[0] = 0.0d;
        agent_start(obs);
        feedCyclingObservations(obs, 10.0d);
        System.out.println("\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n-------------------------------------\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n");
        feedCyclingObservations(obs, -1.0d);
    }

    /**
     * Steps the agent through 100 sweeps over the observation values 0, 1, 2,
     * handing it the same reward on every step.
     */
    private void feedCyclingObservations(Observation obs, double reward) {
        final int sweeps = 100;
        final int valuesPerSweep = 3;
        for (int step = 0; step < sweeps * valuesPerSweep; step++) {
            obs.doubleArray[0] = step % valuesPerSweep;
            agent_step(reward, obs);
        }
    }

    /**
     * Initialises the agent from a task specification string: runs the shared
     * initialisation, builds the action selector over the agent's model and
     * clears the augmenting model.
     *
     * @param str the task specification
     */
    @Override
    public void agent_init(String str) {
        GeneralAgent.agent_init(str, this);
        System.out.println("rlAgent params: " + this.params.toOneLineStr());
        this.actSelector = new ActionSelect(
                this.model,
                this.params.selectionMethod,
                this.params.selectionParams,
                this.currObsAndAct.getAct().duplicate());
        this.numEpsBeforePause = -1;
        this.qAugModel = null;
        endInitHelper();
    }

    /**
     * Begins an episode and takes the first step with a reward of zero.
     *
     * @param observation the first observation of the episode
     * @param d forwarded to {@link #agent_step(double, Observation, double, Action)}
     *     as the step start time
     * @param action action to use, or {@code null} to let the selector choose
     * @return the action taken for the first step
     */
    @Override
    public Action agent_start(Observation observation, double d, Action action) {
        startHelper();
        Action firstAction = agent_step(0.0d, observation, d, action);
        return firstAction;
    }

    /**
     * Performs one step: fixes the action for the new observation, updates the
     * model from the previous step and returns the action.
     *
     * @param d the reward used as the immediate term of the update target
     * @param observation the current observation
     * @param d2 stored as the step start time
     * @param action action to use, or {@code null} to let the selector choose
     * @return the current action after end-of-step processing
     */
    @Override
    public Action agent_step(double d, Observation observation, double d2, Action action) {
        this.stepStartTime = d2;
        stepStartHelper(d);
        if (action != null) {
            this.currObsAndAct.setAct(action);
        } else {
            this.currObsAndAct.setAct(this.actSelector.selectAction(observation, this.lastObsAndAct.getAct()));
        }
        if (this.stepsThisEp == ACT_VALS_REPORT_STEP) {
            printActionValues(observation);
        }
        processPrevTimeStep(d, observation);
        stepEndHelper(d, observation);
        return this.currObsAndAct.getAct();
    }

    /** Prints the model's outputs for every possible action in the given observation. */
    private void printActionValues(Observation observation) {
        System.out.println("Sarsa act vals: " + Arrays.toString(this.model.getStateActOutputs(observation, this.model.getPossActions(observation))));
    }

    /**
     * Finishes an episode: performs the final model update with no successor
     * observation, then anneals the action selector.
     *
     * @param d the final reward
     * @param d2 stored as the step start time
     */
    @Override
    public void agent_end(double d, double d2) {
        this.stepStartTime = d2;
        endHelper(d);
        // A null observation marks the terminal transition: nothing to bootstrap from.
        processPrevTimeStep(d, null);
        this.actSelector.anneal();
    }

    /**
     * Builds the training sample for the previous state-action pair and hands
     * it to the model. The label is {@code reward + discount * nextValue},
     * where {@code nextValue} is zero when {@code nextObs} is {@code null}.
     * Does nothing while {@code stepsThisEp} is at most 1.
     *
     * @param reward the immediate reward term of the target
     * @param nextObs the current observation, or {@code null} at episode end
     */
    private void processPrevTimeStep(double reward, Observation nextObs) {
        if (this.stepsThisEp <= 1) {
            return;
        }
        final boolean augmented = this.qAugModel != null;

        // Weight of the augmenting model, evaluated at the previous state-action pair.
        final double influence = augmented
                ? ((CombinationModel) this.actSelector.valFcnModel).hInf.getHInfluence(this.lastObsAndAct.getObs(), this.lastObsAndAct.getAct())
                : 0.0d;

        double nextValue = 0.0d;
        if (nextObs != null) {
            nextValue = this.model.predictLabel(nextObs, this.currObsAndAct.getAct());
            if (augmented) {
                nextValue = nextValue + influence * this.qAugModel.predictLabel(nextObs, this.currObsAndAct.getAct());
            }
        }

        Sample sample = new Sample(
                this.featGen.getFeats(this.lastObsAndAct.getObs(), this.lastObsAndAct.getAct()),
                reward + (this.discountFactorForLearning.getValue() * nextValue),
                1.0d);

        if (augmented) {
            // The weighted augmenting prediction for the previous pair is passed with the sample.
            ((IncModel) this.model).addInstance(sample, this.qAugModel.predictLabel(this.lastObsAndAct.getObs(), this.lastObsAndAct.getAct()) * influence);
        } else {
            this.model.addInstance(sample);
        }
    }

    /**
     * Entry point: creates the agent, applies pre-init arguments and runs it
     * through an {@link AgentLoader}.
     *
     * @param strArr command-line arguments
     */
    public static void main(String[] strArr) {
        SarsaLambdaAgent agent = new SarsaLambdaAgent();
        agent.processPreInitArgs(strArr);
        AgentLoader loader = new AgentLoader(agent);
        loader.run();
    }
}
