package edu.umd.cs.psl.application.topicmodel.reasoner.admm;

import edu.umd.cs.psl.application.topicmodel.kernel.LDAgroundLogLoss;
import edu.umd.cs.psl.application.topicmodel.reasoner.function.NegativeLogFunction;
import edu.umd.cs.psl.config.ConfigBundle;
import edu.umd.cs.psl.model.kernel.GroundCompatibilityKernel;
import edu.umd.cs.psl.model.kernel.GroundConstraintKernel;
import edu.umd.cs.psl.model.kernel.GroundKernel;
import edu.umd.cs.psl.reasoner.admm.ADMMObjectiveTerm;
import edu.umd.cs.psl.reasoner.admm.ADMMReasoner;
import edu.umd.cs.psl.reasoner.function.ConstraintTerm;
import edu.umd.cs.psl.reasoner.function.FunctionSum;
import edu.umd.cs.psl.reasoner.function.FunctionTerm;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:edu/umd/cs/psl/application/topicmodel/reasoner/admm/LatentTopicNetworkADMMReasoner.class */
/**
 * ADMM reasoner for latent topic network models.
 *
 * <p>Compared with {@link ADMMReasoner}, this reasoner:
 * <ol>
 *   <li>overwrites every variable lower bound with a configurable epsilon after the ground model
 *       is built;</li>
 *   <li>creates a {@link NegativeLogLossTerm} for compatibility kernels whose function is a
 *       {@link NegativeLogFunction}, and a {@link LtnLinearConstraintTerm} for every linear
 *       constraint kernel;</li>
 *   <li>initializes constraint dual variables from paired loss terms (see
 *       {@link #initDirichletTerms()}).</li>
 * </ol>
 */
public class LatentTopicNetworkADMMReasoner extends ADMMReasoner {

    /** Configuration key for the lower bound applied to every variable. */
    public static final String LOWER_BOUND_EPSILON_KEY = "admmreasoner.lowerboundepsilon";

    /** Default value used when {@link #LOWER_BOUND_EPSILON_KEY} is not configured. */
    public static final double LOWER_BOUND_EPSILON_DEFAULT = 1.0E-6d;

    /** Value written into every entry of {@code lb} by {@link #buildGroundModel()}. */
    protected double lowerBoundEpsilon;

    /**
     * Creates a reasoner configured from {@code configBundle}.
     *
     * @param configBundle configuration source. {@link #LOWER_BOUND_EPSILON_KEY} is read in addition
     *     to whatever {@link ADMMReasoner} reads.
     */
    public LatentTopicNetworkADMMReasoner(ConfigBundle configBundle) {
        super(configBundle);
        this.lowerBoundEpsilon =
                configBundle.getDouble(LOWER_BOUND_EPSILON_KEY, LOWER_BOUND_EPSILON_DEFAULT);
    }

    /**
     * Builds the ground model as the superclass does. It then replaces every variable lower bound
     * with {@link #lowerBoundEpsilon} and pairs loss/constraint terms via
     * {@link #initDirichletTerms()}.
     *
     * <p>NOTE(review): the decompiled source indicated this override was originally
     * {@code protected}. It is kept {@code public} so existing callers keep compiling.
     */
    @Override
    public void buildGroundModel() {
        super.buildGroundModel();
        // Presumably keeps variables away from 0, where a negative-log loss is unbounded.
        // TODO confirm against NegativeLogLossTerm.
        for (int i = 0; i < this.lb.size(); i++) {
            this.lb.set(i, Double.valueOf(this.lowerBoundEpsilon));
        }
        initDirichletTerms();
    }

    /**
     * Creates the ADMM objective term for one ground kernel.
     *
     * @param groundKernel the kernel to convert
     * @return a {@link NegativeLogLossTerm} for compatibility kernels with a
     *     {@link NegativeLogFunction}; the superclass term for other compatibility kernels; a
     *     {@link LtnLinearConstraintTerm} for constraint kernels
     * @throws IllegalArgumentException if the kernel is neither a compatibility nor a constraint
     *     kernel, or if a constraint's function is not a {@link FunctionSum}
     */
    @Override
    public ADMMObjectiveTerm createTerm(GroundKernel groundKernel) {
        if (groundKernel instanceof GroundCompatibilityKernel) {
            return createCompatibilityTerm((GroundCompatibilityKernel) groundKernel);
        }
        if (groundKernel instanceof GroundConstraintKernel) {
            return createConstraintTerm((GroundConstraintKernel) groundKernel);
        }
        throw new IllegalArgumentException("Unsupported ground kernel: " + groundKernel);
    }

    /**
     * Builds the term for a compatibility kernel. Kernels without a negative-log function fall
     * back to the superclass.
     */
    private ADMMObjectiveTerm createCompatibilityTerm(GroundCompatibilityKernel kernel) {
        FunctionTerm function = kernel.getFunctionDefinition();
        if (!(function instanceof NegativeLogFunction)) {
            return super.createTerm(kernel);
        }
        ADMMReasoner.Hyperplane hyperplane = processHyperplane((FunctionSum) function);
        if (kernel instanceof LDAgroundLogLoss) {
            // LDA log-loss kernels supply their own coefficients instead of the hyperplane's.
            return new NegativeLogLossTerm(this, hyperplane.zIndices,
                    ((LDAgroundLogLoss) kernel).getCoefficientsArray(),
                    kernel.getWeight().getWeight());
        }
        return new NegativeLogLossTerm(this, hyperplane.zIndices, hyperplane.coeffs,
                kernel.getWeight().getWeight());
    }

    /** Builds a linear constraint term; the constraint's function must be a {@link FunctionSum}. */
    private ADMMObjectiveTerm createConstraintTerm(GroundConstraintKernel kernel) {
        ConstraintTerm constraint = kernel.getConstraintDefinition();
        FunctionTerm function = constraint.getFunction();
        if (!(function instanceof FunctionSum)) {
            throw new IllegalArgumentException("Unrecognized constraint: " + constraint);
        }
        ADMMReasoner.Hyperplane hyperplane = processHyperplane((FunctionSum) function);
        // The hyperplane's constant is added to the constraint value to form the term's bound.
        return new LtnLinearConstraintTerm(this, hyperplane.zIndices, hyperplane.coeffs,
                constraint.getValue() + hyperplane.constant, constraint.getComparator());
    }

    /**
     * Pairs each {@link NegativeLogLossTerm} with a {@link LtnLinearConstraintTerm} that shares a
     * variable with it. For each pair, the constraint's dual variables are initialized from the
     * loss term's {@code initAsDirichlet()} result.
     *
     * <p>At most one loss term and one constraint term are expected per variable. This is checked
     * with assertions when they are enabled. If a loss term shares variables with several constraint
     * terms, the pairing from the highest-indexed such variable wins.
     *
     * <p>Pairs are processed in the order each loss term is first encountered. This makes repeated
     * runs behave identically; the original {@code HashMap} gave no ordering guarantee.
     */
    public void initDirichletTerms() {
        System.out.println("Init Dirichlet terms");
        Map<NegativeLogLossTerm, LtnLinearConstraintTerm> pairs =
                new LinkedHashMap<NegativeLogLossTerm, LtnLinearConstraintTerm>();
        for (int i = 0; i < this.varLocations.size(); i++) {
            List<ADMMReasoner.VariableLocation> locations = this.varLocations.get(i);
            NegativeLogLossTerm lossTerm = null;
            LtnLinearConstraintTerm constraintTerm = null;
            for (ADMMReasoner.VariableLocation location : locations) {
                ADMMObjectiveTerm term = location.getTerm();
                if (term instanceof NegativeLogLossTerm) {
                    assert lossTerm == null
                            : "Variable " + i + " appears in more than one NegativeLogLossTerm";
                    lossTerm = (NegativeLogLossTerm) term;
                }
                if (term instanceof LtnLinearConstraintTerm) {
                    assert constraintTerm == null
                            : "Variable " + i + " appears in more than one LtnLinearConstraintTerm";
                    constraintTerm = (LtnLinearConstraintTerm) term;
                }
            }
            if (lossTerm != null && constraintTerm != null) {
                pairs.put(lossTerm, constraintTerm);
            }
        }
        for (Map.Entry<NegativeLogLossTerm, LtnLinearConstraintTerm> pair : pairs.entrySet()) {
            pair.getValue().initDualVariablesAsDirichlet(pair.getKey().initAsDirichlet());
        }
    }
}
