diff --git a/mmseg/apis/test.py b/mmseg/apis/test.py index 1597df6..9728de4 100644 --- a/mmseg/apis/test.py +++ b/mmseg/apis/test.py @@ -149,7 +149,7 @@ def multi_gpu_test(model, results.append(result) if rank == 0: - batch_size = data['img'][0].size(0) + batch_size = len(result) for _ in range(batch_size * world_size): prog_bar.update()