# Source code for towhee.functional.mixins.metric

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from typing import Tuple

from towhee.hparam import param_scope


class Collector:
    """
    Collector class for metric

    Holds everything ``MetricMixin`` accumulates while evaluating:

    * ``metrics`` -- list of metric names to compute,
    * ``labels``  -- ``{model_name: {'actual': ..., 'predicted': ...}}``,
    * ``scores``  -- ``{model_name: {metric_name: value}}``.

    Containers passed in are stored as-is (not copied); omitted ones are
    replaced by a fresh empty container per instance.
    """

    def __init__(self,
                 metrics: list = None,
                 labels: dict = None,
                 scores: dict = None):
        # A fresh container per instance avoids the shared-mutable-default trap.
        self.metrics = [] if metrics is None else metrics
        self.scores = {} if scores is None else scores
        self.labels = {} if labels is None else labels

    def add_metrics(self, value: str):
        """Append one metric name to ``self.metrics`` in place."""
        metric_list = self.metrics
        metric_list.append(value)

    def add_scores(self, value: dict):
        """Merge ``value`` (``{model_name: {metric: score}}``) into ``self.scores``."""
        score_table = self.scores
        score_table.update(value)

    def add_labels(self, value: dict):
        """Merge ``value`` (``{model_name: {'actual', 'predicted'}}``) into ``self.labels``."""
        label_table = self.labels
        label_table.update(value)


def encode_fig_img(mat, size=300):
    """
    Render a confusion matrix as an inline HTML ``<img>`` tag.

    Each row of the matrix is first divided by its sum. The result is drawn with
    sklearn's ``ConfusionMatrixDisplay`` on the non-interactive ``Agg`` backend,
    saved as a JPEG, and embedded as a base64 data URI.

    Args:
        mat: 2-D array-like of counts, e.g. the output of ``confusion_matrix``.
        size: width and height of the ``<img>`` element, in pixels.

    Returns:
        str: an HTML ``<img ...>`` tag containing the encoded image.
    """
    # pylint: disable=import-outside-toplevel
    from towhee.utils.sklearn_utils import ConfusionMatrixDisplay
    from towhee.utils.matplotlib_utils import matplotlib as mpl
    import base64
    import io
    mpl.use('Agg')  # Prevent showing stuff
    # matplotlib is guaranteed importable by the towhee wrapper above.
    import matplotlib.pyplot as plt

    counts = np.asarray(mat, dtype=float)
    row_sums = counts.sum(axis=1, keepdims=True)
    # If a class never occurs, its row sums to 0. Dividing by it would give NaN
    # and a RuntimeWarning, so leave such rows as zeros instead.
    normalized = np.divide(counts, row_sums,
                           out=np.zeros_like(counts),
                           where=row_sums != 0)

    display = ConfusionMatrixDisplay(normalized)
    display.plot(cmap='GnBu')
    buf = io.BytesIO()
    try:
        display.figure_.savefig(buf, format='jpg')
    finally:
        # plot() creates the figure through pyplot. pyplot keeps a reference
        # until the figure is closed, so without this every report leaks a figure.
        plt.close(display.figure_)

    encoded = base64.b64encode(buf.getvalue()).decode()
    return (f'<img src="data:image/jpeg;base64,{encoded}" '
            f'width = "{size}px" height = "{size}px" >')


def get_scores_dict(collector: 'Collector'):
    """
    Pivot ``collector.scores`` into column-oriented data for a DataFrame.

    Args:
        collector: object exposing ``metrics`` (list of metric names, or None)
            and ``scores`` (``{model_name: {metric_name: value}}``).

    Returns:
        dict: ``{metric_name: [value for each model]}``. Model order follows
        ``collector.scores`` insertion order. A model that was evaluated
        without a metric (e.g. ``with_metrics`` changed between evaluations)
        gets ``None`` for that metric instead of raising ``KeyError``.
    """
    metrics = collector.metrics or []
    return {
        metric: [model_scores.get(metric)
                 for model_scores in collector.scores.values()]
        for metric in metrics
    }

def mean_hit_ratio(actual, predicted):
    """
    Average, over queries, of the fraction of relevant items that were retrieved.

    For each pair ``(act, pre)`` the ratio is ``|set(act) & set(pre)| / len(act)``.

    Args:
        actual: sequence of per-query collections of relevant (hashable) items.
        predicted: sequence of per-query collections of retrieved items.
            Paired with ``actual`` via ``zip``, so extra entries are ignored.

    Returns:
        float: the mean ratio. A query with an empty ``act`` contributes 0.0,
        and an empty input yields 0.0 (instead of ``ZeroDivisionError``).
    """
    ratios = [
        len(set(act) & set(pre)) / len(act) if len(act) else 0.0
        for act, pre in zip(actual, predicted)
    ]
    return sum(ratios) / len(ratios) if ratios else 0.0

def mean_average_precision(actual, predicted):
    """
    Mean over queries of the average precision of each ranked prediction list.

    A query's AP is the mean of ``precision@k``, taken over every rank ``k``
    (1-based) where ``pre[k-1]`` is in ``act``. It is normalized by the number
    of hits, not by ``len(act)``. A query with no hits, including one with an
    empty ``pre``, scores 0.

    Args:
        actual: sequence of per-query collections of relevant (hashable) items.
        predicted: sequence of per-query ranked lists of retrieved items.

    Returns:
        float: the mean AP, or 0.0 for empty input.
    """
    aps = []
    for act, pre in zip(actual, predicted):
        relevant = set(act)  # O(1) membership instead of scanning act per item
        hits = 0
        precision_sum = 0
        for rank, item in enumerate(pre, start=1):
            if item in relevant:
                hits += 1
                precision_sum += hits / rank
        # Computed after the inner loop. An empty ``pre`` must not leave this
        # undefined or reuse the previous query's value.
        aps.append(precision_sum / hits if hits else 0)

    return sum(aps) / len(aps) if aps else 0.0


def _evaluate_callback(self):
    """
    Create the callable that ``MetricMixin.evaluate`` dispatches to.

    Args:
        self: the ``MetricMixin`` instance. It is iterated for its entities,
            and its ``collector`` receives the results.

    Returns:
        A ``wrapper(_, index, *arg, **kws)`` function:

        * ``index`` is the pair ``(actual_attr, predicted_attr)`` of entity
          attribute names.
        * The model name is ``kws['name']`` or the single positional argument,
          and is ``None`` if neither is given.

        The wrapper records the attribute pair in ``collector.labels``,
        computes every metric in ``collector.metrics``, stores the results in
        ``collector.scores`` under the model name, and returns ``self`` for
        chaining.

    Raises:
        ValueError: (from the wrapper) if ``collector.metrics`` contains an
            unsupported metric name. Nothing is recorded in that case.
    """

    def wrapper(_: str, index: Tuple[str, str], *arg, **kws):
        # pylint: disable=import-outside-toplevel
        # pylint: disable=unused-argument
        actual, predicted = index
        name = None
        if 'name' in kws:
            name = kws['name']
        elif arg:
            name, = arg

        from towhee.utils import sklearn_utils
        scorers = {
            'accuracy': sklearn_utils.accuracy_score,
            # Weighted average so multi-class labels are supported.
            'recall': lambda y_true, y_pred: sklearn_utils.recall_score(
                y_true, y_pred, average='weighted'),
            'confusion_matrix': sklearn_utils.confusion_matrix,
            'mean_hit_ratio': mean_hit_ratio,
            'mean_average_precision': mean_average_precision,
        }
        metric_types = self.collector.metrics or []
        # Reject unknown metrics before recording anything. Without this check
        # their score would be undefined or copied from the previous metric.
        unsupported = [m for m in metric_types if m not in scorers]
        if unsupported:
            raise ValueError(
                f'Unsupported metric type(s) {unsupported}; '
                f'expected any of {sorted(scorers)}.')

        self.collector.add_labels(
            {name: {
                'actual': actual,
                'predicted': predicted
            }})
        actual_list = []
        predicted_list = []
        for x in self:
            actual_list.append(getattr(x, actual))
            predicted_list.append(getattr(x, predicted))

        score = {
            name: {
                metric_type: scorers[metric_type](actual_list, predicted_list)
                for metric_type in metric_types
            }
        }
        self.collector.add_scores(score)
        return self

    return wrapper


class MetricMixin:
    """
    Mixin for metric

    Adds ``with_metrics`` / ``evaluate`` / ``report`` to a DataCollection. Each
    instance gets its own ``Collector`` holding the requested metrics, the
    evaluated label pairs and the resulting scores.
    """
    # pylint: disable=import-outside-toplevel

    def __init__(self):
        super().__init__()
        self.collector = Collector()
        self.evaluate = param_scope().dispatch(_evaluate_callback(self))

    def with_metrics(self, metric_types: list = None):
        """
        Set the metric names computed by subsequent ``evaluate`` calls.

        Args:
            metric_types: iterable of metric names. ``None`` (the default)
                clears the list instead of storing ``None``, which would break
                later evaluation. The names are copied, so later changes to the
                caller's list do not affect this collection.

        Returns:
            self, for chaining.
        """
        self.collector.metrics = [] if metric_types is None else list(metric_types)
        return self

    def report(self):
        """
        report the metric scores

        Examples:

        >>> from towhee import DataCollection
        >>> from towhee import Entity
        >>> dc1 = DataCollection([Entity(a=a, b=b, c=c) for a, b, c in zip([0,1,1,0,0], [0,1,1,1,0], [0,1,1,0,0])])
        >>> dc1.with_metrics(['accuracy', 'recall']).evaluate['a', 'c'](name='lr').evaluate['a', 'b'](name='rf').report()  # doctest: +NORMALIZE_WHITESPACE
            accuracy  recall
        lr       1.0     1.0
        rf       0.8     0.8
        {'lr': {'accuracy': 1.0, 'recall': 1.0}, 'rf': {'accuracy': 0.8, 'recall': 0.8}}

        >>> dc2 = DataCollection([Entity(pred=[1,6,2,7,8,3,9,10,4,5], act=[1,2,3,4,5])])
        >>> dc2.with_metrics(['mean_average_precision', 'mean_hit_ratio']).evaluate['act', 'pred'](name='test').report()  # doctest: +NORMALIZE_WHITESPACE
              mean_average_precision  mean_hit_ratio
        test                0.622222             1.0
        {'test': {'mean_average_precision': 0.6222222222222221, 'mean_hit_ratio': 1.0}}
        """
        from towhee.utils.ipython_utils import HTML, display
        from towhee.utils.pandas_utils import pandas as pd

        def _format_confusion(mat):
            # A model evaluated without confusion_matrix has no matrix to draw.
            return '' if mat is None else encode_fig_img(mat)

        scores_dict = get_scores_dict(self.collector)
        df = pd.DataFrame(data=scores_dict, index=list(self.collector.scores.keys()))
        if 'confusion_matrix' in (self.collector.metrics or ()):
            # escape=False so the <img> tags produced by encode_fig_img render as HTML.
            display(HTML(df.to_html(formatters={'confusion_matrix': _format_confusion},
                                    escape=False)))
        else:
            display(df)
        return self.collector.scores