package edu.umd.cs.psl.util.datasplitter.splitstep;

import edu.umd.cs.psl.database.Database;
import edu.umd.cs.psl.database.Partition;
import edu.umd.cs.psl.database.loading.Inserter;
import edu.umd.cs.psl.model.argument.GroundTerm;
import edu.umd.cs.psl.model.atom.GroundAtom;
import edu.umd.cs.psl.model.predicate.StandardPredicate;
import edu.umd.cs.psl.util.database.Queries;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeSet;

/* loaded from: input_file:edu/umd/cs/psl/util/datasplitter/splitstep/PredicateUniformSplitStep.class */
/**
 * Split step that distributes the ground atoms of one predicate across a fixed
 * number of freshly allocated partitions ("folds").
 *
 * <p>Atoms are first collected into groups: either every atom forms its own
 * group, or atoms sharing the same argument at a given index are grouped
 * together. The groups are shuffled with the caller-supplied {@link Random}
 * and dealt round-robin to the folds, so all atoms of one group always land in
 * the same partition.
 */
public class PredicateUniformSplitStep implements SplitStep {
    /** Value of {@code groupBy} meaning "do not group; treat every atom on its own". */
    private static final int NO_GROUP = -1;

    private final StandardPredicate target;
    private final int numFolds;
    private final int groupBy;

    /**
     * Creates a split step that keeps atoms with the same argument together.
     *
     * @param target   predicate whose ground atoms are split
     * @param numFolds number of partitions to distribute the atoms over; must be at least 1
     * @param groupBy  index of the argument to group atoms by, or {@code -1} to
     *                 place every atom independently
     * @throws IllegalArgumentException if {@code numFolds < 1} or {@code groupBy < -1}
     */
    public PredicateUniformSplitStep(StandardPredicate target, int numFolds, int groupBy) {
        // Fail fast: a non-positive fold count would otherwise only surface later
        // as a division by zero while dealing atoms to folds.
        if (numFolds < 1) {
            throw new IllegalArgumentException("numFolds must be at least 1, got " + numFolds);
        }
        if (groupBy < NO_GROUP) {
            throw new IllegalArgumentException(
                    "groupBy must be an argument index or " + NO_GROUP + ", got " + groupBy);
        }
        this.target = target;
        this.numFolds = numFolds;
        this.groupBy = groupBy;
    }

    /**
     * Creates a split step that places every atom independently (no grouping).
     *
     * @param target   predicate whose ground atoms are split
     * @param numFolds number of partitions to distribute the atoms over; must be at least 1
     * @throws IllegalArgumentException if {@code numFolds < 1}
     */
    public PredicateUniformSplitStep(StandardPredicate target, int numFolds) {
        this(target, numFolds, NO_GROUP);
    }

    /**
     * Reads all atoms of the target predicate from {@code database}, writes them
     * into {@code numFolds} new partitions obtained from the database's data
     * store, and returns one split per fold.
     *
     * @param database database to read the atoms from; its data store supplies
     *                 the new partitions and their inserters
     * @param random   source of randomness used to shuffle the groups
     * @return a list with {@code numFolds} entries; entry {@code k} contains every
     *         newly created partition except the {@code k}-th one
     * @throws ArrayIndexOutOfBoundsException if grouping is enabled and an atom
     *         has no argument at index {@code groupBy}
     */
    @Override
    public List<Collection<Partition>> getSplits(Database database, Random random) {
        Collection<Set<GroundAtom>> groups = groupAtoms(Queries.getAllAtoms(database, target));

        List<Partition> partitions = new ArrayList<Partition>(numFolds);
        List<Inserter> inserters = new ArrayList<Inserter>(numFolds);
        for (int fold = 0; fold < numFolds; fold++) {
            Partition partition = database.getDataStore().getNextPartition();
            partitions.add(partition);
            inserters.add(database.getDataStore().getInserter(target, partition));
        }

        insertIntoPartitions(groups, inserters, random);

        // Split k is made of all folds but the k-th one.
        List<Collection<Partition>> splits = new ArrayList<Collection<Partition>>(numFolds);
        for (int excluded = 0; excluded < numFolds; excluded++) {
            Set<Partition> split = new TreeSet<Partition>();
            for (int fold = 0; fold < numFolds; fold++) {
                if (fold != excluded) {
                    split.add(partitions.get(fold));
                }
            }
            splits.add(split);
        }
        return splits;
    }

    /**
     * Collects atoms into the groups that must stay together in one partition:
     * singleton groups when grouping is disabled, otherwise one group per
     * distinct argument at index {@code groupBy}.
     */
    private Collection<Set<GroundAtom>> groupAtoms(Collection<GroundAtom> atoms) {
        if (groupBy == NO_GROUP) {
            List<Set<GroundAtom>> singletons = new ArrayList<Set<GroundAtom>>(atoms.size());
            for (GroundAtom atom : atoms) {
                Set<GroundAtom> single = new HashSet<GroundAtom>();
                single.add(atom);
                singletons.add(single);
            }
            return singletons;
        }

        Map<GroundTerm, Set<GroundAtom>> byTerm = new HashMap<GroundTerm, Set<GroundAtom>>();
        for (GroundAtom atom : atoms) {
            GroundTerm key = atom.getArguments()[groupBy];
            Set<GroundAtom> group = byTerm.get(key);
            if (group == null) {
                group = new TreeSet<GroundAtom>();
                byTerm.put(key, group);
            }
            group.add(atom);
        }
        return byTerm.values();
    }

    /**
     * Shuffles the groups and deals them round-robin to the inserters, writing
     * each atom's value and arguments through the inserter of its group's fold.
     *
     * @param collection groups of atoms; each group goes to a single fold
     * @param list       one inserter per fold, in fold order
     * @param random     source of randomness for the shuffle
     */
    private void insertIntoPartitions(Collection<Set<GroundAtom>> collection, List<Inserter> list, Random random) {
        List<Set<GroundAtom>> shuffled = new ArrayList<Set<GroundAtom>>(collection);
        Collections.shuffle(shuffled, random);

        int fold = 0;
        for (Set<GroundAtom> group : shuffled) {
            Inserter inserter = list.get(fold);
            for (GroundAtom atom : group) {
                inserter.insertValue(atom.getValue(), atom.getArguments());
            }
            fold = (fold + 1) % numFolds;
        }
    }
}
