package org.talend.dataquality.parsing.util;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/talend/dataquality/parsing/util/CrfEvaluator.class */
/**
 * Evaluates the tags predicted by a CRF model against reference tags and reports per-tag and
 * overall metrics: precision, recall, F1-measure (per tag, macro- and micro-averaged), item
 * accuracy and instance accuracy.
 *
 * <p>Typical usage:
 *
 * <pre>{@code
 * CrfEvaluator.init(predictions, references);
 * CrfEvaluator.compute();
 * String report = CrfEvaluator.resultsAsString();
 * CrfEvaluator.reset();
 * }</pre>
 *
 * <p>All state lives in static fields with no synchronization, so only one evaluation can be in
 * progress at a time and the class must not be used from several threads concurrently.
 */
public class CrfEvaluator {

    /** Predicted tag sequences, one list of tags per instance. */
    private static final List<List<String>> predictions = new ArrayList<>();

    /** Reference tag sequences, aligned index-by-index with {@code predictions}. */
    private static final List<List<String>> references = new ArrayList<>();

    /** Whether {@link #compute()} completed successfully since the last {@link #reset()}. */
    private static boolean hasComputed = false;

    private static final Logger LOGGER = LoggerFactory.getLogger(CrfEvaluator.class);

    /** Per-tag counters and scores, keyed by tag. */
    private static final Map<String, TagWiseMetrics> tagWiseMetricsMap = new HashMap<>();

    /** Metrics aggregated over all tags and instances. */
    private static final OverallMetrics metrics = new OverallMetrics();

    /**
     * Adds aligned lists of predicted and reference tag sequences to evaluate.
     *
     * <p>Each element is the tag sequence of one instance; the i-th prediction is compared with the
     * i-th reference. Data is appended to whatever earlier calls added, so call {@link #reset()}
     * first to evaluate a fresh data set.
     *
     * @param predictedTags predicted tag sequences, one per instance
     * @param referenceTags reference tag sequences, one per instance
     * @throws IllegalArgumentException if either list is {@code null} or empty, or if their sizes
     *     differ
     */
    public static void init(List<List<String>> predictedTags, List<List<String>> referenceTags) {
        if (predictedTags == null || predictedTags.isEmpty()) {
            throw new IllegalArgumentException("NULL or empty list of predictions");
        }
        if (referenceTags == null || referenceTags.isEmpty()) {
            throw new IllegalArgumentException("NULL or empty list of references");
        }
        if (referenceTags.size() != predictedTags.size()) {
            throw new IllegalArgumentException("the size of prediction list should equals to the size of real tag list");
        }
        predictions.addAll(predictedTags);
        references.addAll(referenceTags);
        metrics.setInstTotalNum(references.size());
    }

    /** Discards all loaded data and every computed metric. */
    public static void reset() {
        hasComputed = false;
        predictions.clear();
        references.clear();
        metrics.setInstTotalNum(0);
        clearDerivedMetrics();
    }

    /**
     * Clears every counter and score produced by {@link #compute()}. It keeps the loaded data and
     * the instance count set by {@link #init(List, List)}.
     */
    private static void clearDerivedMetrics() {
        tagWiseMetricsMap.clear();
        metrics.setNumTags(0);
        metrics.setItemTotalCorrect(0);
        metrics.setItemTotalNum(0);
        metrics.setItemTotalObservation(0);
        metrics.setItemTotalPrediction(0);
        metrics.setItemAccuracy(0.0d);
        metrics.setInstTotalCorrect(0);
        metrics.setInstAccuracy(0.0d);
        metrics.setMacroPrecision(0.0d);
        metrics.setMacroRecall(0.0d);
        metrics.setMacroFMeasure(0.0d);
        metrics.setMicroPrecision(0.0d);
        metrics.setMicroRecall(0.0d);
        metrics.setMicroFMeasure(0.0d);
    }

    /**
     * Computes all metrics from the data supplied to {@link #init(List, List)}.
     *
     * <p>Instances whose predicted and reference sequences differ in length are logged and
     * skipped. Results of an earlier call are discarded first, so calling this method again
     * recomputes the metrics instead of counting the same instances twice. If no tag could be
     * collected from any instance, an error is logged and the scores are not computed.
     */
    public static void compute() {
        // Start from a clean slate: accumulate() only ever adds to the counters.
        hasComputed = false;
        clearDerivedMetrics();

        int instanceCount = metrics.getInstTotalNum();
        for (int i = 0; i < instanceCount; i++) {
            List<String> predicted = predictions.get(i);
            List<String> reference = references.get(i);
            if (predicted.size() != reference.size()) {
                LOGGER.warn("Exclude instance {}: {Predict tags = {}, Reference tags = {}}\nCaused by #Predicts not equals to #References ({} != {})",
                        i, predicted, reference, predicted.size(), reference.size());
            } else {
                accumulate(predicted, reference);
            }
        }
        if (tagWiseMetricsMap.isEmpty()) {
            LOGGER.error("Exit! Impossible to compute the metrics, all instances are invalid.");
            return;
        }

        metrics.setNumTags(tagWiseMetricsMap.size());
        double precisionSum = 0.0d;
        double recallSum = 0.0d;
        double fMeasureSum = 0.0d;
        for (TagWiseMetrics tagMetrics : tagWiseMetricsMap.values()) {
            int correct = tagMetrics.getNum_correct();
            int predicted = tagMetrics.getNum_prediction();
            int observed = tagMetrics.getNum_observation();
            metrics.setItemTotalCorrect(metrics.getItemTotalCorrect() + correct);
            metrics.setItemTotalPrediction(metrics.getItemTotalPrediction() + predicted);
            metrics.setItemTotalObservation(metrics.getItemTotalObservation() + observed);

            // Tags never seen in the references get no per-tag scores and contribute 0 to the
            // macro sums. They are still counted in the macro-average denominator (numTags).
            if (observed == 0) {
                continue;
            }
            double precision = ratio(correct, predicted);
            if (predicted > 0) {
                tagMetrics.setPrecision(precision);
            }
            double recall = ratio(correct, observed);
            tagMetrics.setRecall(recall);
            double fMeasure = ratio(2 * correct, predicted + observed);
            tagMetrics.setFmeasure(fMeasure);

            precisionSum += precision;
            recallSum += recall;
            fMeasureSum += fMeasure;
        }

        int numTags = metrics.getNumTags();
        metrics.setMacroPrecision(precisionSum / numTags);
        metrics.setMacroRecall(recallSum / numTags);
        metrics.setMacroFMeasure(fMeasureSum / numTags);

        int totalCorrect = metrics.getItemTotalCorrect();
        metrics.setMicroPrecision(ratio(totalCorrect, metrics.getItemTotalPrediction()));
        metrics.setMicroRecall(ratio(totalCorrect, metrics.getItemTotalObservation()));
        metrics.setMicroFMeasure(ratio(2.0d * totalCorrect,
                (double) metrics.getItemTotalPrediction() + metrics.getItemTotalObservation()));
        metrics.setItemAccuracy(ratio(totalCorrect, metrics.getItemTotalNum()));
        metrics.setInstAccuracy(ratio(metrics.getInstTotalCorrect(), metrics.getInstTotalNum()));
        hasComputed = true;
    }

    /**
     * Returns {@code numerator / denominator} computed in floating point. It returns 0 when the
     * denominator is not positive, which avoids both truncating integer division and division by
     * zero.
     */
    private static double ratio(double numerator, double denominator) {
        return denominator > 0 ? numerator / denominator : 0.0d;
    }

    /**
     * Adds one instance (whose two sequences have equal length) to the per-tag counters and to
     * the item and instance totals.
     *
     * @param predicted predicted tags of the instance
     * @param reference reference tags of the instance
     */
    private static void accumulate(List<String> predicted, List<String> reference) {
        int matches = 0;
        for (int i = 0; i < reference.size(); i++) {
            String referenceTag = reference.get(i);
            String predictedTag = predicted.get(i);
            increaseTagWiseMetricsMap(referenceTag, "num_observation");
            increaseTagWiseMetricsMap(predictedTag, "num_prediction");
            if (Objects.equals(referenceTag, predictedTag)) {
                increaseTagWiseMetricsMap(referenceTag, "num_correct");
                matches++;
            }
        }
        metrics.setItemTotalNum(metrics.getItemTotalNum() + reference.size());
        // An instance is correct only if every one of its tags matches (vacuously true when empty).
        if (matches == reference.size()) {
            metrics.setInstTotalCorrect(metrics.getInstTotalCorrect() + 1);
        }
    }

    /**
     * Increments the named counter of {@code tag}'s metrics, creating the entry on first use.
     *
     * @param tag the tag whose counter is incremented
     * @param counterName counter name passed to {@code TagWiseMetrics.increaseNum}
     */
    private static void increaseTagWiseMetricsMap(String tag, String counterName) {
        tagWiseMetricsMap.computeIfAbsent(tag, key -> new TagWiseMetrics()).increaseNum(counterName);
    }

    /**
     * Renders the computed metrics as a human-readable table.
     *
     * @return the report, or an empty string (after logging an error) if {@link #compute()} has
     *     not completed since the last {@link #reset()}
     */
    public static String resultsAsString() {
        if (!hasComputed) {
            LOGGER.error("The metrics haven't been computed or have been reset, please call CrfEvaluator.compute()");
            return "";
        }
        String property = System.getProperty("line.separator");
        StringBuilder sb = new StringBuilder(property);
        sb.append("+---------------------------------------------------------------------------------------------------------+\n");
        sb.append("|                                            Model Performance                                            |\n");
        sb.append("+---------------------------------------------------------------------------------------------------------+\n");
        // The accuracy lines start with " - ", so the section title needs its own line.
        sb.append(" Overall accuracy:").append(property);
        sb.append(String.format(" - Item accuracy: %d / %d (%1.3f)%n", metrics.getItemTotalCorrect(), metrics.getItemTotalNum(), metrics.getItemAccuracy()));
        sb.append(String.format(" - Instance accuracy: %d / %d (%1.3f)%n", metrics.getInstTotalCorrect(), metrics.getInstTotalNum(), metrics.getInstAccuracy()));
        sb.append(property);
        sb.append(" Performance by label:\n ");
        sb.append(String.format(" %15s%15s%15s%15s%15s%15s%15s%n", "   ", "Precision", "Recall", "F1-measure", "#Match", "#Predict", "#Ref"));
        tagWiseMetricsMap.forEach((tag, tagMetrics) -> {
            if (tagMetrics != null) {
                sb.append(String.format(" %15s%15.3f%15.3f%15.3f%15d%15d%15d%n", "\"" + tag + "\"", tagMetrics.getPrecision(), tagMetrics.getRecall(), tagMetrics.getFmeasure(), tagMetrics.getNum_correct(), tagMetrics.getNum_prediction(), tagMetrics.getNum_observation()));
            }
        });
        sb.append(String.format("%n %15s%15.3f%15.3f%15.3f%15s%15s%15s%n", "Macro-average", metrics.getMacroPrecision(), metrics.getMacroRecall(), metrics.getMacroFMeasure(), "-", "-", "-"));
        sb.append(String.format(" %15s%15.3f%15.3f%15.3f%15s%15s%15s%n", "Micro-average", metrics.getMicroPrecision(), metrics.getMicroRecall(), metrics.getMicroFMeasure(), "-", "-", "-"));
        sb.append("===========================================================================================================");
        sb.append(property);
        return sb.toString();
    }

    /**
     * Logs and returns the overall metrics.
     *
     * @return the live, mutable metrics object held by this class (not a copy)
     */
    public static OverallMetrics getMetrics() {
        LOGGER.info(metrics.toString());
        return metrics;
    }

    /**
     * Logs and returns the per-tag metrics.
     *
     * @return the live, mutable map held by this class (not a copy), keyed by tag
     */
    public static Map<String, TagWiseMetrics> getTagWiseMetricsMap() {
        LOGGER.info(tagWiseMetricsMap.toString());
        return tagWiseMetricsMap;
    }
}
