package de.uni_mannheim.informatik.dws.winter.matching.rules;

import de.uni_mannheim.informatik.dws.winter.matching.algorithms.RuleLearner;
import de.uni_mannheim.informatik.dws.winter.model.Correspondence;
import de.uni_mannheim.informatik.dws.winter.model.DataSet;
import de.uni_mannheim.informatik.dws.winter.model.Matchable;
import de.uni_mannheim.informatik.dws.winter.model.MatchingGoldStandard;
import de.uni_mannheim.informatik.dws.winter.model.Performance;
import de.uni_mannheim.informatik.dws.winter.model.defaultmodel.Attribute;
import de.uni_mannheim.informatik.dws.winter.model.defaultmodel.FeatureVectorDataSet;
import de.uni_mannheim.informatik.dws.winter.model.defaultmodel.Record;
import de.uni_mannheim.informatik.dws.winter.model.defaultmodel.RecordCSVFormatter;
import de.uni_mannheim.informatik.dws.winter.model.defaultmodel.comparators.RecordComparator;
import de.uni_mannheim.informatik.dws.winter.processing.Processable;
import de.uni_mannheim.informatik.dws.winter.utils.WinterLogManager;
import de.uni_mannheim.informatik.dws.winter.utils.query.Q;
import de.uni_mannheim.informatik.dws.winter.utils.weka.EvaluationWithBalancing;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import org.apache.commons.lang.StringUtils;
import org.apache.logging.log4j.Logger;
import org.dom4j.Document;
import org.dom4j.DocumentException;
import org.dom4j.Element;
import org.dom4j.io.OutputFormat;
import org.dom4j.io.SAXReader;
import org.dom4j.io.XMLWriter;
import weka.attributeSelection.AttributeSelection;
import weka.attributeSelection.GreedyStepwise;
import weka.attributeSelection.WrapperSubsetEval;
import weka.classifiers.Classifier;
import weka.classifiers.evaluation.Evaluation;
import weka.core.DenseInstance;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.pmml.PMMLFactory;
import weka.filters.Filter;
import weka.filters.supervised.instance.Resample;

/* loaded from: input_file:de/uni_mannheim/informatik/dws/winter/matching/rules/WekaMatchingRule.class */
/**
 * A learnable matching rule that delegates the match decision to a Weka {@link Classifier}.
 *
 * <p>Every registered {@link Comparator} contributes one numeric feature. The classifier is trained
 * on labelled feature vectors in {@link #learnParameters}. When the rule is applied to a record
 * pair, the classifier's predicted probability for the class value {@code "1"} becomes the
 * correspondence score. Optionally, training can run greedy forward/backward wrapper feature
 * selection and can balance the classes with Weka's {@link Resample} filter.
 */
public class WekaMatchingRule<RecordType extends Matchable, SchemaElementType extends Matchable> extends FilteringMatchingRule<RecordType, SchemaElementType> implements LearnableMatchingRule<RecordType, SchemaElementType> {
    private static final long serialVersionUID = 1;
    private static final Logger logger = WinterLogManager.getLogger();

    /** Maximum number of folds for cross-validation and wrapper feature selection. */
    private static final int NUM_FOLDS = 10;
    /** Fixed seed so that cross-validation results are reproducible between runs. */
    private static final long CROSS_VALIDATION_SEED = 1L;
    /** Nominal class value denoting a match (index 0 of the class attribute). */
    private static final String MATCH_LABEL = "1";
    /** Nominal class value denoting a non-match (index 1 of the class attribute). */
    private static final String NON_MATCH_LABEL = "0";
    /** Fragment of a PMML loading error that triggers one attempt to repair the PMML file. */
    private static final String PMML_TARGET_META_INFO_ERROR = "[TargetMetaInfo]";

    /** Options the classifier was initialised with, as supplied by the caller. */
    private String[] parameters;
    private Classifier classifier;
    /** Comparators in registration order; position i produces feature "[i] ...". */
    private final List<Comparator<RecordType, SchemaElementType>> comparators = new ArrayList<>();
    private boolean forwardSelection = false;
    private boolean backwardSelection = false;
    /** Attribute selection fitted during training; re-applied to feature vectors in apply(). */
    private AttributeSelection fs;
    private boolean balanceTrainingData = false;
    /** Seed for the Resample filter; -1 means no explicit seed is set. */
    private int randomSeed = -1;
    public final String trainingSet = "trainingSet";
    public final String matchSet = "matchSet";

    /**
     * Creates a rule and immediately instantiates the given Weka classifier.
     *
     * @param d       the final threshold passed to {@link FilteringMatchingRule}
     * @param str     fully qualified class name of the Weka classifier
     * @param strArr  Weka options for the classifier (may be {@code null})
     */
    public WekaMatchingRule(double d, String str, String[] strArr) {
        this(d);
        initialiseClassifier(str, strArr);
    }

    /**
     * Creates a rule without a classifier. Call {@link #initialiseClassifier}, {@link #setClassifier}
     * or {@link #readModel} before using it.
     *
     * @param d the final threshold passed to {@link FilteringMatchingRule}
     */
    public WekaMatchingRule(double d) {
        super(d);
    }

    /** @return the options the classifier was initialised with */
    public String[] getparameters() {
        return this.parameters;
    }

    /** Stores classifier options. This does not re-initialise the classifier. */
    public void setparameters(String[] strArr) {
        this.parameters = strArr;
    }

    public Classifier getClassifier() {
        return this.classifier;
    }

    public void setClassifier(Classifier classifier) {
        this.classifier = classifier;
    }

    /**
     * Instantiates a Weka classifier by class name and options. On failure the error is logged and
     * the classifier is left unchanged.
     *
     * @param str     fully qualified class name of the classifier
     * @param strArr  Weka options (may be {@code null}); stored unmodified for {@link #getparameters()}
     */
    public void initialiseClassifier(String str, String[] strArr) {
        this.parameters = strArr;
        // Weka's option parsing blanks out consumed entries, so hand it a copy.
        String[] options = strArr == null ? null : strArr.clone();
        try {
            this.classifier = (Classifier) Utils.forName(Classifier.class, str, options);
        } catch (Exception e) {
            logger.error(String.format("Could not initialise classifier '%s'", str), e);
        }
    }

    /**
     * Registers a comparator. Each comparator produces one feature, in registration order.
     *
     * @param comparator the comparator to add
     */
    public void addComparator(Comparator<RecordType, SchemaElementType> comparator) {
        this.comparators.add(comparator);
        if (isDebugReportActive()) {
            comparator.setComparisonLog(new ComparatorLogger(comparator.getClass().getName()));
            addComparatorToLog(comparator);
        }
    }

    /**
     * Trains the classifier on the given labelled feature vectors.
     *
     * <p>If forward or backward selection is enabled, greedy wrapper-based feature selection runs
     * first, and its result is kept for {@link #apply}. The classifier is then cross-validated (the
     * results are logged). Finally it is trained on the complete, optionally balanced, training data.
     *
     * @param featureVectorDataSet labelled feature vectors
     * @return the cross-validated performance for the match class, or {@code null} if no classifier
     *         is configured, there are fewer than two training examples, or training fails
     */
    @Override // de.uni_mannheim.informatik.dws.winter.matching.rules.LearnableMatchingRule
    public Performance learnParameters(FeatureVectorDataSet featureVectorDataSet) {
        if (this.classifier == null) {
            logger.error("Please initialise a classifier!");
            return null;
        }
        Instances trainingData = transformToWeka(featureVectorDataSet, this.trainingSet);
        // Weka's cross-validation needs at least two folds, i.e. at least two instances.
        int folds = Math.min(NUM_FOLDS, trainingData.numInstances());
        if (folds < 2) {
            logger.error(String.format("Cannot learn a model from %d training example(s); at least 2 are required.", trainingData.numInstances()));
            return null;
        }
        try {
            if (this.forwardSelection || this.backwardSelection) {
                trainingData = selectFeatures(trainingData);
            }

            Evaluation evaluation;
            if (this.balanceTrainingData) {
                evaluation = new EvaluationWithBalancing(trainingData, createResampleFilter(trainingData));
            } else {
                evaluation = new Evaluation(trainingData);
            }
            evaluation.crossValidateModel(this.classifier, trainingData, folds, new Random(CROSS_VALIDATION_SEED));
            logLines(evaluation.toSummaryString("\nResults\n\n", false));
            logLines(evaluation.toClassDetailsString());
            logLines(evaluation.toMatrixString());

            // The final model is trained on the full data set (balanced if requested).
            if (this.balanceTrainingData) {
                trainingData = Filter.useFilter(trainingData, createResampleFilter(trainingData));
            }
            this.classifier.buildClassifier(trainingData);

            int matchIndex = trainingData.classAttribute().indexOfValue(MATCH_LABEL);
            int truePositives = (int) evaluation.numTruePositives(matchIndex);
            int falsePositives = (int) evaluation.numFalsePositives(matchIndex);
            int falseNegatives = (int) evaluation.numFalseNegatives(matchIndex);
            return new Performance(truePositives, truePositives + falsePositives, truePositives + falseNegatives);
        } catch (Exception e) {
            logger.error("Failed to learn the Weka matching rule.", e);
            return null;
        }
    }

    /**
     * Runs greedy stepwise wrapper feature selection and returns the training data reduced to the
     * selected attributes. The fitted selection is stored in {@code fs} only after it succeeded.
     */
    private Instances selectFeatures(Instances trainingData) throws Exception {
        GreedyStepwise search = new GreedyStepwise();
        search.setSearchBackwards(this.backwardSelection);

        // Configure the evaluator before building it so it is built with the real classifier.
        WrapperSubsetEval evaluator = new WrapperSubsetEval();
        evaluator.setClassifier(this.classifier);
        evaluator.setFolds(NUM_FOLDS);
        evaluator.setThreshold(0.01d);
        evaluator.buildEvaluator(trainingData);

        AttributeSelection selection = new AttributeSelection();
        selection.setEvaluator(evaluator);
        selection.setSearch(search);
        selection.SelectAttributes(trainingData);
        Instances reduced = selection.reduceDimensionality(trainingData);
        this.fs = selection;
        return reduced;
    }

    /**
     * Creates a Resample filter that draws a 100% sample biased towards a uniform class
     * distribution. It uses the configured random seed if one was set.
     */
    private Resample createResampleFilter(Instances format) throws Exception {
        Resample resample = new Resample();
        if (this.randomSeed != -1) {
            resample.setRandomSeed(this.randomSeed);
        }
        resample.setBiasToUniformClass(1.0d);
        resample.setSampleSizePercent(100.0d);
        resample.setInputFormat(format);
        return resample;
    }

    /** Logs the given multi-line text line by line at INFO level. */
    private static void logLines(String text) {
        for (String line : text.split("\n")) {
            logger.info(line);
        }
    }

    /**
     * Converts feature vectors into Weka instances. All non-label attributes are parsed as doubles,
     * and missing values become 0.0.
     *
     * @param featureVectorDataSet the feature vectors
     * @param str name of the data set. Class labels are only copied when it equals
     *            {@link #trainingSet}. An absent or unknown label becomes a Weka missing value.
     * @return the Weka data set, with the class attribute as its last attribute
     */
    public Instances transformToWeka(FeatureVectorDataSet featureVectorDataSet, String str) {
        Instances dataset = defineDataset(featureVectorDataSet, str);
        Collection<Attribute> schema = featureVectorDataSet.getSchema().get();
        int classIndex = dataset.classIndex();
        boolean copyLabels = this.trainingSet.equals(str);
        for (Record record : featureVectorDataSet.get()) {
            // Sized from the Weka header, so the class slot always exists.
            double[] values = new double[dataset.numAttributes()];
            int i = 0;
            for (Attribute attribute : schema) {
                if (attribute.equals(FeatureVectorDataSet.ATTRIBUTE_LABEL)) {
                    continue;
                }
                String value = record.getValue(attribute);
                values[i++] = value == null ? 0.0d : Double.parseDouble(value);
            }
            if (copyLabels) {
                String label = record.getValue(FeatureVectorDataSet.ATTRIBUTE_LABEL);
                int labelIndex = label == null ? -1 : dataset.classAttribute().indexOfValue(label);
                values[classIndex] = labelIndex >= 0 ? labelIndex : Utils.missingValue();
            }
            // The instance is created only after all values (including the label) are set.
            dataset.add(new DenseInstance(1.0d, values));
        }
        return dataset;
    }

    /**
     * Builds an empty Weka header. Every non-label schema attribute becomes a numeric attribute.
     * A nominal class attribute with values {"1", "0"} is appended last.
     */
    private Instances defineDataset(FeatureVectorDataSet featureVectorDataSet, String str) {
        ArrayList<weka.core.Attribute> wekaAttributes = new ArrayList<>();
        for (Attribute attribute : featureVectorDataSet.getSchema().get()) {
            if (!attribute.equals(FeatureVectorDataSet.ATTRIBUTE_LABEL)) {
                wekaAttributes.add(new weka.core.Attribute(attribute.getIdentifier()));
            }
        }
        List<String> labels = new ArrayList<>();
        labels.add(MATCH_LABEL);
        labels.add(NON_MATCH_LABEL);
        wekaAttributes.add(new weka.core.Attribute(FeatureVectorDataSet.ATTRIBUTE_LABEL.getIdentifier(), labels));
        Instances instances = new Instances(str, wekaAttributes, 0);
        instances.setClassIndex(wekaAttributes.size() - 1);
        return instances;
    }

    /**
     * Computes one feature per comparator for the given record pair.
     *
     * @param recordtype   first record
     * @param recordtype2  second record
     * @param processable  schema correspondences passed to the comparators (may be {@code null})
     * @param featureVectorDataSet data set whose schema attributes are reused when names match
     * @return a record named "id1-id2" holding the similarity of every comparator
     */
    @Override // de.uni_mannheim.informatik.dws.winter.matching.rules.LearnableMatchingRule
    public Record generateFeatures(RecordType recordtype, RecordType recordtype2, Processable<Correspondence<SchemaElementType, Matchable>> processable, FeatureVectorDataSet featureVectorDataSet) {
        Record features = new Record(String.format("%s-%s", recordtype.getIdentifier(), recordtype2.getIdentifier()), getClass().getSimpleName());
        Record debugRecord = null;
        if (isDebugReportActive() && continueCollectDebugResults()) {
            debugRecord = initializeDebugRecord(recordtype, recordtype2, -1);
        }
        for (int i = 0; i < this.comparators.size(); i++) {
            Comparator<RecordType, SchemaElementType> comparator = this.comparators.get(i);
            double similarity = comparator.compare(recordtype, recordtype2, processable != null ? getCorrespondenceForComparator(processable, recordtype, recordtype2, comparator) : null);
            Attribute featureAttribute = findFeatureAttribute(featureVectorDataSet, getFeatureName(i, comparator));
            features.setValue(featureAttribute, Double.toString(similarity));
            if (isDebugReportActive() && continueCollectDebugResults()) {
                debugRecord = fillDebugRecord(debugRecord, comparator, i);
                addDebugRecordShort(recordtype, recordtype2, comparator, i);
            }
        }
        if (isDebugReportActive() && continueCollectDebugResults()) {
            fillSimilarity(debugRecord, null);
        }
        return features;
    }

    /**
     * Returns the schema attribute whose string form equals {@code name}. If several match, the
     * last one wins. If none matches, a new attribute with that name is returned.
     */
    private static Attribute findFeatureAttribute(FeatureVectorDataSet features, String name) {
        Attribute match = null;
        for (Attribute candidate : features.getSchema().get()) {
            if (candidate.toString().equals(name)) {
                match = candidate;
            }
        }
        return match != null ? match : new Attribute(name);
    }

    /**
     * Builds the feature name "[index] ComparatorName attr1 attr2". The attribute names are only
     * included for RecordComparators.
     */
    private String getFeatureName(int index, Comparator<RecordType, SchemaElementType> comparator) {
        String attribute1 = "";
        String attribute2 = "";
        if (comparator instanceof RecordComparator) {
            RecordComparator recordComparator = (RecordComparator) comparator;
            attribute1 = recordComparator.getAttributeRecord1().toString();
            attribute2 = recordComparator.getAttributeRecord2().toString();
        }
        return String.format("[%d] %s %s %s", index, getComparatorName(comparator), attribute1, attribute2).trim();
    }

    /**
     * Classifies a record pair.
     *
     * @return a correspondence scored with the classifier's probability for the class value "1".
     *         Returns {@code null} if no classifier is set, feature selection cannot be applied, or
     *         classification fails.
     */
    @Override // de.uni_mannheim.informatik.dws.winter.matching.rules.FilteringMatchingRule
    public Correspondence<RecordType, SchemaElementType> apply(RecordType recordtype, RecordType recordtype2, Processable<Correspondence<SchemaElementType, Matchable>> processable) {
        if (this.classifier == null) {
            logger.error("Please initialise a classifier!");
            return null;
        }
        FeatureVectorDataSet features = initialiseFeatures();
        Record featureRecord = generateFeatures(recordtype, recordtype2, processable, features);
        features.add(featureRecord);
        Instances matchData = transformToWeka(features, this.matchSet);
        if ((this.backwardSelection || this.forwardSelection) && this.fs != null) {
            try {
                matchData = this.fs.reduceDimensionality(matchData);
            } catch (Exception e) {
                // The classifier was trained on the reduced features, so un-reduced input is invalid.
                logger.error(String.format("Could not apply feature selection to record '%s'", String.valueOf(featureRecord)), e);
                return null;
            }
        }
        try {
            int matchIndex = matchData.classAttribute().indexOfValue(MATCH_LABEL);
            double probability = this.classifier.distributionForInstance(matchData.firstInstance())[matchIndex];
            if (isDebugReportActive()) {
                fillSimilarity(recordtype, recordtype2, probability);
            }
            return new Correspondence<>(recordtype, recordtype2, probability, processable);
        } catch (Exception e) {
            logger.error(String.format("Classifier Exception for Record '%s': %s", String.valueOf(featureRecord), e.getMessage()), e);
            return null;
        }
    }

    /**
     * Writes the classifier to {@code file} using Java serialization. Failures are logged.
     */
    @Override // de.uni_mannheim.informatik.dws.winter.matching.rules.LearnableMatchingRule
    public void exportModel(File file) {
        try (FileOutputStream fileOut = new FileOutputStream(file);
             ObjectOutputStream out = new ObjectOutputStream(fileOut)) {
            out.writeObject(getClassifier());
        } catch (IOException e) {
            logger.error(String.format("Could not export model to '%s'", file), e);
        }
    }

    /**
     * Loads a classifier from {@code file}. A Java-serialized classifier is tried first. If that
     * fails with an I/O error, the file is read as a PMML model. If the PMML error mentions
     * "[TargetMetaInfo]", the file is repaired in place once and reloaded. Failures are logged.
     *
     * <p>NOTE: Java deserialization can execute code from the stream. Only load trusted model files.
     */
    @Override // de.uni_mannheim.informatik.dws.winter.matching.rules.LearnableMatchingRule
    public void readModel(File file) {
        readModel(file, false);
    }

    /** Implements {@link #readModel(File)}; {@code pmmlRepaired} prevents repeated repair attempts. */
    private void readModel(File file, boolean pmmlRepaired) {
        try (FileInputStream fileIn = new FileInputStream(file);
             ObjectInputStream in = new ObjectInputStream(fileIn)) {
            setClassifier((Classifier) in.readObject());
        } catch (FileNotFoundException e) {
            logger.error(String.format("Model file '%s' not found", file), e);
        } catch (IOException e) {
            // Not a Java-serialized model: fall back to PMML (the input stream is closed by now).
            readPmmlModel(file, pmmlRepaired);
        } catch (ClassNotFoundException e) {
            logger.error(String.format("Unknown class in model file '%s'", file), e);
        }
    }

    /** Loads a PMML classifier. It repairs the file once if Weka reports missing target meta info. */
    private void readPmmlModel(File file, boolean pmmlRepaired) {
        try {
            setClassifier((Classifier) PMMLFactory.getPMMLModel(file, (weka.gui.Logger) null));
        } catch (Exception e) {
            String message = e.getMessage();
            if (!pmmlRepaired && message != null && message.contains(PMML_TARGET_META_INFO_ERROR)) {
                transformPMMLModel(file);
                readModel(file, true);
            } else {
                logger.error(String.format("Could not read model from '%s'", file), e);
            }
        }
    }

    /**
     * Rewrites a PMML file in place so Weka can load it. The changes are:
     * <ul>
     *   <li>adds {@code priorProbability="0.50"} to {@code TargetValue} elements lacking it;</li>
     *   <li>removes {@code Value} elements whose {@code value} is {@code MISSING_VALUE}.</li>
     * </ul>
     * Failures are logged.
     */
    private void transformPMMLModel(File file) {
        try {
            Document document = new SAXReader().read(file);
            // selectNodes returns a fresh list, so detaching elements while iterating is safe.
            for (Object node : document.selectNodes("//*")) {
                if (!(node instanceof Element)) {
                    continue;
                }
                Element element = (Element) node;
                if ("TargetValue".equals(element.getQualifiedName()) && element.attribute("priorProbability") == null) {
                    element.addAttribute("priorProbability", "0.50");
                }
                if ("Value".equals(element.getQualifiedName()) && "MISSING_VALUE".equals(element.attributeValue("value"))) {
                    element.detach();
                }
            }
            // XMLWriter buffers its output: flush it and close the file before the model is re-read.
            try (FileOutputStream out = new FileOutputStream(file)) {
                XMLWriter writer = new XMLWriter(out, OutputFormat.createPrettyPrint());
                writer.write(document);
                writer.flush();
            }
            logger.info("PMML model transformed!");
        } catch (DocumentException | IOException e) {
            logger.error(String.format("Could not transform PMML model '%s'", file), e);
        }
    }

    /**
     * Not used by this rule; always returns 0.0. The similarity is computed by {@link #apply}.
     */
    @Override // de.uni_mannheim.informatik.dws.winter.matching.rules.Comparator
    public double compare(RecordType recordtype, RecordType recordtype2, Correspondence<SchemaElementType, Matchable> correspondence) {
        return 0.0d;
    }

    /**
     * Creates an empty feature data set. Its schema holds one attribute per comparator, in
     * registration order, followed by the label attribute.
     */
    @Override // de.uni_mannheim.informatik.dws.winter.matching.rules.LearnableMatchingRule
    public FeatureVectorDataSet initialiseFeatures() {
        FeatureVectorDataSet featureVectorDataSet = new FeatureVectorDataSet();
        for (int i = 0; i < this.comparators.size(); i++) {
            featureVectorDataSet.addAttribute(new Attribute(getFeatureName(i, this.comparators.get(i))));
        }
        featureVectorDataSet.addAttribute(FeatureVectorDataSet.ATTRIBUTE_LABEL);
        return featureVectorDataSet;
    }

    /** @return the comparator name used in feature names (its simple class name by default) */
    protected String getComparatorName(Comparator<RecordType, SchemaElementType> comparator) {
        return comparator.getClass().getSimpleName();
    }

    public boolean isForwardSelection() {
        return this.forwardSelection;
    }

    public void setForwardSelection(boolean z) {
        this.forwardSelection = z;
    }

    public boolean isBackwardSelection() {
        return this.backwardSelection;
    }

    /** Enables backward selection; this takes precedence over forward selection if both are set. */
    public void setBackwardSelection(boolean z) {
        this.backwardSelection = z;
    }

    /** Enables class balancing (Resample with uniform-class bias) for evaluation and training. */
    public void setBalanceTrainingData(boolean z) {
        this.balanceTrainingData = z;
    }

    /** Sets the seed used by the Resample filter; -1 keeps the filter's own default. */
    public void setRandomSeed(int i) {
        this.randomSeed = i;
    }

    /** @return the classifier's string representation ("null" if none is set) */
    public String getModelDescription() {
        return String.format("%s", this.classifier);
    }

    @Override
    public String toString() {
        return String.format("WekaMatchingRule: p(match|%s)", StringUtils.join(this.comparators, ", "));
    }

    /**
     * Generates training data for the two data sets and the gold standard, and writes it to
     * {@code file} as CSV.
     */
    @Override // de.uni_mannheim.informatik.dws.winter.matching.rules.LearnableMatchingRule
    public void exportTrainingData(DataSet<RecordType, SchemaElementType> dataSet, DataSet<RecordType, SchemaElementType> dataSet2, MatchingGoldStandard matchingGoldStandard, File file) throws IOException {
        new RecordCSVFormatter().writeCSV(file, new RuleLearner().generateTrainingDataForLearning(dataSet, dataSet2, matchingGoldStandard, this, null), null);
    }
}
