package edu.umd.cs.psl.application.learning.weight.random;

import edu.umd.cs.psl.application.learning.weight.WeightLearningApplication;
import edu.umd.cs.psl.config.ConfigBundle;
import edu.umd.cs.psl.database.Database;
import edu.umd.cs.psl.model.Model;
import edu.umd.cs.psl.model.atom.ObservedAtom;
import edu.umd.cs.psl.model.atom.RandomVariableAtom;
import edu.umd.cs.psl.model.parameters.PositiveWeight;
import java.util.Map;
import java.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:edu/umd/cs/psl/application/learning/weight/random/MetropolisRandOM.class */
/**
 * Skeleton of a weight-learning procedure that alternates rounds of
 * Metropolis-style sampling over kernel weights with an update of the
 * per-kernel means.
 *
 * <p>Each round proposes {@link #numSamples} weight samples (via
 * {@link #sampleAndSetWeights()}), re-optimizes the reasoner for every proposal
 * and accepts a proposal with probability
 * {@code min(1, exp(newLogLikelihood - previousLogLikelihood))}. Rounds repeat
 * until either {@link #maxIter} rounds have run or the Euclidean change of the
 * kernel means drops to {@code changeThresholdFactor * sqrt(numberOfKernels)}
 * or below. How proposals are drawn and how means/variances are updated is
 * left to subclasses through the abstract hooks.
 */
public abstract class MetropolisRandOM extends WeightLearningApplication {
    private static final Logger log = LoggerFactory.getLogger(MetropolisRandOM.class);

    /** Prefix shared by all configuration keys read by this class. */
    public static final String CONFIG_PREFIX = "random";

    /** Key for the maximum number of rounds. */
    public static final String MAX_ITER_KEY = "random.maxiter";
    /** Default value for {@link #MAX_ITER_KEY}. */
    public static final int MAX_ITER_DEFAULT = 30;

    /** Key for the number of samples drawn per round. */
    public static final String NUM_SAMPLES_KEY = "random.numsamples";
    /** Default value for {@link #NUM_SAMPLES_KEY}. */
    public static final int NUM_SAMPLES_DEFAULT = 100;

    /** Key for the number of leading samples per round flagged as burn-in. */
    public static final String BURN_IN_KEY = "random.burnin";
    /** Default value for {@link #BURN_IN_KEY}. */
    public static final int BURN_IN_DEFAULT = 20;

    /** Key for the initial value of every kernel variance; must be positive. */
    public static final String INITIAL_VARIANCE_KEY = "random.initialvariance";
    /** Default value for {@link #INITIAL_VARIANCE_KEY}. */
    public static final double INITIAL_VARIANCE_DEFAULT = 1.0d;

    /** Key for the scale dividing the observation error; must be positive. */
    public static final String OBSERVATION_DENSITY_SCALE_KEY = "random.observationscale";
    /** Default value for {@link #OBSERVATION_DENSITY_SCALE_KEY}. */
    public static final double OBSERVATION_DENSITY_SCALE_DEFAULT = 0.1d;

    /** Key for the factor used in the convergence threshold on the change of the means. */
    public static final String CHANGE_THRESHOLD_KEY = "random.changethreshold";
    /** Default value for {@link #CHANGE_THRESHOLD_KEY}. */
    public static final double CHANGE_THRESHOLD_DEFAULT = 0.05d;

    /** Source of randomness for the acceptance test and Gaussian sampling. */
    protected final Random rand;
    /** Current mean per kernel; (re)allocated at the start of {@link #doLearn()}. */
    protected double[] kernelMeans;
    /** Current variance per kernel; (re)allocated at the start of {@link #doLearn()}. */
    protected double[] kernelVariances;
    protected final int maxIter;
    protected final int numSamples;
    protected final int burnIn;
    protected final double initialVariance;
    protected final double observationScale;
    protected final double changeThresholdFactor;

    /**
     * Reads the sampler configuration and forwards all arguments to the superclass.
     *
     * @param model        forwarded to the superclass constructor
     * @param database     forwarded to the superclass constructor
     * @param database2    forwarded to the superclass constructor
     * @param configBundle source of the {@code random.*} options
     * @throws IllegalArgumentException if the configured initial variance or
     *                                  observation density scale is not positive
     */
    public MetropolisRandOM(Model model, Database database, Database database2, ConfigBundle configBundle) {
        super(model, database, database2, configBundle);
        this.rand = new Random();
        this.maxIter = configBundle.getInt(MAX_ITER_KEY, MAX_ITER_DEFAULT);
        this.numSamples = configBundle.getInt(NUM_SAMPLES_KEY, NUM_SAMPLES_DEFAULT);
        this.burnIn = configBundle.getInt(BURN_IN_KEY, BURN_IN_DEFAULT);
        this.initialVariance = configBundle.getDouble(INITIAL_VARIANCE_KEY, INITIAL_VARIANCE_DEFAULT);
        if (this.initialVariance <= 0.0d) {
            throw new IllegalArgumentException("Initial variance must be positive.");
        }
        this.observationScale = configBundle.getDouble(OBSERVATION_DENSITY_SCALE_KEY, OBSERVATION_DENSITY_SCALE_DEFAULT);
        if (this.observationScale <= 0.0d) {
            throw new IllegalArgumentException("Observation density scale must be positive.");
        }
        this.changeThresholdFactor = configBundle.getDouble(CHANGE_THRESHOLD_KEY, CHANGE_THRESHOLD_DEFAULT);
    }

    /**
     * Runs the sampling rounds and finally writes the learned means back to the
     * kernels as {@link PositiveWeight}s, clamping negative means to zero.
     *
     * <p>At least one round is always executed, regardless of {@link #maxIter}.
     */
    @Override
    public void doLearn() {
        final int numKernels = this.kernels.size();
        this.kernelMeans = new double[numKernels];
        this.kernelVariances = new double[numKernels];
        // Means as of the end of the previous round, used to measure convergence.
        double[] previousMeans = new double[numKernels];
        for (int k = 0; k < numKernels; k++) {
            this.kernelMeans[k] = this.kernels.get(k).getWeight().getWeight();
            previousMeans[k] = this.kernelMeans[k];
            this.kernelVariances[k] = this.initialVariance;
        }
        this.reasoner.optimize();

        int round = 1;
        double change;
        do {
            log.info("Starting Monte Carlo EM round {}.", round);
            prepareForRound();
            int acceptCount = runSamplingRound();
            finishRound();
            // Floating-point division: an int/int quotient would truncate the rate to 0 or 1.
            log.info("Sample acceptance rate: {}", (double) acceptCount / this.numSamples);

            change = updateMeanChange(previousMeans);
            log.info("Change in weight means: {}", change);
            round++;
        } while (round <= this.maxIter
                && change > this.changeThresholdFactor * Math.sqrt(this.kernels.size()));

        for (int k = 0; k < this.kernels.size(); k++) {
            this.kernels.get(k).setWeight(new PositiveWeight(Math.max(0.0d, this.kernelMeans[k])));
        }
    }

    /**
     * Draws {@link #numSamples} proposals, applying the Metropolis acceptance
     * test to each and notifying the subclass hooks.
     *
     * @return the number of accepted proposals
     */
    private int runSamplingRound() {
        this.reasoner.optimize();
        double previousLikelihood = getLogLikelihoodObservations() + getLogLikelihoodSampledWeights();
        int acceptCount = 0;
        for (int sample = 0; sample < this.numSamples; sample++) {
            sampleAndSetWeights();
            this.reasoner.changedGroundKernelWeights();
            optimizeEnergyFunction();

            // Evaluate each term once and reuse it for both the test and the log output.
            double observationLikelihood = getLogLikelihoodObservations();
            double weightLikelihood = getLogLikelihoodSampledWeights();
            double newLikelihood = observationLikelihood + weightLikelihood;
            boolean accepted = this.rand.nextDouble() < Math.exp(newLikelihood - previousLikelihood);

            log.info("Likelihood of observations {}", observationLikelihood);
            log.info("Likelihood of weights {}", weightLikelihood);
            log.info("New likelihood {}", newLikelihood);
            log.info("Previous likelihood {}", previousLikelihood);
            log.info(accepted ? "Accepted" : "Rejected");

            boolean inBurnIn = sample < this.burnIn;
            if (accepted) {
                acceptSample(inBurnIn);
                previousLikelihood = newLikelihood;
                acceptCount++;
            } else {
                rejectSample(inBurnIn);
            }
            updateProposalVariance(acceptCount, sample);
        }
        return acceptCount;
    }

    /**
     * Logs the current kernel means, computes how far they moved since the
     * previous round and stores them as the new reference.
     *
     * @param previousMeans means at the end of the previous round; overwritten
     *                      with the current means
     * @return Euclidean distance between the previous and the current means
     */
    private double updateMeanChange(double[] previousMeans) {
        double sumOfSquares = 0.0d;
        for (int k = 0; k < this.kernels.size(); k++) {
            double delta = this.kernelMeans[k] - previousMeans[k];
            sumOfSquares += delta * delta;
            previousMeans[k] = this.kernelMeans[k];
            log.info("Mean of {} for kernel {}, ", this.kernelMeans[k], this.kernels.get(k));
        }
        return Math.sqrt(sumOfSquares);
    }

    /** Hook invoked at the start of every round, before any sample is drawn. */
    protected abstract void prepareForRound();

    /** Hook that draws a new weight proposal and applies it to the kernels. */
    protected abstract void sampleAndSetWeights();

    /** Re-optimizes the reasoner after a proposal; subclasses may override. */
    protected void optimizeEnergyFunction() {
        this.reasoner.optimize();
    }

    /**
     * Scores how well the current random-variable values match the observations.
     *
     * @return the negated sum, over all training-map entries, of the absolute
     *         difference between the random variable's value and the observed
     *         value, divided by {@link #observationScale}
     */
    protected double getLogLikelihoodObservations() {
        double logLikelihood = 0.0d;
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : this.trainingMap.getTrainingMap().entrySet()) {
            double error = Math.abs(entry.getKey().getValue() - entry.getValue().getValue());
            logLikelihood -= error / this.observationScale;
        }
        return logLikelihood;
    }

    /**
     * Hook invoked after every sample, accepted or not.
     *
     * @param i  number of proposals accepted so far in the current round
     * @param i2 zero-based index of the sample just processed
     */
    protected abstract void updateProposalVariance(int i, int i2);

    /**
     * @return the weight-dependent term added to
     *         {@link #getLogLikelihoodObservations()} in the acceptance test
     */
    protected abstract double getLogLikelihoodSampledWeights();

    /**
     * Hook invoked when a proposal is accepted.
     *
     * @param z {@code true} while the sample index is below {@link #burnIn}
     */
    protected abstract void acceptSample(boolean z);

    /**
     * Hook invoked when a proposal is rejected.
     *
     * @param z {@code true} while the sample index is below {@link #burnIn}
     */
    protected abstract void rejectSample(boolean z);

    /** Hook invoked at the end of every round, after all samples were processed. */
    protected abstract void finishRound();

    /**
     * Draws a value from a Gaussian distribution using {@link #rand}.
     *
     * @param d  mean of the distribution
     * @param d2 variance of the distribution (its square root scales the draw)
     * @return {@code sqrt(d2) * N(0, 1) + d}
     */
    public double sampleFromGaussian(double d, double d2) {
        return (Math.sqrt(d2) * this.rand.nextGaussian()) + d;
    }
}
