package edu.umd.cs.psl.ui.functions.textsimilarity;

import cern.colt.list.tint.IntArrayList;
import cern.colt.map.tdouble.OpenIntDoubleHashMap;
import edu.umd.cs.psl.database.ReadOnlyDatabase;
import edu.umd.cs.psl.model.argument.ArgumentType;
import edu.umd.cs.psl.model.argument.GroundTerm;
import edu.umd.cs.psl.model.argument.StringAttribute;
import edu.umd.cs.psl.model.function.ExternalFunction;
import java.util.regex.Pattern;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:edu/umd/cs/psl/ui/functions/textsimilarity/CosineSimilarity.class */
/**
 * {@link ExternalFunction} that scores two sparse word vectors by their cosine similarity.
 *
 * <p>Each of the two arguments is a {@link StringAttribute} whose value encodes a sparse vector
 * as whitespace-separated {@code index:weight} tokens, e.g. {@code "3:0.5 17:1.25"}. Similarities
 * that do not strictly exceed the configured threshold are reported as {@code 0.0}.
 *
 * <p>Instances keep an unsynchronized call counter that is used only for periodic debug logging.
 */
public class CosineSimilarity implements ExternalFunction {
    private static final Logger log = LoggerFactory.getLogger(CosineSimilarity.class);

    /** Numerical slack allowed when asserting that a computed similarity lies in [0, 1]. */
    private static final double EPSILON = 1.0E-5d;

    /** Threshold used by the no-argument constructor. */
    private static final double DEFAULT_SIMILARITY_THRESHOLD = 0.4d;

    /** A debug line is logged once every this many calls to {@link #getValue}. */
    private static final int LOG_INTERVAL = 10000;

    /** Separates {@code index:weight} tokens; tolerates runs of spaces, tabs and newlines. */
    private static final Pattern TOKEN_SEPARATOR = Pattern.compile("\\s+");

    /** Number of {@link #getValue} calls so far; only drives debug logging (not synchronized). */
    private long numComputed;

    /** A similarity must be strictly greater than this to be returned; otherwise 0.0 is returned. */
    private final double similarityThreshold;

    /** Sparse vector mapping integer word indices to weights. */
    public static class WordVector extends OpenIntDoubleHashMap {
        private static final long serialVersionUID = 2045184972598485102L;

        /**
         * Returns the number of distinct word indices stored in this vector.
         *
         * @return the number of non-absent entries
         */
        public int getNumWords() {
            return size();
        }

        /**
         * Returns one more than the largest word index stored, i.e. the dimension needed to hold
         * this vector when indices are non-negative. Returns {@code -1} for an empty vector.
         *
         * @return largest stored index plus one, or {@code -1} if no index is stored
         */
        public int getMaxWordIndex() {
            IntArrayList indices = keys();
            // elements() exposes the backing array, which may be longer than size().
            int[] raw = indices.elements();
            int bound = -1;
            for (int pos = 0, count = indices.size(); pos < count; pos++) {
                // For ints, raw[pos] >= bound is the same as raw[pos] + 1 > bound,
                // so bound always tracks (largest index seen) + 1.
                if (raw[pos] >= bound) {
                    bound = raw[pos] + 1;
                }
            }
            return bound;
        }

        /**
         * Sets the weight of a word index, replacing any previous weight for that index.
         *
         * @param i word index
         * @param d weight to store
         */
        public void addWord(int i, double d) {
            put(i, d);
        }
    }

    /** Creates a similarity function using the default threshold of 0.4. */
    public CosineSimilarity() {
        this(DEFAULT_SIMILARITY_THRESHOLD);
    }

    /**
     * Creates a similarity function with a custom threshold.
     *
     * @param similarityThreshold similarities must strictly exceed this value to be returned;
     *     anything else is reported as 0.0
     */
    public CosineSimilarity(double similarityThreshold) {
        this.numComputed = 0L;
        this.similarityThreshold = similarityThreshold;
    }

    /** This function takes exactly two arguments. */
    @Override
    public int getArity() {
        return 2;
    }

    /** Both arguments are string-encoded sparse vectors. */
    @Override
    public ArgumentType[] getArgumentTypes() {
        return new ArgumentType[] {ArgumentType.String, ArgumentType.String};
    }

    /**
     * Computes the thresholded cosine similarity of the two string-encoded vectors.
     *
     * @param readOnlyDatabase unused
     * @param groundTermArr two {@link StringAttribute}s in {@code index:weight ...} format
     * @return the cosine similarity if it is strictly greater than the threshold, else 0.0
     * @throws IllegalArgumentException if either argument contains a malformed token
     */
    @Override
    public double getValue(ReadOnlyDatabase readOnlyDatabase, GroundTerm... groundTermArr) {
        WordVector first = getVector(((StringAttribute) groundTermArr[0]).getValue());
        WordVector second = getVector(((StringAttribute) groundTermArr[1]).getValue());
        double similarity = cosineSimilarity(first, second);

        numComputed++;
        if (numComputed % LOG_INTERVAL == 0) {
            log.debug("Num computed {} | Similarity {}", numComputed, similarity);
        }
        return similarity > similarityThreshold ? similarity : 0.0d;
    }

    @Override
    public String toString() {
        return "Cosine Similarity";
    }

    /**
     * Parses a sparse vector from whitespace-separated {@code index:weight} tokens.
     *
     * <p>Leading, trailing and repeated whitespace is ignored. A blank string yields an empty
     * vector. If an index appears more than once, the last weight wins.
     *
     * @param str encoded vector, e.g. {@code "3:0.5 17:1.25"}; must not be null
     * @return the parsed vector
     * @throws IllegalArgumentException if a token is not exactly {@code index:weight}, or its
     *     parts are not a valid int and double ({@link NumberFormatException})
     */
    public static WordVector getVector(String str) {
        WordVector vector = new WordVector();
        String trimmed = str.trim();
        if (trimmed.isEmpty()) {
            return vector;
        }
        for (String token : TOKEN_SEPARATOR.split(trimmed)) {
            // limit -1 keeps trailing empty parts, so "5:" is rejected rather than misparsed.
            String[] parts = token.split(":", -1);
            if (parts.length != 2) {
                throw new IllegalArgumentException(
                        "Malformed word vector token '" + token
                                + "' (expected index:weight) in \"" + str + "\"");
            }
            vector.addWord(Integer.parseInt(parts[0]), Double.parseDouble(parts[1]));
        }
        return vector;
    }

    /** Returns the Euclidean (L2) norm of the vector's weights. */
    private static double vecLength(WordVector wordVector) {
        // elements() is the backing array; only the first size() slots are meaningful.
        double[] weights = wordVector.values().elements();
        double sumOfSquares = 0.0d;
        for (int pos = 0, count = wordVector.size(); pos < count; pos++) {
            sumOfSquares += weights[pos] * weights[pos];
        }
        return Math.sqrt(sumOfSquares);
    }

    /** Returns the dot product of two sparse vectors. */
    private static double multiplyVec(WordVector wordVector, WordVector wordVector2) {
        // Iterate over the vector with fewer entries and probe the other one.
        WordVector smaller = wordVector;
        WordVector larger = wordVector2;
        if (wordVector2.size() < wordVector.size()) {
            smaller = wordVector2;
            larger = wordVector;
        }
        int[] indices = smaller.keys().elements();
        double dot = 0.0d;
        for (int pos = 0, count = smaller.size(); pos < count; pos++) {
            int index = indices[pos];
            if (larger.containsKey(index)) {
                dot += smaller.get(index) * larger.get(index);
            }
        }
        return dot;
    }

    /**
     * Computes the cosine similarity of two sparse vectors.
     *
     * <p>Returns 0.0 if either vector has zero norm. With assertions enabled, the result is
     * checked to lie in [0, 1] up to {@code 1e-5}. This encodes the assumption that weights are
     * non-negative; negative weights may trip the assertion.
     *
     * @param wordVector first vector
     * @param wordVector2 second vector
     * @return dot(a, b) / (|a| * |b|), or 0.0 if either norm is zero
     */
    public static double cosineSimilarity(WordVector wordVector, WordVector wordVector2) {
        double norm1 = vecLength(wordVector);
        double norm2 = vecLength(wordVector2);
        double similarity = 0.0d;
        if (norm1 > 0.0d && norm2 > 0.0d) {
            similarity = multiplyVec(wordVector, wordVector2) / (norm1 * norm2);
        }
        assert similarity >= -EPSILON && similarity <= 1.0d + EPSILON : similarity;
        return similarity;
    }
}
