package edu.umd.cs.psl.application.learning.weight;

import edu.umd.cs.psl.database.Database;
import edu.umd.cs.psl.database.DatabaseQuery;
import edu.umd.cs.psl.database.ResultList;
import edu.umd.cs.psl.model.argument.GroundTerm;
import edu.umd.cs.psl.model.argument.Variable;
import edu.umd.cs.psl.model.atom.AtomManager;
import edu.umd.cs.psl.model.atom.GroundAtom;
import edu.umd.cs.psl.model.atom.ObservedAtom;
import edu.umd.cs.psl.model.atom.QueryAtom;
import edu.umd.cs.psl.model.atom.RandomVariableAtom;
import edu.umd.cs.psl.model.predicate.Predicate;
import edu.umd.cs.psl.model.predicate.StandardPredicate;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:edu/umd/cs/psl/application/learning/weight/TrainingMap.class */
/**
 * An {@link AtomManager} that pairs the {@link RandomVariableAtom}s of a
 * random-variable database with atoms of a second (observation) database that
 * share the same predicate and arguments.
 * <p>
 * During construction, every grounding of every predicate that the
 * random-variable database reports as not closed is enumerated. Each grounding
 * that resolves to a {@link RandomVariableAtom} is handled in one of two ways:
 * <ul>
 * <li>if the observation database returns an {@link ObservedAtom} for it, the
 * pair is stored in the training map;</li>
 * <li>otherwise the atom is recorded as a latent variable.</li>
 * </ul>
 * <p>
 * All {@link AtomManager} calls are delegated to the random-variable database.
 * The exception is {@link #getAtom}, which refuses to return
 * {@link RandomVariableAtom}s that were not discovered during construction.
 */
public class TrainingMap implements AtomManager {
    /** Database holding the random variables; backs every AtomManager call. */
    private final Database rvDB;

    /** Random-variable atoms paired with their ObservedAtom counterpart. */
    private final Map<RandomVariableAtom, ObservedAtom> trainingMap =
            new HashMap<RandomVariableAtom, ObservedAtom>();

    /** Random-variable atoms with no ObservedAtom counterpart. */
    private final Set<RandomVariableAtom> latentVariables = new HashSet<RandomVariableAtom>();

    /**
     * Scans the open predicates of {@code database} and classifies each random
     * variable found as either labeled (present in the training map) or latent.
     *
     * @param database  the random-variable database; its non-closed predicates
     *                  are queried for all groundings, and it backs this
     *                  AtomManager afterwards
     * @param database2 the observation database, queried with the same
     *                  predicate and arguments for each random variable found
     */
    public TrainingMap(Database database, Database database2) {
        this.rvDB = database;
        for (StandardPredicate predicate : database.getRegisteredPredicates()) {
            // Only predicates the random-variable database reports as open are scanned.
            if (database.isClosed(predicate)) {
                continue;
            }
            // A query atom made only of distinct free variables matches every grounding.
            ResultList groundings = database.executeQuery(
                    new DatabaseQuery(new QueryAtom(predicate, freeVariables(predicate.getArity()))));
            for (int row = 0; row < groundings.size(); row++) {
                GroundTerm[] args = groundings.get(row);
                GroundAtom rvAtom = database.getAtom(predicate, args);
                if (!(rvAtom instanceof RandomVariableAtom)) {
                    continue;
                }
                GroundAtom label = database2.getAtom(predicate, args);
                if (label instanceof ObservedAtom) {
                    trainingMap.put((RandomVariableAtom) rvAtom, (ObservedAtom) label);
                } else {
                    latentVariables.add((RandomVariableAtom) rvAtom);
                }
            }
        }
    }

    /** Creates {@code arity} distinct variables named V0, V1, ..., V(arity-1). */
    private static Variable[] freeVariables(int arity) {
        Variable[] vars = new Variable[arity];
        for (int i = 0; i < arity; i++) {
            vars[i] = new Variable("V" + i);
        }
        return vars;
    }

    /**
     * Returns the random-variable atoms that have an observed counterpart,
     * mapped to that counterpart.
     *
     * @return an unmodifiable view of the training map
     */
    public Map<RandomVariableAtom, ObservedAtom> getTrainingMap() {
        return Collections.unmodifiableMap(this.trainingMap);
    }

    /**
     * Returns the random-variable atoms that had no {@link ObservedAtom}
     * counterpart in the observation database.
     *
     * @return an unmodifiable view of the latent variables
     */
    public Set<RandomVariableAtom> getLatentVariables() {
        return Collections.unmodifiableSet(this.latentVariables);
    }

    /**
     * Looks the atom up in the random-variable database.
     * <p>
     * Atoms that are not {@link RandomVariableAtom}s are returned as-is.
     * Random-variable atoms are returned only if they were discovered during
     * construction, either in the training map or among the latent variables.
     *
     * @param predicate     the atom's predicate
     * @param groundTermArr the atom's arguments
     * @return the atom as returned by the random-variable database
     * @throws IllegalArgumentException if the atom is a RandomVariableAtom that
     *                                  was not discovered during construction
     */
    @Override // edu.umd.cs.psl.model.atom.AtomManager
    public GroundAtom getAtom(Predicate predicate, GroundTerm... groundTermArr) {
        GroundAtom atom = this.rvDB.getAtom(predicate, groundTermArr);
        if (!(atom instanceof RandomVariableAtom)) {
            return atom;
        }
        if (this.trainingMap.containsKey(atom) || this.latentVariables.contains(atom)) {
            return atom;
        }
        throw new IllegalArgumentException("Can only call getAtom() on persisted RandomVariableAtoms using a TrainingMap. Cannot access " + atom);
    }

    /**
     * Delegates to the random-variable database.
     *
     * @param databaseQuery the query to run
     * @return the random-variable database's results
     */
    @Override // edu.umd.cs.psl.model.atom.AtomManager
    public ResultList executeQuery(DatabaseQuery databaseQuery) {
        return this.rvDB.executeQuery(databaseQuery);
    }

    /**
     * Delegates to the random-variable database.
     *
     * @param standardPredicate the predicate to check
     * @return whether the random-variable database reports it as closed
     */
    @Override // edu.umd.cs.psl.model.atom.AtomManager
    public boolean isClosed(StandardPredicate standardPredicate) {
        return this.rvDB.isClosed(standardPredicate);
    }
}
