package edu.utexas.cs.tamerProject.agents.combo;

import edu.utexas.cs.tamerProject.agents.CreditAssign;
import edu.utexas.cs.tamerProject.agents.CreditAssignParamVec;
import edu.utexas.cs.tamerProject.agents.GeneralAgent;
import edu.utexas.cs.tamerProject.featGen.FeatGenerator;
import edu.utexas.cs.tamerProject.modeling.Sample;
import edu.utexas.cs.tamerProject.params.Params;
import java.util.Arrays;
import org.rlcommunity.rlglue.codec.types.Action;
import org.rlcommunity.rlglue.codec.types.Observation;

/* loaded from: input_file:edu/utexas/cs/tamerProject/agents/combo/HInfluence.class */
/**
 * Computes a scalar "H influence" weight, always scaled by {@link #COMB_PARAM}.
 *
 * <p>Two influence methods are supported (selected by the constructor's first argument):
 * <ul>
 *   <li>{@code "annealedParam"}: a single trace value, initialized to 1.0 and multiplied by the
 *       step / episode decay factors; the influence is {@code traces[0] * COMB_PARAM}.</li>
 *   <li>{@code "eligTrace"}: one eligibility trace per feature. Traces are grown from the samples
 *       released by {@link #credA} and decayed every step / episode. The influence for a feature
 *       vector is the normalized-feature-weighted mean of the traces (each capped at 1.0), times
 *       {@code COMB_PARAM}.</li>
 * </ul>
 * Any other method name makes {@code getHInfluence} print an error and exit the process.
 *
 * <p>NOTE(review): "H" presumably refers to the TAMER human-reinforcement model that this weight
 * is combined with; confirm against the callers.
 */
public class HInfluence {
    private static final String ANNEALED_PARAM = "annealedParam";
    private static final String ELIG_TRACE = "eligTrace";
    private static final String REPLACING = "replacing";
    private static final String ACCUMULATING = "accumulating";

    /** Either {@code "annealedParam"} or {@code "eligTrace"}. */
    private String INFLUENCE_METHOD;
    /** Multiplier applied to every influence value returned by this class. */
    public double COMB_PARAM = 1.0d;
    /** Factor every trace is multiplied by on each {@link #stepUpdate}. */
    private double STEP_DECAY_FACTOR;
    /** Factor every trace is multiplied by on each {@link #episodeEndUpdate}. */
    private double EP_DECAY_FACTOR;
    /** Either {@code "replacing"} or {@code "accumulating"}; see {@link #growEligTraces}. */
    private String TRACE_STYLE = ACCUMULATING;
    /** Scale applied to credit * feature when growing accumulating traces. */
    private double ACCUM_FACTOR = 0.0d;
    public FeatGenerator featGen;
    /** When true, features are computed from the observation only (no action). */
    boolean stateOnly = false;
    /** Placeholder action used for credit assignment in state-only mode. */
    Action stateOnlyAction = new Action();
    public CreditAssign credA;
    public double[] traces;
    /** Snapshot of {@link #traces} taken before the most recent step / episode update. */
    public double[] lastStepTraces;
    private double[] minFeats;
    private double[] maxFeats;

    /**
     * Creates an influence calculator.
     *
     * @param influenceMethod {@code "annealedParam"} or {@code "eligTrace"}
     * @param combParam initial value of {@link #COMB_PARAM}
     * @param hInflParamsArg passed through to {@code Params.getHInflParams}, which supplies the
     *     decay factors and (for eligibility traces) the accumulation factor and trace style
     * @param params agent parameters used to configure the credit assigner (eligTrace only)
     * @param generalAgent agent providing the feature generator (eligTrace only)
     * @param useStateOnlyFeats whether eligibility traces use state-only features
     */
    public HInfluence(String influenceMethod, double combParam, String hInflParamsArg, Params params,
            GeneralAgent generalAgent, boolean useStateOnlyFeats) {
        this.INFLUENCE_METHOD = influenceMethod;
        this.COMB_PARAM = combParam;
        boolean eligTraceMethod = influenceMethod.equals(ELIG_TRACE);
        Params hInflParams = Params.getHInflParams(hInflParamsArg, eligTraceMethod);
        if (generalAgent.getClass().toString().contains("TamerRLAgent")
                && ((TamerRLAgent) generalAgent).USING_PY_MC_MODEL) {
            hInflParams.setPyMCParams(getClass().toString(), eligTraceMethod);
        }
        this.STEP_DECAY_FACTOR = Double.parseDouble(hInflParams.hInflParams.get("stepDecayFactor"));
        this.EP_DECAY_FACTOR = Double.parseDouble(hInflParams.hInflParams.get("epDecayFactor"));

        if (usesAnnealedParam()) {
            this.traces = new double[1];
            setTracesToMax();
            // Initialize the lagged snapshot too, so lagged queries made before the first
            // stepUpdate()/episodeEndUpdate() do not dereference null.
            this.lastStepTraces = Arrays.copyOf(this.traces, this.traces.length);
            return;
        }
        if (usesEligTrace()) {
            this.ACCUM_FACTOR = Double.parseDouble(hInflParams.hInflParams.get("accumFactor"));
            this.TRACE_STYLE = hInflParams.hInflParams.get("traceStyle");
            this.featGen = generalAgent.getFeatGen(hInflParams);
            this.credA = new CreditAssign(new CreditAssignParamVec(params.distClass, params.creditDelay,
                    params.windowSize, params.extrapolateFutureRew, params.delayWtedIndivRew,
                    params.noUpdateWhenNoRew));
            this.stateOnly = useStateOnlyFeats;
            if (this.stateOnly) {
                this.maxFeats = this.featGen.getMaxPossSFeats();
                this.minFeats = this.featGen.getMinPossSFeats();
                this.stateOnlyAction.intArray = new int[1];
                this.stateOnlyAction.doubleArray = new double[0];
            } else {
                this.maxFeats = this.featGen.getMaxPossFeats();
                this.minFeats = this.featGen.getMinPossFeats();
            }
            this.traces = new double[this.maxFeats.length];
            this.lastStepTraces = new double[this.maxFeats.length];
        }
    }

    /** Returns true when the configured influence method is {@code "annealedParam"}. */
    private boolean usesAnnealedParam() {
        return this.INFLUENCE_METHOD.equals(ANNEALED_PARAM);
    }

    /** Returns true when the configured influence method is {@code "eligTrace"}. */
    private boolean usesEligTrace() {
        return this.INFLUENCE_METHOD.equals(ELIG_TRACE);
    }

    /** Sets the scale applied when growing accumulating traces. */
    public void setAccumFactor(double accumFactor) {
        this.ACCUM_FACTOR = accumFactor;
    }

    /** Sets every trace to 1.0. */
    public void setTracesToMax() {
        Arrays.fill(this.traces, 1.0d);
    }

    /** Sets the trace style ({@code "replacing"} or {@code "accumulating"}). */
    public void setTraceStyle(String traceStyle) {
        this.TRACE_STYLE = traceStyle;
    }

    /** Sets the factor traces are multiplied by on each step. */
    public void setStepDecayFactor(double stepDecayFactor) {
        this.STEP_DECAY_FACTOR = stepDecayFactor;
    }

    /** Sets the factor traces are multiplied by at each episode end. */
    public void setEpDecayFactor(double epDecayFactor) {
        this.EP_DECAY_FACTOR = epDecayFactor;
    }

    /**
     * Linearly rescales each feature as (f - min) / (max - min) using the min/max possible feature
     * values obtained from the feature generator. A feature whose range is zero maps to 0.0.
     * Values that come out negative (feature below its stated minimum) are reported on stdout.
     */
    private double[] linearNormFeats(double[] feats) {
        double[] normed = new double[feats.length];
        for (int i = 0; i < feats.length; i++) {
            double range = this.maxFeats[i] - this.minFeats[i];
            // Check the range first instead of dividing by zero and overwriting the result.
            normed[i] = (range == 0.0d) ? 0.0d : (feats[i] - this.minFeats[i]) / range;
            if (normed[i] < 0.0d) {
                System.out.println("components of negative normed feats -- feats[i]: " + feats[i]
                        + ", minFeats[i]: " + this.minFeats[i] + ", maxFeats[i]: " + this.maxFeats[i]);
            }
        }
        return normed;
    }

    /**
     * Snapshots the current traces into {@link #lastStepTraces}, then applies the episode decay
     * factor (skipped when it is exactly 1.0).
     */
    public void episodeEndUpdate() {
        this.lastStepTraces = Arrays.copyOf(this.traces, this.traces.length);
        if (this.EP_DECAY_FACTOR != 1.0d) {
            epDecayEligTraces();
        }
    }

    /** Influence for an observation/action pair using the current traces. */
    public double getHInfluence(Observation observation, Action action) {
        return getHInfluence(observation, action, false);
    }

    /**
     * Influence for an observation/action pair.
     *
     * @param useLastStep if true, use {@link #lastStepTraces} instead of the current traces
     */
    public double getHInfluence(Observation observation, Action action, boolean useLastStep) {
        if (usesAnnealedParam()) {
            // No feature generator exists in this mode, so do not compute features.
            double[] tracesForInf = useLastStep ? this.lastStepTraces : this.traces;
            return tracesForInf[0] * this.COMB_PARAM;
        }
        return getHInfluence(getFeats(observation, action), useLastStep);
    }

    /**
     * Influence for an observation/action given as raw RL-Glue arrays, using the current traces.
     */
    public double getHInfluence(int[] obsInts, double[] obsDoubles, char[] obsChars, int[] actInts,
            double[] actDoubles) {
        if (usesAnnealedParam()) {
            return this.traces[0] * this.COMB_PARAM;
        }
        return getHInfluence(getFeats(obsInts, obsDoubles, obsChars, actInts, actDoubles));
    }

    /** Influence for a raw (unnormalized) feature vector using the current traces. */
    public double getHInfluence(double[] feats) {
        return getHInfluence(feats, false);
    }

    /**
     * Influence for a raw (unnormalized) feature vector.
     *
     * <p>For {@code "eligTrace"} this is {@code COMB_PARAM * sum(n_i * min(1, t_i)) / sum(n_i)},
     * where n are the linearly normalized features and t the traces. If the normalized features
     * sum to zero the weighted mean is undefined and 0.0 is returned (previously NaN).
     *
     * @param feats raw feature vector (ignored in {@code "annealedParam"} mode)
     * @param useLastStep if true, use {@link #lastStepTraces} instead of the current traces
     */
    public double getHInfluence(double[] feats, boolean useLastStep) {
        double[] tracesForInf = useLastStep ? this.lastStepTraces : this.traces;
        if (usesAnnealedParam()) {
            return tracesForInf[0] * this.COMB_PARAM;
        }
        if (!usesEligTrace()) {
            System.err.println("Influence method " + this.INFLUENCE_METHOD + " not supported. Exiting.");
            System.exit(1);
            return -1.0d;
        }
        double weightedSum = 0.0d;
        double featSum = 0.0d;
        double[] normFeats = linearNormFeats(feats);
        for (int i = 0; i < normFeats.length; i++) {
            weightedSum += normFeats[i] * Math.min(1.0d, tracesForInf[i]);
            featSum += normFeats[i];
            if (tracesForInf[i] < 0.0d) {
                System.out.println("\nfound a trace below 0!!!");
                System.out.println("normFeats[i]: " + normFeats[i]);
                System.out.println("tracesForInf[i]: " + tracesForInf[i]);
            }
        }
        if (featSum == 0.0d) {
            // No feature mass to weight the traces by; avoid returning NaN (0/0).
            return 0.0d;
        }
        return this.COMB_PARAM * (weightedSum / featSum);
    }

    /**
     * Per-step update: snapshots the traces, applies the step decay (skipped when exactly 1.0),
     * and, in {@code "eligTrace"} mode, lets the credit assigner release finished samples.
     *
     * @param growTraces passed to the credit assigner; when true, each released sample grows the
     *     traces by its {@code creditUsedLastStep}
     * @param caValue passed to {@code credA.processSamplesAndRemoveFinished}
     *     (NOTE(review): presumably the current time; confirm against CreditAssign)
     */
    public void stepUpdate(boolean growTraces, double caValue) {
        this.lastStepTraces = Arrays.copyOf(this.traces, this.traces.length);
        if (this.STEP_DECAY_FACTOR != 1.0d) {
            stepDecayEligTraces();
        }
        if (usesEligTrace()) {
            Sample[] finished = this.credA.processSamplesAndRemoveFinished(caValue, growTraces);
            if (growTraces) {
                for (Sample sample : finished) {
                    growEligTraces(sample.feats, sample.creditUsedLastStep);
                }
            }
        }
    }

    /** Features for an observation/action; state-only features when {@link #stateOnly} is set. */
    private double[] getFeats(Observation observation, Action action) {
        return this.stateOnly ? this.featGen.getSFeats(observation) : this.featGen.getFeats(observation, action);
    }

    /** Wraps raw RL-Glue arrays in an Observation and Action, then computes their features. */
    private double[] getFeats(int[] obsInts, double[] obsDoubles, char[] obsChars, int[] actInts,
            double[] actDoubles) {
        Observation observation = new Observation();
        observation.intArray = obsInts;
        observation.doubleArray = obsDoubles;
        observation.charArray = obsChars;
        Action action = new Action();
        action.intArray = actInts;
        action.doubleArray = actDoubles;
        return getFeats(observation, action);
    }

    /**
     * In {@code "eligTrace"} mode, records the start of a time step with the credit assigner,
     * using normalized features (and the placeholder action in state-only mode).
     *
     * @param caValue passed to {@code credA.recordTimeStepStart}
     *     (NOTE(review): presumably the step start time; confirm)
     */
    public void recordTimeStepStart(Observation observation, Action action, double caValue) {
        if (!usesEligTrace()) {
            return;
        }
        Action effectiveAction = this.stateOnly ? this.stateOnlyAction : action;
        this.credA.recordTimeStepStart(linearNormFeats(getFeats(observation, effectiveAction)), caValue);
    }

    /**
     * In {@code "eligTrace"} mode, records the end of a time step with the credit assigner.
     *
     * @param caValue passed to {@code credA.recordTimeStepEnd}
     */
    public void recordTimeStepEnd(double caValue) {
        if (usesEligTrace()) {
            this.credA.recordTimeStepEnd(caValue);
        }
    }

    /** Multiplies every trace by the step decay factor. */
    public void stepDecayEligTraces() {
        scaleTraces(this.STEP_DECAY_FACTOR);
    }

    /** Multiplies every trace by the episode decay factor. */
    public void epDecayEligTraces() {
        scaleTraces(this.EP_DECAY_FACTOR);
    }

    /** Multiplies every element of {@link #traces} by {@code factor}. */
    private void scaleTraces(double factor) {
        for (int i = 0; i < this.traces.length; i++) {
            this.traces[i] *= factor;
        }
    }

    /**
     * Grows the traces from one released sample.
     * "replacing": each trace becomes max(feature, trace).
     * "accumulating": each trace grows by credit * feature * ACCUM_FACTOR, capped at 1.0.
     * Any other style is a fatal configuration error.
     */
    private void growEligTraces(double[] feats, double credit) {
        System.out.print("g");
        System.out.flush();
        for (int i = 0; i < feats.length; i++) {
            if (this.TRACE_STYLE.equals(REPLACING)) {
                this.traces[i] = Math.max(feats[i], this.traces[i]);
            } else if (this.TRACE_STYLE.equals(ACCUMULATING)) {
                this.traces[i] = Math.min((credit * feats[i] * this.ACCUM_FACTOR) + this.traces[i], 1.0d);
            } else {
                System.err.println("Trace style " + this.TRACE_STYLE + "  is not supported in " + getClass() + ". Exiting.");
                // Non-zero status: this is an error, consistent with getHInfluence's exit(1).
                System.exit(1);
            }
        }
    }
}
