package edu.umd.cs.psl.application.learning.weight.maxmargin;

import edu.umd.cs.psl.application.learning.weight.WeightLearningApplication;
import edu.umd.cs.psl.config.ConfigBundle;
import edu.umd.cs.psl.database.Database;
import edu.umd.cs.psl.model.Model;
import edu.umd.cs.psl.model.atom.ObservedAtom;
import edu.umd.cs.psl.model.atom.RandomVariableAtom;
import edu.umd.cs.psl.model.kernel.GroundCompatibilityKernel;
import edu.umd.cs.psl.model.kernel.GroundKernel;
import edu.umd.cs.psl.model.parameters.NegativeWeight;
import edu.umd.cs.psl.model.parameters.PositiveWeight;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for max-margin weight learning.
 *
 * <p>Learning keeps a {@link MinNormProgram} over one variable per kernel plus
 * one trailing slack variable. Each iteration asks the subclass for the
 * current point (via {@link #runSeparationOracle()}) and its loss (via
 * {@link #evaluateLoss()}); if the resulting constraint is violated by more
 * than the tolerance, the constraint is added to the program, the program is
 * re-solved and the kernel weights are updated from the solution.
 *
 * <p>Subclasses supply the separation oracle and the loss through the four
 * abstract hooks at the bottom of this class.
 */
public abstract class MaxMargin extends WeightLearningApplication {
    /** Prefix shared by all configuration keys of this class. */
    public static final String CONFIG_PREFIX = "maxmargin";

    /** Key for the violation tolerance below which learning stops. */
    public static final String CUTTING_PLANE_TOLERANCE_KEY = "maxmargin.tolerance";
    /** Default value for {@link #CUTTING_PLANE_TOLERANCE_KEY}. */
    public static final double CUTTING_PLANE_TOLERANCE_DEFAULT = 0.001d;

    /** Key for the coefficient applied to the slack variable in the objective. */
    public static final String SLACK_PENALTY_KEY = "maxmargin.slackpenalty";
    /** Default value for {@link #SLACK_PENALTY_KEY}. */
    public static final double SLACK_PENALTY_DEFAULT = 1.0d;

    /** Key for the maximum number of constraint-adding iterations. */
    public static final String MAX_ITER_KEY = "maxmargin.maxiter";
    /** Default value for {@link #MAX_ITER_KEY}. */
    public static final int MAX_ITER_DEFAULT = 500;

    /** Key for the flag that is forwarded to {@link MinNormProgram} as its non-negativity setting. */
    public static final String NONNEGATIVE_WEIGHTS_KEY = "maxmargin.nonnegativeweights";
    /** Default value for {@link #NONNEGATIVE_WEIGHTS_KEY}. */
    public static final boolean NONNEGATIVE_WEIGHTS_DEFAULT = true;

    /** Key for the {@link NormScalingType} used to build the quadratic coefficients. */
    public static final String SCALE_NORM_KEY = "maxmargin.scalenorm";
    /** Default value for {@link #SCALE_NORM_KEY}. */
    public static final NormScalingType SCALE_NORM_DEFAULT = NormScalingType.NONE;

    /** Key for the flag selecting a quadratic (true) or linear (false) slack penalty. */
    public static final String SQUARE_SLACK_KEY = "maxmargin.squareslack";
    /** Default value for {@link #SQUARE_SLACK_KEY}. */
    public static final boolean SQUARE_SLACK_DEFAULT = false;

    private static final Logger log = LoggerFactory.getLogger(MaxMargin.class);

    protected final double tolerance;
    protected final int maxIter;
    protected final boolean nonnegativeWeights;
    protected double slackPenalty;
    protected final NormScalingType scaleNorm;
    protected final boolean squareSlack;
    protected MinNormProgram normProgram;

    /**
     * How the per-kernel quadratic coefficients of the norm program are chosen
     * in {@link MaxMargin#initGroundModel()}.
     */
    public enum NormScalingType {
        /** Every kernel gets coefficient 1. */
        NONE,
        /** Coefficient proportional to the kernel's number of groundings. */
        NUM_GROUNDINGS,
        /** Coefficient proportional to 1 / (number of groundings), or 0 for a kernel without groundings. */
        INVERSE_NUM_GROUNDINGS;

        /**
         * Returns a fresh array of all constants, in declaration order.
         * Kept for backward compatibility; equivalent to {@link #values()}.
         *
         * @return a new array holding every {@code NormScalingType}
         */
        public static NormScalingType[] valuesCustom() {
            // values() already hands out a new copy on every call.
            return values();
        }
    }

    /**
     * Reads the max-margin settings from the configuration.
     *
     * @param model        forwarded unchanged to the superclass constructor
     * @param database     forwarded unchanged to the superclass constructor (second argument)
     * @param database2    forwarded unchanged to the superclass constructor (third argument)
     * @param configBundle source of the {@code maxmargin.*} settings; also forwarded to the superclass
     */
    public MaxMargin(Model model, Database database, Database database2, ConfigBundle configBundle) {
        super(model, database, database2, configBundle);
        this.tolerance = configBundle.getDouble(CUTTING_PLANE_TOLERANCE_KEY, CUTTING_PLANE_TOLERANCE_DEFAULT);
        this.maxIter = configBundle.getInt(MAX_ITER_KEY, MAX_ITER_DEFAULT);
        this.nonnegativeWeights = configBundle.getBoolean(NONNEGATIVE_WEIGHTS_KEY, NONNEGATIVE_WEIGHTS_DEFAULT);
        this.slackPenalty = configBundle.getDouble(SLACK_PENALTY_KEY, SLACK_PENALTY_DEFAULT);
        this.scaleNorm = (NormScalingType) configBundle.getEnum(SCALE_NORM_KEY, SCALE_NORM_DEFAULT);
        this.squareSlack = configBundle.getBoolean(SQUARE_SLACK_KEY, SQUARE_SLACK_DEFAULT);
    }

    /**
     * Initializes the ground model through the superclass and then builds the
     * norm program: one variable per kernel followed by a single slack variable.
     *
     * <p>The slack penalty is placed either in the linear coefficients
     * ({@code squareSlack == false}) or in the quadratic coefficients
     * ({@code squareSlack == true}). The per-kernel quadratic coefficients are
     * chosen according to {@link #scaleNorm}; for the grounding-based types the
     * coefficient vector is rescaled so that its Euclidean norm equals
     * {@code sqrt(kernels.size())}.
     *
     * @throws IllegalStateException if {@link #scaleNorm} is not a recognized type
     */
    @Override // edu.umd.cs.psl.application.learning.weight.WeightLearningApplication
    public void initGroundModel() throws ClassNotFoundException, IllegalAccessException, InstantiationException {
        super.initGroundModel();

        final int numKernels = this.kernels.size();
        final int slackIndex = numKernels;
        this.normProgram = new MinNormProgram(numKernels + 1, this.nonnegativeWeights, this.config);

        /* Linear part: only the slack variable may carry a coefficient. */
        double[] linearCoefficients = new double[numKernels + 1];
        linearCoefficients[slackIndex] = this.squareSlack ? 0.0d : this.slackPenalty;
        this.normProgram.setLinearCoefficients(linearCoefficients);

        /* Quadratic part: per-kernel coefficients, slack slot filled in last. */
        double[] quadCoefficients = new double[numKernels + 1];
        if (NormScalingType.NONE.equals(this.scaleNorm)) {
            Arrays.fill(quadCoefficients, 0, numKernels, 1.0d);
        } else {
            int[] numGroundings = new int[numKernels];
            for (int i = 0; i < numKernels; i++) {
                numGroundings[i] = countGroundings(i);
            }

            if (NormScalingType.NUM_GROUNDINGS.equals(this.scaleNorm)) {
                for (int i = 0; i < numKernels; i++) {
                    quadCoefficients[i] = numGroundings[i];
                }
            } else if (NormScalingType.INVERSE_NUM_GROUNDINGS.equals(this.scaleNorm)) {
                for (int i = 0; i < numKernels; i++) {
                    quadCoefficients[i] = numGroundings[i] > 0 ? 1.0d / numGroundings[i] : 0.0d;
                }
            } else {
                throw new IllegalStateException("Unrecognized NormScalingType.");
            }

            /* Rescale so the coefficient vector has norm sqrt(numKernels). */
            double sumSquares = 0.0d;
            for (double coefficient : quadCoefficients) {
                sumSquares += coefficient * coefficient;
            }
            if (sumSquares > 0.0d) {
                double scale = Math.sqrt(numKernels) / Math.sqrt(sumSquares);
                for (int i = 0; i < numKernels; i++) {
                    quadCoefficients[i] = quadCoefficients[i] * scale;
                }
            } else if (numKernels > 0) {
                /*
                 * No kernel has any grounding, so dividing by the zero norm would
                 * turn every coefficient into NaN. Unit coefficients have the same
                 * target norm sqrt(numKernels), so use them instead.
                 */
                log.warn("No kernel has any groundings; using unit norm coefficients instead of {} scaling.",
                        this.scaleNorm);
                Arrays.fill(quadCoefficients, 0, numKernels, 1.0d);
            }
        }

        if (log.isDebugEnabled()) {
            log.debug("Quad coeffs: {}", Arrays.toString(quadCoefficients));
        }
        quadCoefficients[slackIndex] = this.squareSlack ? this.slackPenalty : 0.0d;
        // Second argument is an all-zero vector of the same length
        // (NOTE(review): presumably the origin/offset of the quadratic term - confirm in MinNormProgram).
        this.normProgram.setQuadraticTerm(quadCoefficients, new double[numKernels + 1]);
    }

    /**
     * Runs the learning loop.
     *
     * <p>First records, per kernel, the total incompatibility with the training
     * atoms set to their observed values. Then, until either {@link #maxIter}
     * constraints have been added or the most recent constraint is violated by
     * no more than {@link #tolerance}, it runs the separation oracle, builds a
     * constraint from the difference between the recorded and the current
     * incompatibilities, adds it to the norm program, re-solves, and installs
     * the new weights.
     *
     * <p>If solving or installing the weights fails with an
     * {@link IllegalArgumentException} or {@link IllegalStateException}, the
     * failure is logged and the method returns early. The separation oracle is
     * torn down on every exit path once it has been set up.
     */
    @Override // edu.umd.cs.psl.application.learning.weight.WeightLearningApplication
    protected void doLearn() {
        final int numKernels = this.kernels.size();
        final int slackIndex = numKernels;

        /* Current solution: kernel weights followed by the slack (initially 0). */
        double[] weights = new double[numKernels + 1];
        for (int i = 0; i < numKernels; i++) {
            weights[i] = this.kernels.get(i).getWeight().getWeight();
        }

        /* Incompatibility of each kernel with the training atoms at their observed values. */
        double[] truthIncompatibility = new double[numKernels];
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : this.trainingMap.getTrainingMap().entrySet()) {
            entry.getKey().setValue(entry.getValue().getValue());
        }
        for (int i = 0; i < numKernels; i++) {
            truthIncompatibility[i] = totalIncompatibility(i);
        }

        setupSeparationOracle();
        try {
            int iter = 0;
            double violation = Double.POSITIVE_INFINITY;
            while (iter < this.maxIter && violation > this.tolerance) {
                runSeparationOracle();

                double slack = weights[slackIndex];
                double loss = evaluateLoss();

                /* Constraint coefficients: recorded minus current incompatibility, per kernel. */
                double[] constraintCoefficients = new double[numKernels + 1];
                double weightedDifference = 0.0d;
                for (int i = 0; i < numKernels; i++) {
                    constraintCoefficients[i] = truthIncompatibility[i] - totalIncompatibility(i);
                    weightedDifference += weights[i] * constraintCoefficients[i];
                }

                violation = (weightedDifference - slack) + loss;
                log.debug("Violation of most recent constraint: {}", Double.valueOf(violation));
                log.debug("Loss at most recent point: {}", Double.valueOf(loss));
                log.debug("Slack: {}", Double.valueOf(slack));

                if (violation > this.tolerance) {
                    constraintCoefficients[slackIndex] = -1.0d;
                    this.normProgram.addInequalityConstraint(constraintCoefficients, (-1.0d) * loss);
                    try {
                        this.normProgram.solve();
                        weights = this.normProgram.getSolution();
                        for (int i = 0; i < numKernels; i++) {
                            /* Pick the weight wrapper by the sign of the learned value. */
                            if (weights[i] >= 0.0d) {
                                this.kernels.get(i).setWeight(new PositiveWeight(weights[i]));
                            } else {
                                this.kernels.get(i).setWeight(new NegativeWeight(weights[i]));
                            }
                        }
                        this.reasoner.changedGroundKernelWeights();
                        iter++;
                    } catch (IllegalArgumentException e) {
                        log.error("Norm minimization program failed (IllegalArgumentException). Returning early.", e);
                        return;
                    } catch (IllegalStateException e) {
                        log.error("Norm minimization program failed (IllegalStateException). Returning early.", e);
                        return;
                    }
                }
            }
            log.debug("Number of separation oracle calls: {}", Integer.valueOf(iter));
        } finally {
            /* Always undo setupSeparationOracle(), including on the early-return paths. */
            tearDownSeparationOracle();
        }
    }

    /** Counts the ground kernels the reasoner holds for the kernel at {@code kernelIndex}. */
    private int countGroundings(int kernelIndex) {
        int count = 0;
        Iterator<GroundKernel> groundings =
                this.reasoner.getGroundKernels(this.kernels.get(kernelIndex)).iterator();
        while (groundings.hasNext()) {
            groundings.next();
            count++;
        }
        return count;
    }

    /** Sums the current incompatibility over all groundings of the kernel at {@code kernelIndex}. */
    private double totalIncompatibility(int kernelIndex) {
        double total = 0.0d;
        Iterator<GroundKernel> groundings =
                this.reasoner.getGroundKernels(this.kernels.get(kernelIndex)).iterator();
        while (groundings.hasNext()) {
            total += ((GroundCompatibilityKernel) groundings.next()).getIncompatibility();
        }
        return total;
    }

    /** Called once by {@link #doLearn()} before the first iteration. */
    protected abstract void setupSeparationOracle();

    /**
     * Called at the start of every iteration of {@link #doLearn()}, before the
     * loss and the per-kernel incompatibilities are read.
     */
    protected abstract void runSeparationOracle();

    /**
     * Called once per iteration, right after {@link #runSeparationOracle()}.
     *
     * @return the loss term used in the violation and as the (negated) constraint bound
     */
    protected abstract double evaluateLoss();

    /** Called once by {@link #doLearn()} after the loop, on every exit path. */
    protected abstract void tearDownSeparationOracle();
}
