package edu.umd.cs.psl.sampler;

import cern.colt.list.tlong.LongArrayList;
import cern.colt.list.tobject.ObjectArrayList;
import cern.colt.map.tobject.OpenLongObjectHashMap;
import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.ujmp.core.Matrix;

/* loaded from: input_file:edu/umd/cs/psl/sampler/LinearSampler.class */
/**
 * Sampler (built on {@link AbstractHitAndRunSampler}) that draws the scalar step {@code alpha}
 * from a density proportional to {@code exp(-f(alpha))}, where {@code f} is a piecewise-linear
 * function assembled row by row from a linear system (see {@link #sampleAlpha}).
 *
 * <p>Breakpoints of {@code f} are bucketed by rounding alpha to a fixed number of decimal
 * digits, so breakpoints closer together than that resolution share a single segment boundary.
 */
public abstract class LinearSampler extends AbstractHitAndRunSampler {
    private static final Logger log = LoggerFactory.getLogger(LinearSampler.class);

    /** Magnitude below which a slope is treated as zero; also the slack allowed by assertions. */
    private static final double EPSILON = 1.0E-4d;

    /** Decimal digits of alpha resolution added on top of the constructor's second argument. */
    private static final int ALPHA_PLUS_SIGNIFICANT_DIGITS = 3;

    /** Scale used to bucket alpha into integer bins: {@code bin = round(alpha * scale)}. */
    private final long alphaRoundingScheme;

    /**
     * Change in slope ("rate") and intercept ("constant") of the piecewise-linear objective that
     * takes effect at one alpha bin. Handles are summed in bin order to obtain the slope and
     * intercept of each segment.
     */
    private static class ObjFunHandle {
        private final long alphaBinned;
        private double rate = 0.0d;
        private double constant = 0.0d;

        ObjFunHandle(long alphaBinned) {
            this.alphaBinned = alphaBinned;
        }

        double getRate() {
            return rate;
        }

        double getConstant() {
            return constant;
        }

        /** Adds {@code deltaRate} to the slope and {@code deltaConstant} to the intercept. */
        void increaseBy(double deltaRate, double deltaConstant) {
            rate += deltaRate;
            constant += deltaConstant;
        }

        /**
         * Returns the handle stored under {@code bin}, creating and storing a fresh zero handle
         * if the map has none yet.
         */
        public static final ObjFunHandle getHandle(long bin, OpenLongObjectHashMap handles) {
            Object existing = handles.get(bin);
            if (existing == null) {
                existing = new ObjFunHandle(bin);
                handles.put(bin, existing);
            }
            return (ObjFunHandle) existing;
        }

        @Override
        public String toString() {
            return "{" + alphaBinned + "}: R: " + rate + " C: " + constant;
        }
    }

    /**
     * @param i  forwarded to the {@link AbstractHitAndRunSampler} constructor
     * @param i2 forwarded to the {@link AbstractHitAndRunSampler} constructor; alpha is also
     *           bucketed to {@code i2 + 3} decimal digits when merging breakpoints
     */
    public LinearSampler(int i, int i2) {
        super(i, i2);
        this.alphaRoundingScheme = (long) Math.pow(10.0d, i2 + ALPHA_PLUS_SIGNIFICANT_DIGITS);
    }

    /**
     * Draws {@code alpha} in {@code [alphaLow, alphaHigh]} with density proportional to
     * {@code exp(-f(alpha))}, where {@code f} is piecewise linear in alpha.
     *
     * <p>For each row {@code i}, let {@code r_i = (coefficients * direction)_i} and
     * {@code c_i = (coefficients * currentPt)_i + offsets_i}. The term {@code r_i * alpha + c_i}
     * contributes to {@code f} only on the side of its zero crossing {@code -c_i / r_i} where it
     * is positive (this looks like a hinge {@code max(0, .)} -- NOTE(review): confirm intent).
     * Rows whose slope is within {@link #EPSILON} of zero, or whose term is inactive over the
     * whole interval, contribute only their constant. That constant shifts {@code f} uniformly
     * and so does not change the normalized density.
     *
     * <p>The interval is split at the (rounded) crossings. Each segment's probability mass is
     * integrated in closed form, a uniform point is drawn on the cumulative mass, and alpha is
     * recovered by inverting that segment's CDF.
     *
     * @param direction    vector multiplied by {@code coefficients}; presumably the hit-and-run
     *                     direction -- TODO confirm against AbstractHitAndRunSampler
     * @param coefficients matrix whose rows define the linear terms
     * @param offsets      column vector added to {@code coefficients * currentPt}
     * @param alphaLow     lower bound of the admissible alpha interval
     * @param alphaHigh    upper bound of the admissible alpha interval
     * @return the sampled alpha, clamped to {@code [alphaLow, alphaHigh]}
     */
    @Override // edu.umd.cs.psl.sampler.AbstractHitAndRunSampler
    protected double sampleAlpha(Matrix direction, Matrix coefficients, Matrix offsets,
                                 double alphaLow, double alphaHigh) {
        Matrix rates = coefficients.mtimes(direction);
        Matrix values = coefficients.mtimes(this.currentPt);

        // --- 1. Accumulate slope/intercept changes per rounded breakpoint. ---
        OpenLongObjectHashMap bins = new OpenLongObjectHashMap();
        ObjFunHandle lowHandle = ObjFunHandle.getHandle(toBin(alphaLow), bins);
        for (int row = 0; row < coefficients.getRowCount(); row++) {
            double rate = rates.getAsDouble(new long[]{row, 0});
            double constant = values.getAsDouble(new long[]{row, 0})
                    + offsets.getAsDouble(new long[]{row, 0});
            double crossing = -constant / rate;
            long crossingBin = toBin(crossing);
            log.trace("Alpha {}", crossing);
            log.trace("Rate {} Constant {}", rate, constant);
            if (rate > EPSILON && crossing < alphaHigh) {
                // Increasing term: active from its crossing (or from alphaLow) upwards.
                ObjFunHandle start = crossing <= alphaLow
                        ? lowHandle
                        : ObjFunHandle.getHandle(crossingBin, bins);
                start.increaseBy(rate, constant);
            } else if (rate >= -EPSILON || crossing <= alphaLow) {
                // Flat term, or term inactive over the whole interval: constant shift only.
                lowHandle.increaseBy(0.0d, constant);
            } else {
                // Decreasing term: active from alphaLow until its crossing.
                lowHandle.increaseBy(rate, constant);
                if (crossing < alphaHigh) {
                    ObjFunHandle.getHandle(crossingBin, bins).increaseBy(-rate, -constant);
                }
            }
        }
        // Make sure the last segment ends exactly at the upper bound.
        ObjFunHandle.getHandle(toBin(alphaHigh), bins);

        int size = bins.size();
        LongArrayList keyList = new LongArrayList(size);
        ObjectArrayList handleList = new ObjectArrayList(size);
        bins.pairsSortedByKey(keyList, handleList);
        assert keyList.size() == size && handleList.size() == size;
        long[] binKeys = keyList.elements();
        Object[] handles = handleList.elements();
        if (log.isTraceEnabled()) {
            log.trace("alphaBins: {}", Arrays.toString(binKeys));
            log.trace("handles: {}", Arrays.toString(handles));
        }

        // --- 2. Integrate exp(reference - f) over each segment into a cumulative mass table. ---
        // reference = f at the first breakpoint; subtracting f from it keeps the exponents small.
        ObjFunHandle first = (ObjFunHandle) handles[0];
        double reference = first.getRate() * toAlpha(binKeys[0]) + first.getConstant();
        double[] cumulative = new double[size];
        double[] segmentRate = new double[size - 1];
        double[] segmentConstant = new double[size - 1];
        cumulative[0] = 0.0d;
        double runningRate = 0.0d;
        double runningConstant = 0.0d;
        for (int k = 1; k < size; k++) {
            ObjFunHandle delta = (ObjFunHandle) handles[k - 1];
            runningRate += delta.getRate();
            segmentRate[k - 1] = runningRate;
            runningConstant += delta.getConstant();
            segmentConstant[k - 1] = runningConstant;
            double segStart = toAlpha(binKeys[k - 1]);
            double segEnd = toAlpha(binKeys[k]);
            double mass;
            if (isSignificant(runningRate)) {
                mass = (1.0d / runningRate)
                        * (Math.exp(reference - runningRate * segStart - runningConstant)
                        - Math.exp(reference - runningRate * segEnd - runningConstant));
            } else {
                // Slope treated as zero: the density is constant on this segment.
                mass = (segEnd - segStart) * Math.exp(reference - runningConstant);
            }
            assert mass > -EPSILON : mass;
            cumulative[k] = cumulative[k - 1] + mass;
        }

        // --- 3. Draw a point on the cumulative mass and invert the segment CDF. ---
        double target = Math.random() * cumulative[size - 1];
        log.trace("Random Pt {}", target);
        if (log.isTraceEnabled()) {
            log.trace("logcumulative: {}", Arrays.toString(cumulative));
        }
        int index = Arrays.binarySearch(cumulative, target);
        log.trace("Find Index {}", index);

        double alpha;
        if (index >= 0) {
            // Landed exactly on a breakpoint.
            alpha = toAlpha(binKeys[index]);
        } else {
            int upper = -index - 1; // first breakpoint whose cumulative mass exceeds target
            assert upper > 0 : upper;
            double segStart = toAlpha(binKeys[upper - 1]);
            double r = segmentRate[upper - 1];
            double c = segmentConstant[upper - 1];
            double massBefore = cumulative[upper - 1];
            log.trace("R {}, c {} aL " + segStart, r, c);
            if (isSignificant(r)) {
                // Solve massBefore + (1/r)(e^{ref - r*aL - c} - e^{ref - r*a - c}) = target for a.
                double scaled = -target + massBefore + (1.0d / r) * Math.exp((-c - r * segStart) + reference);
                double inside = r * scaled;
                if (log.isTraceEnabled()) {
                    log.trace("Target value {}", scaled);
                    log.trace("Log: {}, inside log {}", Math.log(inside), inside);
                }
                alpha = (-1.0d / r) * ((Math.log(inside) + c) - reference);
            } else {
                // Constant density on this segment: the CDF is linear in alpha.
                double density = Math.exp(reference - c);
                alpha = ((target / density) - (massBefore / density)) + segStart;
            }
            assert !Double.isNaN(alpha)
                    : "C: " + c + " | R: " + r + " Pt: " + target + " aL: " + segStart + "  C0:" + massBefore;
        }
        assert alpha >= alphaLow - EPSILON && alpha <= alphaHigh + EPSILON : alpha;
        double clamped = Math.min(Math.max(alpha, alphaLow), alphaHigh);
        log.trace("Sampled alpha {}", clamped);
        return clamped;
    }

    /** Rounds {@code alpha} to its integer bin at the configured resolution. */
    private long toBin(double alpha) {
        return Math.round(alpha * alphaRoundingScheme);
    }

    /**
     * Converts a bin back to an alpha value. The cast is required: {@code long / long} would
     * truncate away the fractional part of alpha.
     */
    private double toAlpha(long bin) {
        return (double) bin / alphaRoundingScheme;
    }

    /** True when {@code x} lies outside {@code [-EPSILON, EPSILON]} (false for NaN). */
    private static boolean isSignificant(double x) {
        return x < -EPSILON || x > EPSILON;
    }
}
