package edu.utexas.cs.tamerProject.agents;

import edu.utexas.cs.tamerProject.modeling.Sample;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;

/* loaded from: input_file:edu/utexas/cs/tamerProject/agents/CreditAssign.class */
/**
 * Distributes human reward signals across recently recorded time steps and turns those
 * time steps into {@link Sample}s for learning updates.
 *
 * <p>Each call to {@link #recordTimeStepStart} appends a {@link TimeStep} and a matching
 * active {@link Sample}. Rewards passed to {@link #processNewHReward} are split across the
 * active steps according to the configured distribution ({@code "uniform"},
 * {@code "previousStep"} or {@code "immediate"}). {@link #processSamplesAndRemoveFinished}
 * then returns the samples that should be used for an update.
 *
 * <p>Invariant: {@code timeStepsInWindow} and {@code activeSamples} are parallel lists;
 * index {@code i} of one always corresponds to index {@code i} of the other.
 *
 * <p>Invalid configurations and call sequences terminate the JVM via {@code System.exit(1)}
 * rather than throwing.
 */
public class CreditAssign {
    /** Distribution name: credit spread uniformly over a delayed time window. */
    private static final String DIST_UNIFORM = "uniform";
    /** Distribution name: all credit goes to the step before the current one. */
    private static final String DIST_PREVIOUS_STEP = "previousStep";
    /** Distribution name: all credit goes to the current step. */
    private static final String DIST_IMMEDIATE = "immediate";

    public static final Random randGenerator = new Random();
    /** Time steps that can still receive credit; parallel to {@link #activeSamples}. */
    ArrayList<TimeStep> timeStepsInWindow;
    /** One sample per entry of {@link #timeStepsInWindow}, accumulating that step's reward. */
    ArrayList<Sample> activeSamples;
    /** Number of time steps recorded so far. */
    int totalTimeSteps;
    /** Offset added to {@link #totalTimeSteps} for the third {@code Sample} constructor argument (presumably a unique id). */
    final int UNIQUE_START;
    boolean inTrainSess = false;
    /** Used credit at or above which a step is frozen ("set in stone") when training ends. */
    final double SAMPLE_CUMUL_CRED_MIN = 0.9d;
    /** When true, sample labels are divided by the credit used so far. */
    boolean EXTRAPOLATE_FUTURE_REW = false;
    /** Used credit a partially credited sample must exceed to be emitted early. */
    private double MIN_USED_CRED_FOR_EXTRAP = 0.5d;
    /** Used credit at or above which a "uniform" sample counts as fully credited. */
    public static final double APPROX_ONE = 0.99999d;
    /** Name of the credit distribution; compared by value against the DIST_* names. */
    String distClass;
    boolean delayWtedIndivRew = false;
    boolean noUpdateWhenNoRew = false;
    /** Start of the credit window, as a delay after the event. */
    double windowStart;
    /** End of the credit window, as a delay after the event. */
    double windowEnd;

    /** Bookkeeping for one recorded time step. */
    public class TimeStep {
        /** Sentinel for a start time that has not been recorded. */
        public static final double UNASSIGNED_START_TIME = Double.NEGATIVE_INFINITY;
        /** Sentinel for an end time that has not been recorded (the step is still running). */
        public static final double UNASSIGNED_END_TIME = Double.POSITIVE_INFINITY;
        public double startTime = UNASSIGNED_START_TIME;
        public double endTime = UNASSIGNED_END_TIME;
        /** Feature vector of the step (the caller's array, not a copy). */
        public double[] feats = null;
        /** Marks the step and its sample for discarding. */
        public boolean throwOut = false;
        /** Marks the step's credit as final; it is no longer recomputed. */
        public boolean setInStone = false;
        /** The sample's used credit as of the previous processing pass. */
        public double credUsedBeforeLastStep = 0.0d;

        public TimeStep() {
        }

        @Override
        public String toString() {
            return "\n"
                    + "startTime: " + String.format("%f", this.startTime) + "\n"
                    + "endTime: " + String.format("%f", this.endTime) + "\n"
                    + "throwOut: " + this.throwOut + "\n"
                    + "setInStone: " + this.setInStone + "\n"
                    + "feats: " + Arrays.toString(this.feats) + "\n";
        }
    }

    /**
     * Creates a credit assigner configured from the given parameters.
     *
     * @param creditAssignParamVec distribution name, window delay and size, id offset and mode flags
     */
    public CreditAssign(CreditAssignParamVec creditAssignParamVec) {
        System.out.println("Creating CreditAssign object with params: " + creditAssignParamVec);
        this.distClass = creditAssignParamVec.distClass;
        this.windowStart = creditAssignParamVec.creditDelay;
        this.windowEnd = creditAssignParamVec.windowSize + this.windowStart;
        this.UNIQUE_START = creditAssignParamVec.uniqueStart;
        this.EXTRAPOLATE_FUTURE_REW = creditAssignParamVec.EXTRAPOLATE_FUTURE_REW;
        this.delayWtedIndivRew = creditAssignParamVec.delayWtedIndivRew;
        this.noUpdateWhenNoRew = creditAssignParamVec.noUpdateWhenNoRew;
        if (this.delayWtedIndivRew) {
            // Delay-weighted mode disables label extrapolation and lets any partially
            // credited sample be emitted.
            this.EXTRAPOLATE_FUTURE_REW = false;
            this.MIN_USED_CRED_FOR_EXTRAP = 0.0d;
        }
        this.totalTimeSteps = 0;
        clearHistory();
    }

    /** Returns whether the configured distribution is {@code name} (by value, null-safe). */
    private boolean isDist(String name) {
        return name.equals(this.distClass);
    }

    /**
     * Records the end time of the most recent time step; does nothing if none is recorded.
     *
     * @param d the step's end time
     */
    public void recordTimeStepEnd(double d) {
        if (!this.timeStepsInWindow.isEmpty()) {
            this.timeStepsInWindow.get(this.timeStepsInWindow.size() - 1).endTime = d;
        }
    }

    /**
     * Starts a new time step and its active sample. Exits the JVM if the previous step was
     * never ended via {@link #recordTimeStepEnd}.
     *
     * @param dArr the step's feature vector (stored by reference)
     * @param d    the step's start time
     */
    public void recordTimeStepStart(double[] dArr, double d) {
        int lastIndex = this.timeStepsInWindow.size() - 1;
        if (lastIndex >= 0 && this.timeStepsInWindow.get(lastIndex).endTime == TimeStep.UNASSIGNED_END_TIME) {
            System.err.println("\n\n\nTried to create a new time step in CreditAssign before ending the last. Exiting.");
            System.err.println(Arrays.toString(Thread.currentThread().getStackTrace()));
            System.exit(1);
        }
        TimeStep timeStep = new TimeStep();
        timeStep.feats = dArr;
        timeStep.startTime = d;
        this.timeStepsInWindow.add(timeStep);
        int sampleId = this.delayWtedIndivRew ? -1 : this.totalTimeSteps + this.UNIQUE_START;
        this.activeSamples.add(new Sample(dArr, 1.0d, sampleId));
        this.totalTimeSteps++;
    }

    /** Number of steps recorded after step {@code i} (0 for the newest step). */
    private int getStepsBeforeCurrent(int i) {
        return (this.timeStepsInWindow.size() - i) - 1;
    }

    /**
     * Updates credit bookkeeping and returns the samples that should be used for a learning
     * update now. Finished steps are removed from the window.
     *
     * @param d current time
     * @param z when false, steps that are just beginning to receive credit (and, for the
     *          step-based distributions, steps that finish now) are thrown out; presumably
     *          "currently in a training session" — TODO confirm against callers
     * @return samples to update on; never null
     */
    public Sample[] processSamplesAndRemoveFinished(double d, boolean z) {
        if (GeneralAgent.duringStepTransition && !this.timeStepsInWindow.isEmpty()
                && this.timeStepsInWindow.get(this.timeStepsInWindow.size() - 1).endTime == TimeStep.UNASSIGNED_END_TIME) {
            System.err.println("Calling processSamplesAndRemoveFinished() with a step's start recorded but the end unrecorded in CreditAssign. Be careful to call this method before calling recordTimeStepStart().");
            System.err.println("Thread stack trace:\n" + Arrays.toString(Thread.currentThread().getStackTrace()));
            System.err.println("Killing agent process\n\n");
            System.exit(1);
        }
        ArrayList<Sample> partiallyCredited = processTimeSteps(d, z);
        ArrayList<Sample> finished = removeFinishedTimeSteps(d, z);
        if (this.noUpdateWhenNoRew && this.delayWtedIndivRew && allSamplesHaveZeroRew()) {
            return new Sample[0];
        }
        ArrayList<Sample> samplesForUpdate = new ArrayList<>();
        if (this.EXTRAPOLATE_FUTURE_REW || this.delayWtedIndivRew) {
            samplesForUpdate.addAll(partiallyCredited);
            if (this.delayWtedIndivRew) {
                // The partially credited samples returned above are the same objects as the
                // active ones; swap in clones (presumably copies) before resetting reward so
                // the returned samples keep their values.
                for (int i = 0; i < this.activeSamples.size(); i++) {
                    Sample fresh = this.activeSamples.get(i).m49clone();
                    fresh.unweightedRew = 0.0d;
                    fresh.label = 0.0d;
                    this.activeSamples.set(i, fresh);
                }
            }
        }
        samplesForUpdate.addAll(finished);
        if (this.delayWtedIndivRew) {
            removeSamplesWNoNewCred(samplesForUpdate);
            setWtToCredLastStep(samplesForUpdate);
        } else if (this.noUpdateWhenNoRew) {
            removeSamplesWZeroRew(samplesForUpdate);
        }
        return samplesForUpdate.toArray(new Sample[samplesForUpdate.size()]);
    }

    /**
     * Recomputes used credit and labels for every step not yet set in stone, flags steps to
     * throw out, and returns partially credited samples eligible for an early update.
     */
    private ArrayList<Sample> processTimeSteps(double d, boolean z) {
        ArrayList<Sample> partiallyCredited = new ArrayList<>();
        boolean stepBasedDist = isDist(DIST_PREVIOUS_STEP) || isDist(DIST_IMMEDIATE);
        for (int i = 0; i < this.timeStepsInWindow.size(); i++) {
            TimeStep timeStep = this.timeStepsInWindow.get(i);
            if (timeStep.setInStone) {
                continue;
            }
            Sample sample = this.activeSamples.get(i);
            double prevUsedCredit = sample.usedCredit;
            sample.usedCredit = getCreditPastElig(i, getStepsBeforeCurrent(i), d);
            sample.label = sample.unweightedRew;
            if (this.EXTRAPOLATE_FUTURE_REW) {
                sample.label /= sample.usedCredit;
            }
            sample.creditUsedLastStep = sample.usedCredit - timeStep.credUsedBeforeLastStep;
            if (!z) {
                // A step that only now starts receiving credit would be only partly credited.
                if (sample.usedCredit > 0.0d
                        && (prevUsedCredit == 0.0d || prevUsedCredit == Double.NEGATIVE_INFINITY)) {
                    timeStep.throwOut = true;
                }
                if (stepBasedDist && isSampleFinished(i)) {
                    timeStep.throwOut = true;
                }
            }
            if (!stepBasedDist && sample.usedCredit > this.MIN_USED_CRED_FOR_EXTRAP
                    && sample.usedCredit < APPROX_ONE && !timeStep.throwOut) {
                partiallyCredited.add(sample);
            }
            timeStep.credUsedBeforeLastStep = sample.usedCredit;
        }
        return partiallyCredited;
    }

    /**
     * Removes thrown-out and finished steps from the window and returns the finished samples
     * (with final labels). Except under "immediate", the newest step is never removed here.
     * The time/flag parameters are currently unused.
     */
    private ArrayList<Sample> removeFinishedTimeSteps(double d, boolean z) {
        ArrayList<Sample> finished = new ArrayList<>();
        int protectedTail = isDist(DIST_IMMEDIATE) ? 0 : 1;
        int i = 0;
        while (i < this.timeStepsInWindow.size() - protectedTail) {
            if (this.timeStepsInWindow.get(i).throwOut) {
                removeSample(i);
            } else if (isSampleFinished(i)) {
                finished.add(removeSample(i));
            } else {
                i++;
            }
        }
        return finished;
    }

    /** Returns whether step {@code i} can receive no further credit. */
    private boolean isSampleFinished(int i) {
        if (this.timeStepsInWindow.get(i).setInStone) {
            return true;
        }
        if (isDist(DIST_PREVIOUS_STEP) || isDist(DIST_IMMEDIATE)) {
            int finishingDistance;
            if (isDist(DIST_IMMEDIATE)) {
                finishingDistance = GeneralAgent.duringStepTransition ? 0 : 1;
            } else {
                finishingDistance = GeneralAgent.duringStepTransition ? 1 : 2;
            }
            return getStepsBeforeCurrent(i) == finishingDistance;
        }
        if (isDist(DIST_UNIFORM)) {
            double usedCredit = this.activeSamples.get(i).usedCredit;
            return usedCredit <= 1.0d && usedCredit >= APPROX_ONE;
        }
        return false;
    }

    /** Removes step {@code i} and returns its sample with label = reward / used credit. */
    private Sample removeSample(int i) {
        this.timeStepsInWindow.remove(i);
        Sample removed = this.activeSamples.remove(i);
        // NOTE(review): yields Infinity/NaN if usedCredit is 0 (e.g. for thrown-out samples,
        // whose result is discarded by the caller).
        removed.label = removed.unweightedRew / removed.usedCredit;
        return removed;
    }

    /**
     * Sets the training-session flag. When a session ends, steps with enough credit are
     * frozen and partially credited steps are thrown out.
     *
     * @param d current time
     * @param z new value of the training-session flag
     */
    public void setInTrainSess(double d, boolean z) {
        if (this.inTrainSess && !z) {
            for (int i = 0; i < this.timeStepsInWindow.size(); i++) {
                double creditPastElig = getCreditPastElig(i, getStepsBeforeCurrent(i), d);
                if (creditPastElig >= SAMPLE_CUMUL_CRED_MIN) {
                    this.timeStepsInWindow.get(i).setInStone = true;
                    this.activeSamples.get(i).usedCredit = creditPastElig;
                } else if (creditPastElig > 0.0d) {
                    this.timeStepsInWindow.get(i).throwOut = true;
                    println("Throwing out an already credited time step.");
                }
            }
        }
        this.inTrainSess = z;
        println("\n\n\n\n\n\n\n\nIn training session: " + this.inTrainSess + "\n\n\n");
    }

    /** Forgets all recorded time steps and active samples. */
    public void clearHistory() {
        this.timeStepsInWindow = new ArrayList<>();
        this.activeSamples = new ArrayList<>();
    }

    /**
     * Distributes a human reward over the steps in the window.
     *
     * @param d  reward value
     * @param d2 time at which the reward was given
     */
    public void processNewHReward(double d, double d2) {
        for (int i = 0; i < this.timeStepsInWindow.size(); i++) {
            int stepsBefore = getStepsBeforeCurrent(i);
            double credit = d * getCredit(d2, i, stepsBefore);
            Sample sample = this.activeSamples.get(i);
            sample.unweightedRew += credit;
            sample.usedCredit = Math.max(getCreditPastElig(i, stepsBefore, d2), sample.usedCredit);
            sample.label = sample.unweightedRew;
            if (this.EXTRAPOLATE_FUTURE_REW) {
                sample.label /= sample.usedCredit;
            }
        }
    }

    /**
     * Time elapsed since step {@code i} ended, or 0 if it has not ended yet (its end time is
     * still the {@link TimeStep#UNASSIGNED_END_TIME} sentinel).
     */
    private double getRelNearBound(int i, double d) {
        double endTime = this.timeStepsInWindow.get(i).endTime;
        if (endTime == TimeStep.UNASSIGNED_END_TIME) {
            return 0.0d;
        }
        return d - endTime;
    }

    /**
     * Fraction of a reward given at time {@code d} that step {@code i} receives.
     * For "uniform": the overlap of the step's [since-end, since-start] interval with the
     * window [windowStart, windowEnd], divided by the window width.
     */
    private double getCredit(double d, int i, int i2) {
        if (isDist(DIST_PREVIOUS_STEP)) {
            return i2 == 1 ? 1.0d : 0.0d;
        }
        if (isDist(DIST_IMMEDIATE)) {
            return i2 == 0 ? 1.0d : 0.0d;
        }
        double credit = 0.0d;
        if (isDist(DIST_UNIFORM)) {
            double sinceStart = d - this.timeStepsInWindow.get(i).startTime;
            double sinceEnd = getRelNearBound(i, d);
            if (sinceStart <= this.windowStart || sinceEnd >= this.windowEnd) {
                credit = 0.0d;
            } else {
                if (sinceEnd < this.windowStart) {
                    sinceEnd = this.windowStart;
                }
                if (sinceStart > this.windowEnd) {
                    sinceStart = this.windowEnd;
                }
                credit = (sinceStart - sinceEnd) / (this.windowEnd - this.windowStart);
            }
        } else {
            System.err.println("Using an invalid distribution class for credit assignment!");
            System.exit(1);
        }
        if (credit < 0.0d || credit > 1.0d) {
            println("bad credit: " + credit);
        }
        return credit;
    }

    /**
     * Fraction of step {@code i}'s total possible credit whose reward times lie before
     * {@code d}.
     *
     * <p>For "uniform", credit per reward time is the overlap of the delayed window with the
     * step, which as a function of time is a trapezoid: it rises for {@code ramp} time,
     * stays flat for {@code flat} time, then falls for {@code ramp} time, with total area
     * {@code ramp * (ramp + flat)}. The result is the area accumulated so far divided by the
     * total, capped at 1.
     *
     * @param i  step index
     * @param i2 steps recorded after step i (unused by the current formulas)
     * @param d  current time
     */
    public double getCreditPastElig(int i, int i2, double d) {
        if (isDist(DIST_PREVIOUS_STEP) || isDist(DIST_IMMEDIATE)) {
            return isSampleFinished(i) ? 1.0d : 0.0d;
        }
        if (!isDist(DIST_UNIFORM)) {
            System.err.println("Using an invalid distribution class for credit assignment!");
            System.exit(1);
            return 0.0d;
        }
        TimeStep timeStep = this.timeStepsInWindow.get(i);
        double windowWidth = this.windowEnd - this.windowStart;
        double stepDuration = Math.min(timeStep.endTime, d) - timeStep.startTime;
        if (stepDuration == 0.0d) {
            if (timeStep.startTime >= d) {
                return 0.0d;
            }
            if (!this.inTrainSess) {
                return 1.0d;
            }
            System.err.println("Steps of zero duration are being thrown out during training. Exiting from CreditAssign.");
            System.exit(1);
            return 1.0d;
        }
        double elapsed = d - (timeStep.startTime + this.windowStart);
        double flat = Math.abs(windowWidth - stepDuration);
        double ramp = Math.min(windowWidth, stepDuration);
        double totalArea = ramp * (ramp + flat);
        if (totalArea == 0.0d) {
            return 0.0d;
        }
        double risingArea = Math.pow(Math.min(ramp, Math.max(0.0d, elapsed)), 2.0d) / 2.0d;
        double flatArea = Math.max(0.0d, Math.min(flat, elapsed - ramp)) * ramp;
        // The falling ramp's height is (ramp - s), so its integral is ramp*z - z^2/2
        // (the previous z^2/2 integrated a rising ramp and under-counted this phase).
        double z = Math.min(ramp, Math.max(0.0d, elapsed - (ramp + flat)));
        double fallingArea = Math.max(0.0d, (ramp * z) - ((z * z) / 2.0d));
        return Math.min(1.0d, ((risingArea + flatArea) + fallingArea) / totalArea);
    }

    /** Removes, in place, samples that received no credit in the last processing pass. */
    private void removeSamplesWNoNewCred(ArrayList<Sample> arrayList) {
        for (int i = arrayList.size() - 1; i >= 0; i--) {
            if (arrayList.get(i).creditUsedLastStep == 0.0d) {
                arrayList.remove(i);
            }
        }
    }

    /** Removes, in place, samples whose label is exactly zero. */
    private void removeSamplesWZeroRew(ArrayList<Sample> arrayList) {
        for (int i = arrayList.size() - 1; i >= 0; i--) {
            if (arrayList.get(i).label == 0.0d) {
                arrayList.remove(i);
            }
        }
    }

    /** Sets each sample's weight to the credit it received in the last processing pass. */
    private void setWtToCredLastStep(ArrayList<Sample> arrayList) {
        for (Sample sample : arrayList) {
            sample.weight = sample.creditUsedLastStep;
        }
    }

    /** Returns whether every active sample has zero accumulated (unweighted) reward. */
    private boolean allSamplesHaveZeroRew() {
        for (Sample sample : this.activeSamples) {
            if (sample.unweightedRew != 0.0d) {
                return false;
            }
        }
        return true;
    }

    /**
     * Draws a delay from the configured distribution: 0 for "previousStep", uniform in
     * [windowStart, windowEnd) for "uniform". Any other distribution exits the JVM.
     * NOTE(review): "immediate" is not handled here although the rest of the class supports it.
     */
    public double drawDelay() {
        if (isDist(DIST_PREVIOUS_STEP)) {
            return 0.0d;
        }
        if (isDist(DIST_UNIFORM)) {
            return this.windowStart + (randGenerator.nextDouble() * (this.windowEnd - this.windowStart));
        }
        System.err.println("Using an invalid distribution class for credit assignment!");
        System.exit(1);
        return 0.0d;
    }

    /** Ad-hoc demonstration that prints credit-assignment results to stdout. */
    public static void main(String[] strArr) {
        CreditAssign creditAssign = new CreditAssign(new CreditAssignParamVec("uniform", 0.2d, 0.6d, true, false, true));
        creditAssign.setInTrainSess(0.0d, true);
        for (int i = 0; i < 10; i++) {
            double t = 0.2d * i;
            creditAssign.println("\n\n");
            creditAssign.recordTimeStepEnd(t);
            creditAssign.recordTimeStepStart(new double[]{i, i + 10}, t);
            // A unit reward arrives only on the last iteration (t = 1.8).
            creditAssign.processNewHReward(i == 9 ? 1.0d : 0.0d, t);
            System.out.println("\nsamples for update: \n" + Arrays.toString(creditAssign.processSamplesAndRemoveFinished(t, true)));
            System.out.println("\nactiveSamples: \n" + creditAssign.activeSamples);
        }
        double[] dArr = {10.0d, 20.0d};
        creditAssign.processNewHReward(1.0d, 1.9d);
        creditAssign.recordTimeStepEnd(2.0d);
        creditAssign.recordTimeStepStart(dArr, 2.0d);
        creditAssign.removeFinishedTimeSteps(2.0d, true);
        creditAssign.processNewHReward(1.0d, 2.0d);
        creditAssign.println("\n\n\nCredited samples: " + Arrays.toString(creditAssign.processSamplesAndRemoveFinished(2.0d, true)));
        creditAssign.println("");
        creditAssign.clearHistory();
        creditAssign.processNewHReward(2.0d, 2.0d);
        creditAssign.println("\n\n\nCredited samples after clearing: " + Arrays.toString(creditAssign.processSamplesAndRemoveFinished(2.0d, true)));
        System.out.println("\n\n\n-----calculation of example in figure of journal paper-----");
        creditAssign.recordTimeStepStart(dArr, 3.25d);
        creditAssign.processNewHReward(2.0d, 3.4d);
        creditAssign.recordTimeStepEnd(3.45d);
        creditAssign.recordTimeStepStart(dArr, 3.45d);
        creditAssign.processNewHReward(2.0d, 3.85d);
        Sample[] processSamplesAndRemoveFinished = creditAssign.processSamplesAndRemoveFinished(3.9d, true);
        creditAssign.println("activeSamples (where example label is): " + creditAssign.activeSamples);
        creditAssign.println("\n\n\nCredited samples: " + Arrays.toString(processSamplesAndRemoveFinished));
        creditAssign.processNewHReward(2.0d, 4.1d);
    }

    public void print(String str) {
        System.out.print(str);
    }

    public void println(String str) {
        System.out.println(str);
    }
}
