package edu.umd.cs.psl.application.learning.weight.maxlikelihood;

import edu.umd.cs.psl.config.ConfigBundle;
import edu.umd.cs.psl.database.Database;
import edu.umd.cs.psl.model.Model;
import edu.umd.cs.psl.model.kernel.GroundCompatibilityKernel;
import edu.umd.cs.psl.model.kernel.GroundKernel;
import java.util.Arrays;

/**
 * Voted-perceptron weight learning in which the expected incompatibility of each kernel is
 * approximated by its incompatibility at the single state produced by
 * {@code reasoner.optimize()} (presumably the MPE state, as the class name suggests).
 *
 * <p>Both incompatibility vectors are kept in "full" form. Indices
 * {@code [0, kernels.size())} hold the learnable kernels. Indices
 * {@code [kernels.size(), kernels.size() + immutableKernels.size())} hold the immutable
 * kernels. Only the learnable prefix is returned to the superclass. The full vectors are
 * retained so that {@link #computeLoss()} can also account for the immutable kernels.
 */
public class MaxLikelihoodMPE extends VotedPerceptron {
    /** Total observed incompatibility per kernel: learnable kernels first, then immutable ones. */
    double[] fullObservedIncompatibility;
    /** Total incompatibility per kernel at the optimized state; same layout as the observed vector. */
    double[] fullExpectedIncompatibility;

    /**
     * Creates the learner. All arguments are forwarded unchanged to the
     * {@link VotedPerceptron} constructor.
     *
     * @param model the model whose kernels are learned
     * @param database first database argument of the superclass constructor
     *     (NOTE(review): presumably the random-variable database — confirm in VotedPerceptron)
     * @param database2 second database argument of the superclass constructor
     *     (NOTE(review): presumably the observed/label database — confirm in VotedPerceptron)
     * @param configBundle configuration forwarded to the superclass
     */
    public MaxLikelihoodMPE(Model model, Database database, Database database2, ConfigBundle configBundle) {
        super(model, database, database2, configBundle);
    }

    /**
     * Runs the reasoner's optimizer, then sums each kernel's ground incompatibilities.
     *
     * <p>The full vector, covering learnable and immutable kernels, is stored in
     * {@link #fullExpectedIncompatibility} for later use by {@link #computeLoss()}.
     *
     * @return a fresh array holding only the learnable-kernel entries ({@code kernels.size()} long)
     */
    @Override
    protected double[] computeExpectedIncomp() {
        // Allocate before optimizing, so a failed optimize() leaves a zeroed (not stale) vector.
        fullExpectedIncompatibility = new double[kernels.size() + immutableKernels.size()];
        reasoner.optimize();
        accumulateIncompatibility(fullExpectedIncompatibility, null);
        return Arrays.copyOf(fullExpectedIncompatibility, kernels.size());
    }

    /**
     * Calls {@code setLabeledRandomVariables()}, then sums each kernel's ground
     * incompatibilities and counts the groundings of each learnable kernel.
     *
     * <p>The per-kernel grounding counts are written to {@code numGroundings}. The full
     * incompatibility vector is stored in {@link #fullObservedIncompatibility}.
     *
     * @return a fresh array holding only the learnable-kernel entries ({@code kernels.size()} long)
     */
    @Override
    protected double[] computeObservedIncomp() {
        numGroundings = new double[kernels.size()];
        fullObservedIncompatibility = new double[kernels.size() + immutableKernels.size()];
        setLabeledRandomVariables();
        accumulateIncompatibility(fullObservedIncompatibility, numGroundings);
        return Arrays.copyOf(fullObservedIncompatibility, kernels.size());
    }

    /**
     * Computes the weighted difference between observed and expected incompatibility.
     *
     * <p>The sum runs over both learnable and immutable kernels, using each kernel's
     * current weight.
     *
     * @return the sum over all kernels of {@code weight * (observed - expected)}
     * @throws IllegalStateException if either incompatibility vector has not been computed yet
     */
    @Override
    protected double computeLoss() {
        if (fullObservedIncompatibility == null || fullExpectedIncompatibility == null) {
            throw new IllegalStateException(
                    "computeLoss() requires computeObservedIncomp() and computeExpectedIncomp() to run first");
        }
        int numKernels = kernels.size();
        double loss = 0.0;
        for (int i = 0; i < numKernels; i++) {
            loss += kernels.get(i).getWeight().getWeight()
                    * (fullObservedIncompatibility[i] - fullExpectedIncompatibility[i]);
        }
        for (int i = 0; i < immutableKernels.size(); i++) {
            int j = numKernels + i; // immutable kernels follow the learnable ones in the full vectors
            loss += immutableKernels.get(i).getWeight().getWeight()
                    * (fullObservedIncompatibility[j] - fullExpectedIncompatibility[j]);
        }
        return loss;
    }

    /**
     * Adds every ground kernel's incompatibility, under the reasoner's current state, into
     * {@code target}. Learnable kernels come first, then immutable kernels.
     *
     * @param target zero-initialized array of length {@code kernels.size() + immutableKernels.size()}
     * @param groundingCounts if non-null, receives +1 per grounding of each learnable kernel
     */
    private void accumulateIncompatibility(double[] target, double[] groundingCounts) {
        int numKernels = kernels.size();
        for (int i = 0; i < numKernels; i++) {
            for (GroundKernel groundKernel : reasoner.getGroundKernels(kernels.get(i))) {
                target[i] += ((GroundCompatibilityKernel) groundKernel).getIncompatibility();
                if (groundingCounts != null) {
                    groundingCounts[i] += 1.0;
                }
            }
        }
        for (int i = 0; i < immutableKernels.size(); i++) {
            for (GroundKernel groundKernel : reasoner.getGroundKernels(immutableKernels.get(i))) {
                target[numKernels + i] += ((GroundCompatibilityKernel) groundKernel).getIncompatibility();
            }
        }
    }
}
