package edu.umd.cs.psl.application.learning.weight.em;

import edu.umd.cs.psl.application.learning.weight.TrainingMap;
import edu.umd.cs.psl.application.learning.weight.WeightLearningApplication;
import edu.umd.cs.psl.application.learning.weight.maxlikelihood.VotedPerceptron;
import edu.umd.cs.psl.application.util.Grounding;
import edu.umd.cs.psl.config.ConfigBundle;
import edu.umd.cs.psl.database.Database;
import edu.umd.cs.psl.model.Model;
import edu.umd.cs.psl.model.atom.ObservedAtom;
import edu.umd.cs.psl.model.atom.RandomVariableAtom;
import edu.umd.cs.psl.model.kernel.CompatibilityKernel;
import edu.umd.cs.psl.model.kernel.linearconstraint.GroundValueConstraint;
import edu.umd.cs.psl.model.parameters.PositiveWeight;
import edu.umd.cs.psl.reasoner.Reasoner;
import edu.umd.cs.psl.reasoner.ReasonerFactory;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:edu/umd/cs/psl/application/learning/weight/em/ExpectationMaximization.class */
/**
 * Base class for weight learning with latent variables using an expectation-maximization loop.
 *
 * <p>Each EM round first calls the subclass-defined {@link #minimizeKLDivergence()} step
 * (presumably the E-step over the latent variables — see subclasses), then runs the
 * {@link VotedPerceptron} learning procedure as the M-step. The loop stops after
 * {@code em.iterations} rounds, or earlier once the Euclidean norm of the change in kernel
 * weights produced by one M-step is at most {@code em.tolerance}.
 */
public abstract class ExpectationMaximization extends VotedPerceptron {
    private static final Logger log = LoggerFactory.getLogger(ExpectationMaximization.class);

    /** Prefix for all configuration keys read by this class. */
    public static final String CONFIG_PREFIX = "em";

    /** Key for the maximum number of EM rounds. */
    public static final String ITER_KEY = "em.iterations";
    /** Default maximum number of EM rounds. */
    public static final int ITER_DEFAULT = 10;

    /**
     * Key for whether the step-size schedule restarts in every M-step (true), or continues
     * across EM rounds (false; only relevant when step-size scheduling is enabled).
     */
    public static final String RESET_SCHEDULE_KEY = "em.resetschedule";
    /** Default for {@link #RESET_SCHEDULE_KEY}. */
    public static final boolean RESET_SCHEDULE_DEFAULT = true;

    /** Key for whether the weights after every EM round are recorded. */
    public static final String STORE_WEIGHTS_KEY = "em.storeweights";
    /** Default for {@link #STORE_WEIGHTS_KEY}. */
    public static final boolean STORE_WEIGHTS_DEFAULT = false;

    /** Key for the convergence threshold on the norm of the per-round weight change. */
    public static final String TOLERANCE_KEY = "em.tolerance";
    /** Default for {@link #TOLERANCE_KEY}. */
    public static final double TOLERANCE_DEFAULT = 0.001d;

    /** Maximum number of EM rounds. */
    protected final int iterations;
    /** Convergence threshold on the L2 norm of the weight change of one M-step. */
    protected final double tolerance;
    /** Whether the step-size schedule restarts at the beginning of every M-step. */
    protected final boolean resetSchedule;
    /** 1-based index of the EM round currently running; read by {@link #getStepSize(int)}. */
    private int round;
    /** Whether per-round weights are recorded in {@link #storedWeights}. */
    protected final boolean storeWeights;
    /** Positive weights after each EM round; {@code null} unless {@link #storeWeights}. */
    protected ArrayList<Map<CompatibilityKernel, Double>> storedWeights;
    /**
     * Reasoner whose ground model also pins every training atom to its observed value;
     * created by {@link #initGroundModel()} and released by {@link #close()}.
     */
    protected Reasoner latentVariableReasoner;

    /**
     * Creates the EM learner and reads its {@code em.*} settings from {@code config}.
     *
     * @param model the model whose kernel weights are learned
     * @param rvDB first database, passed to the superclass and later used as the random-variable
     *     side of the {@link TrainingMap} (NOTE(review): ordering assumed from
     *     {@code initGroundModel}; confirm against {@code WeightLearningApplication})
     * @param observedDB second database, passed to the superclass and later used as the observed
     *     side of the {@link TrainingMap}
     * @param config configuration bundle supplying the {@code em.*} options
     */
    public ExpectationMaximization(Model model, Database rvDB, Database observedDB, ConfigBundle config) {
        super(model, rvDB, observedDB, config);
        this.iterations = config.getInt(ITER_KEY, ITER_DEFAULT);
        this.tolerance = config.getDouble(TOLERANCE_KEY, TOLERANCE_DEFAULT);
        this.resetSchedule = config.getBoolean(RESET_SCHEDULE_KEY, RESET_SCHEDULE_DEFAULT);
        this.storeWeights = config.getBoolean(STORE_WEIGHTS_KEY, STORE_WEIGHTS_DEFAULT);
        if (this.storeWeights) {
            this.storedWeights = new ArrayList<>();
        }
    }

    /**
     * Runs the EM loop. Each round calls {@link #minimizeKLDivergence()}, then the superclass
     * learner, and stops early once the M-step weight change has norm at most {@link #tolerance}.
     * If {@code averageSteps} is set, the kernels end up with the uniform average of the weights
     * produced by every completed round.
     */
    @Override
    public void doLearn() {
        final int numKernels = this.kernels.size();

        // Weights after the most recent M-step (initially the starting weights).
        double[] currentWeights = new double[numKernels];
        for (int k = 0; k < numKernels; k++) {
            currentWeights[k] = this.kernels.get(k).getWeight().getWeight();
        }
        // Running uniform average of the weights produced by each completed round.
        double[] averagedWeights = new double[numKernels];

        this.round = 0;
        // Post-increment: 'round' is 1-based inside the body, which getStepSize depends on.
        while (this.round++ < this.iterations) {
            log.debug("Beginning EM round {} of {}", this.round, this.iterations);
            minimizeKLDivergence();
            super.doLearn();

            double squaredChange = 0.0d;
            for (int k = 0; k < numKernels; k++) {
                double updated = this.kernels.get(k).getWeight().getWeight();
                double delta = currentWeights[k] - updated;
                squaredChange += delta * delta;
                currentWeights[k] = updated;
                // Incremental mean: avg_r = (1 - 1/r) * avg_{r-1} + (1/r) * w_r
                averagedWeights[k] = (1.0d - 1.0d / this.round) * averagedWeights[k]
                        + (1.0d / this.round) * updated;
            }

            if (this.storeWeights) {
                this.storedWeights.add(snapshotPositiveWeights(
                        this.averageSteps ? averagedWeights : currentWeights));
            }

            double loss = getLoss();
            double regularizer = computeRegularizer();
            double objective = loss + regularizer;
            double stepNorm = Math.sqrt(squaredChange);
            if (stepNorm <= this.tolerance) {
                log.info("EM converged with m-step norm {} in {} rounds. Loss: {}",
                        new Object[] {stepNorm, this.round, loss});
                break;
            }
            log.info("Finished EM round {} with m-step norm {}. Loss: {}, regularizer: {}, objective: {}",
                    new Object[] {this.round, stepNorm, loss, regularizer, objective});
        }

        if (this.averageSteps) {
            for (int k = 0; k < numKernels; k++) {
                this.kernels.get(k).setWeight(new PositiveWeight(averagedWeights[k]));
            }
        }
    }

    /**
     * Builds a kernel-to-weight map holding only the strictly positive entries of
     * {@code weights}, which is indexed in the same order as {@code kernels}.
     */
    private Map<CompatibilityKernel, Double> snapshotPositiveWeights(double[] weights) {
        Map<CompatibilityKernel, Double> snapshot = new HashMap<>();
        for (int k = 0; k < weights.length; k++) {
            if (weights[k] > 0.0d) {
                snapshot.put(this.kernels.get(k), weights[k]);
            }
        }
        return snapshot;
    }

    /**
     * Subclass-defined step run at the start of every EM round, before the M-step.
     * Implementations presumably update the latent variables (e.g. via
     * {@link #latentVariableReasoner}) — see subclasses.
     */
    protected abstract void minimizeKLDivergence();

    /**
     * Builds the training map and two ground models: the ordinary learning reasoner, and
     * {@link #latentVariableReasoner}, which additionally constrains every training atom to its
     * observed value. Any previously created latent-variable reasoner is closed first.
     */
    @Override
    public void initGroundModel() throws ClassNotFoundException, IllegalAccessException, InstantiationException {
        this.trainingMap = new TrainingMap(this.rvDB, this.observedDB);

        this.reasoner = newReasoner();
        Grounding.groundAll(this.model, this.trainingMap, this.reasoner);

        if (this.latentVariableReasoner != null) {
            this.latentVariableReasoner.close();
        }
        this.latentVariableReasoner = newReasoner();
        Grounding.groundAll(this.model, this.trainingMap, this.latentVariableReasoner);
        // Pin each random-variable atom to the value of its observed counterpart.
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : this.trainingMap.getTrainingMap().entrySet()) {
            this.latentVariableReasoner.addGroundKernel(
                    new GroundValueConstraint(entry.getKey(), entry.getValue().getValue()));
        }
    }

    /** Creates a fresh reasoner from the factory configured under {@code REASONER_KEY}. */
    private Reasoner newReasoner() throws ClassNotFoundException, IllegalAccessException, InstantiationException {
        ReasonerFactory factory =
                (ReasonerFactory) this.config.getFactory(WeightLearningApplication.REASONER_KEY, REASONER_DEFAULT);
        return factory.getReasoner(this.config);
    }

    /**
     * Re-optimizes the latent-variable reasoner after notifying it that kernel weights changed.
     *
     * @throws IllegalStateException if no latent-variable reasoner exists (not yet initialized,
     *     or already closed)
     */
    public void inferLatentVariables() {
        if (this.latentVariableReasoner == null) {
            throw new IllegalStateException("A model must have been learned before latent variables can be inferred.");
        }
        this.latentVariableReasoner.changedGroundKernelWeights();
        this.latentVariableReasoner.optimize();
    }

    /**
     * Returns the step size for step {@code iter} of the current M-step. When step-size
     * scheduling is on and the schedule is not reset per round, the decay continues across
     * rounds as {@code stepSize / (t + 1)}, where {@code t} counts steps since the first round.
     */
    @Override
    public double getStepSize(int iter) {
        if (this.scheduleStepSize && !this.resetSchedule) {
            int globalStep = (this.round - 1) * this.numSteps + iter;
            return this.stepSize / (globalStep + 1);
        }
        return super.getStepSize(iter);
    }

    /**
     * Returns the weights recorded after each EM round (only strictly positive weights are
     * included), or {@code null} when {@code em.storeweights} is disabled. The returned list is
     * the internal one, not a copy.
     */
    public ArrayList<Map<CompatibilityKernel, Double>> getStoredWeights() {
        return this.storeWeights ? this.storedWeights : null;
    }

    /** Releases the superclass resources and the latent-variable reasoner, if one exists. */
    @Override
    public void close() {
        super.close();
        if (this.latentVariableReasoner != null) {
            this.latentVariableReasoner.close();
            this.latentVariableReasoner = null;
        }
    }
}
