# (scraper page-metadata removed: "246 lines / 11 KiB / Python")
import argparse
|
|
import os
|
|
import sys
|
|
import json
|
|
|
|
sys.path.append(os.getcwd())
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torchvision
|
|
from torchvision import datasets, models, transforms
|
|
from load_model import initialize_model
|
|
from sklearn.metrics import f1_score
|
|
from sklearn.metrics import recall_score
|
|
from sklearn.metrics import precision_score
|
|
|
|
def accuracy(output, target, topk=(1,), e2e=False):
    """
    Compute top-k accuracy for each k in `topk`.

    In top-k accuracy an example counts as correct when the true label
    appears anywhere among the model's k highest-scoring predictions.

    ref:
    - https://pytorch.org/docs/stable/generated/torch.topk.html
    - https://discuss.pytorch.org/t/imagenet-example-accuracy-calculation/7840
    - https://gist.github.com/weiaicunzai/2a5ae6eac6712c70bde0630f3e76b77b
    - https://discuss.pytorch.org/t/top-k-error-calculation/48815/2
    - https://stackoverflow.com/questions/59474987/how-to-get-top-k-accuracy-in-semantic-segmentation-using-pytorch

    :param output: model predictions; raw scores/logits of shape [B, n_classes]
        when ``e2e=False``, or already-ranked label indices [B, maxk] when
        ``e2e=True``
    :param target: ground-truth labels, shape [B]
    :param topk: tuple of k values to evaluate, e.g. (1, 2, 5)
    :param e2e: when True, ``output`` already holds ranked label indices
        (best first) instead of scores
    :return: numpy array of top-k accuracies [top-k1, top-k2, ...] matching `topk`
    """
    with torch.no_grad():
        k_max = max(topk)  # deepest rank we ever need to inspect
        n_samples = target.size(0)

        # Ranked label indices per example: either supplied directly (e2e)
        # or derived from the k_max largest scores.
        if e2e:
            ranked = output
        else:
            ranked = output.topk(k=k_max, dim=1)[1]  # [B, n_classes] -> [B, k_max]
        ranked = ranked.t()  # [B, k_max] -> [k_max, B]

        # hits[r, b] is True when the rank-r guess for example b equals its
        # true label. Each column can contain at most one True, so summing
        # never over-counts an example.
        truth = target.view(1, -1).expand_as(ranked)  # [B] -> [1, B] -> [k_max, B]
        hits = ranked == truth

        accs = []  # one entry per requested k, in `topk` order
        for k in topk:
            # Keep only the first k ranks, flatten, and count the matches.
            n_hit = hits[:k].reshape(-1).float().sum(dim=0, keepdim=True)  # [k*B] -> [1]
            accs.append((n_hit / n_samples).cpu().numpy()[0])
        return np.array(accs)
|
|
|
|
def evaluate(data_dir, backbone, model_def_path, pretrained_path, device, topk=(1,)):
    """Evaluate a classification checkpoint on an ImageFolder-style directory.

    :param data_dir: root directory with one sub-directory per class
        (hidden entries starting with '.' are ignored)
    :param backbone: backbone name understood by ``initialize_model``
    :param model_def_path: path to the model definition for ``initialize_model``
    :param pretrained_path: path to a saved state_dict checkpoint
    :param device: torch device string, e.g. 'cpu' or 'cuda:0'
    :param topk: iterable of k values for top-k accuracy; entries larger
        than the number of classes are dropped
    :return: (top-k accuracies ndarray, per-class f1, recall, precision arrays)

    Side effect: writes a summary report to 'eval_results.txt' in the cwd.
    """
    # Each visible sub-directory of data_dir is one class.
    num_classes = len([f for f in os.listdir(data_dir) if not f.startswith('.')])
    # top-k is undefined for k > num_classes, so drop those entries.
    if max(topk) > num_classes:
        topk = [k for k in topk if k <= num_classes]

    model_structure, input_size = initialize_model(backbone, num_classes, False, model_def_path)
    # map_location lets a GPU-trained checkpoint load on a CPU-only host.
    model_structure.load_state_dict(torch.load(pretrained_path, map_location=device))
    model = model_structure.eval()
    model = model.to(device)

    # Two-stage normalization: the first Normalize undoes ToTensor's [0,1]
    # scaling back to [0,255]; the second shifts/scales to roughly [-0.5, 0.5).
    data_transforms = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0, 0, 0], [1/255.0, 1/255.0, 1/255.0]),
        transforms.Normalize([0.5*256, 0.5*256, 0.5*256], [256.0, 256.0, 256.0])
        #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image_datasets = datasets.ImageFolder(data_dir, data_transforms)
    batch_size = 32
    dataloaders = torch.utils.data.DataLoader(image_datasets, shuffle=False, batch_size=batch_size, num_workers=4)

    list_topk_accs = np.zeros(len(topk))
    y_preds = []
    y_labels = []
    for inputs, labels in dataloaders:
        with torch.no_grad():
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)

            # accuracy() returns per-batch means; weight by batch size so the
            # final division by the dataset size yields the dataset mean.
            list_topk_accs += accuracy(outputs, labels, topk) * len(labels)
            _, y_pred = outputs.topk(k=1)
            y_preds += y_pred.cpu().numpy().tolist()
            y_labels += labels.cpu().numpy().tolist()

    print()
    list_topk_accs = list_topk_accs / len(image_datasets)
    with open('eval_results.txt', 'w') as writefile:
        for i, k in enumerate(topk):
            acc_str = 'top ' + str(k) + ' accuracy: ' + str(list_topk_accs[i])
            print(acc_str)
            # Newline per entry so report lines don't run together.
            writefile.write(acc_str + '\n')
        print()
        writefile.write('\n')

        # Invert class_to_idx so numeric predictions map back to folder names.
        class_id = image_datasets.class_to_idx
        class_id = dict([(value, key) for key, value in class_id.items()])
        f1 = f1_score(y_labels, y_preds, average=None)
        recall = recall_score(y_labels, y_preds, average=None)
        precision = precision_score(y_labels, y_preds, average=None)
        header = 'Label Precision Recall F1 score'
        itn_line = '{:10} {:8.3f} {:8.3f} {:8.3f}'
        writefile.write(header + '\n')
        print(header)
        for i, score in enumerate(f1):
            res_str = itn_line.format(class_id[i], precision[i], recall[i], score)
            print(res_str)
            writefile.write(res_str + '\n')

    return list_topk_accs, f1, recall, precision
|
|
|
|
def evaluate_e2e(gt_path, classification_path, topk=(1, 5, 10)):
    """Evaluate end-to-end classification results stored as per-image JSON files.

    :param gt_path: JSON file mapping image path -> ground-truth label
    :param classification_path: directory of per-image JSON files; each file
        carries 'img_path' and a key '0_0' holding a list of [score, label]
        pairs
    :param topk: k values for top-k accuracy; entries larger than the number
        of distinct ground-truth labels are dropped (default changed from a
        mutable list to an equivalent tuple; never mutated in place)
    :return: (top-k accuracies ndarray, per-class f1, recall, precision arrays)

    Side effect: writes a summary report to 'eval_results.txt' in the cwd.
    """
    # Collect per-image predictions, sorted by descending score so that
    # index 0 is the model's best guess.
    preds = {}
    for file in os.listdir(classification_path):
        if file.split('.')[-1] != 'json':
            continue

        full_filename = os.path.join(classification_path, file)
        with open(full_filename, 'r') as fi:
            dic = json.load(fi)
        entries = dic["0_0"]  # [[score1, label1], [score2, label2], ...]
        entries.sort(reverse=True)
        preds[dic['img_path']] = entries

    with open(gt_path, 'r') as json_file2:
        gts = json.load(json_file2)  # {img_path: label}

    pred_scores = []   # descending scores per image
    pred_labels = []   # best (top-1) label per image
    pred_labels_ = []  # all candidate labels per image, best first
    y_true = []

    for img_name in preds:
        res = preds[img_name]
        scores, labels = zip(*res)
        pred_scores.append(list(scores))
        pred_labels.append(labels[0])
        pred_labels_.append(labels)
        y_true.append(gts[img_name])

    nc = len(set(y_true))

    # top-k is undefined for k > number of observed classes.
    if max(topk) > nc:
        topk = [k for k in topk if k <= nc]

    # e2e=True: pred_labels_ already holds ranked label indices, not scores.
    list_topk_accs = accuracy(torch.FloatTensor(pred_labels_), torch.FloatTensor(y_true), topk=topk, e2e=True)
    print()
    with open('eval_results.txt', 'w') as writefile:
        for i, k in enumerate(topk):
            acc_str = 'top ' + str(k) + ' accuracy: ' + str(list_topk_accs[i])
            print(acc_str)
            writefile.write(acc_str + '\n')
        print()
        writefile.write('\n')

        f1 = f1_score(y_true, pred_labels, average=None)
        recall = recall_score(y_true, pred_labels, average=None)
        precision = precision_score(y_true, pred_labels, average=None)

        header = 'Label Precision Recall F1 score'
        itn_line = '{:10} {:8.3f} {:8.3f} {:8.3f}'
        writefile.write(header + '\n')
        print(header)
        for i, score in enumerate(f1):
            res_str = itn_line.format(str(i), precision[i], recall[i], score)
            print(res_str)
            writefile.write(res_str + '\n')

    return list_topk_accs, f1, recall, precision
|
|
|
|
|
|
def check_args(parsed_args):
    """Check parsed arguments for inherent contradictions.

    Args
        parsed_args: the namespace returned by parser.parse_args()
    Returns
        parsed_args, unchanged
    Raises
        ValueError: if a GPU id was requested but CUDA is not available
    """
    # idiomatic truth test instead of `== False`
    if parsed_args.gpu >= 0 and not torch.cuda.is_available():
        raise ValueError("No gpu is available")
    return parsed_args
|
|
|
|
|
|
def parse_args(args):
    """
    Parse the command-line arguments.

    :param args: list of argument strings, e.g. sys.argv[1:]
    :return: validated argparse.Namespace (see ``check_args``)
    """
    parser = argparse.ArgumentParser(description='Simple training script for training a image classification network.')
    parser.add_argument('--data-dir', type=str, help='Path to the image directory')
    parser.add_argument('--model-def-path', type=str, help='Path to pretrained model definition', default=None )
    parser.add_argument('--backbone', help='Backbone model.', default='resnet18', type=str)
    parser.add_argument('--snapshot', help='Path to the pretrained models.', default=None)
    parser.add_argument('--gpu', help='Id of the GPU to use (as reported by nvidia-smi). (-1 for cpu)',type=int,default=-1)
    parser.add_argument('--preds', help='path to predicted results',type=str,default=None)
    parser.add_argument('--gts', help='path to ground truth',type=str,default=None)

    # Parse once and reuse (the original parsed twice: once for the print,
    # once for validation).
    parsed_args = parser.parse_args(args)
    print(vars(parsed_args))
    return check_args(parsed_args)
|
|
|
|
|
|
def main(args=None):
    """Entry point: parse CLI arguments and dispatch to the right evaluator."""
    # Default to the process's command-line arguments.
    if args is None:
        args = sys.argv[1:]

    args = parse_args(args)
    device = "cuda:" + str(args.gpu) if args.gpu >= 0 else "cpu"

    # With --preds we score pre-computed end-to-end JSON results; otherwise we
    # run the model over the image directory.
    if args.preds is not None:
        list_topk_accs, f1, recall, precision = evaluate_e2e(args.gts, args.preds)
    else:
        list_topk_accs, f1, recall, precision = evaluate(
            args.data_dir, args.backbone, args.model_def_path,
            args.snapshot, device, [1, 5, 10])
|
|
|
|
# Run the CLI evaluation when executed as a script.
if __name__ == '__main__':
    main()
|
|
|