package edu.umd.cs.psl.application.topicmodel;

import edu.umd.cs.psl.application.learning.weight.maxlikelihood.SimplexSampler;
import edu.umd.cs.psl.application.learning.weight.maxlikelihood.VotedPerceptron;
import edu.umd.cs.psl.config.ConfigBundle;
import edu.umd.cs.psl.database.Database;
import edu.umd.cs.psl.model.Model;
import edu.umd.cs.psl.model.atom.RandomVariableAtom;
import edu.umd.cs.psl.model.kernel.CompatibilityKernel;
import edu.umd.cs.psl.model.kernel.GroundCompatibilityKernel;
import edu.umd.cs.psl.model.predicate.Predicate;
import edu.umd.cs.psl.util.model.ConstraintBlocker;
import java.util.Calendar;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;

/* loaded from: input_file:edu/umd/cs/psl/application/topicmodel/LatentTopicNetworkMaxPseudoLikelihood.class */
/**
 * Weight learning for latent topic networks that replaces the expected-incompatibility
 * computation of {@link VotedPerceptron} with a sampled, block-wise (pseudo-likelihood style)
 * estimate.
 *
 * <p>For each block of random-variable atoms produced by a {@link ConstraintBlocker}, the
 * atoms of that block are set to a series of candidate assignments while every other atom
 * keeps its current value. Each incident ground compatibility kernel's incompatibility is
 * then averaged over those assignments, weighted by
 * {@code exp(-sum(weight * incompatibility))}.
 *
 * <p>Candidate assignments are:
 * <ul>
 *   <li>one-hot vectors (plus the all-zero vector when the block is not flagged
 *       "exactly one") when {@link #BOOLEAN_KEY} is {@code true};</li>
 *   <li>otherwise random draws from a symmetric Dirichlet distribution when every atom of
 *       the block uses the topic predicate passed to the constructor;</li>
 *   <li>otherwise draws from {@link SimplexSampler}.</li>
 * </ul>
 */
public class LatentTopicNetworkMaxPseudoLikelihood extends VotedPerceptron {
    /** Configuration key prefix. The spelling "speudo" is part of the persisted key; do not change it. */
    public static final String CONFIG_PREFIX = "LTNmaxspeudolikelihood";
    /** If true, enumerate one-hot assignments per block instead of sampling. */
    public static final String BOOLEAN_KEY = "LTNmaxspeudolikelihood.bool";
    public static final boolean BOOLEAN_DEFAULT = false;
    /** Samples per atom in a block (continuous mode); must be positive. */
    public static final String NUM_SAMPLES_KEY = "LTNmaxspeudolikelihood.numsamples";
    public static final int NUM_SAMPLES_DEFAULT = 10;
    /** Constraint tolerance; must be positive. */
    public static final String CONSTRAINT_TOLERANCE_KEY = "LTNmaxspeudolikelihood.constrainttolerance";
    public static final double CONSTRAINT_TOLERANCE_DEFAULT = 1.0E-5d;
    /** Minimum width; must be positive. */
    public static final String MIN_WIDTH_KEY = "LTNmaxspeudolikelihood.minwidth";
    public static final double MIN_WIDTH_DEFAULT = 0.01d;

    /** Lower bound on the number of samples drawn per block in continuous (non-boolean) mode. */
    private static final int MIN_CONTINUOUS_SAMPLES = 150;

    private ConstraintBlocker blocker;
    private final boolean bool;
    private final int numSamples;
    // minWidth and constraintTol are read and validated here but not otherwise used in this class.
    private final double minWidth;
    private final double constraintTol;
    /** Concentration of the symmetric Dirichlet used for blocks made only of atoms of {@link #p}. */
    private final double dirichletParam;
    /** Predicate whose blocks are sampled from the Dirichlet distribution. */
    private final Predicate p;
    private static Random rng = new Random(Calendar.getInstance().getTimeInMillis() + Thread.currentThread().getId());

    /**
     * Creates the learner and reads its settings from {@code configBundle}.
     *
     * @param model         the model whose weights are learned
     * @param database      passed unchanged to {@link VotedPerceptron} (presumably the
     *                      random-variable database — confirm against VotedPerceptron)
     * @param database2     passed unchanged to {@link VotedPerceptron} (presumably the
     *                      observed/labels database — confirm against VotedPerceptron)
     * @param configBundle  configuration source for the {@code LTNmaxspeudolikelihood.*} keys
     * @param d             Dirichlet concentration; must be positive whenever Dirichlet
     *                      sampling is actually used
     * @param predicate     topic predicate selecting Dirichlet sampling for a block
     * @throws IllegalArgumentException if the number of samples, the minimum width or the
     *                                  constraint tolerance is not positive
     */
    public LatentTopicNetworkMaxPseudoLikelihood(Model model, Database database, Database database2, ConfigBundle configBundle, double d, Predicate predicate) {
        super(model, database, database2, configBundle);
        this.bool = configBundle.getBoolean(BOOLEAN_KEY, BOOLEAN_DEFAULT);
        this.numSamples = configBundle.getInt(NUM_SAMPLES_KEY, NUM_SAMPLES_DEFAULT);
        if (this.numSamples <= 0) {
            throw new IllegalArgumentException("Number of samples must be positive integer.");
        }
        this.minWidth = configBundle.getDouble(MIN_WIDTH_KEY, MIN_WIDTH_DEFAULT);
        if (this.minWidth <= 0.0d) {
            throw new IllegalArgumentException("Minimum width must be positive double.");
        }
        this.constraintTol = configBundle.getDouble(CONSTRAINT_TOLERANCE_KEY, CONSTRAINT_TOLERANCE_DEFAULT);
        if (this.constraintTol <= 0.0d) {
            throw new IllegalArgumentException("Constraint tolerance must be positive double.");
        }
        this.dirichletParam = d;
        this.p = predicate;
    }

    /**
     * Grounds the model via the superclass, then builds the atom blocks used by
     * {@link #computeExpectedIncomp()}.
     */
    @Override // edu.umd.cs.psl.application.learning.weight.WeightLearningApplication
    public void initGroundModel() throws ClassNotFoundException, IllegalAccessException, InstantiationException {
        super.initGroundModel();
        this.blocker = new ConstraintBlocker(this.reasoner);
        this.blocker.prepareBlocks(true);
    }

    /**
     * Estimates, for every compatibility kernel, the expected total incompatibility by
     * resampling each block of atoms in isolation.
     *
     * <p>The atoms' values are left as they were on return.
     *
     * @return one entry per element of {@code kernels}, in the same order
     */
    @Override // edu.umd.cs.psl.application.learning.weight.maxlikelihood.VotedPerceptron
    protected double[] computeExpectedIncomp() {
        RandomVariableAtom[][] rvBlocks = this.blocker.getRVBlocks();
        boolean[] exactlyOne = this.blocker.getExactlyOne();
        GroundCompatibilityKernel[][] incidentGKs = this.blocker.getIncidentGKs();
        double[] expectedIncomp = new double[this.kernels.size()];

        for (int i = 0; i < rvBlocks.length; i++) {
            RandomVariableAtom[] block = rvBlocks[i];
            GroundCompatibilityKernel[] gks = incidentGKs[i];
            // A block without atoms or without incident kernels contributes nothing.
            if (block.length == 0 || gks.length == 0) {
                continue;
            }
            double[][] samples = drawSamples(block, exactlyOne[i]);
            Map<CompatibilityKernel, double[]> incompatibilities =
                    computeSampleIncompatibilities(block, gks, samples);
            accumulateExpectation(incompatibilities, samples.length, expectedIncomp);
        }
        return expectedIncomp;
    }

    /**
     * Builds the candidate assignments for one block; row {@code s} holds one value per atom.
     *
     * @param block        the block's atoms (non-empty)
     * @param isExactlyOne the blocker's exactly-one flag for this block
     */
    private double[][] drawSamples(RandomVariableAtom[] block, boolean isExactlyOne) {
        int width = block.length;
        if (this.bool) {
            // Enumerate every one-hot assignment; add the all-zero assignment when the block
            // is not flagged exactly-one.
            double[][] samples = new double[isExactlyOne ? width : width + 1][];
            for (int j = 0; j < width; j++) {
                samples[j] = new double[width];
                samples[j][j] = 1.0d;
            }
            if (!isExactlyOne) {
                samples[width] = new double[width];
            }
            return samples;
        }

        double[][] samples = new double[Math.max(this.numSamples * width, MIN_CONTINUOUS_SAMPLES)][];
        if (allAtomsOfTopicPredicate(block)) {
            for (int s = 0; s < samples.length; s++) {
                samples[s] = sampleFromDirichlet(width, this.dirichletParam);
            }
        } else {
            SimplexSampler simplexSampler = new SimplexSampler();
            for (int s = 0; s < samples.length; s++) {
                samples[s] = simplexSampler.getNext(width);
            }
        }
        return samples;
    }

    /** Returns true when every atom of {@code block} uses the topic predicate {@link #p}. */
    private boolean allAtomsOfTopicPredicate(RandomVariableAtom[] block) {
        for (RandomVariableAtom atom : block) {
            if (atom.getPredicate() != this.p) {
                return false;
            }
        }
        return true;
    }

    /**
     * Sets the block's atoms to each sample in turn and sums, per kernel, the
     * incompatibility of the incident ground kernels. The atoms' original values are
     * restored afterwards, even if evaluation throws.
     *
     * @return map from kernel to an array with one summed incompatibility per sample
     */
    private static Map<CompatibilityKernel, double[]> computeSampleIncompatibilities(
            RandomVariableAtom[] block, GroundCompatibilityKernel[] gks, double[][] samples) {
        Map<CompatibilityKernel, double[]> incompatibilities = new HashMap<CompatibilityKernel, double[]>();
        double[] savedValues = new double[block.length];
        for (int j = 0; j < block.length; j++) {
            savedValues[j] = block[j].getValue();
        }
        try {
            for (int s = 0; s < samples.length; s++) {
                // Set the assignment once per sample, not once per (kernel, sample) pair.
                for (int j = 0; j < block.length; j++) {
                    block[j].setValue(samples[s][j]);
                }
                for (GroundCompatibilityKernel gk : gks) {
                    if (gk == null) {
                        continue;
                    }
                    CompatibilityKernel kernel = (CompatibilityKernel) gk.getKernel();
                    double[] perSample = incompatibilities.get(kernel);
                    if (perSample == null) {
                        perSample = new double[samples.length];
                        incompatibilities.put(kernel, perSample);
                    }
                    perSample[s] += gk.getIncompatibility();
                }
            }
        } finally {
            for (int j = 0; j < block.length; j++) {
                block[j].setValue(savedValues[j]);
            }
        }
        return incompatibilities;
    }

    /**
     * Adds this block's weighted-average incompatibility per kernel into
     * {@code expectedIncomp}. Each sample's weight is {@code exp(-sum(weight * incompatibility))}.
     */
    private void accumulateExpectation(Map<CompatibilityKernel, double[]> incompatibilities,
            int sampleCount, double[] expectedIncomp) {
        double[] logWeights = new double[sampleCount];
        double maxLogWeight = Double.NEGATIVE_INFINITY;
        for (int s = 0; s < sampleCount; s++) {
            double logWeight = 0.0d;
            for (Map.Entry<CompatibilityKernel, double[]> entry : incompatibilities.entrySet()) {
                logWeight -= entry.getKey().getWeight().getWeight() * entry.getValue()[s];
            }
            logWeights[s] = logWeight;
            if (logWeight > maxLogWeight) {
                maxLogWeight = logWeight;
            }
        }

        // Log-sum-exp shift: it cancels in the ratio below. Without it, exp() can underflow
        // to 0 for every sample (dropping the block) or overflow to infinity (giving NaN).
        double shift = (Double.isInfinite(maxLogWeight) || Double.isNaN(maxLogWeight)) ? 0.0d : maxLogWeight;
        double[] sampleWeights = new double[sampleCount];
        double normalizer = 0.0d;
        for (int s = 0; s < sampleCount; s++) {
            sampleWeights[s] = Math.exp(logWeights[s] - shift);
            normalizer += sampleWeights[s];
        }

        for (int k = 0; k < this.kernels.size(); k++) {
            double[] perSample = incompatibilities.get(this.kernels.get(k));
            if (perSample == null) {
                continue;
            }
            double weightedSum = 0.0d;
            for (int s = 0; s < sampleCount; s++) {
                weightedSum += sampleWeights[s] * perSample[s];
            }
            // Only strictly positive contributions are added; this also skips NaN.
            if (weightedSum > 0.0d) {
                expectedIncomp[k] += weightedSum / normalizer;
            }
        }
    }

    /**
     * Draws one point from a symmetric Dirichlet distribution by normalizing independent
     * Gamma(concentration, 1) draws.
     *
     * <p>If every gamma draw underflows to 0 (possible for very small concentrations), the
     * draw is repeated rather than dividing by zero.
     *
     * @param dimension     number of components
     * @param concentration common Dirichlet parameter; must be positive
     * @throws IllegalArgumentException if {@code concentration} is not positive
     */
    private static double[] sampleFromDirichlet(int dimension, double concentration) {
        double[] sample = new double[dimension];
        if (dimension == 0) {
            return sample;
        }
        double total;
        do {
            total = 0.0d;
            for (int j = 0; j < dimension; j++) {
                sample[j] = sampleGamma(concentration, 1.0d);
                total += sample[j];
            }
        } while (!(total > 0.0d));
        for (int j = 0; j < dimension; j++) {
            sample[j] /= total;
        }
        return sample;
    }

    /**
     * Draws from a Gamma(shape, scale) distribution by rejection sampling.
     *
     * <p>For {@code shape < 1}: propose {@code X = E1^(1/shape)} and accept when
     * {@code E1 + E2 >= (1-shape) * shape^(shape/(1-shape)) + X}, with E1, E2 ~ Exp(1).
     * For {@code shape >= 1}: Cheng's (1977) algorithm GB.
     *
     * @throws IllegalArgumentException if {@code shape} is not positive
     */
    private static double sampleGamma(double shape, double scale) {
        if (!(shape > 0.0d)) {
            throw new IllegalArgumentException("Gamma shape parameter must be positive double.");
        }
        if (shape < 1.0d) {
            double invShape = 1.0d / shape;
            double threshold = (1.0d - shape) * Math.pow(shape, shape / (1.0d - shape));
            while (true) {
                double e1 = -Math.log(nextOpenUniform());
                double e2 = -Math.log(nextOpenUniform());
                double x = Math.pow(e1, invShape);
                if (e1 + e2 >= threshold + x) {
                    return x * scale;
                }
            }
        }
        double b = shape - Math.log(4.0d);
        double c = shape + Math.sqrt((2.0d * shape) - 1.0d);
        double invA = Math.sqrt((2.0d * shape) - 1.0d);
        double d = 1.0d + Math.log(4.5d);
        while (true) {
            double u1 = nextOpenUniform();
            double u2 = nextOpenUniform();
            double v = (1.0d / invA) * Math.log(u2 / (1.0d - u2));
            double x = shape * Math.exp(v);
            double z = u1 * u2 * u2;
            double r = (b + (c * v)) - x;
            if (r >= (4.5d * z) - d || r >= Math.log(z)) {
                return x * scale;
            }
        }
    }

    /**
     * Returns a uniform draw strictly inside (0, 1). {@link Random#nextDouble()} may return
     * exactly 0.0, which would make {@code log(u)} infinite in the samplers above.
     */
    private static double nextOpenUniform() {
        double u;
        do {
            u = rng.nextDouble();
        } while (u == 0.0d);
        return u;
    }
}
