package edu.umd.cs.psl.model.kernel.setdefinition;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import edu.umd.cs.psl.application.groundkernelstore.GroundKernelStore;
import edu.umd.cs.psl.database.DatabaseQuery;
import edu.umd.cs.psl.database.ResultList;
import edu.umd.cs.psl.model.argument.GroundTerm;
import edu.umd.cs.psl.model.argument.Term;
import edu.umd.cs.psl.model.argument.Variable;
import edu.umd.cs.psl.model.argument.VariableTypeMap;
import edu.umd.cs.psl.model.atom.AtomEvent;
import edu.umd.cs.psl.model.atom.AtomEventFramework;
import edu.umd.cs.psl.model.atom.AtomManager;
import edu.umd.cs.psl.model.atom.GroundAtom;
import edu.umd.cs.psl.model.atom.QueryAtom;
import edu.umd.cs.psl.model.atom.RandomVariableAtom;
import edu.umd.cs.psl.model.atom.VariableAssignment;
import edu.umd.cs.psl.model.formula.Conjunction;
import edu.umd.cs.psl.model.formula.Formula;
import edu.umd.cs.psl.model.formula.FormulaAnalysis;
import edu.umd.cs.psl.model.formula.traversal.FormulaGrounder;
import edu.umd.cs.psl.model.kernel.AbstractKernel;
import edu.umd.cs.psl.model.kernel.ConstraintKernel;
import edu.umd.cs.psl.model.kernel.Kernel;
import edu.umd.cs.psl.model.kernel.rule.AbstractGroundRule;
import edu.umd.cs.psl.model.parameters.Parameters;
import edu.umd.cs.psl.model.predicate.Predicate;
import edu.umd.cs.psl.model.predicate.SpecialPredicate;
import edu.umd.cs.psl.model.predicate.StandardPredicate;
import edu.umd.cs.psl.model.set.aggregator.EntityAggregatorFunction;
import edu.umd.cs.psl.model.set.membership.SoftTermMembership;
import edu.umd.cs.psl.model.set.membership.TermMembership;
import edu.umd.cs.psl.model.set.term.BasicSetTerm;
import edu.umd.cs.psl.model.set.term.SetTerm;
import edu.umd.cs.psl.ui.aggregators.AggregateConstantSetOverlap;
import edu.umd.cs.psl.ui.aggregators.AggregateSetAverage;
import edu.umd.cs.psl.ui.aggregators.AggregateSetEquality;
import edu.umd.cs.psl.ui.aggregators.AggregateSetOverlap;
import edu.umd.cs.psl.util.dynamicclass.DynamicClassLoader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:edu/umd/cs/psl/model/kernel/setdefinition/SetDefinitionKernel.class */
/**
 * Constraint kernel that defines atoms of {@code setPredicate} by comparing two set terms with an
 * {@link EntityAggregatorFunction}.
 *
 * <p>For every pair of basic terms (one from each set) the constructor builds a trigger formula that
 * conjoins the terms' membership formulas (when present) with a {@code comparisonPredicate} atom over
 * the two leaves. Those trigger formulas drive both eager grounding ({@link #groundAll}) and
 * event-driven grounding ({@link #notifyAtomEvent}).
 */
public class SetDefinitionKernel extends AbstractKernel implements ConstraintKernel {
    private static final Logger log = LoggerFactory.getLogger(SetDefinitionKernel.class);

    final SetTerm set1;
    final SetTerm set2;
    /** Variables bound, in order, to the arguments of {@link #setPredicate}. */
    final Variable[] argumentVariableMap;
    final Predicate comparisonPredicate;
    final EntityAggregatorFunction setCompareFct;
    final StandardPredicate setPredicate;
    /** Position of each entry of {@link #argumentVariableMap} within the set predicate's arguments. */
    final Map<Variable, Integer> variablePosition;
    /** One DNF clause per (set1 term, set2 term) pair; see the class documentation. */
    private final List<FormulaAnalysis.DNFClause> triggerFormulas;
    /** Projection used for all trigger queries: the set predicate's argument variables. */
    private final List<Variable> projection;
    /** Basic terms of set1 (index 0) and set2 (index 1). */
    private final List<Set<BasicSetTerm>> sets;
    /** When true, membership degrees come from the truth value of grounded membership formulas. */
    private final boolean isSoftSet;
    private final int hashcode;

    /** Built-in aggregator names accepted by the {@code String}-based constructors. */
    static final Map<String, Class<? extends EntityAggregatorFunction>> definedSetComparatorFun =
            new ImmutableMap.Builder<String, Class<? extends EntityAggregatorFunction>>()
                    .put("setequality", AggregateSetEquality.class)
                    .put("setaverage", AggregateSetAverage.class)
                    .put("setoverlap", AggregateSetOverlap.class)
                    .put("setconstant", AggregateConstantSetOverlap.class)
                    .build();

    /**
     * Creates a set definition kernel.
     *
     * @param standardPredicate predicate whose atoms are defined by this kernel
     * @param setTerm first set term
     * @param setTerm2 second set term
     * @param variableArr variables bound to the arguments of {@code standardPredicate}, in order; each must
     *     be an anchor variable of one of the set terms with a matching argument type. The array is copied.
     * @param predicate predicate used to compare elements of the two sets; must be a
     *     {@link StandardPredicate} or {@link SpecialPredicate}
     * @param entityAggregatorFunction function aggregating the pairwise comparisons
     * @param z whether set membership is soft (see {@link #isSoftSet})
     * @throws IllegalArgumentException if the comparison predicate has an unsupported type, the number of
     *     variables does not match the predicate arity, or a variable is missing or mistyped
     */
    public SetDefinitionKernel(StandardPredicate standardPredicate, SetTerm setTerm, SetTerm setTerm2, Variable[] variableArr, Predicate predicate, EntityAggregatorFunction entityAggregatorFunction, boolean z) {
        if (!(predicate instanceof StandardPredicate) && !(predicate instanceof SpecialPredicate)) {
            throw new IllegalArgumentException("Expected basic predicate for comparison!");
        }
        this.set1 = setTerm;
        this.set2 = setTerm2;
        // Defensive copy: the projection below is a view over this array, so external mutation must not leak in.
        this.argumentVariableMap = variableArr.clone();
        this.comparisonPredicate = predicate;
        this.setPredicate = standardPredicate;
        this.setCompareFct = entityAggregatorFunction;
        this.isSoftSet = z;
        this.projection = Arrays.asList(this.argumentVariableMap);
        this.sets = Lists.newArrayList();
        this.sets.add(this.set1.getBasicTerms());
        this.sets.add(this.set2.getBasicTerms());
        this.triggerFormulas = new ArrayList<FormulaAnalysis.DNFClause>(this.sets.get(0).size() * this.sets.get(1).size());

        if (this.setPredicate.getArity() != this.argumentVariableMap.length) {
            throw new IllegalArgumentException("Number of variables does not match predicate arity");
        }
        VariableTypeMap anchorVariables = setTerm.getAnchorVariables(new VariableTypeMap());
        setTerm2.getAnchorVariables(anchorVariables);
        for (int i = 0; i < this.argumentVariableMap.length; i++) {
            Variable variable = this.argumentVariableMap[i];
            if (!anchorVariables.hasVariable(variable)) {
                throw new IllegalArgumentException("Variable does not occur in either set term: " + variable);
            }
            if (!anchorVariables.getType(variable).equals(this.setPredicate.getArgumentType(i))) {
                throw new IllegalArgumentException("Variable does not have matching predicate argument type: " + variable);
            }
        }

        for (BasicSetTerm left : this.sets.get(0)) {
            for (BasicSetTerm right : this.sets.get(1)) {
                Term leftLeaf = left.getLeaf();
                Term rightLeaf = right.getLeaf();
                assert !(leftLeaf instanceof Variable && rightLeaf instanceof Variable && leftLeaf.equals(rightLeaf))
                        : "Both set terms use the same leaf variable: " + leftLeaf;
                Formula trigger = new QueryAtom(this.comparisonPredicate, leftLeaf, rightLeaf);
                // Conjoin the membership formulas (first set1's, then set2's) in front of the comparison atom.
                for (Formula membership : new Formula[] {left.getFormula(), right.getFormula()}) {
                    if (membership != null) {
                        trigger = new Conjunction(membership, trigger);
                    }
                }
                this.triggerFormulas.add(new FormulaAnalysis(trigger).getDNFClause(0));
            }
        }

        this.hashcode = new HashCodeBuilder().append(this.setPredicate).toHashCode();
        this.variablePosition = new HashMap<Variable, Integer>(this.argumentVariableMap.length);
        for (int i = 0; i < this.argumentVariableMap.length; i++) {
            this.variablePosition.put(this.argumentVariableMap[i], Integer.valueOf(i));
        }
    }

    /** Same as the aggregator-based constructor, resolving the aggregator by name via {@link #parseDefinition}. */
    public SetDefinitionKernel(StandardPredicate standardPredicate, SetTerm setTerm, SetTerm setTerm2, Variable[] variableArr, Predicate predicate, String str, boolean z) {
        this(standardPredicate, setTerm, setTerm2, variableArr, predicate, parseDefinition(str), z);
    }

    /** Same as the aggregator-based constructor with hard set membership, resolving the aggregator by name. */
    public SetDefinitionKernel(StandardPredicate standardPredicate, SetTerm setTerm, SetTerm setTerm2, Variable[] variableArr, Predicate predicate, String str) {
        this(standardPredicate, setTerm, setTerm2, variableArr, predicate, parseDefinition(str));
    }

    /** Same as the seven-argument aggregator constructor with hard (non-soft) set membership. */
    public SetDefinitionKernel(StandardPredicate standardPredicate, SetTerm setTerm, SetTerm setTerm2, Variable[] variableArr, Predicate predicate, EntityAggregatorFunction entityAggregatorFunction) {
        this(standardPredicate, setTerm, setTerm2, variableArr, predicate, entityAggregatorFunction, false);
    }

    /** Returns the aggregator used to compare the two sets. */
    public EntityAggregatorFunction getAggregator() {
        return this.setCompareFct;
    }

    /** Returns a new kernel built from the same definition (the variable array is copied by the constructor). */
    @Override // edu.umd.cs.psl.model.kernel.AbstractKernel, edu.umd.cs.psl.model.kernel.Kernel
    /* renamed from: clone */
    public Kernel m64clone() {
        return new SetDefinitionKernel(this.setPredicate, this.set1, this.set2, this.argumentVariableMap, this.comparisonPredicate, this.setCompareFct, this.isSoftSet);
    }

    /** Human-readable form: {@code {set1} agg(pred) {set2} =: setPredicate[soft|]}. */
    public String getName() {
        StringBuilder sb = new StringBuilder();
        sb.append("{").append(this.set1).append("} ");
        sb.append(this.setCompareFct.getName()).append("(").append(this.comparisonPredicate).append(")");
        sb.append(" {").append(this.set2).append("}");
        sb.append(" =: ").append(this.setPredicate);
        sb.append(this.isSoftSet ? "[soft]" : "[]");
        return sb.toString();
    }

    /** This kernel has no parameters. */
    @Override // edu.umd.cs.psl.model.kernel.AbstractKernel, edu.umd.cs.psl.model.kernel.Kernel
    public Parameters getParameters() {
        return Parameters.NoParameters;
    }

    /** Always throws: this kernel has no parameters to set. */
    @Override // edu.umd.cs.psl.model.kernel.AbstractKernel, edu.umd.cs.psl.model.kernel.Kernel
    public void setParameters(Parameters parameters) {
        throw new UnsupportedOperationException("Aggregate Predicates have no parameters!");
    }

    @Override
    public String toString() {
        return getName();
    }

    /**
     * Runs every trigger formula against the atom manager, projected onto the set predicate's arguments,
     * and unconditionally grounds a set definition for each result not yet grounded by this kernel.
     */
    @Override // edu.umd.cs.psl.model.kernel.Kernel
    public void groundAll(AtomManager atomManager, GroundKernelStore groundKernelStore) {
        for (FormulaAnalysis.DNFClause clause : this.triggerFormulas) {
            DatabaseQuery query = new DatabaseQuery(clause.getQueryFormula());
            query.getProjectionSubset().addAll(this.projection);
            ResultList results = atomManager.executeQuery(query);
            log.debug("Grounding size {} for formula {}", results.size(), clause.getQueryFormula());
            for (int r = 0; r < results.size(); r++) {
                newSetDefinition(atomManager, groundKernelStore, results.get(r), true);
            }
        }
    }

    /**
     * Handles an activation event.
     *
     * <p>An activated set-predicate atom is grounded directly (unconditionally). An activated
     * comparison-predicate atom is traced through every trigger clause; each resulting assignment is
     * re-queried, and set definitions are grounded only when the aggregator reports enough support.
     *
     * @throws UnsupportedOperationException for non-activation events, atoms of other predicates, or
     *     clauses that trace the atom to more than one assignment
     * @throws IllegalArgumentException if no trigger clause matches the atom
     */
    @Override // edu.umd.cs.psl.model.kernel.AbstractKernel
    public void notifyAtomEvent(AtomEvent atomEvent, GroundKernelStore groundKernelStore) {
        RandomVariableAtom atom = atomEvent.getAtom();
        AtomEventFramework eventFramework = atomEvent.getEventFramework();
        if (!atomEvent.getType().equals(AtomEvent.Type.ActivatedRVAtom)) {
            throw new UnsupportedOperationException("Currently, only activation events are supporte: " + atomEvent);
        }
        if (atom.getPredicate().equals(this.setPredicate)) {
            if (atom.getRegisteredGroundKernels(this).isEmpty()) {
                newSetDefinition(eventFramework, groundKernelStore, atom.getArguments(), true);
            }
            return;
        }
        if (!atom.getPredicate().equals(this.comparisonPredicate)) {
            throw new UnsupportedOperationException("Currently, the set membership formulas must be fact based only!");
        }
        int tracedAssignments = 0;
        for (FormulaAnalysis.DNFClause clause : this.triggerFormulas) {
            List<VariableAssignment> assignments = clause.traceAtomEvent(atom);
            if (assignments.isEmpty()) {
                continue;
            }
            tracedAssignments += assignments.size();
            if (assignments.size() > 1) {
                throw new UnsupportedOperationException("Second order ativation is not yet supported!");
            }
            for (VariableAssignment assignment : assignments) {
                log.trace("{}", clause.getQueryFormula());
                DatabaseQuery query = new DatabaseQuery(clause.getQueryFormula());
                query.getPartialGrounding().putAll(assignment);
                query.getProjectionSubset().addAll(this.projection);
                ResultList results = eventFramework.executeQuery(query);
                for (int r = 0; r < results.size(); r++) {
                    newSetDefinition(eventFramework, groundKernelStore, results.get(r), false);
                }
            }
        }
        if (tracedAssignments == 0) {
            throw new IllegalArgumentException("No event is actually triggered!");
        }
    }

    /**
     * Grounds the set definition for the set-predicate atom with the given arguments, unless this kernel
     * already has a ground kernel registered on that atom.
     *
     * @param groundTermArr arguments of the set-predicate atom, aligned with {@link #argumentVariableMap}
     * @param z when true, ground unconditionally; otherwise only if the aggregator reports enough support
     */
    private void newSetDefinition(AtomManager atomManager, GroundKernelStore groundKernelStore, GroundTerm[] groundTermArr, boolean z) {
        GroundAtom setAtom = atomManager.getAtom(this.setPredicate, groundTermArr);
        if (!setAtom.getRegisteredGroundKernels(this).isEmpty()) {
            return;
        }
        VariableAssignment assignment = new VariableAssignment();
        for (int i = 0; i < groundTermArr.length; i++) {
            assignment.assign(this.argumentVariableMap[i], groundTermArr[i]);
        }
        SoftTermMembership first = buildMembership(atomManager, this.sets.get(0), assignment);
        SoftTermMembership second = buildMembership(atomManager, this.sets.get(1), assignment);
        if (!z && !enoughSupport(atomManager, first, second)) {
            return;
        }
        Set<GroundAtom> comparisonAtoms = collectComparisonAtoms(atomManager, first, second);
        // Without at least one (set1, set2) element pair there is nothing to compare.
        boolean noPairs = !first.iterator().hasNext() || !second.iterator().hasNext();
        if (noPairs) {
            groundKernelStore.addGroundKernel(new GroundEmptySetDefinition(this, setAtom, this.setCompareFct.aggregateValue(first, second, comparisonAtoms)));
        } else {
            groundKernelStore.addGroundKernel(new GroundSetDefinition(this, setAtom, first, second, comparisonAtoms));
        }
    }

    /**
     * Builds the membership of one side of the definition under the given argument assignment.
     *
     * <p>A basic term without a formula contributes its leaf with degree 1.0. The leaf is either a
     * constant or a variable resolved through the assignment. A basic term with a formula is queried:
     * <ul>
     *   <li>for soft sets, each grounding contributes its leaf with the formula's truth value;</li>
     *   <li>otherwise, each query result contributes its leaf with degree 1.0.</li>
     * </ul>
     */
    private SoftTermMembership buildMembership(AtomManager atomManager, Set<BasicSetTerm> basicTerms, VariableAssignment assignment) {
        SoftTermMembership membership = new SoftTermMembership();
        for (BasicSetTerm basicTerm : basicTerms) {
            Formula formula = basicTerm.getFormula();
            Term leaf = basicTerm.getLeaf();
            if (formula == null) {
                if (leaf instanceof GroundTerm) {
                    membership.addMember((GroundTerm) leaf, 1.0d);
                } else {
                    membership.addMember(assignment.getVariable((Variable) leaf), 1.0d);
                }
                continue;
            }
            DatabaseQuery query = new DatabaseQuery(formula);
            query.getPartialGrounding().putAll(assignment);
            if (this.isSoftSet) {
                FormulaGrounder grounder = new FormulaGrounder(atomManager, atomManager.executeQuery(query), assignment);
                while (grounder.hasNext()) {
                    membership.addMember(grounder.getResultVariable((Variable) leaf), AbstractGroundRule.formulaNorm.getTruthValue(grounder.ground(formula)));
                    grounder.next();
                }
            } else {
                // Project onto the leaf only, so each result row holds the member at index 0.
                query.getProjectionSubset().add((Variable) leaf);
                ResultList results = atomManager.executeQuery(query);
                for (int r = 0; r < results.size(); r++) {
                    membership.addMember(results.get(r)[0], 1.0d);
                }
            }
        }
        return membership;
    }

    /** Asks the aggregator whether the comparison atoms between the two memberships give enough support. */
    private boolean enoughSupport(AtomManager atomManager, TermMembership termMembership, TermMembership termMembership2) {
        return this.setCompareFct.enoughSupport(termMembership, termMembership2, collectComparisonAtoms(atomManager, termMembership, termMembership2));
    }

    /**
     * Collects the comparison-predicate atoms for every (first, second) member pair.
     *
     * <p>Null lookups are skipped. NOTE(review): it is unverified whether {@code getAtom} can actually
     * return null; the original support check guarded against it, so both call sites now do.
     */
    private Set<GroundAtom> collectComparisonAtoms(AtomManager atomManager, TermMembership first, TermMembership second) {
        Set<GroundAtom> atoms = new HashSet<GroundAtom>();
        for (Iterator<?> outer = first.iterator(); outer.hasNext();) {
            GroundTerm left = (GroundTerm) outer.next();
            for (Iterator<?> inner = second.iterator(); inner.hasNext();) {
                GroundAtom atom = atomManager.getAtom(this.comparisonPredicate, left, (GroundTerm) inner.next());
                if (atom != null) {
                    atoms.add(atom);
                }
            }
        }
        return atoms;
    }

    /** Registers every trigger clause and the set predicate itself for activation events. */
    @Override // edu.umd.cs.psl.model.kernel.AbstractKernel
    public void registerForAtomEvents(AtomEventFramework atomEventFramework) {
        for (FormulaAnalysis.DNFClause clause : this.triggerFormulas) {
            clause.registerClauseForEvents(atomEventFramework, AtomEvent.ActivatedEventTypeSet, this);
        }
        atomEventFramework.registerAtomEventListener(AtomEvent.ActivatedEventTypeSet, this.setPredicate, this);
    }

    /** Reverses {@link #registerForAtomEvents}. */
    @Override // edu.umd.cs.psl.model.kernel.AbstractKernel
    public void unregisterForAtomEvents(AtomEventFramework atomEventFramework) {
        for (FormulaAnalysis.DNFClause clause : this.triggerFormulas) {
            clause.unregisterClauseForEvents(atomEventFramework, AtomEvent.ActivatedEventTypeSet, this);
        }
        atomEventFramework.unregisterAtomEventListener(AtomEvent.ActivatedEventTypeSet, this.setPredicate, this);
    }

    /** Hash based on the defined set predicate only; equality remains identity (equals is not overridden here). */
    @Override
    public int hashCode() {
        return this.hashcode;
    }

    /**
     * Resolves an aggregator by name, using {@link #definedSetComparatorFun} for the built-in names.
     *
     * @throws AssertionError if the aggregator cannot be loaded; the underlying failure is attached as cause
     */
    static final EntityAggregatorFunction parseDefinition(String str) {
        try {
            return (EntityAggregatorFunction) DynamicClassLoader.loadClassArbitraryArgs(str, definedSetComparatorFun, EntityAggregatorFunction.class);
        } catch (Throwable th) {
            // Broad catch kept on purpose: any loading failure is reported the same way. Log it instead of
            // printing the stack trace, and keep the cause instead of dropping it.
            log.error("Could not load set comparison function: " + str, th);
            AssertionError error = new AssertionError("Unknown similarity function: " + str);
            error.initCause(th);
            throw error;
        }
    }
}
