package edu.utexas.cs.tamerProject.actSelect;

import edu.utexas.cs.tamerProject.agents.ObsAndTerm;
import edu.utexas.cs.tamerProject.agents.combo.HInfluence;
import edu.utexas.cs.tamerProject.env.EnvTransModel;
import edu.utexas.cs.tamerProject.featGen.FeatGenerator;
import edu.utexas.cs.tamerProject.modeling.CombinationModel;
import edu.utexas.cs.tamerProject.modeling.ObsActModel;
import edu.utexas.cs.tamerProject.modeling.RegressionModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Random;
import org.rlcommunity.rlglue.codec.types.Action;
import org.rlcommunity.rlglue.codec.types.Observation;

/* loaded from: input_file:edu/utexas/cs/tamerProject/actSelect/ActionSelect.class */
/**
 * Chooses actions from a value-function model, either directly ("greedy" /
 * "e-greedy") or via a depth-limited tree search that combines a reward model,
 * an environment transition model, and the value function.
 *
 * <p>NOTE: the discount parameter and the three tree-search settings are
 * {@code static}, so they are shared by every instance; constructing a new
 * {@code ActionSelect} overwrites the search settings for all existing ones.
 *
 * <p>NOTE(review): epsilon draws use a fresh {@code new Random()} per call while
 * tie-breaking uses {@code FeatGenerator.staticRandGenerator}. This is kept
 * as-is so that the sequence drawn from the shared generator is unchanged.
 */
public class ActionSelect {
    /** Name of the selection strategy; this class handles "greedy" and "e-greedy". */
    public String selectionMethod;
    /** String-valued parameters; keys read here include "epsilon" and "epsilonAnnealRate". */
    private HashMap<String, String> selectionParams;
    /** Model used to score observation-action pairs. */
    public RegressionModel valFcnModel;
    /** When true, {@link #selectAction} exploits via tree search instead of a one-step greedy choice. */
    boolean treeSearch;
    private static double discountParam = Double.MAX_VALUE;
    private static int greedyLeafPathLength = 0;
    private static int exhaustiveSearchDepth = 1;
    private static boolean randomizeSearchDepth = true;
    private ObsActModel rewModel = null;
    private EnvTransModel envTransModel = null;
    private DiscountTypes discountType = DiscountTypes.EXPON;

    /** Shape of the discount curve applied by the tree search (see {@code getDiscount}). */
    public enum DiscountTypes {
        EXPON,
        HYPER;

        /**
         * Returns a fresh array holding all constants; equivalent to {@link #values()},
         * which already returns a new array on every call.
         */
        public static DiscountTypes[] valuesCustom() {
            return values();
        }
    }

    /**
     * Creates a selector.
     *
     * @param regressionModel value-function model used to score actions
     * @param str             selection method name ("greedy" or "e-greedy")
     * @param hashMap         parameters; "treeSearch", "greedyLeafPathLength",
     *                        "exhaustiveSearchDepth" and "randomizeSearchDepth" are read here
     * @param action          unused by this constructor
     * @throws NumberFormatException if either integer parameter is missing or malformed
     */
    public ActionSelect(RegressionModel regressionModel, String str, HashMap<String, String> hashMap, Action action) {
        this.valFcnModel = regressionModel;
        this.selectionMethod = str;
        this.selectionParams = hashMap;
        // parseBoolean yields false for a missing (null) value, exactly like Boolean.valueOf.
        this.treeSearch = Boolean.parseBoolean(hashMap.get("treeSearch"));
        greedyLeafPathLength = Integer.parseInt(hashMap.get("greedyLeafPathLength"));
        exhaustiveSearchDepth = Integer.parseInt(hashMap.get("exhaustiveSearchDepth"));
        randomizeSearchDepth = Boolean.parseBoolean(hashMap.get("randomizeSearchDepth"));
        System.out.println("selectionParams in ActionSelect: " + hashMap.toString());
    }

    /**
     * Wraps the current value-function model and {@code regressionModel} in a
     * {@link CombinationModel}, which then becomes the model used for selection.
     */
    public void addModelForActBias(RegressionModel regressionModel, HInfluence hInfluence) {
        this.valFcnModel = new CombinationModel(this.valFcnModel, regressionModel, hInfluence);
    }

    public void setEnvTransModel(EnvTransModel envTransModel) {
        this.envTransModel = envTransModel;
    }

    public EnvTransModel getEnvTransModel() {
        return this.envTransModel;
    }

    /**
     * Sets the reward model used by tree search. Replacing an existing model with
     * {@code null} is treated as fatal: the JVM exits with status 1.
     */
    public void setRewModel(ObsActModel obsActModel) {
        if (this.rewModel != null && obsActModel == null) {
            System.err.println("Attempting to change reward model to null in ActionSelect. Exiting.");
            System.exit(1);
        }
        this.rewModel = obsActModel;
    }

    public ObsActModel getRewModel() {
        return this.rewModel;
    }

    public void setTreeSearchFlag(boolean z) {
        this.treeSearch = z;
    }

    public boolean getTreeSearchFlag() {
        return this.treeSearch;
    }

    /** Sets the (static, shared) discount parameter and logs the new value. */
    public void setDiscountParam(double d) {
        discountParam = d;
        System.out.println("Discount parameter in ActionSelect set to: " + discountParam);
    }

    public void setDiscountType(DiscountTypes discountTypes) {
        this.discountType = discountTypes;
    }

    /**
     * Selects an action for {@code observation}.
     *
     * <p>With tree search enabled, "greedy" always exploits via tree search; any
     * other method name exploits with probability {@code 1 - epsilon} and otherwise
     * returns a random action. Without tree search, only "greedy" and "e-greedy"
     * are accepted; anything else prints an error and exits the JVM with status 1.
     *
     * @param observation current observation
     * @param action      previous action; preferred by the greedy choice when it is among the maximizers
     * @return the chosen action
     */
    public Action selectAction(Observation observation, Action action) {
        if (this.treeSearch) {
            // Short-circuit: for "greedy" neither epsilon nor a random draw is needed.
            boolean exploit = this.selectionMethod.equals("greedy") || new Random().nextDouble() > epsilon();
            if (exploit) {
                return treeSearchBasedExploitSelect(this.valFcnModel, this.rewModel, this.envTransModel, observation, action, this.discountType);
            }
            return this.valFcnModel.getRandomAction();
        }
        if (this.selectionMethod.equals("greedy")) {
            return greedyActSelect(this.valFcnModel, observation, action);
        }
        if (this.selectionMethod.equals("e-greedy")) {
            return eGreedyActSelect(epsilon(), this.valFcnModel, observation, action);
        }
        exitUnsupportedMethod();
        return null;
    }

    /**
     * Multiplies "epsilon" by "epsilonAnnealRate" for "e-greedy"; a no-op for
     * "greedy". Any other method prints an error and exits the JVM with status 1.
     */
    public void anneal() {
        if (this.selectionMethod.equals("greedy")) {
            return;
        }
        if (!this.selectionMethod.equals("e-greedy")) {
            exitUnsupportedMethod();
            return;
        }
        double annealRate = Double.parseDouble(this.selectionParams.get("epsilonAnnealRate"));
        this.selectionParams.put("epsilon", Double.toString(epsilon() * annealRate));
    }

    /** Greedy selection using this instance's value-function model. */
    public Action greedyActSelect(Observation observation, Action action) {
        return greedyActSelect(this.valFcnModel, observation, action);
    }

    /** Reads the current "epsilon" parameter; throws NullPointerException if it is absent. */
    private double epsilon() {
        return Double.parseDouble(this.selectionParams.get("epsilon"));
    }

    /** Reports an unsupported selection method and terminates with a failure status. */
    private void exitUnsupportedMethod() {
        System.err.println("Action selection method " + this.selectionMethod + " not supported. Exiting.");
        // Failure status, consistent with every other fatal path in this class (previously 0).
        System.exit(1);
    }

    /**
     * Returns a maximizing action for {@code observation}. If {@code action}
     * (compared by its int, double and char arrays) is among the maximizers, a
     * duplicate of it is returned; otherwise a maximizer is picked uniformly at random.
     */
    private static Action greedyActSelect(RegressionModel regressionModel, Observation observation, Action action) {
        ArrayList<Action> maxActs = regressionModel.getMaxActs(observation, null);
        if (maxActs.size() == 0) {
            System.err.println("A list of zero maximum acts was returned by RegressionModel.getMaxActs(). Exiting.");
            System.err.println("state-action values: " + Arrays.toString(regressionModel.getStateActOutputs(observation, regressionModel.getFeatGen().getPossActions(observation))));
            System.exit(1);
        }
        if (action != null) {
            for (Action candidate : maxActs) {
                if (Arrays.equals(candidate.intArray, action.intArray)
                        && Arrays.equals(candidate.doubleArray, action.doubleArray)
                        && Arrays.equals(candidate.charArray, action.charArray)) {
                    return action.duplicate();
                }
            }
        }
        return maxActs.get(FeatGenerator.staticRandGenerator.nextInt(maxActs.size()));
    }

    /** With probability {@code d} returns a random action, otherwise the greedy one. */
    private static Action eGreedyActSelect(double d, RegressionModel regressionModel, Observation observation, Action action) {
        if (new Random().nextDouble() > d) {
            return greedyActSelect(regressionModel, observation, action);
        }
        return regressionModel.getRandomAction();
    }

    /**
     * Tie-tracking step: if {@code candidate} beats {@code bestVal} the list is
     * reset; if it equals the (possibly updated) best it is appended. A NaN value
     * compares false both ways and is therefore never kept.
     *
     * @return the best value after considering {@code candidate}
     */
    private static double keepIfBest(ArrayList<PathAndVal> best, PathAndVal candidate, double bestVal) {
        double val = candidate.getVal();
        if (val > bestVal) {
            best.clear();
            bestVal = val;
        }
        if (val == bestVal) {
            best.add(candidate);
        }
        return bestVal;
    }

    /**
     * Picks one element uniformly at random using the shared generator.
     *
     * @throws IllegalStateException if {@code best} is empty (no actions were
     *         available, or every candidate value was NaN)
     */
    private static PathAndVal pickRandom(ArrayList<PathAndVal> best) {
        if (best.isEmpty()) {
            throw new IllegalStateException("Tree search produced no candidate path: the action list was empty or all path values were NaN.");
        }
        return best.get(FeatGenerator.staticRandGenerator.nextInt(best.size()));
    }

    /**
     * Evaluates every possible first action with {@code planWithStartAct} and
     * returns the first action of a best-valued path (ties broken at random).
     * The search depth is {@code exhaustiveSearchDepth}, or a uniform draw from
     * {@code [0, exhaustiveSearchDepth]} when depth randomization is on.
     * Exits the JVM with status 1 if either model is {@code null}.
     */
    private static Action treeSearchBasedExploitSelect(RegressionModel regressionModel, ObsActModel obsActModel, EnvTransModel envTransModel, Observation observation, Action action, DiscountTypes discountTypes) {
        if (envTransModel == null) {
            System.err.println("Attempting to treeSearch without setting an environment model.");
            System.exit(1);
        }
        if (obsActModel == null) {
            System.err.println("Attempting to treeSearch without setting ActionSelect.rewModel.");
            System.exit(1);
        }
        ArrayList<Action> possActions = regressionModel.getPossActions(observation);
        int depth = exhaustiveSearchDepth;
        if (randomizeSearchDepth) {
            depth = FeatGenerator.staticRandGenerator.nextInt(exhaustiveSearchDepth + 1);
        }
        ArrayList<PathAndVal> best = new ArrayList<PathAndVal>();
        double bestVal = Double.NEGATIVE_INFINITY;
        for (int idx = 0; idx < possActions.size(); idx++) {
            PathAndVal path = planWithStartAct(observation, possActions.get(idx), regressionModel, obsActModel, envTransModel, 0, depth, action, discountTypes);
            bestVal = keepIfBest(best, path, bestVal);
        }
        return pickRandom(best).getFirstAct();
    }

    /**
     * Recursive exhaustive search from {@code observation}.
     *
     * @param i  current depth, used as the discount index for rewards at this level
     * @param i2 remaining exhaustive depth; at 0 the value of a sampled greedy
     *           path of length {@code greedyLeafPathLength} is returned
     * @param action last action taken, forwarded to the greedy leaf rollout
     * @return a best-valued path from here (ties broken at random)
     */
    private static PathAndVal plan(Observation observation, RegressionModel regressionModel, ObsActModel obsActModel, EnvTransModel envTransModel, int i, int i2, Action action, DiscountTypes discountTypes) {
        if (i2 == 0) {
            // NOTE(review): the rollout starts at depth i + 1 although this observation is
            // discounted at depth i in the branch below — confirm this offset is intended.
            return new PathAndVal(sampleGreedyPath(observation, regressionModel, obsActModel, envTransModel, i + 1, greedyLeafPathLength, action, discountTypes).getVal());
        }
        ArrayList<Action> possActions = regressionModel.getPossActions(observation);
        ArrayList<PathAndVal> best = new ArrayList<PathAndVal>();
        double bestVal = Double.NEGATIVE_INFINITY;
        for (int idx = 0; idx < possActions.size(); idx++) {
            Action candidate = possActions.get(idx);
            ObsAndTerm next = envTransModel.sampleNextObs(observation, candidate);
            PathAndVal path = !next.getTerm()
                    ? plan(next.getObs(), regressionModel, obsActModel, envTransModel, i + 1, i2 - 1, candidate, discountTypes)
                    : new PathAndVal();
            double reward = obsActModel.predictLabel(observation, candidate);
            path.addObsAndActBeforePath(observation, candidate, getDiscount(i, discountTypes) * reward, reward);
            bestVal = keepIfBest(best, path, bestVal);
        }
        return pickRandom(best);
    }

    /**
     * Builds the best path that begins by taking {@code action} in {@code observation}.
     * With {@code i2 == 0} the path is just the value-function prediction for that
     * pair; otherwise the next observation is sampled, {@code plan} continues from
     * it unless it is terminal, and the discounted predicted reward of the first
     * step is prepended.
     *
     * @param action2 unused
     */
    private static PathAndVal planWithStartAct(Observation observation, Action action, RegressionModel regressionModel, ObsActModel obsActModel, EnvTransModel envTransModel, int i, int i2, Action action2, DiscountTypes discountTypes) {
        if (i2 == 0) {
            return new PathAndVal(regressionModel.predictLabel(observation, action), action);
        }
        ObsAndTerm next = envTransModel.sampleNextObs(observation, action);
        PathAndVal path = new PathAndVal();
        if (!next.getTerm()) {
            path = plan(next.getObs(), regressionModel, obsActModel, envTransModel, i + 1, i2 - 1, action, discountTypes);
        }
        double reward = obsActModel.predictLabel(observation, action);
        path.addObsAndActBeforePath(observation, action, getDiscount(i, discountTypes) * reward, reward);
        return path;
    }

    /**
     * Alternative formulation of {@code planWithStartAct}; not called from within
     * this class. At {@code i2 == 1} the continuation is a sampled greedy path,
     * otherwise every action from the sampled next observation is expanded with
     * {@code planWithStartAct}.
     *
     * @param action2 unused
     */
    private static PathAndVal planWithStartAct2(Observation observation, Action action, RegressionModel regressionModel, ObsActModel obsActModel, EnvTransModel envTransModel, int i, int i2, Action action2, DiscountTypes discountTypes) {
        double reward = obsActModel.predictLabel(observation, action);
        double discountedReward = getDiscount(i, discountTypes) * reward;
        ObsAndTerm next = envTransModel.sampleNextObs(observation, action);
        PathAndVal path;
        if (i2 == 1) {
            path = !next.getTerm()
                    ? new PathAndVal(sampleGreedyPath(next.getObs(), regressionModel, obsActModel, envTransModel, i + 1, greedyLeafPathLength, action, discountTypes).getVal())
                    : new PathAndVal();
        } else if (next.getTerm()) {
            path = new PathAndVal();
        } else {
            ArrayList<Action> possActions = regressionModel.getPossActions(next.getObs());
            ArrayList<PathAndVal> best = new ArrayList<PathAndVal>();
            double bestVal = Double.NEGATIVE_INFINITY;
            for (int idx = 0; idx < possActions.size(); idx++) {
                PathAndVal candidate = planWithStartAct(next.getObs(), possActions.get(idx), regressionModel, obsActModel, envTransModel, i + 1, i2 - 1, action, discountTypes);
                bestVal = keepIfBest(best, candidate, bestVal);
            }
            path = pickRandom(best);
        }
        path.addObsAndActBeforePath(observation, action, discountedReward, reward);
        return path;
    }

    /**
     * Rolls out up to {@code i2} greedy steps from {@code observation}, appending
     * each step's discounted predicted reward. Stops early, without a leaf value,
     * if a sampled transition is terminal; otherwise sets the path's leaf value to
     * the discounted value-function prediction for the greedy action at the final
     * observation.
     *
     * @param i      depth of the first rollout step (discount index)
     * @param i2     number of greedy steps
     * @param action last action taken before the rollout
     */
    private static PathAndVal sampleGreedyPath(Observation observation, RegressionModel regressionModel, ObsActModel obsActModel, EnvTransModel envTransModel, int i, int i2, Action action, DiscountTypes discountTypes) {
        PathAndVal path = new PathAndVal();
        Observation currObs = observation;
        Action lastAct = action;
        for (int step = 0; step < i2; step++) {
            Action greedyAct = greedyActSelect(regressionModel, currObs, lastAct);
            double reward = obsActModel.predictLabel(currObs, greedyAct);
            path.addObsAndActToPathEnd(currObs, greedyAct, getDiscount(i + step, discountTypes) * reward, reward);
            lastAct = greedyAct;
            ObsAndTerm next = envTransModel.sampleNextObs(currObs, greedyAct);
            if (next.getTerm()) {
                return path;
            }
            currObs = next.getObs();
        }
        path.leafValue = getDiscount(i + i2, discountTypes) * regressionModel.predictLabel(currObs, greedyActSelect(regressionModel, currObs, lastAct));
        return path;
    }

    /**
     * Discount weight for depth {@code i}: {@code exp(-discountParam * i)} for
     * EXPON, {@code 1 / (1 + discountParam * i)} for HYPER. Any other type
     * (including {@code null}) prints an error and exits the JVM with status 1.
     */
    private static double getDiscount(int i, DiscountTypes discountTypes) {
        double scaled = discountParam * i;
        // Infinity * 0 is NaN; force the depth-0 product to 0 so the weight there is 1.
        if ((discountParam == Double.MAX_VALUE || Double.isInfinite(discountParam)) && i == 0) {
            scaled = 0.0d;
        }
        if (discountTypes == DiscountTypes.EXPON) {
            return Math.exp(-scaled);
        }
        if (discountTypes == DiscountTypes.HYPER) {
            return 1.0d / (1.0d + scaled);
        }
        System.err.println("Illegal discount type in ActionSelect.getDiscount(). Exiting.");
        System.exit(1);
        return Double.NaN;
    }

    /** Converts a discount parameter to a per-step discount factor: {@code exp(-d)}. */
    public static double discParamToFactor(double d) {
        return Math.exp(-d);
    }

    /**
     * Converts a per-step discount factor to a discount parameter: {@code -ln(d)}.
     * A NaN result (e.g. for a negative or NaN factor) is replaced by
     * {@code Double.MAX_VALUE}. A factor of 0 yields positive infinity.
     */
    public static double discFactorToParam(double d) {
        double param = -Math.log(d);
        // "x == Double.NaN" is always false; Double.isNaN is the only working test.
        if (Double.isNaN(param)) {
            param = Double.MAX_VALUE;
        }
        return param;
    }
}
