diff --git a/docs/en/useful_tools.md b/docs/en/useful_tools.md index d6dc576..56710c5 100644 --- a/docs/en/useful_tools.md +++ b/docs/en/useful_tools.md @@ -378,3 +378,49 @@ configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py \ checkpoint/fcn_r50-d8_512x1024_40k_cityscapes_20200604_192608-efe53f0d.pth \ fcn ``` + +## Confusion Matrix + +In order to generate and plot a ```nxn``` confusion matrix where ```n``` is the number of classes, you can follow these steps: + +### 1. Generate a prediction result in pkl format using `test.py` + +```shell +python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${PATH_TO_RESULT_FILE}] +``` + +Note that the argument for ```--eval``` should be ```None``` so that the result file contains the prediction results as numpy arrays. The usage for distributed testing is just the same. + +Example: + +```shell +python tools/test.py \ +configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py \ +checkpoint/fcn_r50-d8_512x1024_40k_cityscapes_20200604_192608-efe53f0d.pth \ +--out result/pred_result.pkl +``` + +### 2. Use ```confusion_matrix.py``` to generate and plot a confusion matrix + +```shell +python tools/confusion_matrix.py ${CONFIG_FILE} ${PATH_TO_RESULT_FILE} ${SAVE_DIR} --show +``` + +Description of arguments: + +- `config`: Path to the test config file. +- `prediction_path`: Path to the prediction .pkl result. +- `save_dir`: Directory where the confusion matrix will be saved. +- `--show`: Show the resulting confusion matrix plot. +- `--color-theme`: Theme of the matrix color map. +- `--cfg-options`: Custom options to override settings in the config file. + +Example: + +```shell +python tools/confusion_matrix.py \ +configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py \ +result/pred_result.pkl \ +result/confusion_matrix \ +--show +``` diff --git a/tools/confusion_matrix.py b/tools/confusion_matrix.py new file mode 100644 index 0000000..41d308b --- /dev/null +++ b/tools/confusion_matrix.py @@ -0,0 +1,178 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
import argparse
import os

import matplotlib.pyplot as plt
import mmcv
import numpy as np
from matplotlib.ticker import MultipleLocator
from mmcv import Config, DictAction

from mmseg.datasets import build_dataset


def parse_args():
    """Parse the command line arguments of the confusion matrix tool.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes ``config``,
            ``prediction_path``, ``save_dir``, ``show``, ``color_theme`` and
            ``cfg_options``.
    """
    parser = argparse.ArgumentParser(
        description='Generate confusion matrix from segmentation results')
    parser.add_argument('config', help='test config file path')
    parser.add_argument(
        'prediction_path', help='prediction path where test .pkl result')
    parser.add_argument(
        'save_dir', help='directory where confusion matrix will be saved')
    parser.add_argument(
        '--show', action='store_true', help='show confusion matrix')
    parser.add_argument(
        '--color-theme',
        default='winter',
        help='theme of the matrix color map')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    return args


def calculate_confusion_matrix(dataset, results):
    """Calculate the confusion matrix.

    Rows index the ground truth class and columns index the predicted class.
    Pixels whose ground truth label equals the dataset's ``ignore_index``
    (when the dataset exposes one) or lies outside ``[0, num_classes)`` are
    skipped, since they cannot be assigned to any row of the matrix.

    Args:
        dataset (Dataset): Test or val dataset. It must provide ``CLASSES``,
            ``__len__`` and ``get_gt_seg_map_by_idx``.
        results (list[ndarray]): A list of segmentation results in each image.

    Returns:
        ndarray: The un-normalized confusion matrix of shape
            (num_classes, num_classes) holding pixel counts.
    """
    n = len(dataset.CLASSES)
    confusion_matrix = np.zeros(shape=[n, n])
    assert len(dataset) == len(results)
    ignore_index = getattr(dataset, 'ignore_index', None)
    prog_bar = mmcv.ProgressBar(len(results))
    for idx, per_img_res in enumerate(results):
        # Cast to int64 before doing arithmetic so that ``n * gt_segm``
        # cannot wrap around when a map is stored in a narrow integer dtype
        # such as uint8.
        res_segm = np.asarray(per_img_res).astype(np.int64).flatten()
        gt_segm = np.asarray(
            dataset.get_gt_seg_map_by_idx(idx)).astype(np.int64).flatten()

        # Drop pixels that do not belong to any class row; otherwise the
        # flat index exceeds n**2 and the reshape below fails.
        valid = (gt_segm >= 0) & (gt_segm < n)
        if ignore_index is not None:
            valid &= gt_segm != ignore_index

        # Encode each (gt, pred) pair as a single flat index gt * n + pred.
        inds = n * gt_segm[valid] + res_segm[valid]
        mat = np.bincount(inds, minlength=n**2).reshape(n, n)
        confusion_matrix += mat
        prog_bar.update()
    return confusion_matrix


def plot_confusion_matrix(confusion_matrix,
                          labels,
                          save_dir=None,
                          show=True,
                          title='Normalized Confusion Matrix',
                          color_theme='winter'):
    """Draw confusion matrix with matplotlib.

    Each row is normalized to percentages of its ground truth pixel count.
    Rows without any ground truth pixel cannot be normalized and are drawn
    with the placeholder value ``-1%``.

    Args:
        confusion_matrix (ndarray): The confusion matrix.
        labels (list[str]): List of class names.
        save_dir (str|optional): If set, save the confusion matrix plot as
            ``confusion_matrix.png`` inside this directory. The directory is
            created if it does not exist. Default: None.
        show (bool): Whether to show the plot. Default: True.
        title (str): Title of the plot. Default: `Normalized Confusion Matrix`.
        color_theme (str): Theme of the matrix color map. Default: `winter`.
    """
    # normalize the confusion matrix
    per_label_sums = confusion_matrix.sum(axis=1)[:, np.newaxis]
    # Empty rows yield NaN on purpose (handled when drawing the values), so
    # silence the divide-by-zero warnings instead of letting them leak out.
    with np.errstate(divide='ignore', invalid='ignore'):
        confusion_matrix = \
            confusion_matrix.astype(np.float32) / per_label_sums * 100

    num_classes = len(labels)
    fig, ax = plt.subplots(
        figsize=(2 * num_classes, 2 * num_classes * 0.8), dpi=180)
    cmap = plt.get_cmap(color_theme)
    im = ax.imshow(confusion_matrix, cmap=cmap)
    plt.colorbar(mappable=im, ax=ax)

    title_font = {'weight': 'bold', 'size': 12}
    ax.set_title(title, fontdict=title_font)
    label_font = {'size': 10}
    plt.ylabel('Ground Truth Label', fontdict=label_font)
    plt.xlabel('Prediction Label', fontdict=label_font)

    # draw locator
    xmajor_locator = MultipleLocator(1)
    xminor_locator = MultipleLocator(0.5)
    ax.xaxis.set_major_locator(xmajor_locator)
    ax.xaxis.set_minor_locator(xminor_locator)
    ymajor_locator = MultipleLocator(1)
    yminor_locator = MultipleLocator(0.5)
    ax.yaxis.set_major_locator(ymajor_locator)
    ax.yaxis.set_minor_locator(yminor_locator)

    # draw grid
    ax.grid(True, which='minor', linestyle='-')

    # draw label
    ax.set_xticks(np.arange(num_classes))
    ax.set_yticks(np.arange(num_classes))
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)

    ax.tick_params(
        axis='x', bottom=False, top=True, labelbottom=False, labeltop=True)
    plt.setp(
        ax.get_xticklabels(), rotation=45, ha='left', rotation_mode='anchor')

    # draw confusion matrix value
    for i in range(num_classes):
        for j in range(num_classes):
            value = confusion_matrix[i, j]
            # NaN marks a class without ground truth pixels; show -1 instead.
            shown = -1 if np.isnan(value) else round(value, 2)
            ax.text(
                j,
                i,
                '{}%'.format(shown),
                ha='center',
                va='center',
                color='w',
                size=7)

    ax.set_ylim(len(confusion_matrix) - 0.5, -0.5)  # matplotlib>3.1.1

    fig.tight_layout()
    if save_dir is not None:
        # savefig does not create missing directories by itself.
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(
            os.path.join(save_dir, 'confusion_matrix.png'), format='png')
    if show:
        plt.show()


def main():
    """Build the test dataset, load predictions and plot the matrix.

    Raises:
        TypeError: If the loaded prediction file is not a non-empty list of
            ``ndarray`` objects.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # NOTE(review): a .pkl result file is presumably deserialized with
    # pickle here, which can execute arbitrary code; only load result files
    # from a trusted source.
    results = mmcv.load(args.prediction_path)

    # An empty list is rejected too, so that indexing cannot raise IndexError.
    if (not isinstance(results, list) or not results
            or not isinstance(results[0], np.ndarray)):
        raise TypeError('invalid type of prediction results')

    if isinstance(cfg.data.test, dict):
        cfg.data.test.test_mode = True
    elif isinstance(cfg.data.test, list):
        for ds_cfg in cfg.data.test:
            ds_cfg.test_mode = True

    dataset = build_dataset(cfg.data.test)
    confusion_matrix = calculate_confusion_matrix(dataset, results)
    plot_confusion_matrix(
        confusion_matrix,
        dataset.CLASSES,
        save_dir=args.save_dir,
        show=args.show,
        color_theme=args.color_theme)


if __name__ == '__main__':
    main()