package edu.umd.cs.psl.application.topicmodel;

import edu.umd.cs.psl.application.ModelApplication;
import edu.umd.cs.psl.application.topicmodel.kernel.LDAgroundLogLoss;
import edu.umd.cs.psl.application.topicmodel.reasoner.admm.LatentTopicNetworkADMMReasoner;
import edu.umd.cs.psl.application.util.GroundKernels;
import edu.umd.cs.psl.application.util.Grounding;
import edu.umd.cs.psl.config.ConfigBundle;
import edu.umd.cs.psl.database.DataStore;
import edu.umd.cs.psl.database.Database;
import edu.umd.cs.psl.database.Partition;
import edu.umd.cs.psl.database.loading.Inserter;
import edu.umd.cs.psl.evaluation.result.FullInferenceResult;
import edu.umd.cs.psl.evaluation.result.memory.MemoryFullInferenceResult;
import edu.umd.cs.psl.model.Model;
import edu.umd.cs.psl.model.argument.GroundTerm;
import edu.umd.cs.psl.model.argument.UniqueID;
import edu.umd.cs.psl.model.atom.GroundAtom;
import edu.umd.cs.psl.model.atom.PersistedAtomManager;
import edu.umd.cs.psl.model.atom.RandomVariableAtom;
import edu.umd.cs.psl.model.parameters.PositiveWeight;
import edu.umd.cs.psl.model.predicate.Predicate;
import edu.umd.cs.psl.model.predicate.PredicateFactory;
import edu.umd.cs.psl.model.predicate.StandardPredicate;
import edu.umd.cs.psl.reasoner.Reasoner;
import edu.umd.cs.psl.util.database.Queries;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:edu/umd/cs/psl/application/topicmodel/LatentTopicNetwork.class */
/**
 * Latent topic network: an LDA-style topic model trained by EM.
 *
 * <p>The E-step computes expected topic counts per document ({@code expectedCountsTheta}) and per
 * topic/word ({@code expectedCountsPhi}) from the current {@code theta} and {@code phi}. The
 * M-step always performs the closed-form LDA update. Depending on configuration, it then replaces
 * {@code theta} and/or {@code phi} with the result of inference over a PSL model. That model is
 * augmented with one {@code LDAgroundLogLoss} kernel per row of the corresponding matrix.
 *
 * <p>Optionally, PSL rule weights are re-learned every {@code wLearningGap} iterations, starting
 * at iteration {@code firstWLearningIter}.
 */
public class LatentTopicNetwork implements ModelApplication {
    private static final Logger log = LoggerFactory.getLogger(LatentTopicNetwork.class);

    /** Number of leading rows whose sums/entropies are printed in debug output. */
    private static final int DEBUG_ROWS = 10;

    /** Prefix shared by every configuration key read by this class. */
    public static final String CONFIG_PREFIX = "latentTopicNetworks";
    /** Whether theta is re-estimated by PSL inference in the M-step. */
    public static final String HINGE_LOSS_THETA_KEY = "latentTopicNetworks.hingeLossTheta";
    public static final boolean HINGE_LOSS_THETA_DEFAULT = true;
    /** Whether phi is re-estimated by PSL inference in the M-step. */
    public static final String HINGE_LOSS_PHI_KEY = "latentTopicNetworks.hingeLossPhi";
    public static final boolean HINGE_LOSS_PHI_DEFAULT = false;
    /** Number of EM iterations; must be positive. */
    public static final String NUM_ITERATIONS_KEY = "latentTopicNetworks.numIterations";
    public static final int NUM_ITERATIONS_DEFAULT = 200;
    /** Number of initial iterations that use only the LDA M-step; must be non-negative. */
    public static final String NUM_BURNIN_KEY = "latentTopicNetworks.numBurnIn";
    public static final int NUM_BURNIN_DEFAULT = 0;
    /** Number of topics; must be positive. */
    public static final String NUM_TOPICS_KEY = "latentTopicNetworks.numTopics";
    public static final int NUM_TOPICS_DEFAULT = 20;
    /** Dirichlet prior on theta; {@code alpha - 1} pseudo-counts are added in the E-step. */
    public static final String ALPHA_KEY = "latentTopicNetworks.alpha";
    public static final double ALPHA_DEFAULT = 1.01d;
    /** Dirichlet prior on phi; {@code beta - 1} pseudo-counts are added in the E-step. */
    public static final String BETA_KEY = "latentTopicNetworks.beta";
    public static final double BETA_DEFAULT = 1.01d;
    /** Whether PSL rule weights are periodically re-learned (weight-learning constructor only). */
    public static final String WEIGHT_LEARNING_KEY = "latentTopicNetworks.weightLearning";
    public static final boolean WEIGHT_LEARNING_DEFAULT = false;
    /** First iteration at which weight learning runs. */
    public static final String FIRST_W_LEARNING_ITER_KEY = "latentTopicNetworks.firstWLearningIter";
    public static final int FIRST_W_LEARNING_ITER_DEFAULT = 50;
    /** Number of iterations between weight-learning runs. */
    public static final String W_LEARNING_GAP_KEY = "latentTopicNetworks.WLearningGap";
    public static final int W_LEARNING_GAP_DEFAULT = 10;
    /** If set, calls initDirichletTerms() on the theta reasoner before every M-step after the first. */
    public static final String INIT_MSTEP_TO_LDA_THETA_KEY = "latentTopicNetworks.initMStepToLDAtheta";
    public static final boolean INIT_MSTEP_TO_LDA_THETA_DEFAULT = false;
    /** If set, calls initDirichletTerms() on the phi reasoner before every M-step after the first. */
    public static final String INIT_MSTEP_TO_LDA_PHI_KEY = "latentTopicNetworks.initMStepToLDAphi";
    public static final boolean INIT_MSTEP_TO_LDA_PHI_DEFAULT = true;
    /** Path prefix for per-iteration serialized theta/phi; empty disables saving. */
    public static final String SAVE_DIR_KEY = "latentTopicNetworks.saveDir";
    public static final String SAVE_DIR_DEFAULT = "";

    private Model modelTheta;
    private Model modelPhi;
    private Database dbTheta;
    private Database dbPhi;
    LatentTopicNetworkADMMReasoner reasonerTheta;
    LatentTopicNetworkADMMReasoner reasonerPhi;
    PersistedAtomManager atomManagerTheta;
    PersistedAtomManager atomManagerPhi;
    private ConfigBundle config;
    int numDocuments;
    int numTopics;
    int numWords;
    int numIterations;
    int burnIn;
    // expectedCountsTheta/theta are numDocuments x numTopics; expectedCountsPhi/phi are numTopics x numWords.
    double[][] expectedCountsTheta;
    double[][] expectedCountsPhi;
    double[][] theta;
    double[][] phi;
    final boolean hingeLossTheta;
    final boolean hingeLossPhi;
    final boolean initMStepToLDAtheta;
    final boolean initMStepToLDAphi;
    boolean doWeightLearning;
    // Predicate groups used to populate the weight-learning partitions (see weightLearning()).
    StandardPredicate[] X;
    StandardPredicate[] Y;
    StandardPredicate[] Z;
    int firstWLearningIter;
    int wLearningGap;
    DataStore dataStore;
    Database rvDBweightLearningTheta;
    Database observedDBweightLearningTheta;
    Database rvDBweightLearningPhi;
    Database observedDBweightLearningPhi;
    Partition rvPartitionTheta;
    Partition labelPartitionTheta;
    Partition rvPartitionPhi;
    Partition labelPartitionPhi;
    double alpha;
    double beta;
    int[][] docWords;
    int[][] docCounts;
    final String saveDir;

    /**
     * Creates a latent topic network that may also periodically learn PSL rule weights.
     *
     * <p>Weight learning is enabled only if {@link #WEIGHT_LEARNING_KEY} is true in {@code config}.
     * In that case the predicate names are resolved through {@link PredicateFactory}, and four
     * partitions are reserved: the next partition from {@code dataStore} plus the three IDs that
     * follow it.
     *
     * @param xPredicateNames predicates copied only into the random-variable partition and closed there
     * @param yPredicateNames predicates copied into both the random-variable and label partitions
     * @param zPredicateNames predicates copied into both the random-variable and label partitions
     * @param dataStore data store used to build the weight-learning databases
     * @see #LatentTopicNetwork(Model, Database, Model, Database, int[][], int[][], int, ConfigBundle)
     */
    public LatentTopicNetwork(Model modelTheta, Database dbTheta, Model modelPhi, Database dbPhi,
            int[][] docWords, int[][] docCounts, int numWords, String[] xPredicateNames,
            String[] yPredicateNames, String[] zPredicateNames, DataStore dataStore,
            ConfigBundle config) {
        this(modelTheta, dbTheta, modelPhi, dbPhi, docWords, docCounts, numWords, config);
        this.doWeightLearning = config.getBoolean(WEIGHT_LEARNING_KEY, WEIGHT_LEARNING_DEFAULT);
        if (this.doWeightLearning) {
            this.firstWLearningIter = config.getInt(FIRST_W_LEARNING_ITER_KEY, FIRST_W_LEARNING_ITER_DEFAULT);
            this.wLearningGap = config.getInt(W_LEARNING_GAP_KEY, W_LEARNING_GAP_DEFAULT);
            this.X = lookupStandardPredicates(xPredicateNames);
            this.Y = lookupStandardPredicates(yPredicateNames);
            this.Z = lookupStandardPredicates(zPredicateNames);
            this.dataStore = dataStore;
            this.rvPartitionTheta = dataStore.getNextPartition();
            // NOTE(review): assumes the three IDs following getNextPartition() are also unused -- confirm.
            int baseId = this.rvPartitionTheta.getID();
            this.labelPartitionTheta = new Partition(baseId + 1);
            this.rvPartitionPhi = new Partition(baseId + 2);
            this.labelPartitionPhi = new Partition(baseId + 3);
        }
    }

    /**
     * Creates a latent topic network without weight learning.
     *
     * @param modelTheta PSL model used when theta is estimated by inference
     * @param dbTheta database holding the {@code Theta} atoms
     * @param modelPhi PSL model used when phi is estimated by inference
     * @param dbPhi database holding the {@code Phi} atoms
     * @param docWords {@code docWords[d][j]} is a word index (column of phi) for entry j of document d
     * @param docCounts {@code docCounts[d][j]} is the count associated with {@code docWords[d][j]}
     * @param numWords vocabulary size (number of columns of phi)
     * @param config configuration; see the {@code *_KEY} constants
     * @throws IllegalArgumentException if iterations, burn-in, topics, alpha or beta are out of range
     */
    public LatentTopicNetwork(Model modelTheta, Database dbTheta, Model modelPhi, Database dbPhi,
            int[][] docWords, int[][] docCounts, int numWords, ConfigBundle config) {
        // Sentinels so that weight learning never triggers unless the other constructor enables it.
        this.firstWLearningIter = Integer.MAX_VALUE;
        this.wLearningGap = Integer.MAX_VALUE;
        this.numIterations = config.getInt(NUM_ITERATIONS_KEY, NUM_ITERATIONS_DEFAULT);
        if (this.numIterations <= 0) {
            throw new IllegalArgumentException("Number of iterations must be positive.");
        }
        this.burnIn = config.getInt(NUM_BURNIN_KEY, NUM_BURNIN_DEFAULT);
        if (this.burnIn < 0) {
            throw new IllegalArgumentException("Number of burn-in iterations must be non-negative.");
        }
        this.hingeLossTheta = config.getBoolean(HINGE_LOSS_THETA_KEY, HINGE_LOSS_THETA_DEFAULT);
        this.hingeLossPhi = config.getBoolean(HINGE_LOSS_PHI_KEY, HINGE_LOSS_PHI_DEFAULT);
        this.numTopics = config.getInt(NUM_TOPICS_KEY, NUM_TOPICS_DEFAULT);
        if (this.numTopics <= 0) {
            throw new IllegalArgumentException("Number of topics must be positive.");
        }
        this.alpha = config.getDouble(ALPHA_KEY, ALPHA_DEFAULT);
        if (this.alpha <= 0.0d) {
            throw new IllegalArgumentException("alpha must be positive.");
        }
        this.beta = config.getDouble(BETA_KEY, BETA_DEFAULT);
        if (this.beta <= 0.0d) {
            throw new IllegalArgumentException("beta must be positive.");
        }
        this.doWeightLearning = false;
        this.initMStepToLDAtheta = config.getBoolean(INIT_MSTEP_TO_LDA_THETA_KEY, INIT_MSTEP_TO_LDA_THETA_DEFAULT);
        this.initMStepToLDAphi = config.getBoolean(INIT_MSTEP_TO_LDA_PHI_KEY, INIT_MSTEP_TO_LDA_PHI_DEFAULT);
        this.saveDir = config.getString(SAVE_DIR_KEY, SAVE_DIR_DEFAULT);
        this.modelTheta = modelTheta;
        this.dbTheta = dbTheta;
        this.modelPhi = modelPhi;
        this.dbPhi = dbPhi;
        this.config = config;
        this.numWords = numWords;
        this.docWords = docWords;
        this.docCounts = docCounts;
        this.numDocuments = docWords.length;
        this.expectedCountsTheta = new double[this.numDocuments][this.numTopics];
        this.theta = new double[this.numDocuments][this.numTopics];
        this.expectedCountsPhi = new double[this.numTopics][numWords];
        this.phi = new double[this.numTopics][numWords];
    }

    /** Resolves each name to its registered predicate, cast to {@link StandardPredicate}. */
    private static StandardPredicate[] lookupStandardPredicates(String[] names) {
        StandardPredicate[] predicates = new StandardPredicate[names.length];
        for (int i = 0; i < names.length; i++) {
            predicates[i] = (StandardPredicate) PredicateFactory.getFactory().getPredicate(names[i]);
        }
        return predicates;
    }

    /** Returns the live (not copied) numDocuments x numTopics document-topic matrix. */
    public double[][] getTheta() {
        return this.theta;
    }

    /** Returns the live (not copied) numTopics x numWords topic-word matrix. */
    public double[][] getPhi() {
        return this.phi;
    }

    /**
     * Runs {@code numIterations} EM iterations.
     *
     * <p>Iterations before {@code burnIn} use only the LDA M-step. When weight learning is enabled,
     * it runs at {@code firstWLearningIter} and every {@code wLearningGap} iterations afterwards. If
     * {@code saveDir} is non-empty, theta and phi are Java-serialized after every iteration.
     *
     * @throws IOException if writing the per-iteration snapshots fails
     * @throws ClassNotFoundException if reasoner setup, inference or weight learning fails
     * @throws IllegalAccessException if reasoner setup, inference or weight learning fails
     * @throws InstantiationException if reasoner setup, inference or weight learning fails
     * @throws IllegalStateException if the E-step produces a NaN responsibility
     */
    public void trainModel() throws ClassNotFoundException, IllegalAccessException, InstantiationException, IOException {
        initialize();
        boolean firstMStep = true;
        for (int iter = 0; iter < this.numIterations; iter++) {
            log.debug("Iteration {}", iter);
            eStep();
            // Errors now propagate to the caller instead of terminating the JVM.
            if (iter >= this.burnIn) {
                mStep(firstMStep);
                firstMStep = false;
            } else {
                log.debug("Burn in, LDA M-step");
                LdaMStep();
            }
            // logLikelihood() is a full pass over the corpus; only pay for it when it will be logged.
            if (log.isDebugEnabled()) {
                log.debug("Log-likelihood after Iteration {}: {}", iter, logLikelihood());
            }
            if (this.doWeightLearning && iter >= this.firstWLearningIter
                    && (iter - this.firstWLearningIter) % this.wLearningGap == 0) {
                weightLearning();
            }
            if (this.saveDir.length() > 0) {
                saveMatrix(this.theta, this.saveDir + "theta_iteration_" + iter + ".ser");
                saveMatrix(this.phi, this.saveDir + "phi_iteration_" + iter + ".ser");
            }
        }
    }

    /** Java-serializes {@code matrix} to {@code path}, always closing the file. */
    private static void saveMatrix(double[][] matrix, String path) throws IOException {
        FileOutputStream fileOut = new FileOutputStream(path);
        try {
            ObjectOutputStream out = new ObjectOutputStream(fileOut);
            out.writeObject(matrix);
            out.flush();
        } finally {
            fileOut.close();
        }
    }

    /**
     * Randomly initializes theta and phi with row-normalized values and runs one E-step.
     *
     * <p>Then creates the reasoners and atom managers, and grounds the PSL models (plus log-loss
     * kernels) for whichever of theta/phi are estimated by inference.
     */
    protected void initialize() throws ClassNotFoundException, IllegalAccessException, InstantiationException {
        randomizeRows(this.theta);
        randomizeRows(this.phi);
        eStep();
        this.reasonerTheta = new LatentTopicNetworkADMMReasoner(this.config);
        this.reasonerPhi = new LatentTopicNetworkADMMReasoner(this.config);
        this.atomManagerTheta = new PersistedAtomManager(this.dbTheta);
        this.atomManagerPhi = new PersistedAtomManager(this.dbPhi);
        Predicate thetaPredicate = PredicateFactory.getFactory().getPredicate("Theta");
        if (this.hingeLossTheta) {
            initializeForReasoner(this.reasonerTheta, this.atomManagerTheta, this.modelTheta, this.dbTheta,
                    this.expectedCountsTheta, this.theta, thetaPredicate);
        }
        Predicate phiPredicate = PredicateFactory.getFactory().getPredicate("Phi");
        if (this.hingeLossPhi) {
            initializeForReasoner(this.reasonerPhi, this.atomManagerPhi, this.modelPhi, this.dbPhi,
                    this.expectedCountsPhi, this.phi, phiPredicate);
        }
    }

    /** Fills every entry of {@code matrix} with {@link Math#random()} and normalizes each row to sum to 1. */
    private static void randomizeRows(double[][] matrix) {
        for (double[] row : matrix) {
            for (int c = 0; c < row.length; c++) {
                row[c] = Math.random();
            }
        }
        normalizeRows(matrix, matrix);
    }

    /**
     * Writes {@code source[r][c] / sum(source[r])} into {@code target[r][c]} for every row.
     * {@code source} and {@code target} may be the same array (in-place normalization).
     */
    private static void normalizeRows(double[][] source, double[][] target) {
        for (int r = 0; r < source.length; r++) {
            double[] in = source[r];
            double sum = 0.0d;
            for (double v : in) {
                sum += v;
            }
            double[] out = target[r];
            for (int c = 0; c < in.length; c++) {
                out[c] = in[c] / sum;
            }
        }
    }

    /**
     * Grounds {@code model} into {@code reasoner}, then adds one {@code LDAgroundLogLoss} kernel
     * (weight 1.0) per row of {@code expectedCounts}.
     *
     * <p>Each kernel covers the atoms {@code predicate(row, col)} for every column. Each such atom's
     * value is first initialized from {@code values[row][col]}.
     *
     * @param expectedCounts expected counts; its shape determines the rows/columns grounded
     * @param values current parameter values used to initialize the atoms
     */
    protected void initializeForReasoner(Reasoner reasoner, PersistedAtomManager persistedAtomManager, Model model, Database database, double[][] expectedCounts, double[][] values, Predicate predicate) {
        log.info("Grounding out model.");
        Grounding.groundAll(model, persistedAtomManager, reasoner);
        log.info("Adding log loss ground kernels");
        int numRows = expectedCounts.length;
        if (numRows == 0) {
            return;
        }
        int numCols = expectedCounts[0].length;
        // Column IDs are the same for every row; look them up once.
        UniqueID[] columnIds = new UniqueID[numCols];
        for (int c = 0; c < numCols; c++) {
            columnIds[c] = database.getUniqueID(Integer.valueOf(c));
        }
        for (int r = 0; r < numRows; r++) {
            UniqueID rowId = database.getUniqueID(Integer.valueOf(r));
            // NOTE(review): raw lists kept because LDAgroundLogLoss's generic parameter types are not visible here.
            ArrayList atoms = new ArrayList(numCols);
            ArrayList counts = new ArrayList(numCols);
            for (int c = 0; c < numCols; c++) {
                GroundAtom atom = persistedAtomManager.getAtom(predicate, rowId, columnIds[c]);
                atom.getVariable().setValue(values[r][c]);
                atoms.add(atom);
                counts.add(Double.valueOf(expectedCounts[r][c]));
            }
            LDAgroundLogLoss logLoss = new LDAgroundLogLoss(null, atoms, counts, expectedCounts[r]);
            logLoss.setWeight(new PositiveWeight(1.0d));
            reasoner.addGroundKernel(logLoss);
        }
    }

    /**
     * Recomputes the expected counts from the current theta and phi.
     *
     * <p>Every entry starts at its Dirichlet pseudo-count ({@code alpha - 1} or {@code beta - 1}).
     * Then, for each (document, word, count) entry, {@code count} times the topic responsibility
     * {@code theta[d][k] * phi[k][w] / sum_k'(...)} is added.
     *
     * @throws IllegalStateException if a responsibility is NaN (e.g. all topic weights are zero)
     */
    protected void eStep() {
        log.info("E-step");
        for (double[] row : this.expectedCountsTheta) {
            Arrays.fill(row, this.alpha - 1.0d);
        }
        for (double[] row : this.expectedCountsPhi) {
            Arrays.fill(row, this.beta - 1.0d);
        }
        double[] weights = new double[this.numTopics];
        for (int doc = 0; doc < this.numDocuments; doc++) {
            int[] words = this.docWords[doc];
            int[] counts = this.docCounts[doc];
            double[] thetaRow = this.theta[doc];
            double[] expectedThetaRow = this.expectedCountsTheta[doc];
            for (int pos = 0; pos < words.length; pos++) {
                int word = words[pos];
                int count = counts[pos];
                double norm = 0.0d;
                for (int k = 0; k < this.numTopics; k++) {
                    weights[k] = thetaRow[k] * this.phi[k][word];
                    norm += weights[k];
                }
                for (int k = 0; k < this.numTopics; k++) {
                    double responsibility = weights[k] / norm;
                    if (Double.isNaN(responsibility)) {
                        throw new IllegalStateException("IsNan! document=" + doc + " position=" + pos
                                + " topic=" + k + " responsibility=" + responsibility + " normalizer=" + norm);
                    }
                    expectedThetaRow[k] += count * responsibility;
                    this.expectedCountsPhi[k][word] += count * responsibility;
                }
            }
        }
    }

    /** Closed-form LDA M-step: theta and phi become the row-normalized expected counts. */
    protected void LdaMStep() {
        normalizeRows(this.expectedCountsTheta, this.theta);
        normalizeRows(this.expectedCountsPhi, this.phi);
    }

    /**
     * M-step: the LDA update, then PSL inference for theta and/or phi where enabled.
     *
     * <p>After inference, the {@code Theta}/{@code Phi} atom values are read back from the databases
     * into the matrices. Phi is additionally re-normalized per topic.
     *
     * @param firstMStep true on the first non-burn-in M-step. On later steps,
     *     {@code initDirichletTerms()} is called on a reasoner whose {@code initMStepToLDA*} flag is set.
     */
    protected void mStep(boolean firstMStep) throws ClassNotFoundException, IllegalAccessException, InstantiationException {
        log.info("LDA M-step");
        LdaMStep();
        if (this.hingeLossTheta) {
            log.info("M-Step. Inference for theta");
            if (!firstMStep && this.initMStepToLDAtheta) {
                this.reasonerTheta.initDirichletTerms();
            }
            Predicate thetaPredicate = PredicateFactory.getFactory().getPredicate("Theta");
            mpeInference(this.reasonerTheta, this.atomManagerTheta, this.modelTheta, this.dbTheta,
                    this.expectedCountsTheta, this.theta, thetaPredicate);
            log.info("finished inference for theta");
            copyAtomValues(this.dbTheta, thetaPredicate, this.theta);
            if (log.isDebugEnabled()) {
                logRowSums("theta totals: ", this.theta);
            }
        }
        if (this.hingeLossPhi) {
            log.info("inference for phi");
            if (!firstMStep && this.initMStepToLDAphi) {
                this.reasonerPhi.initDirichletTerms();
            }
            Predicate phiPredicate = PredicateFactory.getFactory().getPredicate("Phi");
            mpeInference(this.reasonerPhi, this.atomManagerPhi, this.modelPhi, this.dbPhi,
                    this.expectedCountsPhi, this.phi, phiPredicate);
            log.info("finished for phi");
            copyAtomValues(this.dbPhi, phiPredicate, this.phi);
            if (log.isDebugEnabled()) {
                logRowSums("phi totals ", this.phi);
            }
            normalizeRows(this.phi, this.phi);
            if (log.isDebugEnabled()) {
                logRowEntropies("entropy after normalizing", this.phi);
            }
        }
    }

    /** Copies the value of every {@code predicate(row, col)} atom in {@code db} into {@code target[row][col]}. */
    private static void copyAtomValues(Database db, Predicate predicate, double[][] target) {
        for (GroundAtom atom : Queries.getAllAtoms(db, predicate)) {
            GroundTerm[] args = atom.getArguments();
            int row = Integer.parseInt(args[0].toString());
            int col = Integer.parseInt(args[1].toString());
            target[row][col] = atom.getValue();
        }
    }

    /** Debug-logs the sums of at most the first {@link #DEBUG_ROWS} rows of {@code matrix}. */
    private static void logRowSums(String header, double[][] matrix) {
        log.debug(header);
        int rows = Math.min(DEBUG_ROWS, matrix.length);
        for (int r = 0; r < rows; r++) {
            double sum = 0.0d;
            for (double v : matrix[r]) {
                sum += v;
            }
            log.debug(String.valueOf(sum));
        }
        log.debug("");
    }

    /** Debug-logs the Shannon entropy (natural log, 0 log 0 = 0) of at most the first {@link #DEBUG_ROWS} rows. */
    private static void logRowEntropies(String header, double[][] matrix) {
        log.debug(header);
        int rows = Math.min(DEBUG_ROWS, matrix.length);
        for (int r = 0; r < rows; r++) {
            double entropy = 0.0d;
            for (double p : matrix[r]) {
                if (p > 0.0d) {
                    entropy -= p * Math.log(p);
                }
            }
            log.debug(String.valueOf(entropy) + " ");
        }
        log.debug("");
    }

    /**
     * Re-learns the PSL rule weights of the theta and/or phi model from the current atom values.
     *
     * <p>For each enabled side:
     * <ul>
     *   <li>its atoms are copied into a random-variable partition (X, Y, Z and the target predicate)
     *       and a label partition (Y, Z and the target predicate);</li>
     *   <li>databases are opened over both partitions and {@code LatentTopicNetworkMaxPseudoLikelihood}
     *       is run;</li>
     *   <li>the databases are closed and the partitions deleted, even if learning throws.</li>
     * </ul>
     */
    protected void weightLearning() throws ClassNotFoundException, IllegalAccessException, InstantiationException {
        // Predicates closed in the random-variable databases (shared by both sides, as before).
        HashSet<StandardPredicate> closedInRv = new HashSet<StandardPredicate>(Arrays.asList(this.X));
        if (this.hingeLossTheta) {
            log.info("Populating databases for Theta");
            StandardPredicate thetaPredicate = (StandardPredicate) PredicateFactory.getFactory().getPredicate("Theta");
            populateWeightLearningPartitions(this.rvPartitionTheta, this.labelPartitionTheta, thetaPredicate, this.dbTheta);
            HashSet<StandardPredicate> closedInObserved = observedClosedPredicates(thetaPredicate);
            this.rvDBweightLearningTheta = this.dataStore.getDatabase(this.rvPartitionTheta, closedInRv, new Partition[0]);
            this.observedDBweightLearningTheta = this.dataStore.getDatabase(this.labelPartitionTheta, closedInObserved, new Partition[0]);
            try {
                log.info("Running weight learning for Theta");
                new LatentTopicNetworkMaxPseudoLikelihood(this.modelTheta, this.rvDBweightLearningTheta,
                        this.observedDBweightLearningTheta, this.config, this.alpha, thetaPredicate).learn();
            } finally {
                releaseWeightLearningStorage(this.rvDBweightLearningTheta, this.observedDBweightLearningTheta,
                        this.rvPartitionTheta, this.labelPartitionTheta);
            }
            log.info("Finished running weight learning for Theta");
            log.info(this.modelTheta.toString());
        }
        if (this.hingeLossPhi) {
            log.info("Populating databases for Phi");
            StandardPredicate phiPredicate = (StandardPredicate) PredicateFactory.getFactory().getPredicate("Phi");
            populateWeightLearningPartitions(this.rvPartitionPhi, this.labelPartitionPhi, phiPredicate, this.dbPhi);
            HashSet<StandardPredicate> closedInObserved = observedClosedPredicates(phiPredicate);
            this.rvDBweightLearningPhi = this.dataStore.getDatabase(this.rvPartitionPhi, closedInRv, new Partition[0]);
            this.observedDBweightLearningPhi = this.dataStore.getDatabase(this.labelPartitionPhi, closedInObserved, new Partition[0]);
            try {
                log.info("Running weight learning for Phi");
                new LatentTopicNetworkMaxPseudoLikelihood(this.modelPhi, this.rvDBweightLearningPhi,
                        this.observedDBweightLearningPhi, this.config, this.beta, phiPredicate).learn();
            } finally {
                releaseWeightLearningStorage(this.rvDBweightLearningPhi, this.observedDBweightLearningPhi,
                        this.rvPartitionPhi, this.labelPartitionPhi);
            }
            log.info("Finished running weight learning for Phi");
            log.info(this.modelPhi.toString());
        }
    }

    /** Copies X into {@code rv}; Y, Z and {@code target} into both {@code rv} and {@code label}. */
    private void populateWeightLearningPartitions(Partition rv, Partition label, StandardPredicate target, Database source) {
        for (StandardPredicate p : this.X) {
            updateVariablesForWeightLearning(rv, p, source);
        }
        for (StandardPredicate p : this.Y) {
            updateVariablesForWeightLearning(rv, p, source);
            updateVariablesForWeightLearning(label, p, source);
        }
        for (StandardPredicate p : this.Z) {
            updateVariablesForWeightLearning(rv, p, source);
            updateVariablesForWeightLearning(label, p, source);
        }
        updateVariablesForWeightLearning(rv, target, source);
        updateVariablesForWeightLearning(label, target, source);
    }

    /** Returns Y, Z and {@code target} as the set of predicates closed in the label database. */
    private HashSet<StandardPredicate> observedClosedPredicates(StandardPredicate target) {
        HashSet<StandardPredicate> closed = new HashSet<StandardPredicate>(Arrays.asList(this.Y));
        closed.addAll(Arrays.asList(this.Z));
        closed.add(target);
        return closed;
    }

    /** Closes both weight-learning databases, then deletes both partitions. */
    private void releaseWeightLearningStorage(Database rvDb, Database observedDb, Partition rv, Partition label) {
        rvDb.close();
        observedDb.close();
        this.dataStore.deletePartition(rv);
        this.dataStore.deletePartition(label);
    }

    /** Inserts the current value of every {@code standardPredicate} atom in {@code database} into {@code partition}. */
    protected void updateVariablesForWeightLearning(Partition partition, StandardPredicate standardPredicate, Database database) {
        Inserter inserter = this.dataStore.getInserter(standardPredicate, partition);
        for (GroundAtom atom : Queries.getAllAtoms(database, standardPredicate)) {
            GroundTerm[] args = atom.getArguments();
            inserter.insertValue(database.getAtom(standardPredicate, args).getValue(), args);
        }
    }

    /**
     * Runs {@code reasoner.optimize()} and commits every persisted random-variable atom to its database.
     *
     * <p>Only {@code reasoner} and {@code persistedAtomManager} are used; the remaining parameters
     * are accepted for signature compatibility.
     *
     * @return total weighted incompatibility, infeasibility norm, committed-atom count and reasoner size
     */
    protected FullInferenceResult mpeInference(Reasoner reasoner, PersistedAtomManager persistedAtomManager, Model model, Database database, double[][] dArr, double[][] dArr2, Predicate predicate) throws ClassNotFoundException, IllegalAccessException, InstantiationException {
        log.info("Beginning inference.");
        reasoner.optimize();
        log.info("Inference complete. Writing results to Database.");
        int committed = 0;
        for (RandomVariableAtom atom : persistedAtomManager.getPersistedRVAtoms()) {
            atom.commitToDB();
            committed++;
        }
        return new MemoryFullInferenceResult(
                GroundKernels.getTotalWeightedIncompatibility(reasoner.getCompatibilityKernels()),
                GroundKernels.getInfeasibilityNorm(reasoner.getConstraintKernels()),
                committed, reasoner.size());
    }

    /**
     * Releases references and closes the reasoners.
     *
     * <p>Safe to call even if {@link #trainModel()} never ran, and safe to call more than once.
     */
    @Override // edu.umd.cs.psl.application.ModelApplication
    public void close() {
        this.modelTheta = null;
        this.dbTheta = null;
        this.modelPhi = null;
        this.dbPhi = null;
        this.config = null;
        if (this.reasonerTheta != null) {
            this.reasonerTheta.close();
            this.reasonerTheta = null;
        }
        if (this.reasonerPhi != null) {
            this.reasonerPhi.close();
            this.reasonerPhi = null;
        }
    }

    /**
     * Returns the corpus log-likelihood under the current parameters.
     *
     * <p>The value is {@code sum_d sum_j docCounts[d][j] * log(sum_k theta[d][k] * phi[k][docWords[d][j]])}.
     */
    public double logLikelihood() {
        double total = 0.0d;
        for (int doc = 0; doc < this.numDocuments; doc++) {
            int[] words = this.docWords[doc];
            int[] counts = this.docCounts[doc];
            double[] thetaRow = this.theta[doc];
            for (int pos = 0; pos < words.length; pos++) {
                int word = words[pos];
                double prob = 0.0d;
                for (int k = 0; k < this.numTopics; k++) {
                    prob += thetaRow[k] * this.phi[k][word];
                }
                total += counts[pos] * Math.log(prob);
            }
        }
        return total;
    }
}
