package edu.umd.cs.psl.application.learning.weight;

import com.google.common.collect.Iterables;
import edu.umd.cs.psl.application.ModelApplication;
import edu.umd.cs.psl.application.util.Grounding;
import edu.umd.cs.psl.config.ConfigBundle;
import edu.umd.cs.psl.database.Database;
import edu.umd.cs.psl.model.Model;
import edu.umd.cs.psl.model.atom.ObservedAtom;
import edu.umd.cs.psl.model.atom.RandomVariableAtom;
import edu.umd.cs.psl.model.kernel.CompatibilityKernel;
import edu.umd.cs.psl.reasoner.Reasoner;
import edu.umd.cs.psl.reasoner.ReasonerFactory;
import edu.umd.cs.psl.reasoner.admm.ADMMReasonerFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Observable;

/* loaded from: input_file:edu/umd/cs/psl/application/learning/weight/WeightLearningApplication.class */
/**
 * Skeleton for applications that learn weights for the
 * {@link CompatibilityKernel}s of a {@link Model}.
 *
 * <p>{@link #learn()} drives the workflow: it sorts the model's compatibility
 * kernels into weight-mutable and weight-immutable lists, builds the ground
 * model via {@link #initGroundModel()}, delegates the actual learning to
 * {@link #doLearn()}, and finally releases the per-run state via
 * {@link #cleanUpGroundModel()}.
 */
public abstract class WeightLearningApplication extends Observable implements ModelApplication {
    /** Common prefix of the configuration keys declared by this class. */
    public static final String CONFIG_PREFIX = "weightlearning";
    /** Configuration key under which the {@link ReasonerFactory} is looked up. */
    public static final String REASONER_KEY = "weightlearning.reasoner";
    /** Factory used when {@link #REASONER_KEY} is not configured. */
    public static final ReasonerFactory REASONER_DEFAULT = new ADMMReasonerFactory();

    protected Model model;
    protected Database rvDB;
    protected Database observedDB;
    protected ConfigBundle config;
    /** Compatibility kernels whose weight is mutable; populated only while {@link #learn()} runs. */
    protected final List<CompatibilityKernel> kernels = new ArrayList<CompatibilityKernel>();
    /** Compatibility kernels whose weight is not mutable; populated only while {@link #learn()} runs. */
    protected final List<CompatibilityKernel> immutableKernels = new ArrayList<CompatibilityKernel>();
    /** Built by {@link #initGroundModel()} and dropped by {@link #cleanUpGroundModel()}. */
    protected TrainingMap trainingMap;
    /** Created by {@link #initGroundModel()} and closed by {@link #cleanUpGroundModel()}. */
    protected Reasoner reasoner;

    /**
     * Stores the collaborators; no work is done until {@link #learn()} is called.
     *
     * @param model        model whose compatibility kernels are examined by {@link #learn()}
     * @param database     stored as {@link #rvDB}; first argument to the {@link TrainingMap}
     * @param database2    stored as {@link #observedDB}; second argument to the {@link TrainingMap}
     * @param configBundle configuration used to obtain and configure the reasoner
     */
    public WeightLearningApplication(Model model, Database database, Database database2, ConfigBundle configBundle) {
        this.model = model;
        this.rvDB = database;
        this.observedDB = database2;
        this.config = configBundle;
    }

    /**
     * Runs one complete learning pass.
     *
     * <p>Both kernel lists are emptied before and after the pass so that
     * repeated calls never see kernels left over from an earlier run, and the
     * ground model is cleaned up even when {@link #doLearn()} throws.
     *
     * @throws ClassNotFoundException if thrown by {@link #initGroundModel()}
     * @throws IllegalAccessException if thrown by {@link #initGroundModel()}
     * @throws InstantiationException if thrown by {@link #initGroundModel()}
     */
    public void learn() throws ClassNotFoundException, IllegalAccessException, InstantiationException {
        // Defensive reset: an earlier run may have left entries behind.
        this.kernels.clear();
        this.immutableKernels.clear();

        for (CompatibilityKernel kernel : Iterables.filter(this.model.getKernels(), CompatibilityKernel.class)) {
            if (kernel.isWeightMutable()) {
                this.kernels.add(kernel);
            } else {
                this.immutableKernels.add(kernel);
            }
        }

        boolean groundModelReady = false;
        try {
            initGroundModel();
            groundModelReady = true;
            doLearn();
        } finally {
            this.kernels.clear();
            this.immutableKernels.clear();
            // Only undo what initGroundModel() completed; it releases its own
            // partial state when it fails, and subclass overrides of
            // cleanUpGroundModel() may assume a fully initialised ground model.
            if (groundModelReady) {
                cleanUpGroundModel();
            }
        }
    }

    /**
     * Performs the actual learning. Called by {@link #learn()} after the kernel
     * lists are filled and {@link #initGroundModel()} has returned normally.
     */
    protected abstract void doLearn();

    /**
     * Builds the training map, obtains a reasoner from the configured factory
     * and grounds the model into it. If anything fails after the reasoner has
     * been created, the reasoner is closed before the exception propagates.
     *
     * @throws IllegalArgumentException if the training map reports any latent variables
     * @throws ClassNotFoundException   propagated from the factory lookup
     * @throws IllegalAccessException   propagated from the factory lookup
     * @throws InstantiationException   propagated from the factory lookup
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public void initGroundModel() throws ClassNotFoundException, IllegalAccessException, InstantiationException {
        this.trainingMap = new TrainingMap(this.rvDB, this.observedDB);
        this.reasoner = ((ReasonerFactory) this.config.getFactory(REASONER_KEY, REASONER_DEFAULT)).getReasoner(this.config);

        boolean grounded = false;
        try {
            if (this.trainingMap.getLatentVariables().size() > 0) {
                throw new IllegalArgumentException("All RandomVariableAtoms must have corresponding ObservedAtoms. Latent variables are not supported by this WeightLearningApplication. Example latent variable: " + this.trainingMap.getLatentVariables().iterator().next());
            }
            Grounding.groundAll(this.model, this.trainingMap, this.reasoner);
            grounded = true;
        } finally {
            if (!grounded) {
                releaseGroundModel();
            }
        }
    }

    /**
     * Drops the training map and closes the reasoner. Safe to call when no
     * reasoner is currently held.
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public void cleanUpGroundModel() {
        releaseGroundModel();
    }

    /** Drops the training map and closes the reasoner if one is held. */
    private void releaseGroundModel() {
        this.trainingMap = null;
        if (this.reasoner != null) {
            this.reasoner.close();
            this.reasoner = null;
        }
    }

    /** Drops the references to the model, both databases and the configuration. */
    @Override // edu.umd.cs.psl.application.ModelApplication
    public void close() {
        this.model = null;
        this.rvDB = null;
        this.observedDB = null;
        this.config = null;
    }

    /**
     * Copies, for every entry of the training map, the value of the
     * {@link ObservedAtom} onto its paired {@link RandomVariableAtom}.
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public void setLabeledRandomVariables() {
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : this.trainingMap.getTrainingMap().entrySet()) {
            entry.getKey().setValue(entry.getValue().getValue());
        }
    }
}
