package de.uni_mannheim.informatik.dws.winter.matching.rules;

import de.uni_mannheim.informatik.dws.winter.model.Correspondence;
import de.uni_mannheim.informatik.dws.winter.model.Matchable;
import de.uni_mannheim.informatik.dws.winter.model.Performance;
import de.uni_mannheim.informatik.dws.winter.model.defaultmodel.Attribute;
import de.uni_mannheim.informatik.dws.winter.model.defaultmodel.FeatureVectorDataSet;
import de.uni_mannheim.informatik.dws.winter.model.defaultmodel.Record;
import de.uni_mannheim.informatik.dws.winter.model.defaultmodel.comparators.RecordComparator;
import de.uni_mannheim.informatik.dws.winter.processing.Processable;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import weka.attributeSelection.AttributeSelection;
import weka.attributeSelection.GreedyStepwise;
import weka.attributeSelection.WrapperSubsetEval;
import weka.classifiers.Classifier;
import weka.classifiers.evaluation.Evaluation;
import weka.core.DenseInstance;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.pmml.PMMLFactory;
import weka.gui.Logger;

/* loaded from: input_file:de/uni_mannheim/informatik/dws/winter/matching/rules/WekaMatchingRule.class */
/**
 * Matching rule that delegates the match decision to a Weka {@link Classifier}.
 *
 * <p>Every comparator registered through {@link #addComparator(Comparator)} contributes one numeric
 * feature per record pair. {@link #learnParameters(FeatureVectorDataSet)} trains the classifier on
 * labelled feature vectors, and {@link #apply(Matchable, Matchable, Processable)} classifies a
 * single record pair and reports the predicted class index as the correspondence similarity.
 *
 * @param <RecordType> the type of the records that are matched
 * @param <SchemaElementType> the type of the schema elements used by the comparators
 */
public class WekaMatchingRule<RecordType extends Matchable, SchemaElementType extends Matchable> extends FilteringMatchingRule<RecordType, SchemaElementType> implements LearnableMatchingRule<RecordType, SchemaElementType> {
    private static final long serialVersionUID = 1;

    /** Number of folds used when cross-validating the classifier. */
    private static final int CROSS_VALIDATION_FOLDS = 10;
    /** Fixed seed so that repeated cross-validation runs shuffle the data identically. */
    private static final long RANDOM_SEED = 1L;
    /** Number of folds used by the wrapper evaluator during attribute selection. */
    private static final int WRAPPER_FOLDS = 10;
    /** Threshold handed to the wrapper evaluator during attribute selection. */
    private static final double WRAPPER_THRESHOLD = 0.01d;
    /** Nominal class value at index 0 of the Weka class attribute. */
    private static final String NON_MATCH_LABEL = "0";
    /** Nominal class value at index 1 of the Weka class attribute; treated as the positive class. */
    private static final String MATCH_LABEL = "1";

    private String[] parameters;
    private Classifier classifier;
    private List<Comparator<RecordType, SchemaElementType>> comparators;
    private boolean forwardSelection = false;
    private boolean backwardSelection = false;
    /** Attribute selection fitted in learnParameters; stays null until a selection has been run. */
    private AttributeSelection fs;
    /** Dataset name that marks a labelled training set in {@link #transformToWeka}. */
    public final String trainingSet = "trainingSet";
    /** Dataset name used for the unlabelled data classified in {@link #apply}. */
    public final String machtSet = "matchSet";

    /**
     * Creates a rule backed by the Weka classifier with the given class name.
     *
     * <p>If the classifier cannot be instantiated, the stack trace is printed and the classifier
     * stays {@code null}; it can still be supplied later through {@link #setClassifier(Classifier)}.
     *
     * @param d threshold forwarded to the superclass constructor
     * @param str class name of the Weka classifier to instantiate
     * @param strArr options passed to the classifier on instantiation
     */
    public WekaMatchingRule(double d, String str, String[] strArr) {
        super(d);
        this.parameters = strArr;
        this.comparators = new LinkedList<>();
        try {
            this.classifier = (Classifier) Utils.forName(Classifier.class, str, strArr);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * @return the classifier options this rule was created with (or last set)
     */
    public String[] getparameters() {
        return this.parameters;
    }

    /**
     * Replaces the stored classifier options. The current classifier is not reconfigured.
     *
     * @param strArr the new options
     */
    public void setparameters(String[] strArr) {
        this.parameters = strArr;
    }

    /**
     * @return the classifier used by this rule; may be {@code null} if instantiation failed
     */
    public Classifier getClassifier() {
        return this.classifier;
    }

    /**
     * @param classifier the classifier to use from now on
     */
    public void setClassifier(Classifier classifier) {
        this.classifier = classifier;
    }

    /**
     * Registers a comparator; each comparator yields one feature per record pair.
     *
     * @param comparator the comparator to add
     */
    public void addComparator(Comparator<RecordType, SchemaElementType> comparator) {
        this.comparators.add(comparator);
    }

    /**
     * Trains the classifier on the given labelled feature vectors.
     *
     * <p>When forward or backward selection is enabled, a greedy stepwise wrapper selection is run
     * first and the training data is reduced to the selected attributes. The classifier is then
     * cross-validated, the summary is printed to standard output, and the classifier is built on
     * the complete training data.
     *
     * @param featureVectorDataSet labelled feature vectors
     * @return the cross-validation performance with respect to the match class, or {@code null} if
     *         training failed (the stack trace is printed in that case)
     */
    @Override
    public Performance learnParameters(FeatureVectorDataSet featureVectorDataSet) {
        Instances trainingData = transformToWeka(featureVectorDataSet, this.trainingSet);
        try {
            Evaluation evaluation = new Evaluation(trainingData);
            if (this.forwardSelection || this.backwardSelection) {
                GreedyStepwise search = new GreedyStepwise();
                search.setSearchBackwards(this.backwardSelection);

                // Configure the wrapper completely before building it, so that it is built against
                // the classifier it will actually evaluate.
                WrapperSubsetEval wrapper = new WrapperSubsetEval();
                wrapper.setClassifier(this.classifier);
                wrapper.setFolds(WRAPPER_FOLDS);
                wrapper.setThreshold(WRAPPER_THRESHOLD);
                wrapper.buildEvaluator(trainingData);

                this.fs = new AttributeSelection();
                this.fs.setEvaluator(wrapper);
                this.fs.setSearch(search);
                this.fs.SelectAttributes(trainingData);
                trainingData = this.fs.reduceDimensionality(trainingData);
            }
            evaluation.crossValidateModel(this.classifier, trainingData, CROSS_VALIDATION_FOLDS, new Random(RANDOM_SEED), new Object[0]);
            System.out.println(evaluation.toSummaryString("\nResults\n\n", false));
            this.classifier.buildClassifier(trainingData);

            // The Evaluation counters expect the index of the class VALUE regarded as positive,
            // not the index of the class attribute within the dataset.
            int positiveClass = trainingData.classAttribute().indexOfValue(MATCH_LABEL);
            int truePositives = (int) evaluation.numTruePositives(positiveClass);
            int falsePositives = (int) evaluation.numFalsePositives(positiveClass);
            int falseNegatives = (int) evaluation.numFalseNegatives(positiveClass);
            return new Performance(truePositives, truePositives + falsePositives, truePositives + falseNegatives);
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Converts a feature vector data set into Weka instances.
     *
     * <p>Every non-label attribute becomes a numeric attribute (missing values are written as 0).
     * The nominal class attribute is always appended last; its value is only filled in from the
     * record's label when {@code str} equals {@link #trainingSet}, otherwise it is left at 0.
     *
     * @param featureVectorDataSet the feature vectors to convert
     * @param str name of the resulting dataset; {@code "trainingSet"} enables label transfer
     * @return the Weka dataset with one instance per record
     */
    public Instances transformToWeka(FeatureVectorDataSet featureVectorDataSet, String str) {
        Instances dataset = defineDataset(featureVectorDataSet, str);
        Collection<Attribute> schema = featureVectorDataSet.getSchema().get();
        boolean labelled = this.trainingSet.equals(str);

        for (Record record : featureVectorDataSet.get()) {
            // Size the vector from the Weka header so it always has room for the class slot.
            double[] values = new double[dataset.numAttributes()];
            int index = 0;
            for (Attribute attribute : schema) {
                if (attribute.equals(FeatureVectorDataSet.ATTRIBUTE_LABEL)) {
                    continue;
                }
                String value = record.getValue(attribute);
                values[index] = value != null ? Double.parseDouble(value) : 0.0d;
                index++;
            }
            if (labelled) {
                // After the loop, index points at the class attribute (the last one).
                values[index] = dataset.attribute(index).indexOfValue(record.getValue(FeatureVectorDataSet.ATTRIBUTE_LABEL));
            }
            // Build the instance only once the vector is complete.
            dataset.add(new DenseInstance(1.0d, values));
        }
        return dataset;
    }

    /**
     * Creates an empty Weka dataset whose header mirrors the feature schema: one numeric attribute
     * per non-label feature, followed by a nominal class attribute with the values "0" and "1".
     *
     * @param featureVectorDataSet supplies the feature schema
     * @param str name of the Weka dataset
     * @return the empty dataset with the class index set to the last attribute
     */
    private Instances defineDataset(FeatureVectorDataSet featureVectorDataSet, String str) {
        // Declared as ArrayList because that is the type the Instances constructor is given.
        ArrayList<weka.core.Attribute> wekaAttributes = new ArrayList<>();
        for (Attribute attribute : featureVectorDataSet.getSchema().get()) {
            if (!attribute.equals(FeatureVectorDataSet.ATTRIBUTE_LABEL)) {
                wekaAttributes.add(new weka.core.Attribute(attribute.getIdentifier()));
            }
        }
        List<String> classValues = new ArrayList<>();
        classValues.add(NON_MATCH_LABEL);
        classValues.add(MATCH_LABEL);
        wekaAttributes.add(new weka.core.Attribute("class", classValues));

        Instances instances = new Instances(str, wekaAttributes, 0);
        instances.setClassIndex(wekaAttributes.size() - 1);
        return instances;
    }

    /**
     * Computes the feature vector for one record pair by running every registered comparator.
     *
     * <p>Each similarity is stored under the attribute of {@code featureVectorDataSet} whose string
     * form equals the generated feature name; if no such attribute exists a new one is created.
     * The comparators are invoked without a schema correspondence.
     *
     * @param recordtype the first record
     * @param recordtype2 the second record
     * @param processable schema correspondences (not consulted by this implementation)
     * @param featureVectorDataSet data set whose schema supplies the feature attributes
     * @return a record holding one similarity value per comparator
     */
    @Override
    public Record generateFeatures(RecordType recordtype, RecordType recordtype2, Processable<Correspondence<SchemaElementType, Matchable>> processable, FeatureVectorDataSet featureVectorDataSet) {
        Record features = new Record(String.format("%s-%s", recordtype.getIdentifier(), recordtype2.getIdentifier()), getClass().getSimpleName());
        for (int i = 0; i < this.comparators.size(); i++) {
            Comparator<RecordType, SchemaElementType> comparator = this.comparators.get(i);
            double similarity = comparator.compare(recordtype, recordtype2, null);
            String name = featureName(i, comparator);

            Attribute attribute = findAttribute(featureVectorDataSet, name);
            if (attribute == null) {
                attribute = new Attribute(name);
            }
            features.setValue(attribute, Double.toString(similarity));
        }
        return features;
    }

    /**
     * Builds the feature name for the comparator at the given position: its index, its simple class
     * name and, for {@link RecordComparator}s, the two attributes it compares.
     */
    private String featureName(int index, Comparator<RecordType, SchemaElementType> comparator) {
        String attribute1 = "";
        String attribute2 = "";
        if (comparator instanceof RecordComparator) {
            RecordComparator recordComparator = (RecordComparator) comparator;
            attribute1 = recordComparator.getAttributeRecord1().toString();
            attribute2 = recordComparator.getAttributeRecord2().toString();
        }
        return String.format("[%d] %s %s %s", index, comparator.getClass().getSimpleName(), attribute1, attribute2);
    }

    /**
     * Looks up the schema attribute whose string form equals {@code name}. The whole schema is
     * scanned, so the last matching attribute is returned; {@code null} if there is none.
     */
    private Attribute findAttribute(FeatureVectorDataSet featureVectorDataSet, String name) {
        Attribute found = null;
        for (Attribute candidate : featureVectorDataSet.getSchema().get()) {
            if (candidate.toString().equals(name)) {
                found = candidate;
            }
        }
        return found;
    }

    /**
     * Classifies a single record pair.
     *
     * <p>The pair is turned into a one-row Weka dataset, reduced with the fitted attribute
     * selection when selection is enabled and available, and handed to the classifier. The
     * predicted class index becomes the similarity of the returned correspondence.
     *
     * @param recordtype the first record
     * @param recordtype2 the second record
     * @param processable schema correspondences, attached to the result as causal correspondences
     * @return the correspondence, or {@code null} if attribute reduction or classification failed
     *         (the stack trace is printed in that case)
     */
    @Override
    public Correspondence<RecordType, SchemaElementType> apply(RecordType recordtype, RecordType recordtype2, Processable<Correspondence<SchemaElementType, Matchable>> processable) {
        FeatureVectorDataSet features = initialiseFeatures();
        features.add(generateFeatures(recordtype, recordtype2, processable, features));
        Instances matchData = transformToWeka(features, this.machtSet);
        try {
            if ((this.backwardSelection || this.forwardSelection) && this.fs != null) {
                // A failure here must not fall through: the classifier was trained on the reduced
                // attribute set and cannot sensibly classify the unreduced instance.
                matchData = this.fs.reduceDimensionality(matchData);
            }
            double similarity = this.classifier.classifyInstance(matchData.firstInstance());
            return new Correspondence<>(recordtype, recordtype2, similarity, processable);
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Serializes the current classifier to the given file using Java object serialization.
     * I/O failures are reported by printing the stack trace.
     *
     * @param file the target file
     */
    @Override
    public void storeModel(File file) {
        // Both streams are declared as resources so the file handle is released even if the
        // ObjectOutputStream cannot be created or writing fails.
        try (FileOutputStream fileOut = new FileOutputStream(file);
                ObjectOutputStream objectOut = new ObjectOutputStream(fileOut)) {
            objectOut.writeObject(getClassifier());
            objectOut.flush();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Loads a classifier from the given file.
     *
     * <p>The file is first read as a Java-serialized classifier. If that fails with an
     * {@link IOException} other than a missing file, the file is read as a PMML model instead.
     * All failures are reported by printing the stack trace; the current classifier is then kept.
     *
     * <p>NOTE(review): Java deserialization executes code from the stream; only load model files
     * from trusted sources.
     *
     * @param file the model file
     */
    @Override
    public void readModel(File file) {
        try (FileInputStream fileIn = new FileInputStream(file);
                ObjectInputStream objectIn = new ObjectInputStream(fileIn)) {
            setClassifier((Classifier) objectIn.readObject());
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            // Not readable as a serialized object: fall back to PMML.
            try {
                setClassifier((Classifier) PMMLFactory.getPMMLModel(file, (Logger) null));
            } catch (Exception pmmlFailure) {
                pmmlFailure.printStackTrace();
            }
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        }
    }

    /**
     * Not used by this rule; the decision is made in {@link #apply}.
     *
     * @return always 0
     */
    @Override
    public double compare(RecordType recordtype, RecordType recordtype2, Correspondence<SchemaElementType, Matchable> correspondence) {
        return 0.0d;
    }

    /**
     * Creates an empty feature data set whose schema has one attribute per registered comparator,
     * named as in {@link #generateFeatures}, followed by the label attribute.
     *
     * @return the empty, schema-initialised feature data set
     */
    @Override
    public FeatureVectorDataSet initialiseFeatures() {
        FeatureVectorDataSet features = new FeatureVectorDataSet();
        for (int i = 0; i < this.comparators.size(); i++) {
            features.addAttribute(new Attribute(featureName(i, this.comparators.get(i))));
        }
        features.addAttribute(FeatureVectorDataSet.ATTRIBUTE_LABEL);
        return features;
    }

    /**
     * @return whether attribute selection is enabled in forward direction
     */
    public boolean isForwardSelection() {
        return this.forwardSelection;
    }

    /**
     * @param z {@code true} to run attribute selection before training
     */
    public void setForwardSelection(boolean z) {
        this.forwardSelection = z;
    }

    /**
     * @return whether attribute selection is enabled in backward direction
     */
    public boolean isBackwardSelection() {
        return this.backwardSelection;
    }

    /**
     * @param z {@code true} to run attribute selection with a backwards search before training
     */
    public void setBackwardSelection(boolean z) {
        this.backwardSelection = z;
    }
}
