package edu.umd.cs.psl.application.learning.weight.maxmargin;

import edu.umd.cs.psl.model.atom.Atom;
import edu.umd.cs.psl.model.atom.GroundAtom;
import edu.umd.cs.psl.model.kernel.BindingMode;
import edu.umd.cs.psl.model.kernel.CompatibilityKernel;
import edu.umd.cs.psl.model.kernel.GroundCompatibilityKernel;
import edu.umd.cs.psl.model.parameters.Weight;
import edu.umd.cs.psl.reasoner.function.ConstantNumber;
import edu.umd.cs.psl.reasoner.function.FunctionSum;
import edu.umd.cs.psl.reasoner.function.FunctionSummand;
import edu.umd.cs.psl.reasoner.function.FunctionTerm;
import java.util.HashSet;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:edu/umd/cs/psl/application/learning/weight/maxmargin/LossAugmentingGroundKernel.class */
/**
 * Ground compatibility kernel over a single {@link GroundAtom} whose
 * incompatibility is the absolute difference between the atom's current value
 * and a fixed binary ground-truth value (0.0 or 1.0).
 *
 * <p>The atom and the ground truth are fixed at construction time; only the
 * weight can be replaced afterwards via {@link #setWeight(Weight)}.
 */
public class LossAugmentingGroundKernel implements GroundCompatibilityKernel {
    private static final Logger log = LoggerFactory.getLogger(LossAugmentingGroundKernel.class);

    /** The single ground atom this kernel is defined over. */
    private final GroundAtom atom;

    /** Ground-truth value for {@link #atom}; always exactly 0.0 or 1.0. */
    private final double groundTruth;

    /** Current weight of this kernel; replaceable through {@link #setWeight(Weight)}. */
    private Weight weight;

    /**
     * Creates a kernel for the given atom and binary ground truth.
     *
     * @param groundAtom the atom this kernel is defined over
     * @param d          the ground-truth value; must be exactly 1.0 or 0.0
     * @param weight     the initial weight of this kernel
     * @throws IllegalArgumentException if {@code d} is neither 1.0 nor 0.0
     */
    public LossAugmentingGroundKernel(GroundAtom groundAtom, double d, Weight weight) {
        // Validate before storing anything so a rejected argument never leaves
        // partially initialized state behind.
        if (d != 1.0d && d != 0.0d) {
            throw new IllegalArgumentException("Truth value must be 1.0 or 0.0.");
        }
        this.atom = groundAtom;
        this.groundTruth = d;
        this.weight = weight;
    }

    /**
     * Not supported by this kernel: logs a warning and reports no change.
     *
     * @return always {@code false}
     */
    @Override // edu.umd.cs.psl.model.kernel.GroundKernel
    public boolean updateParameters() {
        log.warn("Called unsupported function on LossAugmentedGroundKernel");
        return false;
    }

    /**
     * This ground kernel is not associated with a parent kernel.
     *
     * @return always {@code null}
     */
    @Override // edu.umd.cs.psl.model.kernel.GroundKernel
    public CompatibilityKernel getKernel() {
        return null;
    }

    /**
     * Returns the atoms this kernel is defined over.
     *
     * @return a newly allocated, mutable set containing only this kernel's atom
     */
    @Override // edu.umd.cs.psl.model.kernel.GroundKernel
    public Set<GroundAtom> getAtoms() {
        // Explicit type argument instead of a raw HashSet, so the compiler
        // checks the element type and no unchecked conversion is needed.
        Set<GroundAtom> atoms = new HashSet<GroundAtom>();
        atoms.add(this.atom);
        return atoms;
    }

    /**
     * Computes the distance between the atom's current value and the ground truth.
     *
     * @return {@code |atom.getValue() - groundTruth|}
     */
    @Override // edu.umd.cs.psl.model.kernel.GroundCompatibilityKernel
    public double getIncompatibility() {
        return Math.abs(this.atom.getValue() - this.groundTruth);
    }

    /**
     * No binding mode is reported for any atom.
     *
     * @param atom the atom being queried (ignored)
     * @return always {@code null}
     */
    @Override // edu.umd.cs.psl.model.kernel.GroundKernel
    public BindingMode getBinding(Atom atom) {
        return null;
    }

    /**
     * @return the current weight of this kernel
     */
    @Override // edu.umd.cs.psl.model.kernel.GroundCompatibilityKernel
    public Weight getWeight() {
        return this.weight;
    }

    /**
     * Replaces the weight of this kernel.
     *
     * @param weight the new weight
     */
    @Override // edu.umd.cs.psl.model.kernel.GroundCompatibilityKernel
    public void setWeight(Weight weight) {
        this.weight = weight;
    }

    /**
     * Builds the function term describing this kernel in terms of the atom's
     * variable.
     *
     * <p>For a ground truth of 1.0 the sum holds two summands built from
     * {@code (1.0, constant 1.0)} and {@code (-1.0, atom variable)}; for a
     * ground truth of 0.0 it holds the single summand
     * {@code (1.0, atom variable)}.
     * NOTE(review): presumably {@code FunctionSummand} takes (coefficient, term),
     * which would make these {@code 1 - x} and {@code x} respectively - confirm
     * against {@code FunctionSummand}.
     *
     * @return a new {@link FunctionSum} for this kernel
     * @throws IllegalStateException if the ground truth is neither 0.0 nor 1.0;
     *         the constructor rejects such values, so this is a defensive check
     */
    @Override // edu.umd.cs.psl.model.kernel.GroundCompatibilityKernel
    public FunctionTerm getFunctionDefinition() {
        FunctionSum functionSum = new FunctionSum();
        if (this.groundTruth == 1.0d) {
            functionSum.add(new FunctionSummand(1.0d, new ConstantNumber(1.0d)));
            functionSum.add(new FunctionSummand(-1.0d, this.atom.getVariable()));
        } else if (this.groundTruth == 0.0d) {
            functionSum.add(new FunctionSummand(1.0d, this.atom.getVariable()));
        } else {
            throw new IllegalStateException("Ground truth is not 0 or 1.");
        }
        return functionSum;
    }

    /**
     * @return the atom this kernel is defined over
     */
    public GroundAtom getAtom() {
        return this.atom;
    }
}
