package edu.umd.cs.psl.sampler;

import edu.umd.cs.psl.application.learning.weight.random.SliceRandOM;
import edu.umd.cs.psl.model.atom.GroundAtom;
import edu.umd.cs.psl.model.atom.RandomVariableAtom;
import edu.umd.cs.psl.model.kernel.GroundCompatibilityKernel;
import edu.umd.cs.psl.model.kernel.GroundConstraintKernel;
import edu.umd.cs.psl.model.kernel.GroundKernel;
import edu.umd.cs.psl.reasoner.function.AtomFunctionVariable;
import edu.umd.cs.psl.reasoner.function.ConstraintTerm;
import edu.umd.cs.psl.reasoner.function.FunctionComparator;
import edu.umd.cs.psl.reasoner.function.FunctionSum;
import edu.umd.cs.psl.reasoner.function.FunctionSummand;
import edu.umd.cs.psl.reasoner.function.FunctionTerm;
import edu.umd.cs.psl.reasoner.function.util.FunctionAnalyser;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.ujmp.core.Matrix;
import org.ujmp.core.MatrixFactory;
import org.ujmp.core.calculation.Calculation;

/* loaded from: input_file:edu/umd/cs/psl/sampler/AbstractHitAndRunSampler.class */
/**
 * Base class for hit-and-run samplers over the feasible region defined by the linear
 * constraint kernels of a ground model.
 *
 * <p>Each random-variable atom is assigned one dimension. The constraint kernels are turned into
 * an equality system {@code Aeq x = beq} and an inequality system {@code Aineq x <= bineq}, to
 * which the box constraints {@code 0 <= x_j <= 1} are appended. Every step draws a direction in
 * the null space of {@code Aeq}, computes the feasible interval along that direction from the
 * inequality system, and moves by a step length chosen by {@link #sampleAlpha}. After a burn-in
 * fraction of the step budget, each accepted point is written back (rounded) to the atoms and
 * handed to {@link #processSampledPoint}.
 *
 * <p>Subclasses decide how to choose the step length and what to do with each sampled point.
 */
public abstract class AbstractHitAndRunSampler implements Sampler {

    private static final Logger log = LoggerFactory.getLogger(AbstractHitAndRunSampler.class);

    /** Default number of hit-and-run steps (the step budget). */
    public static final int defaultMaxNoSteps = 1000000;

    /** Default number of decimal digits kept when sampled values are written back to atoms. */
    public static final int defaultSignificantDigits = 4;

    /** Fraction of the step budget treated as burn-in; no points are reported before it passes. */
    public static final double defaultBurnInStepsPercentage = 0.01d;

    /** Minimum width of the feasible interval along a direction for a step to be taken. */
    private static final double EPSILON = 1.0E-4d;

    /** Tolerance used to flag near-active inequalities and to check the interval bounds. */
    private static final double ERROR_EPSILON = 0.01d;

    /** Matrices are only dumped to the trace log when there are fewer dimensions than this. */
    private static final int MAX_DIMENSION_TO_DISPLAY = 10;

    /** A failed step with more than this many near-active constraints counts as being cornered. */
    private static final int MAX_ACTIVE_CONSTRAINTS = 2;

    /** 10^significantDigits; sampled coordinates are rounded to multiples of its reciprocal. */
    private final long roundingScheme;
    /** Step budget: {@link #sample} keeps stepping until {@link #noSteps} reaches this value. */
    private final int maxSteps;
    /** Number of successful steps taken so far (cumulative across calls to {@link #sample}). */
    private int noSteps;
    /** Number of points passed to {@link #processSampledPoint} so far. */
    private int noSamples;
    /** Number of dimensions (random-variable atoms) registered so far. */
    private int dimensions;
    private transient Map<AtomFunctionVariable, Integer> atomIndex;
    /** Current point of the chain as a {@code dimensions x 1} column vector; null until first use. */
    protected transient Matrix currentPt;
    private final transient HitAndRunSamplerStatistics stats;

    /** Creates a sampler with the default step budget and number of significant digits. */
    public AbstractHitAndRunSampler() {
        // Previously passed 4 here, which was used as a 4-step budget; the intended default is
        // defaultMaxNoSteps (otherwise unused).
        this(defaultMaxNoSteps);
    }

    /**
     * Creates a sampler with the given step budget and the default number of significant digits.
     *
     * @param maxNoSteps number of successful steps after which {@link #sample} stops
     */
    public AbstractHitAndRunSampler(int maxNoSteps) {
        this(maxNoSteps, defaultSignificantDigits);
    }

    /**
     * Creates a sampler.
     *
     * @param maxNoSteps number of successful steps after which {@link #sample} stops
     * @param significantDigits number of decimal digits kept when writing sampled values to atoms
     */
    public AbstractHitAndRunSampler(int maxNoSteps, int significantDigits) {
        this.noSteps = 0;
        this.noSamples = 0;
        this.dimensions = 0;
        this.atomIndex = new HashMap<AtomFunctionVariable, Integer>();
        this.currentPt = null;
        this.roundingScheme = (long) Math.pow(10.0d, significantDigits);
        this.maxSteps = maxNoSteps;
        this.stats = new HitAndRunSamplerStatistics(this);
    }

    /** Returns the statistics object that records setup, corner and completion events. */
    public HitAndRunSamplerStatistics getStatistics() {
        return this.stats;
    }

    /** Returns how many points have been passed to {@link #processSampledPoint} so far. */
    public int getNoSamples() {
        return this.noSamples;
    }

    /**
     * Returns the dimension assigned to a variable, assigning the next free dimension (and
     * notifying {@link #processNewDimension}) if the variable has not been seen before.
     *
     * @throws IllegalArgumentException if the variable is constant
     */
    protected int getorSetIndex(AtomFunctionVariable atomFunctionVariable) {
        if (atomFunctionVariable.isConstant()) {
            throw new IllegalArgumentException("Cannot retrieve index for known atom!");
        }
        Integer index = this.atomIndex.get(atomFunctionVariable);
        if (index == null) {
            index = Integer.valueOf(this.dimensions);
            this.dimensions++;
            this.atomIndex.put(atomFunctionVariable, index);
            processNewDimension(atomFunctionVariable, index.intValue());
        }
        return index.intValue();
    }

    /**
     * Returns the dimension already assigned to a variable.
     *
     * @throws IllegalArgumentException if the variable is constant or has no dimension yet
     */
    protected int getIndex(AtomFunctionVariable atomFunctionVariable) {
        if (atomFunctionVariable.isConstant()) {
            throw new IllegalArgumentException("Cannot retrieve index for known atom!");
        }
        Integer index = this.atomIndex.get(atomFunctionVariable);
        if (index == null) {
            throw new IllegalArgumentException("Atom has not yet been assigned a dimension!");
        }
        return index.intValue();
    }

    /** Called once for every variable when it is first assigned dimension {@code i}. */
    protected abstract void processNewDimension(AtomFunctionVariable atomFunctionVariable, int i);

    /**
     * Chooses the step length along a direction.
     *
     * @param matrix the step direction in atom space ({@code dimensions x 1})
     * @param matrix2 coefficient matrix of the linear objective functions ({@code noObjectives x dimensions})
     * @param matrix3 constant terms of the objective functions ({@code noObjectives x 1})
     * @param d lower end of the feasible interval along the direction
     * @param d2 upper end of the feasible interval along the direction
     * @return the step length; presumably expected to lie in {@code [d, d2]} — TODO confirm
     */
    protected abstract double sampleAlpha(Matrix matrix, Matrix matrix2, Matrix matrix3, double d, double d2);

    /**
     * Called for every post-burn-in point, after the point has been written (rounded) into the
     * atoms' values.
     */
    protected abstract void processSampledPoint(Iterable<GroundKernel> iterable);

    /**
     * Runs the hit-and-run chain over the region defined by the given kernels until the step
     * budget is exhausted.
     *
     * <p>Compatibility kernels must have linear (core) objective functions and constraint
     * kernels must have linear constraint functions; otherwise an {@link AssertionError} is
     * thrown.
     *
     * @param groundKernels the kernels defining the objectives and constraints
     * @param threshold not used by this implementation
     * @param limit not used by this implementation; the number of steps is governed by the
     *     step budget given at construction
     */
    public void sample(Iterable<GroundKernel> groundKernels, double threshold, int limit) {
        // Pass 1: count rows of each kind and register a dimension for every random variable.
        int noEquality = 0;
        int noInequality = 0;
        int noObjective = 0;
        for (GroundKernel kernel : groundKernels) {
            if (kernel instanceof GroundCompatibilityKernel) {
                noObjective++;
            } else if (kernel instanceof GroundConstraintKernel) {
                FunctionComparator comparator =
                        ((GroundConstraintKernel) kernel).getConstraintDefinition().getComparator();
                switch (comparator) {
                    case Equality:
                        noEquality++;
                        break;
                    case SmallerThan:
                    case LargerThan:
                        noInequality++;
                        break;
                    default:
                        throw new AssertionError("Unknown comparator type: " + comparator);
                }
            } else {
                throw new AssertionError("Unknown evidence type: " + kernel);
            }
            for (GroundAtom atom : kernel.getAtoms()) {
                if (atom instanceof RandomVariableAtom) {
                    getorSetIndex(atom.getVariable());
                }
            }
        }
        log.debug("Dimesions: {}", Integer.valueOf(this.dimensions));

        // The inequality system also holds two box-constraint rows per dimension.
        final int noInequalityRows = noInequality + 2 * this.dimensions;
        Matrix aEq = MatrixFactory.sparse(new long[] {noEquality, this.dimensions});
        Matrix bEq = MatrixFactory.dense(new long[] {noEquality, 1});
        log.trace("Equality Constraints: {}", Integer.valueOf(noEquality));
        Matrix aIneq = MatrixFactory.sparse(new long[] {noInequalityRows, this.dimensions});
        Matrix bIneq = MatrixFactory.dense(new long[] {noInequalityRows, 1});
        log.trace("Inequality Constraints: {}", Integer.valueOf(noInequality));
        Matrix aObj = MatrixFactory.sparse(new long[] {noObjective, this.dimensions});
        Matrix objConstant = MatrixFactory.dense(new long[] {noObjective, 1});
        log.trace("Objective Funs: {}", Integer.valueOf(noObjective));

        // Pass 2: fill the matrices. For f(x) = a.x + c compared with v, the row holds a and the
        // right-hand side holds v - c.
        int eqRow = 0;
        int ineqRow = 0;
        int objRow = 0;
        for (GroundKernel kernel : groundKernels) {
            if (kernel instanceof GroundCompatibilityKernel) {
                GroundCompatibilityKernel compatibilityKernel = (GroundCompatibilityKernel) kernel;
                FunctionTerm objective =
                        FunctionAnalyser.getCoreObjectiveFunction(compatibilityKernel.getFunctionDefinition());
                if (objective == null) {
                    objective = compatibilityKernel.getFunctionDefinition();
                }
                if (!objective.isLinear()) {
                    throw new AssertionError("Expected linear probabilistic evidence only, but got: " + compatibilityKernel);
                }
                objConstant.setAsDouble(setMatrixRow(aObj, objRow, objective, false), new long[] {objRow, 0});
                objRow++;
            } else if (kernel instanceof GroundConstraintKernel) {
                GroundConstraintKernel constraintKernel = (GroundConstraintKernel) kernel;
                ConstraintTerm constraint = constraintKernel.getConstraintDefinition();
                if (!constraint.getFunction().isLinear()) {
                    throw new AssertionError("Expected linear constraints only, but got: " + constraintKernel);
                }
                double value = constraint.getValue();
                switch (constraint.getComparator()) {
                    case Equality:
                        bEq.setAsDouble(value - setMatrixRow(aEq, eqRow, constraint.getFunction(), false),
                                new long[] {eqRow, 0});
                        eqRow++;
                        break;
                    case SmallerThan:
                        bIneq.setAsDouble(value - setMatrixRow(aIneq, ineqRow, constraint.getFunction(), false),
                                new long[] {ineqRow, 0});
                        ineqRow++;
                        break;
                    case LargerThan:
                        // Stored negated so that every inequality row reads "<=".
                        bIneq.setAsDouble(-value - setMatrixRow(aIneq, ineqRow, constraint.getFunction(), true),
                                new long[] {ineqRow, 0});
                        ineqRow++;
                        break;
                    default:
                        throw new AssertionError("Unknown comparator type: " + constraint.getComparator());
                }
            }
        }
        assert eqRow == noEquality;
        assert ineqRow == noInequality;
        assert objRow == noObjective;

        // Box constraints: x_j <= 1 and -x_j <= 0 for every dimension j.
        for (int dim = 0; dim < this.dimensions; dim++) {
            aIneq.setAsDouble(1.0d, new long[] {ineqRow, dim});
            bIneq.setAsDouble(1.0d, new long[] {ineqRow, 0});
            aIneq.setAsDouble(-1.0d, new long[] {ineqRow + 1, dim});
            bIneq.setAsDouble(0.0d, new long[] {ineqRow + 1, 0});
            ineqRow += 2;
        }

        initializeCurrentPoint();

        // Columns of the basis span the null space of Aeq; directions are drawn in its coordinates.
        Matrix basis = computeNullSpaceBasis(aEq, noEquality);
        if (this.dimensions < MAX_DIMENSION_TO_DISPLAY) {
            log.trace("P {}", Arrays.toString(basis.getSize()));
        }
        long projectedDims = basis.getSize(1);
        if (this.dimensions < MAX_DIMENSION_TO_DISPLAY) {
            log.trace("Aobj \n{}", aObj);
            log.trace("objConstant \n{}", objConstant);
            log.trace("Aeq \n{}", aEq);
            log.trace("P \n{}", basis);
            log.trace("Aineq \n{}", aIneq);
            log.trace("bsmaller \n{}", bIneq);
        }
        this.stats.finishedSetup(noEquality, noInequality, noObjective, this.dimensions, (int) projectedDims);

        final int totalSteps = this.maxSteps;
        log.debug("Number of steps to take: {}", Integer.valueOf(totalSteps));
        List<Integer> activeConstraints = new ArrayList<Integer>();
        Random random = new Random();
        boolean cornered = false;
        do {
            if ((this.noSteps + 1) % 10000 == 0) {
                log.debug("Starting step #{}", Integer.valueOf(this.noSteps + 1));
            } else {
                log.trace("Starting step #{}", Integer.valueOf(this.noSteps + 1));
            }

            Matrix direction;
            if (cornered) {
                this.stats.inCorner();
                log.debug("Has been cornered with {} active constraints on step {}",
                        Integer.valueOf(activeConstraints.size()), Integer.valueOf(this.noSteps));
                assert activeConstraints.size() > 0 : activeConstraints.size();
                cornered = false;
                direction = escapeCorner(aIneq, activeConstraints, basis, projectedDims, random);
                this.stats.outCorner();
            } else {
                direction = randomGaussianDirection(projectedDims, random);
            }
            normalize(direction, projectedDims);
            if (this.dimensions < MAX_DIMENSION_TO_DISPLAY) {
                log.trace("Direction \n{}", direction);
            }

            // Feasible interval [lower, upper] for t in currentPt + t * step, from Aineq x <= bIneq.
            Matrix step = basis.mtimes(direction);
            Matrix stepRates = aIneq.mtimes(step);
            Matrix currentValues = aIneq.mtimes(this.currentPt);
            activeConstraints.clear();
            double lower = Double.NEGATIVE_INFINITY;
            double upper = Double.POSITIVE_INFINITY;
            for (int row = 0; row < noInequalityRows; row++) {
                double rate = stepRates.getAsDouble(new long[] {row, 0});
                boolean nearlyActive = false;
                if (rate > 0.0d) {
                    double bound = (bIneq.getAsDouble(new long[] {row, 0})
                            - currentValues.getAsDouble(new long[] {row, 0})) / rate;
                    upper = Math.min(upper, bound);
                    nearlyActive = bound < ERROR_EPSILON;
                } else if (rate < 0.0d) {
                    double bound = (bIneq.getAsDouble(new long[] {row, 0})
                            - currentValues.getAsDouble(new long[] {row, 0})) / rate;
                    lower = Math.max(lower, bound);
                    nearlyActive = bound > -ERROR_EPSILON;
                }
                if (nearlyActive) {
                    activeConstraints.add(Integer.valueOf(row));
                }
            }
            assert !(upper <= -ERROR_EPSILON) : String.valueOf(upper) + "\n" + this.currentPt.toString();
            assert !(lower >= ERROR_EPSILON) : String.valueOf(lower) + "\n" + this.currentPt.toString();

            if (upper - lower >= EPSILON) {
                double alpha = sampleAlpha(step, aObj, objConstant, lower, upper);
                this.currentPt = this.currentPt.plus(Calculation.Ret.ORIG, false,
                        step.times(Calculation.Ret.ORIG, false, alpha));
                log.trace("New Point \n{}", this.currentPt);
                this.noSteps++;
                if (this.noSteps > defaultBurnInStepsPercentage * totalSteps) {
                    writePointToAtoms();
                    processSampledPoint(groundKernels);
                    this.noSamples++;
                }
            } else if (activeConstraints.size() > MAX_ACTIVE_CONSTRAINTS) {
                // No room to move and several constraints are tight: compute an escape direction next.
                cornered = true;
            }
        } while (this.noSteps < totalSteps);
        this.stats.finish(this.noSamples);
    }

    /**
     * Creates {@link #currentPt} from the atoms' current values. If it already exists but new
     * dimensions were registered since, grows it and seeds the new coordinates from the atoms.
     */
    private void initializeCurrentPoint() {
        if (this.currentPt == null) {
            this.currentPt = MatrixFactory.sparse(new long[] {this.dimensions, 1});
            copyAtomValuesIntoPoint(0L);
        } else if (this.currentPt.getSize(0) != this.dimensions) {
            // currentPt is a column vector, so its dimension count is the row count (axis 0).
            long previousDimensions = this.currentPt.getSize(0);
            assert previousDimensions < this.dimensions;
            this.currentPt.setSize(new long[] {this.dimensions, 1});
            copyAtomValuesIntoPoint(previousDimensions);
        }
    }

    /** Copies the value of every atom whose dimension is at least {@code fromIndex} into currentPt. */
    private void copyAtomValuesIntoPoint(long fromIndex) {
        for (Map.Entry<AtomFunctionVariable, Integer> entry : this.atomIndex.entrySet()) {
            int index = entry.getValue().intValue();
            if (index >= fromIndex) {
                this.currentPt.setAsDouble(entry.getKey().getValue(), new long[] {index, 0});
            }
        }
    }

    /**
     * Returns a matrix whose columns span the null space of {@code aEq} (the identity when there
     * are no equality constraints), taken from the trailing columns of the SVD's V factor.
     */
    private Matrix computeNullSpaceBasis(Matrix aEq, int noEquality) {
        if (noEquality == 0) {
            return MatrixFactory.eye(new long[] {this.dimensions, this.dimensions});
        }
        log.trace("Aeq {}", Arrays.toString(aEq.getSize()));
        Matrix[] svd = aEq.svd();
        log.trace("U {}", Arrays.toString(svd[0].getSize()));
        log.trace("o {}", Arrays.toString(svd[1].getSize()));
        log.trace("V' {}", Arrays.toString(svd[2].getSize()));
        if (this.dimensions < MAX_DIMENSION_TO_DISPLAY) {
            log.trace("U \n{}", svd[0]);
            log.trace("o \n{}", svd[1]);
            log.trace("V' \n{}", svd[2]);
        }
        assert svd[2].getSize(0) == this.dimensions && svd[2].getSize(1) == this.dimensions;
        if (this.dimensions < MAX_DIMENSION_TO_DISPLAY) {
            log.trace("SVD: \n{}", svd[0].mtimes(svd[1]).mtimes(svd[2].transpose()));
        }
        int rank = numericalRank(svd[1], noEquality);
        log.trace("Rank of Aeq {}", Integer.valueOf(rank));
        return svd[2].subMatrix(Calculation.Ret.NEW, 0L, rank, this.dimensions - 1, this.dimensions - 1);
    }

    /**
     * Counts the singular values on the diagonal of {@code singularValues} that exceed the usual
     * relative tolerance {@code sigma_max * max(rows, cols) * eps}. An exact comparison with zero
     * would overestimate the rank of redundant systems because of floating-point noise.
     */
    private int numericalRank(Matrix singularValues, int noRows) {
        int maxRank = Math.min(noRows, this.dimensions);
        int rank = 0;
        if (maxRank > 0) {
            double largest = Math.abs(singularValues.getAsDouble(new long[] {0, 0}));
            double tolerance = largest * Math.max(noRows, this.dimensions) * Math.ulp(1.0d);
            while (rank < maxRank
                    && Math.abs(singularValues.getAsDouble(new long[] {rank, rank})) > tolerance) {
                rank++;
            }
        }
        return rank;
    }

    /**
     * Computes a direction, in null-space coordinates, that points away from the active
     * inequality constraints. Starting from zero, it repeatedly reflects across the most violated
     * constraint of {@code activeRows * d <= targets} (random non-positive targets) until no row
     * is violated.
     */
    private Matrix escapeCorner(Matrix aIneq, List<Integer> activeConstraints, Matrix basis,
            long projectedDims, Random random) {
        int noActive = activeConstraints.size();
        Matrix direction = MatrixFactory.zeros(new long[] {projectedDims, 1});
        Matrix activeRows = aIneq.selectRows(Calculation.Ret.NEW, activeConstraints).mtimes(basis);
        Matrix targets = MatrixFactory.dense(new long[] {noActive, 1});
        for (int row = 0; row < noActive; row++) {
            targets.setAsDouble(-random.nextDouble(), new long[] {row, 0});
        }

        log.trace("Compute norms");
        double[] rowNorms = new double[noActive];
        for (long[] coordinate : activeRows.availableCoordinates()) {
            rowNorms[(int) coordinate[0]] += Math.pow(activeRows.getAsDouble(coordinate), 2.0d);
        }
        for (int row = 0; row < noActive; row++) {
            rowNorms[row] = Math.sqrt(rowNorms[row]);
        }

        int iterations = 0;
        int mostViolated;
        do {
            mostViolated = -1;
            double largestViolation = 0.0d;
            log.trace("Compute violation matrix");
            Matrix violation = activeRows.mtimes(Calculation.Ret.NEW, false, direction)
                    .minus(Calculation.Ret.ORIG, false, targets);
            log.trace("Find most violated");
            for (int row = 0; row < noActive; row++) {
                double normalized = violation.getAsDouble(new long[] {row, 0}) / rowNorms[row];
                if (normalized > largestViolation) {
                    mostViolated = row;
                    largestViolation = normalized;
                }
            }
            if (mostViolated >= 0) {
                log.trace("Most violated {}", Integer.valueOf(mostViolated));
                // Factor 2: reflect across the violated hyperplane rather than merely projecting onto it.
                double stepSize = (2.0d * (-violation.getAsDouble(new long[] {mostViolated, 0})))
                        / Math.pow(rowNorms[mostViolated], 2.0d);
                direction = direction.plus(Calculation.Ret.ORIG, false,
                        activeRows.selectRows(Calculation.Ret.LINK, new long[] {mostViolated})
                                .transpose()
                                .times(Calculation.Ret.ORIG, false, stepSize));
            }
            iterations++;
        } while (mostViolated >= 0);
        log.debug("Iterative direction steps until convergence: {}", Integer.valueOf(iterations));
        assert aIneq.selectRows(Calculation.Ret.NEW, activeConstraints).mtimes(basis.mtimes(direction)).getMaxValue() < 0.0d
                : aIneq.selectRows(Calculation.Ret.NEW, activeConstraints).mtimes(basis.mtimes(direction)).getMaxValue();
        return direction;
    }

    /** Returns a {@code size x 1} vector of independent standard Gaussian draws. */
    private static Matrix randomGaussianDirection(long size, Random random) {
        Matrix direction = MatrixFactory.dense(new long[] {size, 1});
        for (int row = 0; row < size; row++) {
            direction.setAsDouble(random.nextGaussian(), new long[] {row, 0});
        }
        return direction;
    }

    /** Scales the {@code size x 1} vector in place to unit Euclidean length. */
    private static void normalize(Matrix direction, long size) {
        double squaredNorm = 0.0d;
        for (int row = 0; row < size; row++) {
            squaredNorm += Math.pow(direction.getAsDouble(new long[] {row, 0}), 2.0d);
        }
        direction.times(Calculation.Ret.ORIG, false, 1.0d / Math.sqrt(squaredNorm));
    }

    /** Writes every coordinate of currentPt, rounded to the configured digits, into its atom. */
    private void writePointToAtoms() {
        for (Map.Entry<AtomFunctionVariable, Integer> entry : this.atomIndex.entrySet()) {
            double coordinate = this.currentPt.getAsDouble(new long[] {entry.getValue().intValue(), 0});
            // Divide as double: long / long would truncate every value to an integer.
            entry.getKey().setValue(Math.round(coordinate * this.roundingScheme) / (double) this.roundingScheme);
        }
    }

    /**
     * Writes the coefficients of a linear function into row {@code row} of {@code matrix}.
     *
     * @param negate whether to store negated coefficients (and return the negated constant)
     * @return the constant part of the function, negated when {@code negate} is true
     * @throws IllegalArgumentException if the function is neither a summand nor a sum of summands
     */
    private double setMatrixRow(Matrix matrix, int row, FunctionTerm function, boolean negate) {
        double constant = 0.0d;
        if (function instanceof FunctionSummand) {
            constant = setMatrixCell(matrix, row, (FunctionSummand) function, negate);
        } else if (function instanceof FunctionSum) {
            Iterator<FunctionSummand> summands = ((FunctionSum) function).iterator();
            while (summands.hasNext()) {
                constant += setMatrixCell(matrix, row, summands.next(), negate);
            }
        } else {
            throw new IllegalArgumentException("Expected sum but was given: " + function);
        }
        return negate ? -constant : constant;
    }

    /**
     * Adds one summand's coefficient to its variable's cell in row {@code row}, or returns the
     * summand's value if it is constant.
     *
     * @return the summand's value if it is constant, otherwise 0
     * @throws IllegalArgumentException if the summand's term is not an atom variable
     */
    private double setMatrixCell(Matrix matrix, int row, FunctionSummand summand, boolean negate) {
        if (summand.isConstant()) {
            return summand.getValue();
        }
        if (!(summand.getTerm() instanceof AtomFunctionVariable)) {
            throw new IllegalArgumentException("Expected sum of simple variables but got: " + summand);
        }
        int column = getIndex((AtomFunctionVariable) summand.getTerm());
        double coefficient = negate ? -summand.getCoefficient() : summand.getCoefficient();
        // Accumulate: a variable can occur more than once in the same linear sum.
        long[] cell = new long[] {row, column};
        matrix.setAsDouble(matrix.getAsDouble(cell) + coefficient, cell);
        return 0.0d;
    }
}
