package edu.umd.cs.psl.application.learning.weight.random;

import com.google.common.collect.Iterables;
import edu.umd.cs.psl.application.ModelApplication;
import edu.umd.cs.psl.application.learning.weight.TrainingMap;
import edu.umd.cs.psl.application.learning.weight.maxmargin.LossAugmentingGroundKernel;
import edu.umd.cs.psl.application.learning.weight.maxmargin.MinNormProgram;
import edu.umd.cs.psl.application.util.Grounding;
import edu.umd.cs.psl.config.ConfigBundle;
import edu.umd.cs.psl.database.Database;
import edu.umd.cs.psl.model.Model;
import edu.umd.cs.psl.model.atom.ObservedAtom;
import edu.umd.cs.psl.model.atom.RandomVariableAtom;
import edu.umd.cs.psl.model.kernel.CompatibilityKernel;
import edu.umd.cs.psl.model.kernel.GroundCompatibilityKernel;
import edu.umd.cs.psl.model.parameters.NegativeWeight;
import edu.umd.cs.psl.model.parameters.PositiveWeight;
import edu.umd.cs.psl.reasoner.Reasoner;
import edu.umd.cs.psl.reasoner.ReasonerFactory;
import edu.umd.cs.psl.reasoner.admm.ADMMReasonerFactory;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:edu/umd/cs/psl/application/learning/weight/random/HardEMRandOM.class */
/**
 * Learns per-grounding ("random") weights for the compatibility kernels of a {@link Model},
 * then folds them back into the kernel weights.
 *
 * <p>Each outer iteration does two things:
 * <ol>
 *   <li>It builds a fresh {@link MinNormProgram} over one weight per ground compatibility kernel
 *       plus one slack variable. The quadratic term pulls each ground weight toward its kernel's
 *       current weight; see {@link #getOrigin}. It then runs a cutting-plane inner loop. Each inner
 *       step performs inference with extra {@link LossAugmentingGroundKernel}s attached, adds one
 *       inequality constraint, re-solves, and writes the solution back into the ground kernels.</li>
 *   <li>It sets every compatibility kernel's weight to the mean of its ground kernels' weights.</li>
 * </ol>
 * Iteration stops when the total absolute change in kernel weights falls below
 * {@link #CHANGE_THRESHOLD}, or after {@link #MAX_OUTER_ITER} outer iterations.
 *
 * <p>Every random variable atom must have an observed counterpart; latent variables are rejected.
 */
public class HardEMRandOM implements ModelApplication {
    /** Prefix shared by every configuration key read by this class. */
    public static final String CONFIG_PREFIX = "maxmargin";

    /** Key for the inner-loop tolerance: the cutting-plane loop stops once the violation is at most this value. */
    public static final String CUTTING_PLANE_TOLERANCE = CONFIG_PREFIX + ".tolerance";
    /** Default value for {@link #CUTTING_PLANE_TOLERANCE}. */
    public static final double CUTTING_PLANE_TOLERANCE_DEFAULT = 1.0E-5d;

    /** Key for the linear coefficient applied to the slack variable in the QP objective. */
    public static final String SLACK_PENALTY = CONFIG_PREFIX + ".slack_penalty";
    /** Default value for {@link #SLACK_PENALTY}. */
    public static final double SLACK_PENALTY_DEFAULT = 1.0d;

    /** Key for the maximum number of cutting-plane (inner) iterations per outer iteration. */
    public static final String MAX_INNER_ITER = CONFIG_PREFIX + ".max_inner_iter";
    /** Default value for {@link #MAX_INNER_ITER}. */
    public static final int MAX_INNER_ITER_DEFAULT = 500;

    /** Key for the maximum number of outer (weight-averaging) iterations. */
    public static final String MAX_OUTER_ITER = CONFIG_PREFIX + ".max_outer_iter";
    /** Default value for {@link #MAX_OUTER_ITER}. */
    public static final int MAX_OUTER_ITER_DEFAULT = 500;

    /** Key for the total kernel-weight change below which the outer loop is considered converged. */
    public static final String CHANGE_THRESHOLD = CONFIG_PREFIX + ".change_threshold";
    /** Default value for {@link #CHANGE_THRESHOLD}. */
    public static final double CHANGE_THRESHOLD_DEFAULT = 0.001d;

    /** Key for the {@link ReasonerFactory} used to build the inference reasoner. */
    public static final String REASONER_KEY = CONFIG_PREFIX + ".reasoner";
    /** Reasoner factory used when {@link #REASONER_KEY} is not configured. */
    public static final ReasonerFactory REASONER_DEFAULT = new ADMMReasonerFactory();

    private static final Logger log = LoggerFactory.getLogger(HardEMRandOM.class);

    private Model model;
    private Database rvDB;
    private Database observedDB;
    private ConfigBundle config;

    private final double tolerance;
    private final int maxInnerIter;
    private final int maxOuterIter;
    private final double changeThreshold;
    private double slackPenalty;

    /**
     * Creates the learner and reads its settings from {@code config}.
     *
     * @param model      model whose compatibility kernel weights are learned (mutated by {@link #learn()})
     * @param rvDB       database holding the random variable atoms
     * @param observedDB database holding the observed (training) values for those atoms
     * @param config     configuration bundle; keys are prefixed with {@link #CONFIG_PREFIX}
     */
    public HardEMRandOM(Model model, Database rvDB, Database observedDB, ConfigBundle config) {
        this.model = model;
        this.rvDB = rvDB;
        this.observedDB = observedDB;
        this.config = config;
        this.tolerance = config.getDouble(CUTTING_PLANE_TOLERANCE, CUTTING_PLANE_TOLERANCE_DEFAULT);
        this.maxInnerIter = config.getInt(MAX_INNER_ITER, MAX_INNER_ITER_DEFAULT);
        this.maxOuterIter = config.getInt(MAX_OUTER_ITER, MAX_OUTER_ITER_DEFAULT);
        this.slackPenalty = config.getDouble(SLACK_PENALTY, SLACK_PENALTY_DEFAULT);
        this.changeThreshold = config.getDouble(CHANGE_THRESHOLD, CHANGE_THRESHOLD_DEFAULT);
    }

    /**
     * Overrides the slack penalty read from the configuration.
     *
     * @param slackPenalty linear coefficient of the slack variable in subsequently built QPs
     */
    public void setSlackPenalty(double slackPenalty) {
        this.slackPenalty = slackPenalty;
    }

    /**
     * Runs the learning procedure and writes the learned weights into the model's compatibility kernels.
     *
     * @throws IllegalArgumentException if any random variable atom has no observed counterpart
     * @throws ClassNotFoundException   if the configured reasoner factory cannot be loaded
     * @throws IllegalAccessException   if the configured reasoner factory cannot be instantiated
     * @throws InstantiationException   if the configured reasoner factory cannot be instantiated
     */
    public void learn() throws ClassNotFoundException, IllegalAccessException, InstantiationException {
        Reasoner reasoner = ((ReasonerFactory) config.getFactory(REASONER_KEY, REASONER_DEFAULT)).getReasoner(config);
        TrainingMap trainingMap = new TrainingMap(rvDB, observedDB);
        if (trainingMap.getLatentVariables().size() > 0) {
            throw new IllegalArgumentException("All RandomVariableAtoms must have corresponding ObservedAtoms. Latent variables are not supported by MaxMargin.");
        }
        Grounding.groundAll(model, trainingMap, reasoner);

        List<GroundCompatibilityKernel> groundKernels = new ArrayList<GroundCompatibilityKernel>();
        for (GroundCompatibilityKernel groundKernel : Iterables.filter(reasoner.getGroundKernels(), GroundCompatibilityKernel.class)) {
            groundKernels.add(groundKernel);
        }
        int numKernels = groundKernels.size();

        // Current QP solution: one weight per ground kernel, plus the slack variable at index numKernels.
        double[] weights = new double[numKernels + 1];
        // Incompatibility of each ground kernel when every atom takes its observed value.
        double[] observedIncompatibility = new double[numKernels];

        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : trainingMap.getTrainingMap().entrySet()) {
            entry.getKey().setValue(entry.getValue().getValue());
        }
        for (int i = 0; i < numKernels; i++) {
            GroundCompatibilityKernel groundKernel = groundKernels.get(i);
            observedIncompatibility[i] = groundKernel.getIncompatibility();
            weights[i] = groundKernel.getWeight().getWeight();
        }

        // Attach a loss term per training atom. NOTE(review): presumably this makes each
        // optimize() call loss-augmented inference (finding the most violated constraint) — confirm.
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : trainingMap.getTrainingMap().entrySet()) {
            reasoner.addGroundKernel(new LossAugmentingGroundKernel(entry.getKey(), entry.getValue().getValue(), new NegativeWeight(-1.0d)));
        }

        int outerIter = 0;
        boolean converged = false;
        while (!converged) {
            MinNormProgram program = newMinNormProgram(groundKernels);
            double violation = Double.POSITIVE_INFINITY;

            for (int innerIter = 0; innerIter < maxInnerIter && violation > tolerance; innerIter++) {
                reasoner.optimize();

                // L1 distance between the inferred and the observed values.
                double loss = 0.0d;
                for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : trainingMap.getTrainingMap().entrySet()) {
                    loss += Math.abs(entry.getKey().getValue() - entry.getValue().getValue());
                }

                // Constraint coefficients: observed minus inferred incompatibility, then -1 for the slack.
                double[] constraint = new double[numKernels + 1];
                double weightedDifference = 0.0d;
                for (int i = 0; i < numKernels; i++) {
                    constraint[i] = observedIncompatibility[i] - groundKernels.get(i).getIncompatibility();
                    weightedDifference += weights[i] * constraint[i];
                }
                // Violation of the new constraint under the weights and slack from the previous solve.
                violation = (weightedDifference - weights[numKernels]) + loss;
                constraint[numKernels] = -1.0d;

                program.addInequalityConstraint(constraint, -loss);
                program.solve();
                weights = program.getSolution();
                for (int i = 0; i < numKernels; i++) {
                    groundKernels.get(i).setWeight(new PositiveWeight(weights[i]));
                }
                reasoner.changedGroundKernelWeights();

                log.debug("Violation: {}", violation);
                log.debug("Slack: {}", weights[numKernels]);
                log.debug("Model: {}", model);
            }

            double change = averageGroundWeightsIntoKernels(reasoner);
            outerIter++;
            // '>=' so that at most maxOuterIter outer iterations run.
            converged = change < changeThreshold || outerIter >= maxOuterIter;
        }
    }

    /**
     * Builds the QP for one outer iteration.
     *
     * <p>The linear term is zero for every ground weight and {@link #slackPenalty} for the slack.
     * The quadratic term has coefficient 1 for each ground weight and 0 for the slack. It is centred
     * at {@link #getOrigin}.
     */
    private MinNormProgram newMinNormProgram(List<GroundCompatibilityKernel> groundKernels)
            throws ClassNotFoundException, IllegalAccessException, InstantiationException {
        int numKernels = groundKernels.size();
        MinNormProgram program = new MinNormProgram(numKernels + 1, true, config);

        double[] linear = new double[numKernels + 1];
        linear[numKernels] = slackPenalty;
        program.setLinearCoefficients(linear);

        double[] quadratic = new double[numKernels + 1];
        for (int i = 0; i < numKernels; i++) {
            quadratic[i] = 1.0d;
        }
        quadratic[numKernels] = 0.0d;
        program.setQuadraticTerm(quadratic, getOrigin(groundKernels));
        return program;
    }

    /**
     * Sets each compatibility kernel's weight to the mean weight of its ground kernels.
     *
     * <p>Kernels with no groundings are skipped; averaging over zero groundings would produce NaN.
     *
     * @return the sum of absolute changes in kernel weight
     */
    private double averageGroundWeightsIntoKernels(Reasoner reasoner) {
        double totalChange = 0.0d;
        for (CompatibilityKernel kernel : Iterables.filter(model.getKernels(), CompatibilityKernel.class)) {
            double sum = 0.0d;
            int count = 0;
            for (GroundCompatibilityKernel groundKernel : Iterables.filter(reasoner.getGroundKernels(kernel), GroundCompatibilityKernel.class)) {
                sum += groundKernel.getWeight().getWeight();
                count++;
            }
            if (count == 0) {
                continue;
            }
            double mean = sum / count;
            totalChange += Math.abs(mean - kernel.getWeight().getWeight());
            kernel.setWeight(new PositiveWeight(mean));
        }
        return totalChange;
    }

    /**
     * Returns the centre of the QP's quadratic term.
     *
     * <p>Entry {@code i} is the weight of ground kernel {@code i}'s parent kernel. The final entry
     * (for the slack variable) is 0.
     */
    private double[] getOrigin(List<GroundCompatibilityKernel> list) {
        double[] origin = new double[list.size() + 1];
        for (int i = 0; i < list.size(); i++) {
            origin[i] = list.get(i).getKernel().getWeight().getWeight();
        }
        return origin;
    }

    /** Releases references to the model, both databases, and the configuration. */
    @Override // edu.umd.cs.psl.application.ModelApplication
    public void close() {
        this.model = null;
        this.rvDB = null;
        this.observedDB = null;
        this.config = null;
    }
}
