package edu.umd.cs.psl.evaluation.result.memory;

import de.mathnbits.statistics.DoubleDist;
import edu.umd.cs.psl.evaluation.result.FullConfidenceAnalysisResult;
import edu.umd.cs.psl.model.predicate.Predicate;
import edu.umd.cs.psl.reasoner.function.AtomFunctionVariable;
import java.util.Map;

/* loaded from: input_file:edu/umd/cs/psl/evaluation/result/memory/MemoryFullConfidenceAnalysisResult.class */
/**
 * In-memory {@link FullConfidenceAnalysisResult} backed by a map from atom variables to the
 * {@link DoubleDist} of values observed for each variable.
 *
 * <p>The map passed to the constructor is stored and returned as-is (no copy is made), so
 * later changes to it are visible through this object.
 */
public class MemoryFullConfidenceAnalysisResult implements FullConfidenceAnalysisResult {

    /**
     * Probability floor used by {@link #KLdivergence}. Bins of this result's histogram below it
     * are skipped, and bins of the other histogram are floored to it to avoid log(x / 0).
     */
    private static final double EPSILON = 1.0E-8d;

    private final Map<AtomFunctionVariable, DoubleDist> distributions;

    /**
     * Creates a result over the given per-variable distributions.
     *
     * @param map distributions keyed by atom variable; stored by reference, not copied
     */
    public MemoryFullConfidenceAnalysisResult(Map<AtomFunctionVariable, DoubleDist> map) {
        this.distributions = map;
    }

    /**
     * Returns the per-variable distributions.
     *
     * @return the same map instance supplied to the constructor
     */
    @Override // edu.umd.cs.psl.evaluation.result.FullConfidenceAnalysisResult
    public Map<AtomFunctionVariable, DoubleDist> getDistribution() {
        return this.distributions;
    }

    /**
     * Computes the KL divergence D(P || Q) (natural log) for one variable. P is this result's
     * histogram and Q is {@code other}'s histogram, both with {@code numBins} bins.
     *
     * <p>Bins where P is below {@link #EPSILON} contribute nothing (the usual 0 * log 0 = 0
     * convention). Bins where Q is not above {@link #EPSILON} are treated as {@code EPSILON},
     * so the result stays finite.
     *
     * @param variable the atom variable to compare
     * @param numBins number of histogram bins
     * @param other the result supplying the reference distribution Q
     * @return the divergence of this result's histogram from {@code other}'s
     */
    @Override // edu.umd.cs.psl.evaluation.result.FullConfidenceAnalysisResult
    public double KLdivergence(AtomFunctionVariable variable, int numBins, FullConfidenceAnalysisResult other) {
        double[] p = getHistogram(variable, numBins);
        double[] q = other.getHistogram(variable, numBins);
        double divergence = 0.0d;
        for (int bin = 0; bin < numBins; bin++) {
            if (p[bin] >= EPSILON) {
                double qBin = q[bin] > EPSILON ? q[bin] : EPSILON;
                divergence += p[bin] * Math.log(p[bin] / qBin);
            }
        }
        return divergence;
    }

    /**
     * Averages {@link #KLdivergence} over every variable in this result whose atom's predicate
     * equals {@code predicate}.
     *
     * @param predicate only variables whose atom uses this predicate are included
     * @param numBins number of histogram bins passed to {@link #KLdivergence}
     * @param other the result supplying the reference distributions
     * @return the mean divergence, or {@link Double#NaN} if no variable matches
     */
    @Override // edu.umd.cs.psl.evaluation.result.FullConfidenceAnalysisResult
    public double averageKLdivergence(Predicate predicate, int numBins, FullConfidenceAnalysisResult other) {
        double total = 0.0d;
        int matched = 0;
        for (AtomFunctionVariable variable : this.distributions.keySet()) {
            if (variable.getAtom().getPredicate().equals(predicate)) {
                total += KLdivergence(variable, numBins, other);
                matched++;
            }
        }
        if (matched == 0) {
            // An average over an empty set is undefined; this matches the previous 0.0 / 0.
            return Double.NaN;
        }
        return total / matched;
    }

    /**
     * Builds a normalized histogram of the variable's distribution with {@code numBins}
     * equal-width bins over [0, 1]. A value v goes into bin floor(v * numBins). Values at or
     * above 1 are clamped into the last bin, and values below 0 into the first bin. Each bin
     * holds the fraction of the total count that falls in it.
     *
     * <p>NOTE(review): bin values are presumably truth values in [0, 1]; the clamping only
     * guards against out-of-range values.
     *
     * @param variable the atom variable to summarize
     * @param numBins number of bins; the returned array has this length
     * @return the histogram, or an all-zero array if the variable has no distribution here
     */
    @Override // edu.umd.cs.psl.evaluation.result.FullConfidenceAnalysisResult
    public double[] getHistogram(AtomFunctionVariable variable, int numBins) {
        double[] histogram = new double[numBins];
        if (!this.distributions.containsKey(variable)) {
            return histogram;
        }
        DoubleDist dist = this.distributions.get(variable);
        double total = dist.totalCount();
        for (Double value : dist.getBins()) {
            double fraction = dist.getCount(value) / total;
            int bin = (int) Math.floor(value.doubleValue() * numBins);
            // Clamp both ends: v == 1.0 maps to numBins, and negative v would be a negative index.
            bin = Math.max(0, Math.min(bin, numBins - 1));
            histogram[bin] += fraction;
        }
        return histogram;
    }
}
