package edu.umd.cs.psl.reasoner.bool;

import edu.umd.cs.psl.application.groundkernelstore.MemoryGroundKernelStore;
import edu.umd.cs.psl.config.ConfigBundle;
import edu.umd.cs.psl.model.atom.RandomVariableAtom;
import edu.umd.cs.psl.model.kernel.GroundCompatibilityKernel;
import edu.umd.cs.psl.reasoner.Reasoner;
import edu.umd.cs.psl.util.model.ConstraintBlocker;
import java.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:edu/umd/cs/psl/reasoner/bool/BooleanMCSat.class */
/**
 * Sampling-based marginal inference over Boolean random variable atoms.
 *
 * <p>Atoms are grouped into blocks by {@link ConstraintBlocker}. Within a block the sampler only
 * ever assigns a one-hot configuration (one atom at 1.0, all others at 0.0), plus an all-zero
 * configuration when the block is not flagged "exactly one". Each sweep visits every non-empty
 * block, scores each allowed configuration by its energy {@code sum(weight * incompatibility)}
 * over the compatibility kernels the blocker reports as incident on that block, and draws one
 * configuration with probability proportional to {@code exp(-energy)}. Assuming the incident
 * kernels are exactly those whose incompatibility depends on the block's atoms (TODO confirm
 * against {@code ConstraintBlocker}), this is a block-wise Gibbs-style sampler.
 *
 * <p>After {@code numSamples} sweeps, each atom is set to the fraction of post-burn-in sweeps in
 * which it was sampled as 1.0, i.e. an estimate of its marginal probability of being true.
 *
 * <p>Configuration keys: {@value #NUM_SAMPLES_KEY} (default {@value #NUM_SAMPLES_DEFAULT}) and
 * {@value #NUM_BURN_IN_KEY} (default {@value #NUM_BURN_IN_DEFAULT}).
 *
 * <p>{@link #optimize()} mutates the values of the atoms it samples and performs no
 * synchronization.
 */
public class BooleanMCSat extends MemoryGroundKernelStore implements Reasoner {
    private static final Logger log = LoggerFactory.getLogger(BooleanMCSat.class);

    /** Prefix shared by all configuration keys of this reasoner. */
    public static final String CONFIG_PREFIX = "booleanmcsat";

    /** Key for the total number of sampling sweeps, including burn-in. */
    public static final String NUM_SAMPLES_KEY = "booleanmcsat.numsamples";

    /** Default value for {@link #NUM_SAMPLES_KEY}. */
    public static final int NUM_SAMPLES_DEFAULT = 2500;

    /** Key for the number of initial sweeps discarded before marginals are accumulated. */
    public static final String NUM_BURN_IN_KEY = "booleanmcsat.numburnin";

    /** Default value for {@link #NUM_BURN_IN_KEY}. */
    public static final int NUM_BURN_IN_DEFAULT = 500;

    private final Random rand = new Random();
    private final int numSamples;
    private final int numBurnIn;

    /**
     * Creates a sampler configured from {@code configBundle}.
     *
     * @param configBundle source of {@value #NUM_SAMPLES_KEY} and {@value #NUM_BURN_IN_KEY}
     * @throws IllegalArgumentException if the number of samples is not positive, the burn-in is
     *     negative, or the burn-in is not strictly smaller than the number of samples
     */
    public BooleanMCSat(ConfigBundle configBundle) {
        this.numSamples = configBundle.getInt(NUM_SAMPLES_KEY, NUM_SAMPLES_DEFAULT);
        if (this.numSamples <= 0) {
            throw new IllegalArgumentException("Number of samples must be positive.");
        }
        this.numBurnIn = configBundle.getInt(NUM_BURN_IN_KEY, NUM_BURN_IN_DEFAULT);
        // Previously this re-checked numSamples, so a negative burn-in slipped through and
        // skewed the final averaging denominator (numSamples - numBurnIn).
        if (this.numBurnIn < 0) {
            throw new IllegalArgumentException("Number of burn in samples must be non-negative.");
        }
        if (this.numBurnIn >= this.numSamples) {
            throw new IllegalArgumentException("Number of burn in samples must be less than number of samples.");
        }
    }

    /**
     * Runs {@code numSamples} sampling sweeps over all blocks, then sets every atom to its
     * empirical frequency of being 1.0 across the sweeps after burn-in.
     */
    @Override
    public void optimize() {
        ConstraintBlocker blocker = new ConstraintBlocker(this);
        // NOTE(review): the meaning of the boolean flag is defined by ConstraintBlocker.
        blocker.prepareBlocks(false);
        RandomVariableAtom[][] rvBlocks = blocker.getRVBlocks();
        boolean[] exactlyOne = blocker.getExactlyOne();
        GroundCompatibilityKernel[][] incidentGKs = blocker.getIncidentGKs();
        // Per-atom count of post-burn-in sweeps in which the atom was sampled as 1.0.
        // Presumably zero-filled with the same shape as rvBlocks; verify in ConstraintBlocker.
        double[][] totals = blocker.getEmptyDouble2DArray();
        blocker.randomlyInitializeRVs();

        log.info("Beginning inference.");
        for (int sweep = 0; sweep < this.numSamples; sweep++) {
            boolean recording = sweep >= this.numBurnIn;
            for (int b = 0; b < rvBlocks.length; b++) {
                RandomVariableAtom[] block = rvBlocks[b];
                if (block.length == 0) {
                    continue; // nothing to sample or record
                }
                int chosen = sampleBlock(block, exactlyOne[b], incidentGKs[b]);
                // chosen == block.length is the all-zero configuration: no atom to credit.
                if (recording && chosen < block.length) {
                    totals[b][chosen] += 1.0;
                }
            }
        }
        log.info("Inference complete.");

        double numRecorded = this.numSamples - this.numBurnIn;
        for (int b = 0; b < rvBlocks.length; b++) {
            for (int j = 0; j < rvBlocks[b].length; j++) {
                rvBlocks[b][j].setValue(totals[b][j] / numRecorded);
            }
        }
    }

    /**
     * Resamples one non-empty block and leaves its atoms set to the drawn configuration.
     *
     * @param block the block's atoms (non-empty)
     * @param exactlyOne whether the all-zero configuration is excluded
     * @param incident kernels scored for each candidate configuration
     * @return index of the atom set to 1.0, or {@code block.length} for the all-zero configuration
     */
    private int sampleBlock(RandomVariableAtom[] block, boolean exactlyOne,
            GroundCompatibilityKernel[] incident) {
        int numOptions = exactlyOne ? block.length : block.length + 1;
        double[] energies = new double[numOptions];
        for (int option = 0; option < numOptions; option++) {
            assignOption(block, option);
            energies[option] = computeEnergy(incident);
        }
        int chosen = sampleFromEnergies(energies);
        assignOption(block, chosen);
        return chosen;
    }

    /**
     * Sets {@code block[option]} to 1.0 and every other atom to 0.0. An {@code option} equal to
     * {@code block.length} therefore yields the all-zero configuration.
     */
    private static void assignOption(RandomVariableAtom[] block, int option) {
        for (int j = 0; j < block.length; j++) {
            block[j].setValue(j == option ? 1.0d : 0.0d);
        }
    }

    /**
     * Returns the weighted incompatibility {@code sum(weight * incompatibility)} of the given
     * kernels under the atoms' current values.
     */
    private static double computeEnergy(GroundCompatibilityKernel[] kernels) {
        double energy = 0.0d;
        for (GroundCompatibilityKernel kernel : kernels) {
            energy += kernel.getWeight().getWeight() * kernel.getIncompatibility();
        }
        return energy;
    }

    /**
     * Draws an index with probability proportional to {@code exp(-energies[k])}.
     *
     * @param energies non-empty array of option energies
     * @return the sampled index
     */
    private int sampleFromEnergies(double[] energies) {
        double minEnergy = Double.POSITIVE_INFINITY;
        for (double energy : energies) {
            minEnergy = Math.min(minEnergy, energy);
        }
        // Shifting every energy by the minimum leaves the normalized distribution unchanged but
        // gives the most likely option weight exp(0) = 1, so the total can no longer underflow
        // to zero (which used to turn every probability into NaN). If all options are
        // infinitely penalized there is nothing to prefer, so fall back to uniform.
        boolean uniform = minEnergy == Double.POSITIVE_INFINITY;

        double[] weights = new double[energies.length];
        double total = 0.0d;
        for (int k = 0; k < energies.length; k++) {
            weights[k] = uniform ? 1.0d : Math.exp(minEnergy - energies[k]);
            total += weights[k];
        }

        double threshold = this.rand.nextDouble() * total;
        double cumulative = 0.0d;
        int lastPositive = energies.length - 1;
        for (int k = 0; k < weights.length; k++) {
            // Strict '>' and skipping zero weights ensure impossible options are never drawn.
            if (weights[k] > 0.0d) {
                lastPositive = k;
                cumulative += weights[k];
                if (cumulative > threshold) {
                    return k;
                }
            }
        }
        // Only reachable through floating-point rounding of threshold near total.
        return lastPositive;
    }

    /** No-op: this class opens no resources of its own. */
    @Override
    public void close() {
    }
}
