package de.uni_mannheim.informatik.dws.winter.matching.algorithms;

import de.uni_mannheim.informatik.dws.winter.matching.rules.LearnableMatchingRule;
import de.uni_mannheim.informatik.dws.winter.model.Correspondence;
import de.uni_mannheim.informatik.dws.winter.model.DataSet;
import de.uni_mannheim.informatik.dws.winter.model.Matchable;
import de.uni_mannheim.informatik.dws.winter.model.MatchingGoldStandard;
import de.uni_mannheim.informatik.dws.winter.model.Performance;
import de.uni_mannheim.informatik.dws.winter.model.defaultmodel.Attribute;
import de.uni_mannheim.informatik.dws.winter.model.defaultmodel.FeatureVectorDataSet;
import de.uni_mannheim.informatik.dws.winter.model.defaultmodel.Record;
import de.uni_mannheim.informatik.dws.winter.processing.Processable;
import de.uni_mannheim.informatik.dws.winter.processing.parallel.ParallelProcessableCollection;
import de.uni_mannheim.informatik.dws.winter.utils.ProgressReporter;
import de.uni_mannheim.informatik.dws.winter.utils.WinterLogManager;
import de.uni_mannheim.informatik.dws.winter.utils.query.Q;
import edu.stanford.nlp.util.StringUtils;
import java.lang.invoke.SerializedLambda;
import java.time.Duration;
import java.time.LocalDateTime;
import java.util.Iterator;
import java.util.LinkedList;
import org.apache.commons.lang3.time.DurationFormatUtils;
import org.apache.logging.log4j.Logger;

/* loaded from: input_file:de/uni_mannheim/informatik/dws/winter/matching/algorithms/RuleLearner.class */
/**
 * Learns the parameters of a {@link LearnableMatchingRule} from the labelled record pairs of a
 * {@link MatchingGoldStandard}.
 *
 * <p>Every gold-standard pair is turned into a feature vector by the rule itself; positive pairs
 * are labelled {@code "1"} and negative pairs {@code "0"} in
 * {@link FeatureVectorDataSet#ATTRIBUTE_LABEL}. The resulting data set is handed to
 * {@link LearnableMatchingRule#learnParameters}.
 *
 * @param <RecordType> the type of the records that are matched
 * @param <SchemaElementType> the type of the schema elements of those records
 */
public class RuleLearner<RecordType extends Matchable, SchemaElementType extends Matchable> {
    private static final Logger logger = WinterLogManager.getLogger();

    /**
     * Learns the rule's parameters without deduplicating the generated training examples.
     *
     * <p>Equivalent to calling the six-argument overload with {@code false} as last argument.
     *
     * @param dataSet the first data set
     * @param dataSet2 the second data set
     * @param processable the schema correspondences passed on to feature generation
     * @param learnableMatchingRule the rule whose parameters are learned
     * @param matchingGoldStandard the labelled record pairs used as training data
     * @return whatever {@code learnableMatchingRule.learnParameters(...)} returns
     */
    public Performance learnMatchingRule(DataSet<RecordType, SchemaElementType> dataSet, DataSet<RecordType, SchemaElementType> dataSet2, Processable<? extends Correspondence<SchemaElementType, ?>> processable, LearnableMatchingRule<RecordType, SchemaElementType> learnableMatchingRule, MatchingGoldStandard matchingGoldStandard) {
        return learnMatchingRule(dataSet, dataSet2, processable, learnableMatchingRule, matchingGoldStandard, false);
    }

    /**
     * Generates training examples from the gold standard and learns the rule's parameters.
     *
     * @param dataSet the first data set
     * @param dataSet2 the second data set
     * @param processable the schema correspondences passed on to feature generation
     * @param learnableMatchingRule the rule whose parameters are learned
     * @param matchingGoldStandard the labelled record pairs used as training data
     * @param z if {@code true}, the generated examples are passed through
     *     {@link #deduplicateFeatureDataSet(FeatureVectorDataSet)} before learning
     * @return whatever {@code learnableMatchingRule.learnParameters(...)} returns
     */
    public Performance learnMatchingRule(DataSet<RecordType, SchemaElementType> dataSet, DataSet<RecordType, SchemaElementType> dataSet2, Processable<? extends Correspondence<SchemaElementType, ?>> processable, LearnableMatchingRule<RecordType, SchemaElementType> learnableMatchingRule, MatchingGoldStandard matchingGoldStandard, boolean z) {
        FeatureVectorDataSet trainingData = generateTrainingDataForLearning(dataSet, dataSet2, matchingGoldStandard, learnableMatchingRule, processable);
        if (z) {
            trainingData = deduplicateFeatureDataSet(trainingData);
        }
        return learnableMatchingRule.learnParameters(trainingData);
    }

    /**
     * Copies the given feature vectors into a new data set, giving every copy an identifier that
     * is built from all of its attribute values.
     *
     * <p>NOTE(review): the actual removal of duplicates presumably happens because
     * {@code FeatureVectorDataSet.add} keeps only one record per identifier - confirm against the
     * data set implementation.
     *
     * @param featureVectorDataSet the examples to deduplicate; not modified
     * @return a new data set with the same attributes and the re-keyed examples
     */
    public FeatureVectorDataSet deduplicateFeatureDataSet(FeatureVectorDataSet featureVectorDataSet) {
        FeatureVectorDataSet deduplicated = new FeatureVectorDataSet();

        // Fix the attribute order once so that identifiers and copied values are built consistently.
        LinkedList<Attribute> orderedAttributes = new LinkedList<>();
        for (Attribute attribute : featureVectorDataSet.getSchema().get()) {
            orderedAttributes.add(attribute);
            deduplicated.addAttribute(attribute);
        }

        for (Record example : featureVectorDataSet.get()) {
            // The identifier is the join of all attribute values, so equal vectors share an id.
            String id = StringUtils.join(Q.project(orderedAttributes, a -> example.getValue(a)));
            Record uniqueExample = new Record(id);
            for (Attribute attribute : orderedAttributes) {
                uniqueExample.setValue(attribute, example.getValue(attribute));
            }
            deduplicated.add(uniqueExample);
        }

        logger.info(String.format("Deduplication removed %d/%d examples.", featureVectorDataSet.size() - deduplicated.size(), featureVectorDataSet.size()));
        return deduplicated;
    }

    /**
     * Creates one labelled feature vector per gold-standard pair whose records can be resolved.
     *
     * <p>Positive pairs are processed first and labelled {@code "1"}, negative pairs second and
     * labelled {@code "0"}. All examples are added to the data set returned by
     * {@code learnableMatchingRule.initialiseFeatures()}, which is also the return value.
     *
     * @param dataSet the first data set
     * @param dataSet2 the second data set
     * @param matchingGoldStandard the labelled record pairs
     * @param learnableMatchingRule the rule that initialises the feature data set and generates
     *     the features of each pair
     * @param processable the schema correspondences passed on to feature generation
     * @return the feature data set filled with the generated examples
     */
    public FeatureVectorDataSet generateTrainingDataForLearning(DataSet<RecordType, SchemaElementType> dataSet, DataSet<RecordType, SchemaElementType> dataSet2, MatchingGoldStandard matchingGoldStandard, LearnableMatchingRule<RecordType, SchemaElementType> learnableMatchingRule, Processable<? extends Correspondence<SchemaElementType, ? extends Matchable>> processable) {
        LocalDateTime start = LocalDateTime.now();
        FeatureVectorDataSet features = learnableMatchingRule.initialiseFeatures();
        matchingGoldStandard.printBalanceReport();

        // The message has no format specifier, so it is logged as a plain string.
        logger.info("Starting GenerateFeatures");

        // NOTE(review): the reporter is created but never used afterwards; it is kept only in case
        // its constructor has side effects - confirm and remove if it has none.
        new ProgressReporter(matchingGoldStandard.getPositiveExamples().size() + matchingGoldStandard.getNegativeExamples().size(), "GenerateFeatures");

        // NOTE(review): pairs with unresolvable records are mapped to null; presumably map()
        // drops null results - confirm against the Processable implementation.
        Processable<Record> positiveExamples = new ParallelProcessableCollection<>(matchingGoldStandard.getPositiveExamples())
                .map(pair -> createLabelledExample(pair.getFirst(), pair.getSecond(), dataSet, dataSet2, learnableMatchingRule, processable, features, "1"));
        Processable<Record> negativeExamples = new ParallelProcessableCollection<>(matchingGoldStandard.getNegativeExamples())
                .map(pair -> createLabelledExample(pair.getFirst(), pair.getSecond(), dataSet, dataSet2, learnableMatchingRule, processable, features, "0"));

        for (Record example : positiveExamples.get()) {
            features.add(example);
        }
        for (Record example : negativeExamples.get()) {
            features.add(example);
        }

        logger.info(String.format("GenerateFeatures finished after %s; created %,d examples.", DurationFormatUtils.formatDurationHMS(Duration.between(start, LocalDateTime.now()).toMillis()), features.size()));
        return features;
    }

    /**
     * Builds the labelled feature vector for a single gold-standard pair.
     *
     * <p>The first identifier is looked up in {@code dataSet} and the second in {@code dataSet2}.
     * Only if both lookups fail, the pair is retried with the data sets swapped. The method is
     * static so that the lambdas calling it do not capture the enclosing instance.
     *
     * @return the generated features with the label set, or {@code null} if either record could
     *     not be resolved
     */
    private static <R extends Matchable, S extends Matchable> Record createLabelledExample(String firstId, String secondId, DataSet<R, S> dataSet, DataSet<R, S> dataSet2, LearnableMatchingRule<R, S> rule, Processable<? extends Correspondence<S, ? extends Matchable>> schemaCorrespondences, FeatureVectorDataSet features, String label) {
        R first = dataSet.getRecord(firstId);
        R second = dataSet2.getRecord(secondId);
        if (first == null && second == null) {
            // Neither id was found: the pair may list the data sets in the opposite order.
            first = dataSet2.getRecord(firstId);
            second = dataSet.getRecord(secondId);
        }
        if (first == null || second == null) {
            return null;
        }
        Record example = rule.generateFeatures(first, second, Correspondence.toMatchable(schemaCorrespondences), features);
        example.setValue(FeatureVectorDataSet.ATTRIBUTE_LABEL, label);
        return example;
    }
}
