package edu.umd.cs.psl.application.learning.weight.maxlikelihood;

import edu.umd.cs.psl.application.learning.weight.WeightLearningApplication;
import edu.umd.cs.psl.application.util.Grounding;
import edu.umd.cs.psl.config.ConfigBundle;
import edu.umd.cs.psl.database.Database;
import edu.umd.cs.psl.model.Model;
import edu.umd.cs.psl.model.atom.AtomEventFramework;
import edu.umd.cs.psl.model.atom.GroundAtom;
import edu.umd.cs.psl.model.atom.ObservedAtom;
import edu.umd.cs.psl.model.atom.RandomVariableAtom;
import edu.umd.cs.psl.model.kernel.GroundCompatibilityKernel;
import edu.umd.cs.psl.model.kernel.GroundKernel;
import edu.umd.cs.psl.model.kernel.Kernel;
import edu.umd.cs.psl.model.kernel.linearconstraint.GroundValueConstraint;
import edu.umd.cs.psl.model.predicate.StandardPredicate;
import edu.umd.cs.psl.reasoner.ReasonerFactory;
import edu.umd.cs.psl.util.database.Queries;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:edu/umd/cs/psl/application/learning/weight/maxlikelihood/LazyMaxLikelihoodMPE.class */
/**
 * Voted-perceptron maximum-likelihood weight learning whose ground model is built lazily
 * through an {@link AtomEventFramework}.
 * <p>
 * Every kernel of the model is registered with the event framework and the reasoner. As atoms
 * are activated, the framework's job queue is worked off (presumably grounding the registered
 * kernels over the newly active atoms — TODO confirm against {@code Kernel.registerForAtomEvents}).
 * The "observed" incompatibilities are computed by clamping reachable random variable atoms to
 * their values in the observed database and growing that labeled network until no new atoms
 * are reached.
 */
public class LazyMaxLikelihoodMPE extends VotedPerceptron {

    /** Logger for this class. */
    private static final Logger log = LoggerFactory.getLogger(LazyMaxLikelihoodMPE.class);

    /**
     * Drives lazy atom activation. Created in {@link #initGroundModel()} and cleared in
     * {@link #cleanUpGroundModel()}.
     */
    private AtomEventFramework eventFramework;

    /**
     * Creates a lazy max-likelihood MPE learner. All arguments are forwarded unchanged to the
     * {@link VotedPerceptron} constructor.
     *
     * @param model      the model whose kernel weights are learned
     * @param rvDB       first database argument; presumably the random-variable database later
     *                   read as {@code rvDB} — TODO confirm against the superclass constructor
     * @param observedDB second database argument; presumably the observed (label) database later
     *                   read as {@code observedDB} — TODO confirm against the superclass constructor
     * @param config     configuration bundle (also used to look up the reasoner factory)
     */
    public LazyMaxLikelihoodMPE(Model model, Database rvDB, Database observedDB, ConfigBundle config) {
        super(model, rvDB, observedDB, config);
    }

    /**
     * Builds the initial lazy ground model.
     * <p>
     * Obtains a reasoner from the configured {@link ReasonerFactory}. It then creates an
     * {@link AtomEventFramework} over the random-variable database and registers every model
     * kernel for atom events. Finally it grounds the model and works off activation jobs until
     * the framework reports nothing left to activate.
     *
     * @throws ClassNotFoundException if the configured reasoner factory cannot be loaded
     * @throws IllegalAccessException if the configured reasoner factory cannot be accessed
     * @throws InstantiationException if the configured reasoner factory cannot be instantiated
     */
    @Override
    public void initGroundModel() throws ClassNotFoundException, IllegalAccessException, InstantiationException {
        reasoner = ((ReasonerFactory) config.getFactory(WeightLearningApplication.REASONER_KEY, REASONER_DEFAULT))
                .getReasoner(config);
        eventFramework = new AtomEventFramework(rvDB, config);
        for (Kernel kernel : model.getKernels()) {
            kernel.registerForAtomEvents(eventFramework, reasoner);
        }
        Grounding.groundAll(model, eventFramework, reasoner);
        while (eventFramework.checkToActivate() > 0) {
            eventFramework.workOffJobQueue();
        }
    }

    /**
     * Computes, per learned kernel, the total incompatibility of its ground kernels when random
     * variables are clamped to their observed values.
     * <p>
     * Seeding: every positively-valued observed atom that the event framework maps to a random
     * variable atom is activated and clamped with a {@link GroundValueConstraint}.
     * <p>
     * Growth: inference is then re-run until no activations are pending. Any random variable
     * atom now appearing in a ground kernel that has an observed counterpart is clamped as well.
     * This repeats until a pass adds no new clamp.
     * <p>
     * All clamping constraints added to the reasoner are removed again before returning, even if
     * an exception is thrown, so later inference is not contaminated. Atoms activated here are
     * not deactivated.
     *
     * @return an array indexed like {@code kernels} holding summed incompatibilities
     */
    @Override
    protected double[] computeObservedIncomp() {
        Set<RandomVariableAtom> labeled = new HashSet<RandomVariableAtom>();
        // Exactly the constraints handed to the reasoner, so cleanup never removes one that was not added.
        List<GroundValueConstraint> addedToReasoner = new ArrayList<GroundValueConstraint>();
        try {
            addInitialLabels(labeled, addedToReasoner);

            log.debug("Beginning to grow labeled network.");
            List<GroundValueConstraint> newLabels;
            do {
                optimizeUntilNoPendingActivations();
                // Collect first, add afterwards: the reasoner's ground kernels are being iterated.
                newLabels = collectLabelsForReachedAtoms(labeled);
                for (GroundValueConstraint label : newLabels) {
                    reasoner.addGroundKernel(label);
                    addedToReasoner.add(label);
                }
            } while (!newLabels.isEmpty());
            log.debug("Finished growing labeled network.");

            return sumIncompatibilityPerKernel();
        } finally {
            for (GroundValueConstraint label : addedToReasoner) {
                reasoner.removeGroundKernel(label);
            }
        }
    }

    /**
     * Computes, per learned kernel, the total incompatibility of its ground kernels at the
     * reasoner's current (unclamped) optimum, after working off all pending lazy activations.
     *
     * @return an array indexed like {@code kernels} holding summed incompatibilities
     */
    @Override
    protected double[] computeExpectedIncomp() {
        optimizeUntilNoPendingActivations();
        return sumIncompatibilityPerKernel();
    }

    /**
     * Returns, per learned kernel, the number of its ground kernels currently held by the
     * reasoner, or 1 for kernels with no groundings (so the factor is never zero).
     *
     * @return an array indexed like {@code kernels} holding the scaling factors
     */
    @Override
    public double[] computeScalingFactor() {
        double[] factors = new double[kernels.size()];
        for (int i = 0; i < kernels.size(); i++) {
            long groundings = 0;
            for (GroundKernel ignored : reasoner.getGroundKernels(kernels.get(i))) {
                groundings++;
            }
            factors[i] = (groundings == 0) ? 1.0 : groundings;
        }
        return factors;
    }

    /**
     * Unregisters every model kernel from the event framework and drops the framework. It then
     * delegates to the superclass cleanup.
     */
    @Override
    public void cleanUpGroundModel() {
        for (Kernel kernel : model.getKernels()) {
            kernel.unregisterForAtomEvents(eventFramework, reasoner);
        }
        eventFramework = null;
        super.cleanUpGroundModel();
    }

    /**
     * Activates and clamps every positively-valued observed atom that the event framework maps
     * to a random variable atom. Each clamp is added to the reasoner immediately.
     *
     * @param labeled receives every atom that gets clamped
     * @param added   receives every constraint handed to the reasoner
     */
    private void addInitialLabels(Set<RandomVariableAtom> labeled, List<GroundValueConstraint> added) {
        for (StandardPredicate predicate : observedDB.getRegisteredPredicates()) {
            for (GroundAtom observed : Queries.getAllAtoms(observedDB, predicate)) {
                if (observed instanceof ObservedAtom && observed.getValue() > 0.0) {
                    GroundAtom atom = eventFramework.getAtom(observed.getPredicate(), observed.getArguments());
                    if (atom instanceof RandomVariableAtom) {
                        RandomVariableAtom rvAtom = (RandomVariableAtom) atom;
                        eventFramework.activateAtom(rvAtom);
                        GroundValueConstraint label = new GroundValueConstraint(rvAtom, observed.getValue());
                        labeled.add(rvAtom);
                        reasoner.addGroundKernel(label);
                        added.add(label);
                    }
                }
            }
        }
    }

    /**
     * Scans all ground kernels in the reasoner for random variable atoms that are not yet clamped
     * but have an {@link ObservedAtom} counterpart in the observed database. It creates one clamp
     * per such atom. Newly clamped atoms are recorded in {@code labeled} immediately, so an atom
     * shared by several ground kernels is clamped only once.
     *
     * @param labeled atoms already clamped; updated in place
     * @return the new clamps, not yet added to the reasoner
     */
    private List<GroundValueConstraint> collectLabelsForReachedAtoms(Set<RandomVariableAtom> labeled) {
        List<GroundValueConstraint> fresh = new ArrayList<GroundValueConstraint>();
        for (GroundKernel groundKernel : reasoner.getGroundKernels()) {
            for (GroundAtom atom : groundKernel.getAtoms()) {
                if (!(atom instanceof RandomVariableAtom)) {
                    continue;
                }
                RandomVariableAtom rvAtom = (RandomVariableAtom) atom;
                if (labeled.contains(rvAtom)) {
                    continue;
                }
                GroundAtom observed = observedDB.getAtom(rvAtom.getPredicate(), rvAtom.getArguments());
                if (observed instanceof ObservedAtom) {
                    fresh.add(new GroundValueConstraint(rvAtom, observed.getValue()));
                    labeled.add(rvAtom);
                }
            }
        }
        return fresh;
    }

    /**
     * Alternates between working off the event framework's job queue and optimizing the reasoner
     * (at least once). It stops when the framework reports no atoms left to activate.
     */
    private void optimizeUntilNoPendingActivations() {
        do {
            eventFramework.workOffJobQueue();
            reasoner.optimize();
        } while (eventFramework.checkToActivate() > 0);
    }

    /**
     * Sums {@link GroundCompatibilityKernel#getIncompatibility()} over the reasoner's ground
     * kernels of each learned kernel.
     *
     * @return an array indexed like {@code kernels}
     */
    private double[] sumIncompatibilityPerKernel() {
        double[] totals = new double[kernels.size()];
        for (int i = 0; i < kernels.size(); i++) {
            for (GroundKernel groundKernel : reasoner.getGroundKernels(kernels.get(i))) {
                totals[i] += ((GroundCompatibilityKernel) groundKernel).getIncompatibility();
            }
        }
        return totals;
    }
}
