package org.talend.dataquality.nlp;

import com.intel.ssg.bdt.nlp.CRF;
import com.intel.ssg.bdt.nlp.CRFModel;
import com.intel.ssg.bdt.nlp.Sequence;
import com.intel.ssg.bdt.nlp.Token;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.talend.dataquality.nlp.CRFLabeling;
import org.talend.dataquality.nlp.toolkit.ToolkitType;
import scala.Tuple2;

/* loaded from: input_file:org/talend/dataquality/nlp/NLPProcessing.class */
public class NLPProcessing implements Serializable {
    private static final long serialVersionUID = 1;
    private static Pattern delimiter = Pattern.compile("\t");
    private static FeatureConstructor fc;
    private JavaRDD<List<String>> otherFeatureRDD;
    private String[] template;
    private JavaRDD<List<String>> tokensRDD;
    private JavaRDD<List<String>> labelRDD;
    private JavaRDD<String> textRDD;

    /* loaded from: input_file:org/talend/dataquality/nlp/NLPProcessing$FilterFunction.class */
    public class FilterFunction implements Function<Tuple2<Tuple2<Tuple2<List<String>, List<String>>, List<String>>, Long>, Boolean> {
        private static final long serialVersionUID = 1;
        private long upperbound;
        private long lowerbound;
        private boolean isTrain;

        public FilterFunction(long j, long j2, boolean z) {
            this.upperbound = j;
            this.lowerbound = j2;
            this.isTrain = z;
        }

        public Boolean call(Tuple2<Tuple2<Tuple2<List<String>, List<String>>, List<String>>, Long> tuple2) throws Exception {
            if (this.isTrain) {
                return Boolean.valueOf(this.upperbound > ((Long) tuple2._2()).longValue() && ((Long) tuple2._2()).longValue() >= this.lowerbound);
            }
            return Boolean.valueOf(this.upperbound <= ((Long) tuple2._2()).longValue() || ((Long) tuple2._2()).longValue() < this.lowerbound);
        }
    }

    /* loaded from: input_file:org/talend/dataquality/nlp/NLPProcessing$GetElem.class */
    public class GetElem implements Function<Tuple2<Tuple2<Tuple2<List<String>, List<String>>, List<String>>, Long>, List<String>> {
        private static final long serialVersionUID = 1;
        private String toGet;

        public GetElem(String str) {
            this.toGet = str;
        }

        public List<String> call(Tuple2<Tuple2<Tuple2<List<String>, List<String>>, List<String>>, Long> tuple2) throws Exception {
            return "token".equals(this.toGet) ? (List) ((Tuple2) ((Tuple2) tuple2._1())._1())._1() : "otherfeature".equals(this.toGet) ? (List) ((Tuple2) ((Tuple2) tuple2._1())._1())._2() : (List) ((Tuple2) tuple2._1())._2();
        }
    }

    /* loaded from: input_file:org/talend/dataquality/nlp/NLPProcessing$Model.class */
    public static class Model implements Serializable {
        private static final long serialVersionUID = 1;
        private String m;
        private LinkedHashMap<String, List<String>> mostFrequentPredecessorMap;
        private String pl;
        private String[] schema;
        private ToolkitType tool;

        public Model(String str, LinkedHashMap<String, List<String>> linkedHashMap, String str2, String[] strArr, ToolkitType toolkitType) {
            this.m = str;
            this.mostFrequentPredecessorMap = linkedHashMap;
            this.pl = str2;
            this.schema = strArr;
            this.tool = toolkitType;
        }

        public Model() {
        }

        public CRFModel getModel() {
            return CRFModel.load(this.m);
        }

        public LinkedHashMap<String, List<String>> getMap() {
            return this.mostFrequentPredecessorMap;
        }

        public String getPipeline() {
            return this.pl;
        }

        public String[] getSchema() {
            return this.schema;
        }

        public ToolkitType getToolKit() {
            return this.tool;
        }
    }

    public NLPProcessing(FeatureConstructor featureConstructor, JavaRDD<List<String>> javaRDD, JavaRDD<List<String>> javaRDD2, JavaRDD<List<String>> javaRDD3, String[] strArr) {
        fc = featureConstructor;
        this.otherFeatureRDD = javaRDD2;
        this.tokensRDD = javaRDD;
        this.labelRDD = javaRDD3;
        this.template = strArr;
    }

    public NLPProcessing(FeatureConstructor featureConstructor, JavaRDD<List<String>> javaRDD, JavaRDD<List<String>> javaRDD2, JavaRDD<String> javaRDD3) {
        fc = featureConstructor;
        this.tokensRDD = javaRDD;
        this.otherFeatureRDD = javaRDD2;
        this.textRDD = javaRDD3;
    }

    public void setText(JavaRDD<String> javaRDD) {
        this.textRDD = javaRDD;
    }

    /* JADX WARN: Type inference failed for: r0v7, types: [com.intel.ssg.bdt.nlp.Token[], com.intel.ssg.bdt.nlp.Token[][]] */
    public static Token[][] chunkArray(Token[] tokenArr, int i) {
        int ceil = (int) Math.ceil(tokenArr.length / i);
        ?? r0 = new Token[ceil];
        for (int i2 = 0; i2 < ceil; i2++) {
            int i3 = i2 * i;
            int min = Math.min(tokenArr.length - i3, i);
            Token[] tokenArr2 = new Token[min];
            System.arraycopy(tokenArr, i3, tokenArr2, 0, min);
            r0[i2] = tokenArr2;
        }
        return r0;
    }

    public JavaRDD<Sequence> featureConstruction(JavaRDD<List<String>> javaRDD, JavaRDD<List<String>> javaRDD2, JavaRDD<List<String>> javaRDD3) throws IOException {
        return fc.constructFeatures(javaRDD, javaRDD2, javaRDD3).zip(javaRDD3).map(new Function<Tuple2<List<String>, List<String>>, Sequence>() { // from class: org.talend.dataquality.nlp.NLPProcessing.2
            private static final long serialVersionUID = 1;

            public Sequence call(Tuple2<List<String>, List<String>> tuple2) {
                ArrayList arrayList = new ArrayList();
                for (int i = 0; i < ((List) tuple2._1()).size(); i++) {
                    arrayList.add(new Token((String) ((List) tuple2._2()).get(i), NLPProcessing.delimiter.split((CharSequence) ((List) tuple2._1()).get(i))));
                }
                return new Sequence((Token[]) arrayList.toArray(new Token[arrayList.size()]));
            }
        }).filter(new Function<Sequence, Boolean>() { // from class: org.talend.dataquality.nlp.NLPProcessing.1
            private static final long serialVersionUID = 1;

            public Boolean call(Sequence sequence) {
                return Boolean.valueOf(sequence.toArray().length != 0);
            }
        });
    }

    public Model train() throws InstantiationException, IllegalAccessException, ClassNotFoundException, IOException {
        return new Model(CRFModel.save(CRF.train(this.template, featureConstruction(this.tokensRDD, this.otherFeatureRDD, this.labelRDD).rdd(), 0.25d, 2, 300, 1.0E-4d)), fc.getPredecessorMap(), fc.getPipeline(), this.template, fc.getToolkit());
    }

    public JavaPairRDD<String, String> predict(Model model) throws InstantiationException, IllegalAccessException, ClassNotFoundException, IOException {
        fc.setMostFrequentPredecessorMap(model.getMap());
        return fc.predict(this.tokensRDD, this.textRDD, this.otherFeatureRDD, model.getModel());
    }

    public void evaluateCV(int i) throws InstantiationException, IllegalAccessException, ClassNotFoundException, IOException {
        ArrayList<CRFLabeling.Scores> arrayList = new ArrayList();
        HashSet<String> hashSet = new HashSet();
        JavaPairRDD zipWithIndex = this.tokensRDD.zip(this.otherFeatureRDD).zip(this.labelRDD).zipWithIndex();
        long count = zipWithIndex.count();
        long j = count / i;
        int i2 = 0;
        while (i2 < i) {
            long j2 = i2 == i - 1 ? count : (i2 + 1) * j;
            long j3 = i2 * j;
            JavaPairRDD filter = zipWithIndex.filter(new FilterFunction(j2, j3, true));
            JavaPairRDD filter2 = zipWithIndex.filter(new FilterFunction(j2, j3, false));
            List<CRFLabeling.Scores> evaluation = fc.evaluation(filter2.map(new GetElem("token")), filter2.map(new GetElem("otherfeature")), filter2.map(new GetElem("label")), CRF.train(this.template, featureConstruction(filter.map(new GetElem("token")), filter.map(new GetElem("otherfeature")), filter.map(new GetElem("label"))).rdd(), 0.25d, 2, 300, 1.0E-4d));
            if (i2 == 0) {
                Iterator<CRFLabeling.Scores> it = evaluation.iterator();
                while (it.hasNext()) {
                    hashSet.add(it.next().getClassName());
                }
            }
            arrayList.addAll(evaluation);
            i2++;
        }
        System.out.println("\n" + i + " fold cross validation\n\n");
        for (String str : hashSet) {
            double d = 0.0d;
            double d2 = 0.0d;
            double d3 = 0.0d;
            for (CRFLabeling.Scores scores : arrayList) {
                if (scores.getClassName().equals(str)) {
                    d += scores.getTruePositive();
                    d2 += scores.getPredictedTrue();
                    d3 += scores.getLabelTrue();
                }
            }
            double d4 = d / d2;
            double d5 = d / d3;
            System.out.println("Class Name : " + str);
            System.out.println("True Positive : " + ((int) d) + "\tPredicted True : " + ((int) d2) + "\tLabeled True : " + ((int) d3));
            System.out.println("Precision\t:\t" + d4 + "\nRecall\t\t:\t" + d5 + "\nF1 Score\t:\t" + (2.0d * ((d5 * d4) / (d5 + d4))) + "\n");
        }
    }
}
