package edu.umd.cs.psl.application.learning.weight.random;

import edu.umd.cs.psl.config.ConfigBundle;
import edu.umd.cs.psl.database.Database;
import edu.umd.cs.psl.model.Model;
import edu.umd.cs.psl.model.kernel.GroundCompatibilityKernel;
import edu.umd.cs.psl.model.kernel.GroundKernel;
import edu.umd.cs.psl.model.parameters.PositiveWeight;
import java.util.Iterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:edu/umd/cs/psl/application/learning/weight/random/FirstOrderSliceRandOM.class */
/**
 * Slice-sampling RandOM weight learner that resamples one kernel weight per slice step.
 *
 * <p>Each step brackets the current weight of kernel {@link #nextKernel} with the
 * stepping-out procedure ({@link #stepOut()}), then samples inside the bracket with
 * shrinkage ({@link #stepIn()}), after which the next kernel (cyclically) is selected.
 * Post-burn-in samples are accumulated per kernel and, at the end of each round,
 * turned into per-kernel means and sample variances ({@link #finishRound()}).
 */
public class FirstOrderSliceRandOM extends SliceRandOM {
    private static final Logger log = LoggerFactory.getLogger(FirstOrderSliceRandOM.class);

    /** Maximum number of shrinkage proposals {@link #stepIn()} makes before giving up. */
    private static final int MAX_SHRINK_ATTEMPTS = 100;

    /** Bracket width below which {@link #stepIn()} treats the interval as collapsed. */
    private static final double COLLAPSE_TOLERANCE = 1.0E-8d;

    /** Sampled (unclamped) weight of each kernel, indexed like {@code kernels}. */
    protected double[] currentWeights;
    /** Snapshot of the kernel weights taken when {@link #doLearn()} starts. */
    protected double[] previousWeights;
    /** Per-kernel running sum of sampled weights for the current round. */
    protected double[] sum;
    /** Per-kernel running sum of squared sampled weights for the current round. */
    protected double[] sumSq;
    /**
     * Average of the per-kernel sample variances from the last completed round, floored at
     * 0.001; starts out equal to {@code initialVariance}.
     */
    protected double variance;
    /** Index of the kernel whose weight the current slice step updates. */
    protected int nextKernel;
    /** Weight of {@link #nextKernel} before the current slice step (set by {@link #stepOut()}). */
    protected double current;
    /** Left end of the bracket built by {@link #stepOut()}. */
    protected double l;
    /** Right end of the bracket built by {@link #stepOut()}. */
    protected double r;
    /** Initial bracket width and the increment used when stepping out. */
    protected final double stepSize = 0.5d;
    /** Total step-out budget, split at random between the left and right ends. */
    protected final int maxNumSteps = 20;

    public FirstOrderSliceRandOM(Model model, Database database, Database database2, ConfigBundle configBundle) {
        super(model, database, database2, configBundle);
        // stepSize and maxNumSteps are final and already initialized at their declarations;
        // re-assigning them here (as the decompiled code did) does not compile.
        this.variance = this.initialVariance;
    }

    /**
     * Initializes per-kernel sampling state from the kernels' current weights, then delegates
     * to the superclass learning loop.
     *
     * <p>Declared {@code public} here although it overrides a {@code protected} method
     * (access widening is legal and kept for compatibility).
     */
    @Override // edu.umd.cs.psl.application.learning.weight.random.SliceRandOM, edu.umd.cs.psl.application.learning.weight.WeightLearningApplication
    public void doLearn() {
        int numKernels = this.kernels.size();
        this.currentWeights = new double[numKernels];
        for (int i = 0; i < numKernels; i++) {
            this.currentWeights[i] = this.kernels.get(i).getWeight().getWeight();
        }
        this.previousWeights = this.currentWeights.clone();
        this.sum = new double[numKernels];
        this.sumSq = new double[numKernels];
        super.doLearn();
    }

    /** Clears the per-round sample accumulators and restarts the kernel cycle at kernel 0. */
    @Override // edu.umd.cs.psl.application.learning.weight.random.SliceRandOM
    protected void prepareForRound() {
        for (int i = 0; i < this.kernels.size(); i++) {
            this.sum[i] = 0.0d;
            this.sumSq[i] = 0.0d;
        }
        this.nextKernel = 0;
    }

    /** One slice step per kernel, so each sample updates every kernel weight once. */
    @Override // edu.umd.cs.psl.application.learning.weight.random.SliceRandOM
    protected int getNumStepsPerSample() {
        return this.kernels.size();
    }

    /**
     * Builds the bracket {@code [l, r]} around the current weight of {@link #nextKernel}.
     *
     * <p>A window of width {@link #stepSize} is placed at a random offset around the current
     * weight. The {@link #maxNumSteps} step budget is split at random between the two ends. Each
     * end is then widened by {@link #stepSize} while the log-likelihood there is still above
     * {@code sliceHeight} (an inherited field) and that end has budget left.
     *
     * <p>Side effect: every evaluation moves the kernel weight and re-optimizes the reasoner
     * (see {@link #moveAndCheck(double)}).
     */
    @Override // edu.umd.cs.psl.application.learning.weight.random.SliceRandOM
    protected void stepOut() {
        this.current = this.currentWeights[this.nextKernel];
        this.l = this.current - (this.stepSize * this.rand.nextDouble());
        this.r = this.l + this.stepSize;

        int stepsLeft = (int) Math.floor(this.maxNumSteps * this.rand.nextDouble());
        int stepsRight = (this.maxNumSteps - 1) - stepsLeft;

        while (stepsLeft > 0 && this.sliceHeight < moveAndCheck(this.l)) {
            log.debug("Stepped left.");
            this.l -= this.stepSize;
            stepsLeft--;
        }
        while (stepsRight > 0 && this.sliceHeight < moveAndCheck(this.r)) {
            log.debug("Stepped right.");
            this.r += this.stepSize;
            stepsRight--;
        }
        log.debug("L: {}, R: {}", this.l, this.r);
        log.debug("J: {}, K: {}", stepsLeft, stepsRight);
    }

    /**
     * Samples a new weight for {@link #nextKernel} uniformly within {@code [l, r]}, shrinking the
     * bracket toward {@link #current} after each rejected proposal.
     *
     * <p>The first proposal whose log-likelihood exceeds {@code sliceHeight} is accepted. If the
     * bracket has narrowed below {@link #COLLAPSE_TOLERANCE}, the proposal is accepted anyway and
     * a warning is logged. On acceptance the cycle advances to the next kernel.
     *
     * @return the total log-likelihood (observations plus weight prior) at the accepted weight
     * @throws IllegalStateException if no proposal is accepted within
     *     {@link #MAX_SHRINK_ATTEMPTS} attempts; the kernel is then left at the last proposal
     */
    @Override // edu.umd.cs.psl.application.learning.weight.random.SliceRandOM
    protected double stepIn() {
        double lower = this.l;
        double upper = this.r;
        for (int attempt = 0; attempt < MAX_SHRINK_ATTEMPTS; attempt++) {
            double proposal = lower + (this.rand.nextDouble() * (upper - lower));
            double likelihood = moveAndCheck(proposal);
            log.debug("Likelihood at weight {}: {}", proposal, likelihood);

            boolean collapsed = Math.abs(upper - lower) < COLLAPSE_TOLERANCE;
            if (this.sliceHeight < likelihood || collapsed) {
                if (collapsed) {
                    log.warn("Interval collapsed.");
                }
                advanceKernel();
                return likelihood;
            }

            // Shrink the side of the bracket that contains the rejected proposal, keeping the
            // pre-step weight inside the bracket.
            if (proposal < this.current) {
                lower = proposal;
            } else {
                upper = proposal;
            }
        }
        throw new IllegalStateException("Step in failed.");
    }

    /**
     * Log-density (up to an additive constant) of the sampled weights under independent
     * Gaussians centered at {@code kernelMeans} with variance {@code initialVariance}.
     *
     * <p>NOTE(review): {@link #variance} is re-estimated in {@link #finishRound()} but is not
     * used here; confirm whether the prior is meant to keep using {@code initialVariance}.
     */
    @Override // edu.umd.cs.psl.application.learning.weight.random.SliceRandOM
    protected double getLogLikelihoodSampledWeights() {
        double logLikelihood = 0.0d;
        for (int i = 0; i < this.kernels.size(); i++) {
            double deviation = this.currentWeights[i] - this.kernelMeans[i];
            logLikelihood -= Math.pow(deviation, 2.0d) / (2.0d * this.initialVariance);
        }
        return logLikelihood;
    }

    /** Adds the current sample's weights (and their squares) to the per-round accumulators. */
    @Override // edu.umd.cs.psl.application.learning.weight.random.SliceRandOM
    protected void processSample() {
        for (int i = 0; i < this.kernels.size(); i++) {
            double weight = this.currentWeights[i];
            log.debug("Current weight : {}, mean : {}", weight, this.kernelMeans[i]);
            this.sum[i] += weight;
            this.sumSq[i] += weight * weight;
        }
    }

    /**
     * Converts the round's accumulators into per-kernel means and unbiased sample variances,
     * and sets {@link #variance} to their average, floored at 0.001.
     *
     * <p>Uses {@code numSamples - burnIn} as the sample count (assumes that is how many times
     * {@link #processSample()} ran this round; TODO confirm against the superclass). With
     * fewer than two samples the sample variance is undefined. Previously this produced NaN,
     * which {@code Math.max} does not filter out. Now the previous estimates are kept and a
     * warning is logged.
     */
    @Override // edu.umd.cs.psl.application.learning.weight.random.SliceRandOM
    protected void finishRound() {
        final double count = this.numSamples - this.burnIn;
        if (count < 2) {
            log.warn("Only {} post-burn-in sample(s); keeping previous mean and variance estimates.", count);
            return;
        }

        this.variance = 0.0d;
        for (int i = 0; i < this.kernels.size(); i++) {
            this.kernelMeans[i] = this.sum[i] / count;
            double sampleVariance = (this.sumSq[i] - ((this.sum[i] * this.sum[i]) / count)) / (count - 1);
            // A variance cannot be negative; clamp tiny negatives from floating-point cancellation.
            this.kernelVariances[i] = Math.max(sampleVariance, 0.0d);
            this.variance += this.kernelVariances[i];
            log.debug("Variance of {} for kernel {}", this.kernelVariances[i], this.kernels.get(i));
        }
        this.variance /= this.kernels.size();
        this.variance = Math.max(this.variance, 0.001d);
        log.info("Variance: {}", this.variance);
    }

    /** Moves {@link #nextKernel} to the following kernel, wrapping back to 0 after the last. */
    private void advanceKernel() {
        this.nextKernel = (this.nextKernel + 1) % this.kernels.size();
    }

    /**
     * Sets the weight of {@link #nextKernel} to {@code d}, re-optimizes the reasoner, and
     * returns the resulting total log-likelihood.
     *
     * <p>{@code currentWeights} stores {@code d} unclamped, while the weight handed to the
     * kernel is clamped at zero before being wrapped in a {@code PositiveWeight}. Every ground
     * kernel of that kernel is reported to the reasoner as changed before optimizing.
     *
     * @param d the candidate weight
     * @return observation log-likelihood plus weight-prior log-likelihood
     */
    private double moveAndCheck(double d) {
        this.currentWeights[this.nextKernel] = d;
        this.kernels.get(this.nextKernel).setWeight(new PositiveWeight(Math.max(d, 0.0d)));
        for (Iterator<GroundKernel> it = this.reasoner.getGroundKernels(this.kernels.get(this.nextKernel)).iterator(); it.hasNext();) {
            this.reasoner.changedGroundKernelWeight((GroundCompatibilityKernel) it.next());
        }
        this.reasoner.optimize();

        double observationLikelihood = getLogLikelihoodObservations();
        double weightLikelihood = getLogLikelihoodSampledWeights();
        double total = observationLikelihood + weightLikelihood;
        log.debug("Likelihood of observations: {}", observationLikelihood);
        log.debug("likelihood of weights: {}", weightLikelihood);
        log.debug("Total likelihood: {}", total);
        return total;
    }
}
