package edu.umd.cs.psl.application.inference;

import de.mathnbits.statistics.DoubleDist;
import edu.umd.cs.psl.application.ModelApplication;
import edu.umd.cs.psl.application.util.Grounding;
import edu.umd.cs.psl.config.ConfigBundle;
import edu.umd.cs.psl.database.Database;
import edu.umd.cs.psl.evaluation.result.FullConfidenceAnalysisResult;
import edu.umd.cs.psl.evaluation.result.memory.MemoryFullConfidenceAnalysisResult;
import edu.umd.cs.psl.model.ConfidenceValues;
import edu.umd.cs.psl.model.Model;
import edu.umd.cs.psl.model.atom.PersistedAtomManager;
import edu.umd.cs.psl.model.atom.RandomVariableAtom;
import edu.umd.cs.psl.reasoner.Reasoner;
import edu.umd.cs.psl.reasoner.ReasonerFactory;
import edu.umd.cs.psl.reasoner.admm.ADMMReasonerFactory;
import edu.umd.cs.psl.sampler.MarginalSampler;

/* loaded from: input_file:edu/umd/cs/psl/application/inference/ConfidenceAnalysis.class */
/**
 * Runs MAP optimization over a grounded model, then samples marginals of the ground
 * kernels and writes, for every persisted random-variable atom, the sample mean as its
 * truth value and a confidence derived from the sample standard deviation back to the
 * database.
 *
 * <p>Configuration keys (read from the {@link ConfigBundle} passed to the constructor):
 * <ul>
 *   <li>{@code confidenceanalysis.numsamples} - number of samples handed to the
 *       {@link MarginalSampler}; defaults to {@value #NUM_SAMPLES_DEFAULT} and must be
 *       positive.</li>
 *   <li>{@code confidenceanalysis.reasoner} - {@link ReasonerFactory} used to build the
 *       reasoner; defaults to an {@link ADMMReasonerFactory}.</li>
 * </ul>
 */
public class ConfidenceAnalysis implements ModelApplication {
    /** Common prefix of this application's configuration keys. */
    public static final String CONFIG_PREFIX = "confidenceanalysis";

    /** Key for the number of marginal samples to draw. */
    public static final String NUM_SAMPLES_KEY = "confidenceanalysis.numsamples";

    /** Sample count used when {@link #NUM_SAMPLES_KEY} is not configured. */
    public static final int NUM_SAMPLES_DEFAULT = 1000;

    /** Key for the {@link ReasonerFactory} that builds the reasoner. */
    public static final String REASONER_KEY = "confidenceanalysis.reasoner";

    /** Reasoner factory used when {@link #REASONER_KEY} is not configured. */
    public static final ReasonerFactory REASONER_DEFAULT = new ADMMReasonerFactory();

    // Non-final so that close() can drop the references.
    private Model model;
    private Database db;
    private ConfigBundle config;

    private final int numSamples;

    /**
     * Creates a confidence analysis over the given model and database.
     *
     * @param model        model whose kernels are grounded
     * @param database     database whose persisted random-variable atoms are updated
     * @param configBundle configuration source for the sample count and reasoner factory
     * @throws IllegalArgumentException if the configured sample count is not positive
     */
    public ConfidenceAnalysis(Model model, Database database, ConfigBundle configBundle) {
        int requestedSamples = configBundle.getInt(NUM_SAMPLES_KEY, NUM_SAMPLES_DEFAULT);
        if (!(requestedSamples > 0)) {
            throw new IllegalArgumentException("Number of samples must be positive.");
        }
        this.numSamples = requestedSamples;
        this.model = model;
        this.db = database;
        this.config = configBundle;
    }

    /**
     * Grounds the model, optimizes it, samples marginals and commits per-atom results.
     *
     * <p>Each persisted random-variable atom receives the mean of its sampled distribution
     * as its value and {@code 1 / stdDev} (clamped into
     * [{@code ConfidenceValues.getMin()}, {@code ConfidenceValues.getMax()}]) as its
     * confidence, and is then committed to the database.
     *
     * @return the sampled distributions wrapped in an in-memory result
     * @throws ClassNotFoundException if the configured reasoner factory cannot be loaded
     * @throws IllegalAccessException if the configured reasoner factory is inaccessible
     * @throws InstantiationException if the configured reasoner factory cannot be created
     */
    public FullConfidenceAnalysisResult runConfidenceAnalysis() throws ClassNotFoundException, IllegalAccessException, InstantiationException {
        // Resolve the reasoner factory from configuration, falling back to ADMM.
        ReasonerFactory factory = (ReasonerFactory) this.config.getFactory(REASONER_KEY, REASONER_DEFAULT);
        Reasoner solver = factory.getReasoner(this.config);
        // NOTE(review): this reasoner is created per call and never released here;
        // confirm whether the Reasoner API requires explicit cleanup.

        PersistedAtomManager atomManager = new PersistedAtomManager(this.db);
        Grounding.groundAll(this.model, atomManager, solver);
        solver.optimize();

        MarginalSampler sampler = new MarginalSampler(this.numSamples);
        // NOTE(review): the meaning of the 1.0d and 1 arguments is defined by
        // MarginalSampler.sample and is not visible here — confirm before changing.
        sampler.sample(solver.getGroundKernels(), 1.0d, 1);

        for (RandomVariableAtom atom : atomManager.getPersistedRVAtoms()) {
            applyMarginal(atom, sampler.getDistribution(atom.getVariable()));
        }
        return new MemoryFullConfidenceAnalysisResult(sampler.getDistributions());
    }

    /** Writes one atom's sampled mean and derived confidence and commits it to the DB. */
    private static void applyMarginal(RandomVariableAtom atom, DoubleDist marginal) {
        atom.setValue(marginal.mean());
        atom.setConfidenceValue(confidenceFromSpread(marginal.stdDev()));
        atom.commitToDB();
    }

    /**
     * Maps a standard deviation to a confidence: its reciprocal clamped into the range
     * [{@code ConfidenceValues.getMin()}, {@code ConfidenceValues.getMax()}].
     *
     * <p>A zero spread gives +Infinity before clamping and therefore the maximum
     * confidence. A NaN spread propagates through {@code Math.max}/{@code Math.min} and
     * yields NaN.
     */
    private static double confidenceFromSpread(double spread) {
        double inverseSpread = 1.0d / spread;
        double floor = ConfidenceValues.getMin();
        double ceiling = ConfidenceValues.getMax();
        return Math.min(Math.max(inverseSpread, floor), ceiling);
    }

    /**
     * Drops the references to the model, database and configuration. After this call,
     * {@link #runConfidenceAnalysis()} would dereference a null configuration.
     */
    @Override // edu.umd.cs.psl.application.ModelApplication
    public void close() {
        this.model = null;
        this.db = null;
        this.config = null;
    }
}
