Implementing an AUC Metric in MMPretrain

2024.09.16

While training a model with MMPretrain, I found that AUC is not among the metrics it provides. Since I needed it, I implemented the metric following the official documentation and used it during training.

What the official documentation says about adding a metric

MMPretrain supports implementing customized evaluation metrics for users who need more customization.

You need to create a new file under mmpretrain/evaluation/metrics and implement the new metric there, for example in mmpretrain/evaluation/metrics/my_metric.py, defining a custom evaluation metric class MyMetric that inherits from BaseMetric in MMEngine.

You need to override the data-processing method process and the metric-computation method compute_metrics respectively, and add the class to the METRICS registry so the custom evaluation metric can be used.

python
from typing import Dict, List, Sequence

from mmengine.evaluator import BaseMetric
from mmpretrain.registry import METRICS

@METRICS.register_module()
class MyMetric(BaseMetric):

    def process(self, data_batch: Sequence[Dict], data_samples: Sequence[Dict]):
        """The processed results should be stored in ``self.results``, which
        will be used to compute the metrics when all batches have been
        processed. `data_batch` stores the batch data from the dataloader,
        and `data_samples` stores the batch outputs from the model.
        """
        ...

    def compute_metrics(self, results: List):
        """Compute the metrics from processed results and return the
        evaluation results.
        """
        ...

Then, import it in mmpretrain/evaluation/metrics/__init__.py to add it to the mmpretrain.evaluation package.

python
# In mmpretrain/evaluation/metrics/__init__.py
...
from .my_metric import MyMetric

__all__ = [..., 'MyMetric']

Finally, use MyMetric in the val_evaluator and test_evaluator fields of the config file:

python
val_evaluator = dict(type='MyMetric', ...)
test_evaluator = val_evaluator

For more details, refer to the Evaluation design documentation in MMEngine (design/evaluation).

An incorrect first attempt

Following the documentation, create a new file auc.py under the mmpretrain/evaluation/metrics directory:

python
@METRICS.register_module()
class SingleLabelAUC(BaseMetric):
    # Prefix used when results are printed to the terminal. If `compute_metrics`
    # returns {'foo': 1, 'bar': 0.9} and the prefix is `bzz`, the log shows
    # `bzz/foo: 1 bzz/bar: 0.9`.
    default_prefix: Optional[str] = 'auc'

    def process(self, data_batch, data_samples: Sequence[dict]) -> None:
        # Store the model outputs in `self.results`.
        ...

    def compute_metrics(self, results: List) -> dict:
        # Compute and return the metrics as a dict; every key-value pair is
        # printed to the terminal during validation and testing.
        ...

For the process method, follow the implementation in mmpretrain/evaluation/metrics/single_label.py:

python
for data_sample in data_samples:
    result = dict()
    if 'pred_score' in data_sample:
        result['pred_score'] = data_sample['pred_score'].cpu()
    else:
        result['pred_label'] = data_sample['pred_label'].cpu()
    result['gt_label'] = data_sample['gt_label'].cpu()
    # Save the result to `self.results`.
    self.results.append(result)

This snippet shows that process merely stores the model's predictions and the ground-truth labels in self.results, and those two are exactly what an AUC computation needs. So compute_metrics only has to read them back and run the calculation.

scikit-learn already provides a ready-made function for computing AUC, so the implementation simply calls it:

python
metrics = {}

target = torch.cat([res['gt_label'] for res in results]) # concatenate all ground-truth labels
if 'pred_score' in results[0]:
    pred = torch.stack([res['pred_score'] for res in results]) # stack all prediction scores
    auc = roc_auc_score(target, pred, average='macro', sample_weight=None,
                        max_fpr=None, multi_class='ovr', labels=None)

    metrics['auc'] = auc
else:
    # If only label in the `pred_label`.
    pred = torch.cat([res['pred_label'] for res in results]) # concatenate all predicted labels
    auc = roc_auc_score(target, pred, average='macro', sample_weight=None,
                        max_fpr=None, multi_class='ovr', labels=None)
    metrics['auc'] = auc

return metrics

The keys of the dict returned here have no special requirements; you can name them freely (the default_prefix is prepended to each key when the results are logged).

After writing this class, we need to export the custom metric in mmpretrain/evaluation/metrics/__init__.py:

python
from .auc import SingleLabelAUC

__all__ = [..., 'SingleLabelAUC']

Finally, add the following to our training config file: val_evaluator = dict(type='SingleLabelAUC')
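
For example, here is a minimal sketch of a custom config that inherits from resnet18_8xb16_cifar10.py; the file name and the `_base_` path below are my own assumptions and may need adjusting to your local layout:

python
# auc_resnet18_cifar10.py -- hypothetical file name
_base_ = ['./resnet18_8xb16_cifar10.py']  # illustrative path, adjust as needed

# Replace the default evaluator with the custom AUC metric.
val_evaluator = dict(type='SingleLabelAUC')
test_evaluator = val_evaluator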

I tested with the resnet18_8xb16_cifar10.py config file, and it printed the computed result. The complete auc.py file is as follows:

python
import torch

from mmengine.evaluator import BaseMetric
from mmpretrain.registry import METRICS
from sklearn.metrics import roc_auc_score
from typing import List, Optional, Sequence


@METRICS.register_module()
class SingleLabelAUC(BaseMetric):
    default_prefix: Optional[str] = 'auc'

    def process(self, data_batch, data_samples: Sequence[dict]) -> None:
        for data_sample in data_samples:
            result = dict()
            if 'pred_score' in data_sample:
                result['pred_score'] = data_sample['pred_score'].cpu()
            else:
                result['pred_label'] = data_sample['pred_label'].cpu()
            result['gt_label'] = data_sample['gt_label'].cpu()
            self.results.append(result)

    def compute_metrics(self, results: List) -> dict:
        metrics = {}

        # concat
        target = torch.cat([res['gt_label'] for res in results])
        if 'pred_score' in results[0]:
            pred = torch.stack([res['pred_score'] for res in results])

            auc = roc_auc_score(target, pred, average='macro', sample_weight=None,
                                max_fpr=None, multi_class='ovr', labels=None)

            metrics['auc'] = auc
        else:
            pred = torch.cat([res['pred_label'] for res in results])
            auc = roc_auc_score(target, pred, average='macro', sample_weight=None,
                                max_fpr=None, multi_class='ovr', labels=None)
            metrics['auc'] = auc

        return metrics

The correct approach

After setting everything up I ran a test, and found that with an accuracy of 0.5 the AUC reached 0.9 (at the first validation during training), which did not look right. So I abandoned my own code, copied the SingleLabelMetric implementation directly, and replaced the key computations:

python
import torch
import numpy as np
import torch.nn.functional as F

from mmengine.evaluator import BaseMetric
from mmpretrain.registry import METRICS
from sklearn.metrics import roc_auc_score
from typing import List, Optional, Sequence, Union
from .single_label import to_tensor


@METRICS.register_module()
class SingleLabelAUC(BaseMetric):

    default_prefix: Optional[str] = 'single-label'

    def __init__(self,
                 thrs: Union[float, Sequence[Union[float, None]], None] = 0.,
                 average: Optional[str] = 'macro',
                 num_classes: Optional[int] = None,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)

        if isinstance(thrs, float) or thrs is None:
            self.thrs = (thrs, )
        else:
            self.thrs = tuple(thrs)

        self.average = average
        self.num_classes = num_classes

    def process(self, data_batch, data_samples: Sequence[dict]):
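        # Cache each sample's prediction (score or hard label) together with
        # its ground-truth label; AUC is computed from these in compute_metrics.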
        for data_sample in data_samples:
            result = dict()
            if 'pred_score' in data_sample:
                result['pred_score'] = data_sample['pred_score'].cpu()
            else:
                num_classes = self.num_classes or data_sample.get(
                    'num_classes')
                assert num_classes is not None, \
                    'The `num_classes` must be specified if no `pred_score`.'
                result['pred_label'] = data_sample['pred_label'].cpu()
                result['num_classes'] = num_classes
            result['gt_label'] = data_sample['gt_label'].cpu()
            self.results.append(result)

    def compute_metrics(self, results: List):
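        # Concatenate everything gathered by `process` and compute AUC with
        # scikit-learn, once per threshold when prediction scores are available.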

        metrics = {}

        # concat
        target = torch.cat([res['gt_label'] for res in results])
        if 'pred_score' in results[0]:
            pred = torch.stack([res['pred_score'] for res in results])
            auc = self.calculate(
                pred, target, thrs=self.thrs, average=self.average)

            multi_thrs = len(self.thrs) > 1
            for i, thr in enumerate(self.thrs):
                if multi_thrs:
                    suffix = 'auc_no-thr' if thr is None else f'auc_thr-{thr:.2f}'
                else:
                    suffix = 'auc'
                # `auc` holds one value per threshold; take the one for this thr.
                metrics[suffix] = auc[i]
        else:
            pred = torch.cat([res['pred_label'] for res in results])
            auc = self.calculate(
                pred,
                target,
                average=self.average,
                num_classes=results[0]['num_classes'])
            metrics['auc'] = auc

        return metrics

    @staticmethod
    def calculate(
        pred: Union[torch.Tensor, np.ndarray, Sequence],
        target: Union[torch.Tensor, np.ndarray, Sequence],
        thrs: Sequence[Union[float, None]] = (0., ),
        average: Optional[str] = 'macro',
        num_classes: Optional[int] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
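        # One-hot encode the ground truth and the (top-1) predictions, then
        # compute AUC with scikit-learn's roc_auc_score, once per threshold
        # when prediction scores are given.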
        average_options = ['micro', 'macro', None]
        assert average in average_options, 'Invalid `average` argument, ' \
            f'please specify from {average_options}.'

        pred = to_tensor(pred)
        target = to_tensor(target).to(torch.int64)
        assert pred.size(0) == target.size(0), \
            f"The size of pred ({pred.size(0)}) doesn't match "\
            f'the target ({target.size(0)}).'

        if pred.ndim == 1:
            assert num_classes is not None, \
                'Please specify the `num_classes` if the `pred` is labels ' \
                'instead of scores.'
            gt_positive = F.one_hot(target.flatten(), num_classes)
            pred_positive = F.one_hot(pred.to(torch.int64), num_classes)
            return roc_auc_score(gt_positive, pred_positive, average=average)
        else:
            # For pred score, calculate on all thresholds.
            num_classes = pred.size(1)
            pred_score, pred_label = torch.topk(pred, k=1)
            pred_score = pred_score.flatten()
            pred_label = pred_label.flatten()

            gt_positive = F.one_hot(target.flatten(), num_classes)

            results = []
            for thr in thrs:
                pred_positive = F.one_hot(pred_label, num_classes)
                if thr is not None:
                    pred_positive[pred_score <= thr] = 0
                results.append(
                    roc_auc_score(gt_positive, pred_positive, 
                                                 average=average))

            return results

Running the experiment again under otherwise identical conditions, the AUC at the first validation during training was around 0.75, which matches expectations.
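
As a quick standalone sanity check (my own sketch, not part of MMPretrain), the metric class can also be fed random scores directly; with random predictions the reported AUC should hover around 0.5:

python
import torch
from mmpretrain.evaluation.metrics import SingleLabelAUC  # relies on the __init__.py export above

metric = SingleLabelAUC()
data_samples = [
    dict(pred_score=torch.softmax(torch.randn(10), dim=0),  # fake scores for 10 classes
         gt_label=torch.tensor([i % 10]))                    # fake ground-truth label
    for i in range(100)
]
metric.process(data_batch=None, data_samples=data_samples)
print(metric.compute_metrics(metric.results))  # roughly {'auc': 0.5} for random scores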

This post was last updated on September 24, 2024.