package edu.umd.cs.psl.application.learning.weight.maxlikelihood;

import com.google.common.collect.Iterables;
import edu.umd.cs.psl.application.learning.weight.WeightLearningApplication;
import edu.umd.cs.psl.application.learning.weight.maxmargin.LossAugmentingGroundKernel;
import edu.umd.cs.psl.config.ConfigBundle;
import edu.umd.cs.psl.database.Database;
import edu.umd.cs.psl.model.Model;
import edu.umd.cs.psl.model.atom.ObservedAtom;
import edu.umd.cs.psl.model.atom.RandomVariableAtom;
import edu.umd.cs.psl.model.kernel.GroundCompatibilityKernel;
import edu.umd.cs.psl.model.kernel.GroundKernel;
import edu.umd.cs.psl.model.parameters.NegativeWeight;
import edu.umd.cs.psl.model.parameters.PositiveWeight;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for weight learning by (voted) perceptron-style gradient steps.
 *
 * <p>{@link #doLearn()} first computes each kernel's incompatibility with the random variables set
 * to their labeled values ({@link #computeObservedIncomp()}). It then runs up to {@code numSteps}
 * steps. Each step asks the subclass for the expected incompatibility
 * ({@link #computeExpectedIncomp()}) and moves every kernel weight by
 * {@code (expected - truth - l2 * w - l1 * sign(w)) / scale * stepSize(step)}.
 *
 * <p>Optional behaviors, each controlled by a config key below:
 * <ul>
 *   <li>clamping weights at zero;</li>
 *   <li>averaging the weights over all steps taken;</li>
 *   <li>adding loss-augmenting ground kernels for the duration of learning.</li>
 * </ul>
 */
public abstract class VotedPerceptron extends WeightLearningApplication {
    private static final Logger log = LoggerFactory.getLogger(VotedPerceptron.class);

    /** Weight magnitude (with sign) used for the loss-augmenting ground kernels. */
    private static final double LOSS_AUGMENTATION_WEIGHT = -1.0;

    /** Prefix shared by all configuration keys of this class. */
    public static final String CONFIG_PREFIX = "votedperceptron";

    /** Key for whether loss-augmenting ground kernels are added while learning. */
    public static final String AUGMENT_LOSS_KEY = CONFIG_PREFIX + ".augmentloss";
    /** Default for {@link #AUGMENT_LOSS_KEY}. */
    public static final boolean AUGMENT_LOSS_DEFAULT = false;

    /** Key for the L2 regularization coefficient (must be non-negative). */
    public static final String L2_REGULARIZATION_KEY = CONFIG_PREFIX + ".l2regularization";
    /** Default for {@link #L2_REGULARIZATION_KEY}. */
    public static final double L2_REGULARIZATION_DEFAULT = 0.0d;

    /** Key for the L1 regularization coefficient (must be non-negative). */
    public static final String L1_REGULARIZATION_KEY = CONFIG_PREFIX + ".l1regularization";
    /** Default for {@link #L1_REGULARIZATION_KEY}. */
    public static final double L1_REGULARIZATION_DEFAULT = 0.0d;

    /** Key for the base step size (must be positive). */
    public static final String STEP_SIZE_KEY = CONFIG_PREFIX + ".stepsize";
    /** Default for {@link #STEP_SIZE_KEY}. */
    public static final double STEP_SIZE_DEFAULT = 1.0d;

    /** Key for whether the step size decays as {@code stepSize / (step + 1)}. */
    public static final String STEP_SCHEDULE_KEY = CONFIG_PREFIX + ".schedule";
    /** Default for {@link #STEP_SCHEDULE_KEY}. */
    public static final boolean STEP_SCHEDULE_DEFAULT = true;

    /** Key for whether each kernel's gradient is divided by its number of groundings. */
    public static final String SCALE_GRADIENT_KEY = CONFIG_PREFIX + ".scalegradient";
    /** Default for {@link #SCALE_GRADIENT_KEY}. */
    public static final boolean SCALE_GRADIENT_DEFAULT = true;

    /** Key for whether the final weights are the average of the per-step weights. */
    public static final String AVERAGE_STEPS_KEY = CONFIG_PREFIX + ".averagesteps";
    /** Default for {@link #AVERAGE_STEPS_KEY}. */
    public static final boolean AVERAGE_STEPS_DEFAULT = true;

    /** Key for the maximum number of learning steps (must be positive). */
    public static final String NUM_STEPS_KEY = CONFIG_PREFIX + ".numsteps";
    /** Default for {@link #NUM_STEPS_KEY}. */
    public static final int NUM_STEPS_DEFAULT = 25;

    /** Key for whether weights are clamped at zero after each step. */
    public static final String NONNEGATIVE_WEIGHTS_KEY = CONFIG_PREFIX + ".nonnegativeweights";
    /** Default for {@link #NONNEGATIVE_WEIGHTS_KEY}. */
    public static final boolean NONNEGATIVE_WEIGHTS_DEFAULT = true;

    /** Number of ground kernels per kernel; filled by {@link #computeObservedIncomp()}. */
    protected double[] numGroundings;
    protected final double stepSize;
    protected final int numSteps;
    protected final double l2Regularization;
    protected final double l1Regularization;
    protected final boolean augmentLoss;
    protected final boolean scheduleStepSize;
    protected final boolean scaleGradient;
    protected final boolean averageSteps;
    protected final boolean nonnegativeWeights;
    /** Per-kernel incompatibility at the labeled values, computed once per {@link #doLearn()}. */
    protected double[] truthIncompatibility;
    /** Per-kernel expected incompatibility from the most recent step. */
    protected double[] expectedIncompatibility;
    /**
     * Set by {@link #stop()} and checked after each step of {@link #doLearn()}. It is only reset in
     * the constructor, so once set, later runs also stop after their first step.
     */
    protected boolean toStop;
    /** Value of {@link #computeLoss()} from the most recent step. */
    private double loss;

    /**
     * Progress notification passed to {@code notifyObservers} after every learning step.
     */
    public class IntermediateState {
        /** Zero-based index of the step that just finished. */
        public final int step;
        /** Configured maximum number of steps. */
        public final int maxStep;

        public IntermediateState(int step, int maxStep) {
            this.step = step;
            this.maxStep = maxStep;
        }
    }

    /**
     * Reads all learning options from {@code config}, falling back to the {@code *_DEFAULT}
     * constants.
     *
     * @param model the model whose kernel weights are learned (forwarded to the superclass)
     * @param rvDB first database, forwarded unchanged to {@code WeightLearningApplication}
     *     (presumably the random-variable database; confirm there)
     * @param observedDB second database, forwarded unchanged to {@code WeightLearningApplication}
     *     (presumably the database of observed labels; confirm there)
     * @param config source of the options listed in the {@code *_KEY} constants
     * @throws IllegalArgumentException if the step size or number of steps is not positive, or a
     *     regularization coefficient is negative
     */
    public VotedPerceptron(Model model, Database rvDB, Database observedDB, ConfigBundle config) {
        super(model, rvDB, observedDB, config);
        this.toStop = false;
        this.loss = Double.POSITIVE_INFINITY;

        this.stepSize = config.getDouble(STEP_SIZE_KEY, STEP_SIZE_DEFAULT);
        if (this.stepSize <= 0.0) {
            throw new IllegalArgumentException("Step size must be positive.");
        }
        this.numSteps = config.getInt(NUM_STEPS_KEY, NUM_STEPS_DEFAULT);
        if (this.numSteps <= 0) {
            throw new IllegalArgumentException("Number of steps must be positive.");
        }
        this.l2Regularization = config.getDouble(L2_REGULARIZATION_KEY, L2_REGULARIZATION_DEFAULT);
        if (this.l2Regularization < 0.0) {
            throw new IllegalArgumentException("L2 regularization parameter must be non-negative.");
        }
        this.l1Regularization = config.getDouble(L1_REGULARIZATION_KEY, L1_REGULARIZATION_DEFAULT);
        if (this.l1Regularization < 0.0) {
            throw new IllegalArgumentException("L1 regularization parameter must be non-negative.");
        }
        this.augmentLoss = config.getBoolean(AUGMENT_LOSS_KEY, AUGMENT_LOSS_DEFAULT);
        this.scheduleStepSize = config.getBoolean(STEP_SCHEDULE_KEY, STEP_SCHEDULE_DEFAULT);
        this.scaleGradient = config.getBoolean(SCALE_GRADIENT_KEY, SCALE_GRADIENT_DEFAULT);
        this.averageSteps = config.getBoolean(AVERAGE_STEPS_KEY, AVERAGE_STEPS_DEFAULT);
        this.nonnegativeWeights = config.getBoolean(NONNEGATIVE_WEIGHTS_KEY, NONNEGATIVE_WEIGHTS_DEFAULT);
    }

    /**
     * Adds one {@link LossAugmentingGroundKernel} per labeled random variable to the reasoner.
     *
     * <p>The kernel depends on the observed truth value:
     * <ul>
     *   <li>exactly 0 or 1: the kernel targets that value with a weight of -1;</li>
     *   <li>at least 0.5 (otherwise): the kernel targets 1.0 with a weight of -1;</li>
     *   <li>below 0.5 (otherwise): the kernel targets 1.0 with a weight of +1.</li>
     * </ul>
     */
    public void addLossAugmentedKernels() {
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry
                : trainingMap.getTrainingMap().entrySet()) {
            RandomVariableAtom atom = entry.getKey();
            double truth = entry.getValue().getValue();
            LossAugmentingGroundKernel kernel;
            if (truth == 1.0 || truth == 0.0) {
                kernel = new LossAugmentingGroundKernel(
                        atom, truth, new NegativeWeight(LOSS_AUGMENTATION_WEIGHT));
            } else if (truth >= 0.5) {
                kernel = new LossAugmentingGroundKernel(
                        atom, 1.0, new NegativeWeight(LOSS_AUGMENTATION_WEIGHT));
            } else {
                kernel = new LossAugmentingGroundKernel(
                        atom, 1.0, new PositiveWeight(-LOSS_AUGMENTATION_WEIGHT));
            }
            reasoner.addGroundKernel(kernel);
        }
    }

    /** Removes every {@link LossAugmentingGroundKernel} currently registered with the reasoner. */
    public void removeLossAugmentedKernels() {
        // Collect first so the reasoner's kernels are not modified while they are being iterated.
        List<LossAugmentingGroundKernel> toRemove = new ArrayList<LossAugmentingGroundKernel>();
        for (LossAugmentingGroundKernel kernel
                : Iterables.filter(reasoner.getGroundKernels(), LossAugmentingGroundKernel.class)) {
            toRemove.add(kernel);
        }
        for (LossAugmentingGroundKernel kernel : toRemove) {
            reasoner.removeGroundKernel(kernel);
        }
    }

    /**
     * Returns the step size for the given zero-based step.
     *
     * @param step zero-based step index
     * @return {@code stepSize / (step + 1)} when the schedule is enabled, else {@code stepSize}
     */
    public double getStepSize(int step) {
        return scheduleStepSize ? stepSize / (step + 1) : stepSize;
    }

    /**
     * Runs the perceptron updates and leaves the learned weights on the model's kernels.
     *
     * <p>Learning stops early if {@link #stop()} is called. When averaging is enabled, the final
     * weights are the mean over the steps actually taken. Loss-augmenting kernels, if they were
     * added, are removed even if learning throws.
     */
    @Override
    public void doLearn() {
        double[] weightSums = new double[kernels.size()];
        truthIncompatibility = computeObservedIncomp();
        if (augmentLoss) {
            addLossAugmentedKernels();
        }
        try {
            // Start from an all-zero assignment of labeled and latent variables.
            for (RandomVariableAtom atom : trainingMap.getTrainingMap().keySet()) {
                atom.setValue(0.0);
            }
            for (RandomVariableAtom atom : trainingMap.getLatentVariables()) {
                atom.setValue(0.0);
            }

            int completedSteps = 0;
            for (int step = 0; step < numSteps; step++) {
                log.debug("Starting iter {}", step + 1);
                expectedIncompatibility = computeExpectedIncomp();
                double[] scalingFactor = computeScalingFactor();
                loss = computeLoss();

                for (int i = 0; i < kernels.size(); i++) {
                    double weight = kernels.get(i).getWeight().getWeight();
                    // Subgradient of l1 * |w| (see computeRegularizer); equals +l1 for w >= 0.
                    double l1Gradient = (weight < 0.0) ? -l1Regularization : l1Regularization;
                    double currentStep = (expectedIncompatibility[i] - truthIncompatibility[i]
                            - l2Regularization * weight - l1Gradient)
                            / scalingFactor[i] * getStepSize(step);
                    log.debug("Step of {} for kernel {}", currentStep, kernels.get(i));
                    log.debug(" --- Expected incomp.: {}, Truth incomp.: {}",
                            expectedIncompatibility[i], truthIncompatibility[i]);

                    double newWeight = weight + currentStep;
                    if (nonnegativeWeights) {
                        newWeight = Math.max(newWeight, 0.0);
                    }
                    weightSums[i] += newWeight;
                    setKernelWeight(i, newWeight);
                }
                reasoner.changedGroundKernelWeights();
                completedSteps++;

                setChanged();
                notifyObservers(new IntermediateState(step, numSteps));
                if (toStop) {
                    break;
                }
            }

            if (averageSteps) {
                // Divide by the steps actually taken: stop() can end learning before numSteps.
                // completedSteps >= 1 because numSteps is validated to be positive.
                for (int i = 0; i < kernels.size(); i++) {
                    setKernelWeight(i, weightSums[i] / completedSteps);
                }
                reasoner.changedGroundKernelWeights();
            }
        } finally {
            if (augmentLoss) {
                removeLossAugmentedKernels();
            }
        }
    }

    /**
     * Sets kernel {@code index}'s weight to {@code value}. The value is wrapped as a
     * {@link PositiveWeight} when it is non-negative and as a {@link NegativeWeight} otherwise.
     */
    private void setKernelWeight(int index, double value) {
        kernels.get(index).setWeight(value >= 0.0 ? new PositiveWeight(value) : new NegativeWeight(value));
    }

    /**
     * Sums each kernel's ground-kernel incompatibilities with the random variables set to their
     * labeled values, and records the number of groundings in {@link #numGroundings}.
     *
     * @return per-kernel total incompatibility, indexed like {@code kernels}
     */
    protected double[] computeObservedIncomp() {
        numGroundings = new double[kernels.size()];
        double[] truthIncomp = new double[kernels.size()];
        setLabeledRandomVariables();
        for (int i = 0; i < kernels.size(); i++) {
            for (GroundKernel groundKernel : reasoner.getGroundKernels(kernels.get(i))) {
                truthIncomp[i] += ((GroundCompatibilityKernel) groundKernel).getIncompatibility();
                numGroundings[i] += 1.0;
            }
        }
        return truthIncomp;
    }

    /**
     * Computes the expected per-kernel incompatibility under the current weights.
     *
     * @return per-kernel expected incompatibility, indexed like {@code kernels}
     */
    protected abstract double[] computeExpectedIncomp();

    /**
     * Returns the regularization penalty of the current weights:
     * {@code 0.5 * l2 * sum(w^2) + l1 * sum(|w|)}.
     */
    public double computeRegularizer() {
        double squaredNorm = 0.0;
        double absNorm = 0.0;
        for (int i = 0; i < kernels.size(); i++) {
            double weight = kernels.get(i).getWeight().getWeight();
            squaredNorm += Math.pow(weight, 2.0);
            absNorm += Math.abs(weight);
        }
        return (0.5 * l2Regularization * squaredNorm) + (l1Regularization * absNorm);
    }

    /**
     * Computes the training loss recorded at each step. The base implementation returns
     * {@link Double#POSITIVE_INFINITY}; subclasses may override it.
     */
    protected double computeLoss() {
        return Double.POSITIVE_INFINITY;
    }

    /** Returns the loss recorded at the most recent step, or positive infinity before any step. */
    public double getLoss() {
        return loss;
    }

    /**
     * Returns the per-kernel gradient divisors. Each divisor is the kernel's number of groundings
     * when gradient scaling is on and that count is positive, and 1.0 otherwise.
     */
    public double[] computeScalingFactor() {
        double[] factors = new double[numGroundings.length];
        for (int i = 0; i < numGroundings.length; i++) {
            factors[i] = (!scaleGradient || numGroundings[i] <= 0.0) ? 1.0 : numGroundings[i];
        }
        return factors;
    }

    /** Requests that {@link #doLearn()} stop after the step currently in progress. */
    public void stop() {
        toStop = true;
    }
}
