package edu.umd.cs.psl.application.learning.weight.em;

import edu.umd.cs.psl.config.ConfigBundle;
import edu.umd.cs.psl.database.Database;
import edu.umd.cs.psl.model.Model;
import edu.umd.cs.psl.model.atom.GroundAtom;
import edu.umd.cs.psl.model.atom.ObservedAtom;
import edu.umd.cs.psl.model.atom.RandomVariableAtom;
import edu.umd.cs.psl.model.kernel.GroundCompatibilityKernel;
import edu.umd.cs.psl.model.kernel.GroundKernel;
import edu.umd.cs.psl.model.kernel.linearconstraint.GroundValueConstraint;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Vector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Expectation-maximization weight learning that approximates the distribution over
 * latent random variables with a fully factorized (mean-field) product of
 * Bernoulli distributions.
 * <p>
 * Each latent atom is treated as a binary variable; {@link #means} stores the
 * probability that it equals 1. The E-step ({@link #minimizeKLDivergence()}) runs
 * sweeps of coordinate-wise mean-field updates. Expectations are computed by
 * enumerating every 0/1 assignment of the latent atoms incident on a ground
 * kernel, so the cost is exponential in the number of latent atoms per ground
 * kernel.
 * <p>
 * Every ground kernel registered to a latent atom must be a
 * {@link GroundCompatibilityKernel}; otherwise the E-step throws
 * {@link IllegalStateException}.
 */
public class BernoulliMeanFieldEM extends ExpectationMaximization {
    private static final Logger log = LoggerFactory.getLogger(BernoulliMeanFieldEM.class);

    /** Prefix of the {@link ConfigBundle} property keys read by this class. */
    public static final String CONFIG_PREFIX = "bernoullimeanfieldem";

    /**
     * Key for a boolean property. If true, the means are re-initialized from an MPE
     * state at the start of every E-step. If false, they are initialized to 0.5 when
     * the ground model is built.
     */
    public static final String MPE_INITIALIZATION_KEY = CONFIG_PREFIX + ".mpeinit";

    /** Default value for {@link #MPE_INITIALIZATION_KEY}. */
    public static final boolean MPE_INITIALIZATION_DEFAULT = true;

    /** Key for an int property: number of mean-field sweeps per E-step. */
    public static final String MEAN_FIELD_ITERATIONS_KEY = CONFIG_PREFIX + ".iterations";

    /** Default value for {@link #MEAN_FIELD_ITERATIONS_KEY}. */
    public static final int MEAN_FIELD_ITERATIONS_DEFAULT = 10;

    /** Replacement for an updated mean that underflows to exactly 0. */
    private static final double MIN_MEAN = 1.0E-4d;

    /** Replacement for an updated mean that rounds to exactly 1. */
    private static final double MAX_MEAN = 0.9999d;

    /** Mean-field probability that each latent atom takes the value 1. */
    protected final Map<RandomVariableAtom, Double> means;

    /** Whether means are initialized from an MPE state at each E-step. */
    protected final boolean mpeInit;

    /** Number of mean-field sweeps over all latent atoms per E-step. */
    protected final int meanFieldIterations;

    /**
     * Creates the learner.
     *
     * @param model        model whose weights are learned (passed to the superclass)
     * @param database     first database, passed unchanged to the superclass
     * @param database2    second database, passed unchanged to the superclass
     * @param configBundle source of the {@code bernoullimeanfieldem.*} properties
     * @throws IllegalArgumentException if the configured number of iterations is negative
     */
    public BernoulliMeanFieldEM(Model model, Database database, Database database2, ConfigBundle configBundle) {
        super(model, database, database2, configBundle);
        this.means = new HashMap<RandomVariableAtom, Double>();
        this.mpeInit = configBundle.getBoolean(MPE_INITIALIZATION_KEY, MPE_INITIALIZATION_DEFAULT);
        this.meanFieldIterations = configBundle.getInt(MEAN_FIELD_ITERATIONS_KEY, MEAN_FIELD_ITERATIONS_DEFAULT);
        if (this.meanFieldIterations < 0) {
            throw new IllegalArgumentException(
                    MEAN_FIELD_ITERATIONS_KEY + " must be non-negative, got " + this.meanFieldIterations);
        }
    }

    /** Runs the superclass learning procedure, then logs the final means at debug level. */
    @Override
    public void doLearn() {
        super.doLearn();
        if (log.isDebugEnabled()) {
            for (Map.Entry<RandomVariableAtom, Double> entry : this.means.entrySet()) {
                log.debug("Mean for {}: {}", entry.getKey(), entry.getValue());
            }
        }
    }

    /**
     * E-step: performs {@link #meanFieldIterations} sweeps of mean-field updates
     * over all latent atoms.
     * <p>
     * Each update uses the means already updated earlier in the same sweep.
     *
     * @throws IllegalStateException if a ground kernel registered to a latent atom is
     *                               not a compatibility kernel, or has no latent atoms
     */
    @Override
    protected void minimizeKLDivergence() {
        if (this.mpeInit) {
            setMeansToMPE();
        }
        setLabeledRandomVariables();
        // The KL computation enumerates every ground kernel, so only do it when it will be logged.
        if (log.isDebugEnabled()) {
            log.debug("Starting KL divergence: {}", getKLDivergence());
        }
        for (int round = 0; round < this.meanFieldIterations; round++) {
            for (RandomVariableAtom atom : this.trainingMap.getLatentVariables()) {
                this.means.put(atom, computeUpdatedMean(atom));
            }
            if (log.isDebugEnabled()) {
                log.debug("KL divergence after round {}: {}", round + 1, getKLDivergence());
            }
        }
    }

    /**
     * Computes the mean-field update for one latent atom.
     * <p>
     * It accumulates {@code delta = E[w*phi | atom=1] - E[w*phi | atom=0]} over the
     * atom's registered ground kernels. Each expectation is over the kernel's other
     * latent atoms under the current means. The method returns
     * {@code 1 / (1 + exp(delta))}; a result of exactly 0 or 1 is replaced by
     * {@link #MIN_MEAN} / {@link #MAX_MEAN}.
     * <p>
     * Side effect: sets the values of the latent atoms incident on the kernels it visits.
     */
    private double computeUpdatedMean(RandomVariableAtom atom) {
        double energyDifference = 0.0d;
        for (GroundKernel groundKernel : atom.getRegisteredGroundKernels()) {
            if (!(groundKernel instanceof GroundCompatibilityKernel)) {
                throw new IllegalStateException("Model contains a constraint: " + groundKernel);
            }
            if (!this.reasoner.containsGroundKernel(groundKernel)) {
                log.warn("Ground kernel {} registered to atom {} is not in the current distribution. Skipping.", groundKernel, atom);
                continue;
            }
            GroundCompatibilityKernel kernel = (GroundCompatibilityKernel) groundKernel;
            List<RandomVariableAtom> latent = getLatentAtoms(kernel);
            if (latent.isEmpty()) {
                throw new IllegalStateException("Expected there to be at least one incident latent RV.");
            }
            // If the atom is not among the kernel's latent atoms, every assignment counts with sign +1.
            int atomIndex = latent.indexOf(atom);
            double weight = kernel.getWeight().getWeight();
            int numConfigs = countConfigurations(latent);
            for (int config = 0; config < numConfigs; config++) {
                double othersProbability = applyConfiguration(latent, config, atom);
                double sign = (atomIndex < 0 || isOn(config, atomIndex)) ? 1.0d : -1.0d;
                energyDifference += weight * kernel.getIncompatibility() * othersProbability * sign;
            }
        }
        double mean = 1.0d / (1.0d + Math.exp(energyDifference));
        // Keep the mean strictly inside (0, 1) when exp() over/underflows.
        if (mean == 0.0d) {
            mean = MIN_MEAN;
        } else if (mean == 1.0d) {
            mean = MAX_MEAN;
        }
        return mean;
    }

    /**
     * Computes, for each kernel, the sum over its ground kernels of the expected
     * incompatibility under the current means. Labeled random variables are fixed
     * via {@code setLabeledRandomVariables()}. Also records the number of ground
     * kernels per kernel in {@code numGroundings}.
     */
    @Override
    protected double[] computeObservedIncomp() {
        this.numGroundings = new double[this.kernels.size()];
        double[] observedIncomp = new double[this.kernels.size()];
        setLabeledRandomVariables();
        for (int i = 0; i < this.kernels.size(); i++) {
            for (GroundKernel groundKernel : this.reasoner.getGroundKernels(this.kernels.get(i))) {
                GroundCompatibilityKernel kernel = (GroundCompatibilityKernel) groundKernel;
                List<RandomVariableAtom> latent = getLatentAtoms(kernel);
                int numConfigs = countConfigurations(latent);
                for (int config = 0; config < numConfigs; config++) {
                    double probability = applyConfiguration(latent, config, null);
                    observedIncomp[i] += kernel.getIncompatibility() * probability;
                }
                this.numGroundings[i] += 1.0d;
            }
        }
        return observedIncomp;
    }

    /**
     * Runs the reasoner's {@code optimize()} and returns, for each kernel, the
     * summed incompatibility of its ground kernels at the resulting state.
     */
    @Override
    protected double[] computeExpectedIncomp() {
        double[] expectedIncomp = new double[this.kernels.size()];
        this.reasoner.optimize();
        for (int i = 0; i < this.kernels.size(); i++) {
            for (GroundKernel groundKernel : this.reasoner.getGroundKernels(this.kernels.get(i))) {
                expectedIncomp[i] += ((GroundCompatibilityKernel) groundKernel).getIncompatibility();
            }
        }
        return expectedIncomp;
    }

    /**
     * Builds the ground model via the superclass. When MPE initialization is
     * disabled, it also sets every latent atom's mean to 0.5.
     */
    @Override
    public void initGroundModel() throws ClassNotFoundException, IllegalAccessException, InstantiationException {
        super.initGroundModel();
        if (this.mpeInit) {
            return;
        }
        this.means.clear();
        for (RandomVariableAtom atom : this.trainingMap.getLatentVariables()) {
            this.means.put(atom, 0.5d);
        }
    }

    /**
     * Resets the means to the values of the latent atoms in an MPE state, with the
     * labeled atoms pinned to their observed values.
     * <p>
     * Temporary {@link GroundValueConstraint}s fix each labeled atom during
     * optimization. They are always removed afterwards, even if optimization throws.
     */
    protected void setMeansToMPE() {
        log.debug("Running MPE inference to initialize mean field.");
        List<GroundValueConstraint> constraints = new ArrayList<GroundValueConstraint>();
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : this.trainingMap.getTrainingMap().entrySet()) {
            GroundValueConstraint constraint = new GroundValueConstraint(entry.getKey(), entry.getValue().getValue());
            // The constraint must be added; otherwise the MPE ignores the observed labels.
            this.reasoner.addGroundKernel(constraint);
            constraints.add(constraint);
        }
        try {
            this.reasoner.optimize();
        } finally {
            for (GroundValueConstraint constraint : constraints) {
                this.reasoner.removeGroundKernel(constraint);
            }
        }
        this.means.clear();
        for (RandomVariableAtom atom : this.trainingMap.getLatentVariables()) {
            this.means.put(atom, atom.getValue());
        }
    }

    /**
     * Computes the sum of the means' negative entropies plus the expected weighted
     * incompatibility of every ground compatibility kernel under the means.
     * 0 * log(0) is taken to be 0, so means of exactly 0 or 1 do not produce NaN.
     * <p>
     * Side effects: calls {@code setLabeledRandomVariables()} and overwrites the
     * values of latent atoms.
     */
    protected double getKLDivergence() {
        double divergence = 0.0d;
        for (Double mean : this.means.values()) {
            double q = mean.doubleValue();
            divergence += xLogX(q) + xLogX(1.0d - q);
        }
        setLabeledRandomVariables();
        for (GroundCompatibilityKernel kernel : this.reasoner.getCompatibilityKernels()) {
            List<RandomVariableAtom> latent = getLatentAtoms(kernel);
            double weight = kernel.getWeight().getWeight();
            int numConfigs = countConfigurations(latent);
            for (int config = 0; config < numConfigs; config++) {
                double probability = applyConfiguration(latent, config, null);
                divergence += weight * kernel.getIncompatibility() * probability;
            }
        }
        return divergence;
    }

    /** Returns the atoms of {@code kernel} that are latent variables of the training map, in iteration order. */
    private List<RandomVariableAtom> getLatentAtoms(GroundCompatibilityKernel kernel) {
        List<RandomVariableAtom> latent = new ArrayList<RandomVariableAtom>();
        for (GroundAtom atom : kernel.getAtoms()) {
            if (this.trainingMap.getLatentVariables().contains(atom)) {
                latent.add((RandomVariableAtom) atom);
            }
        }
        return latent;
    }

    /**
     * Returns 2^n, the number of 0/1 assignments of {@code latent}.
     *
     * @throws IllegalStateException if 2^n does not fit in an {@code int}
     */
    private static int countConfigurations(List<RandomVariableAtom> latent) {
        int n = latent.size();
        if (n >= Integer.SIZE - 1) {
            throw new IllegalStateException("Too many latent random variables (" + n
                    + ") in one ground kernel to enumerate their assignments.");
        }
        return 1 << n;
    }

    /**
     * Sets each atom in {@code latent} to 1 or 0 according to bit j of {@code config},
     * where j is the atom's list index.
     * <p>
     * Returns the probability of that assignment under {@link #means}, skipping the
     * factor for {@code skip}; {@code skip} may be null.
     */
    private double applyConfiguration(List<RandomVariableAtom> latent, int config, RandomVariableAtom skip) {
        double probability = 1.0d;
        for (int j = 0; j < latent.size(); j++) {
            RandomVariableAtom atom = latent.get(j);
            double mean = this.means.get(atom).doubleValue();
            boolean on = isOn(config, j);
            if (!atom.equals(skip)) {
                probability *= on ? mean : 1.0d - mean;
            }
            atom.setValue(on ? 1.0d : 0.0d);
        }
        return probability;
    }

    /** Returns whether bit {@code index} of {@code config} is set. */
    private static boolean isOn(int config, int index) {
        return ((config >> index) & 1) == 1;
    }

    /** Returns {@code x * log(x)}, using the convention that it is 0 when x is 0. */
    private static double xLogX(double x) {
        return x == 0.0d ? 0.0d : x * Math.log(x);
    }
}
