package edu.umd.cs.psl.application.learning.weight.random;

import edu.umd.cs.psl.config.ConfigBundle;
import edu.umd.cs.psl.database.Database;
import edu.umd.cs.psl.model.Model;
import edu.umd.cs.psl.model.atom.GroundAtom;
import edu.umd.cs.psl.model.atom.ObservedAtom;
import edu.umd.cs.psl.model.atom.RandomVariableAtom;
import edu.umd.cs.psl.model.kernel.GroundCompatibilityKernel;
import edu.umd.cs.psl.model.kernel.GroundKernel;
import edu.umd.cs.psl.model.parameters.PositiveWeight;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:edu/umd/cs/psl/application/learning/weight/random/UnforgivingGroundSliceRandOM.class */
/**
 * A {@link GroundSliceRandOM} variant with an "unforgiving" observation likelihood: the training
 * observations have log-likelihood {@code 0} only if every random variable atom lies strictly
 * closer than {@link #l1Dimension} (absolute difference) to its observed truth value, and
 * {@link Double#NEGATIVE_INFINITY} otherwise.
 *
 * <p>Before delegating to {@link GroundSliceRandOM#doLearn()}, the learner initialises the ground
 * compatibility kernels attached to each training atom. Kernels involving another random variable
 * get weight 0. Of the unary kernels, only the one that penalises disagreement with the observed
 * value gets weight 1; all others get weight 0. The reasoner is then re-optimised with those
 * weights.
 */
public class UnforgivingGroundSliceRandOM extends GroundSliceRandOM {
    private static final Logger log = LoggerFactory.getLogger(UnforgivingGroundSliceRandOM.class);

    /** Prefix shared by this class's configuration keys. */
    public static final String CONFIG_PREFIX = "unforgivinggroundslicerandom";

    /** Configuration key for {@link #l1Dimension}. */
    public static final String L1_DIMENSION_KEY = CONFIG_PREFIX + ".l1dimension";

    /** Default value for {@link #L1_DIMENSION_KEY}. */
    public static final double L1_DIMENSION_DEFAULT = 0.25d;

    /**
     * Tolerance on |predicted - observed| per atom; reaching or exceeding it makes the
     * observation log-likelihood negative infinity. Always strictly positive.
     */
    protected double l1Dimension;

    /**
     * Creates the learner and reads {@link #L1_DIMENSION_KEY} from the configuration.
     *
     * @throws IllegalArgumentException if the configured L1 dimension is not strictly positive
     *         (including NaN)
     */
    public UnforgivingGroundSliceRandOM(Model model, Database database, Database database2, ConfigBundle configBundle) {
        super(model, database, database2, configBundle);
        this.l1Dimension = configBundle.getDouble(L1_DIMENSION_KEY, L1_DIMENSION_DEFAULT);
        // Written as !(x > 0) so that NaN is rejected too; with NaN every ">=" test in
        // getLogLikelihoodObservations would be false and the likelihood would always be 0.
        if (!(this.l1Dimension > 0.0d)) {
            throw new IllegalArgumentException("L1 dimension must be positive.");
        }
    }

    /**
     * Initialises the unary kernel weights for every training atom, re-optimises the reasoner,
     * logs the resulting observation log-likelihood and then runs the parent learning procedure.
     *
     * @throws IllegalStateException if a training atom lacks a positive or negative unary
     *         ground compatibility kernel, or if an observed truth value is neither 0 nor 1
     */
    @Override
    public void doLearn() {
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : this.trainingMap.getTrainingMap().entrySet()) {
            RandomVariableAtom atom = entry.getKey();
            initializeUnaryWeights(atom, entry.getValue(), collectUnaryKernels(atom));
        }
        this.reasoner.changedGroundKernelWeights();
        this.reasoner.optimize();
        log.warn("Log likelihood of observations: {}", Double.valueOf(getLogLikelihoodObservations()));
        super.doLearn();
    }

    /**
     * Returns the compatibility kernels registered with {@code atom} that involve no other random
     * variable atom. As a side effect, compatibility kernels that do involve another random
     * variable atom get weight 0.
     */
    private Set<GroundCompatibilityKernel> collectUnaryKernels(RandomVariableAtom atom) {
        Set<GroundCompatibilityKernel> unaryKernels = new HashSet<GroundCompatibilityKernel>();
        for (GroundKernel groundKernel : atom.getRegisteredGroundKernels()) {
            if (!(groundKernel instanceof GroundCompatibilityKernel)) {
                continue;
            }
            GroundCompatibilityKernel kernel = (GroundCompatibilityKernel) groundKernel;
            if (involvesNoOtherRandomVariable(kernel, atom)) {
                unaryKernels.add(kernel);
            } else {
                kernel.setWeight(new PositiveWeight(0.0d));
            }
        }
        return unaryKernels;
    }

    /** Returns true iff every random variable atom in {@code kernel} equals {@code atom}. */
    private static boolean involvesNoOtherRandomVariable(GroundKernel kernel, RandomVariableAtom atom) {
        for (GroundAtom other : kernel.getAtoms()) {
            // Observed atoms are fine; any *other* random variable makes the kernel non-unary.
            if (other instanceof RandomVariableAtom && !other.equals(atom)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Finds, among {@code unaryKernels}, the first two-atom kernel that is fully incompatible at
     * value 0 and fully compatible at value 1 ("toward true"), and the first one with the reverse
     * behaviour ("toward false"). Sets every unary kernel's weight to 0, then gives weight 1 to
     * the kernel matching the observed value. The atom's value is restored after probing.
     *
     * @throws IllegalStateException if either kernel is missing, or the observed value is not
     *         exactly 0 or 1
     */
    private void initializeUnaryWeights(RandomVariableAtom atom, ObservedAtom observation,
            Set<GroundCompatibilityKernel> unaryKernels) {
        GroundCompatibilityKernel towardTrue = null;
        GroundCompatibilityKernel towardFalse = null;
        double originalValue = atom.getValue();
        for (GroundCompatibilityKernel kernel : unaryKernels) {
            // Probe the kernel's incompatibility at both ends of the atom's range.
            atom.setValue(0.0d);
            double incompatibilityAtZero = kernel.getIncompatibility();
            atom.setValue(1.0d);
            double incompatibilityAtOne = kernel.getIncompatibility();
            boolean twoAtoms = kernel.getAtoms().size() == 2;
            if (incompatibilityAtZero == 1.0d && incompatibilityAtOne == 0.0d && twoAtoms && towardTrue == null) {
                towardTrue = kernel;
            } else if (incompatibilityAtZero == 0.0d && incompatibilityAtOne == 1.0d && twoAtoms && towardFalse == null) {
                towardFalse = kernel;
            }
        }
        // Undo the probing so the atom is not silently left at 1.0.
        atom.setValue(originalValue);

        if (towardTrue == null || towardFalse == null) {
            throw new IllegalStateException("Did not find a positive and a negative unary ground compatibility kernel for atom: " + atom);
        }
        for (GroundCompatibilityKernel kernel : unaryKernels) {
            kernel.setWeight(new PositiveWeight(0.0d));
        }
        double observed = observation.getValue();
        if (observed == 1.0d) {
            towardTrue.setWeight(new PositiveWeight(1.0d));
        } else if (observed == 0.0d) {
            towardFalse.setWeight(new PositiveWeight(1.0d));
        } else {
            throw new IllegalStateException("Unexpected truth value of " + observed + " for atom " + atom + ".");
        }
    }

    /**
     * Returns 0 if every training atom's current value is strictly within {@link #l1Dimension}
     * of its observed value, and {@link Double#NEGATIVE_INFINITY} otherwise.
     */
    @Override
    public double getLogLikelihoodObservations() {
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : this.trainingMap.getTrainingMap().entrySet()) {
            double distance = Math.abs(entry.getKey().getValue() - entry.getValue().getValue());
            if (distance >= this.l1Dimension) {
                return Double.NEGATIVE_INFINITY;
            }
        }
        return 0.0d;
    }
}
