package org.talend.dataquality.parsing.util;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/talend/dataquality/parsing/util/CrfEvaluator.class */
/**
 * Evaluates CRF tag predictions against reference tags.
 *
 * <p>It reports per-tag precision, recall and F1-measure, macro and micro averages, item
 * (token) accuracy and instance (whole-sequence) accuracy.
 *
 * <p>Usage: call {@link #init(List, List)} with aligned prediction/reference sequences, then
 * {@link #compute()}, then {@link #output()} or the getters. Call {@link #reset()} before
 * evaluating an unrelated data set, because {@code init} appends to the previously supplied data.
 *
 * <p>All state lives in static fields, so this class is not safe for concurrent use.
 */
public class CrfEvaluator {
    /** Predicted tag sequences, one list per instance. */
    private static final List<List<String>> predictions = new ArrayList<>();
    /** Reference tag sequences, aligned index-by-index with {@code predictions}. */
    private static final List<List<String>> references = new ArrayList<>();
    /** True once {@link #compute()} has completed successfully since the last reset/compute. */
    private static boolean has_computed = false;
    private static final Logger log = LoggerFactory.getLogger(CrfEvaluator.class);
    /** Per-tag counters and scores, keyed by tag. */
    private static final Map<String, TagWiseMetrics> tagWiseMetricsMap = new HashMap<>();
    /** Aggregate counters and scores over all evaluated instances. */
    private static final OverallMetrics metrics = new OverallMetrics();

    /**
     * Appends prediction and reference sequences to the data to be evaluated.
     *
     * @param list predicted tag sequences, one per instance
     * @param list2 reference tag sequences, aligned with {@code list}
     * @throws IllegalArgumentException if either list is null or empty, or their sizes differ
     */
    public static void init(List<List<String>> list, List<List<String>> list2) {
        if (list == null || list.isEmpty()) {
            throw new IllegalArgumentException("NULL or empty list of predictions");
        }
        if (list2 == null || list2.isEmpty()) {
            throw new IllegalArgumentException("NULL or empty list of references");
        }
        if (list2.size() != list.size()) {
            throw new IllegalArgumentException("the size of prediction list should equals to the size of real tag list");
        }
        predictions.addAll(list);
        references.addAll(list2);
        // The instance count always reflects everything accumulated so far.
        metrics.setInst_total_num(references.size());
    }

    /** Discards all input data and every computed metric. */
    public static void reset() {
        has_computed = false;
        predictions.clear();
        references.clear();
        metrics.setInst_total_num(0);
        clearComputedMetrics();
    }

    /**
     * Clears the per-tag map and every computed counter/score.
     *
     * <p>The instance count set by {@code init} is left untouched.
     */
    private static void clearComputedMetrics() {
        tagWiseMetricsMap.clear();
        metrics.setNum_tags(0);
        metrics.setItem_total_correct(0);
        metrics.setItem_total_num(0);
        metrics.setItem_total_observation(0);
        metrics.setItem_total_prediction(0);
        metrics.setItem_accuracy(0.0d);
        metrics.setInst_total_correct(0);
        metrics.setInst_accuracy(0.0d);
        metrics.setMacro_precision(0.0d);
        metrics.setMacro_recall(0.0d);
        metrics.setMacro_fmeasure(0.0d);
        metrics.setMicro_precision(0.0d);
        metrics.setMicro_recall(0.0d);
        metrics.setMicro_fmeasure(0.0d);
    }

    /**
     * Computes per-tag and overall metrics from the data supplied via {@link #init(List, List)}.
     *
     * <p>Instances whose prediction and reference lengths differ are skipped with a warning. If no
     * instance is usable, an error is logged and the metrics stay uncomputed. Calling this method
     * again recomputes from scratch rather than accumulating on top of earlier results.
     */
    public static void compute() {
        // Start clean so repeated calls neither double-count items nor reuse stale averages.
        has_computed = false;
        clearComputedMetrics();

        for (int i = 0; i < metrics.getInst_total_num(); i++) {
            List<String> predicted = predictions.get(i);
            List<String> reference = references.get(i);
            if (predicted.size() != reference.size()) {
                log.warn("Exclude instance {}: {Predict tags = {}, Reference tags = {}}\nCaused by #Predicts not equals to #References ({} != {})",
                        i, predicted, reference, predicted.size(), reference.size());
            } else {
                accumulate(predicted, reference);
            }
        }
        if (tagWiseMetricsMap.isEmpty()) {
            log.error("Exit! Impossible to compute the metrics, all instances are invalid.");
            return;
        }

        double precisionSum = 0.0d;
        double recallSum = 0.0d;
        double fmeasureSum = 0.0d;
        metrics.setNum_tags(tagWiseMetricsMap.size());
        for (TagWiseMetrics tagMetrics : tagWiseMetricsMap.values()) {
            int numCorrect = tagMetrics.getNum_correct();
            int numPrediction = tagMetrics.getNum_prediction();
            int numObservation = tagMetrics.getNum_observation();
            metrics.setItem_total_correct(metrics.getItem_total_correct() + numCorrect);
            metrics.setItem_total_prediction(metrics.getItem_total_prediction() + numPrediction);
            metrics.setItem_total_observation(metrics.getItem_total_observation() + numObservation);
            // Tags that never occur in the references get no scores. They still count in the
            // macro-average denominator (num_tags), as before.
            if (numObservation > 0) {
                double precision = ratio(numCorrect, numPrediction);
                double recall = ratio(numCorrect, numObservation);
                // 2c/(p+o) equals the harmonic mean of precision and recall.
                double fmeasure = ratio(2 * numCorrect, numPrediction + numObservation);
                if (numPrediction > 0) {
                    tagMetrics.setPrecision(precision);
                }
                tagMetrics.setRecall(recall);
                tagMetrics.setFmeasure(fmeasure);
                precisionSum += precision;
                recallSum += recall;
                fmeasureSum += fmeasure;
            }
        }
        // num_tags > 0 here because the map is non-empty.
        metrics.setMacro_precision(precisionSum / metrics.getNum_tags());
        metrics.setMacro_recall(recallSum / metrics.getNum_tags());
        metrics.setMacro_fmeasure(fmeasureSum / metrics.getNum_tags());

        metrics.setMicro_precision(ratio(metrics.getItem_total_correct(), metrics.getItem_total_prediction()));
        metrics.setMicro_recall(ratio(metrics.getItem_total_correct(), metrics.getItem_total_observation()));
        metrics.setMicro_fmeasure(ratio(2 * metrics.getItem_total_correct(),
                metrics.getItem_total_prediction() + metrics.getItem_total_observation()));
        metrics.setItem_accuracy(ratio(metrics.getItem_total_correct(), metrics.getItem_total_num()));
        metrics.setInst_accuracy(ratio(metrics.getInst_total_correct(), metrics.getInst_total_num()));
        has_computed = true;
    }

    /**
     * Returns {@code numerator / denominator} using floating-point division.
     *
     * <p>Returns 0 when the denominator is not positive.
     */
    private static double ratio(double numerator, double denominator) {
        return denominator > 0 ? numerator / denominator : 0.0d;
    }

    /**
     * Adds one instance's counts to the per-tag map and the item/instance counters.
     *
     * @param predicted predicted tags; must be the same length as {@code reference}
     * @param reference reference tags
     */
    private static void accumulate(List<String> predicted, List<String> reference) {
        int matches = 0;
        for (int i = 0; i < reference.size(); i++) {
            String referenceTag = reference.get(i);
            String predictedTag = predicted.get(i);
            increaseTagWiseMetricsMap(referenceTag, "num_observation");
            increaseTagWiseMetricsMap(predictedTag, "num_prediction");
            if (referenceTag.equals(predictedTag)) {
                increaseTagWiseMetricsMap(referenceTag, "num_correct");
                matches++;
            }
        }
        metrics.setItem_total_num(metrics.getItem_total_num() + reference.size());
        // An instance is correct only if every one of its tags matches.
        if (matches == reference.size()) {
            metrics.setInst_total_correct(metrics.getInst_total_correct() + 1);
        }
    }

    /**
     * Increments the named counter of {@code str}'s metrics.
     *
     * <p>The entry is created first if the tag has not been seen yet.
     */
    private static void increaseTagWiseMetricsMap(String str, String str2) {
        tagWiseMetricsMap.computeIfAbsent(str, key -> new TagWiseMetrics()).increaseNum(str2);
    }

    /**
     * Prints a formatted performance report to standard output.
     *
     * <p>If {@link #compute()} has not completed, it logs an error instead.
     */
    public static void output() {
        if (!has_computed) {
            log.error("The metrics haven't been computed or have been reset, please call Evaluator.compute()");
            return;
        }
        System.out.println(System.lineSeparator() + "+---------------------------------------------------------------------------------------------------------+\n|                                            Model Performance                                            |\n+---------------------------------------------------------------------------------------------------------+\n");
        System.out.println(" Overall accuracy:");
        System.out.format(" - Item accuracy: %d / %d (%1.3f)%n", metrics.getItem_total_correct(), metrics.getItem_total_num(), metrics.getItem_accuracy());
        System.out.format(" - Instance accuracy: %d / %d (%1.3f)%n", metrics.getInst_total_correct(), metrics.getInst_total_num(), metrics.getInst_accuracy());
        System.out.println(System.lineSeparator());
        System.out.println(" Performance by label:\n ");
        System.out.format(" %15s%15s%15s%15s%15s%15s%15s%n", "   ", "Precision", "Recall", "F1-measure", "#Match", "#Predict", "#Ref");
        tagWiseMetricsMap.forEach((tag, tagMetrics) -> {
            if (tagMetrics != null) {
                System.out.format(" %15s%15.3f%15.3f%15.3f%15d%15d%15d%n", "\"" + tag + "\"", tagMetrics.getPrecision(), tagMetrics.getRecall(), tagMetrics.getFmeasure(), tagMetrics.getNum_correct(), tagMetrics.getNum_prediction(), tagMetrics.getNum_observation());
            }
        });
        System.out.format("%n %15s%15.3f%15.3f%15.3f%15s%15s%15s%n", "Macro-average", metrics.getMacro_precision(), metrics.getMacro_recall(), metrics.getMacro_fmeasure(), "-", "-", "-");
        System.out.format(" %15s%15.3f%15.3f%15.3f%15s%15s%15s%n", "Micro-average", metrics.getMicro_precision(), metrics.getMicro_recall(), metrics.getMicro_fmeasure(), "-", "-", "-");
        System.out.println("===========================================================================================================");
        System.out.println(System.lineSeparator());
    }

    /** Logs the overall metrics at INFO level and returns the live (mutable) instance. */
    public static OverallMetrics getMetrics() {
        log.info(metrics.toString());
        return metrics;
    }

    /** Logs the per-tag metrics at INFO level and returns the live (mutable) map. */
    public static Map<String, TagWiseMetrics> getTagWiseMetricsMap() {
        log.info(tagWiseMetricsMap.toString());
        return tagWiseMetricsMap;
    }
}
