
Commit 54b3593

Merge branch 'enhanced_validate.py' of github.com:ha405/pytorch-image-models into ha405-enhanced_validate.py
2 parents: 6d30b82 + 74425f9

File tree

1 file changed: +51 −3 lines

validate.py: 51 additions & 3 deletions
@@ -42,6 +42,12 @@
 
 has_compile = hasattr(torch, 'compile')
 
+try:
+    from sklearn.metrics import precision_score, recall_score, f1_score
+    has_sklearn = True
+except ImportError:
+    has_sklearn = False
+
 _logger = logging.getLogger('validate')
 
 
@@ -158,6 +164,11 @@
 parser.add_argument('--retry', default=False, action='store_true',
                     help='Enable batch size decay & retry for single model validation')
 
+parser.add_argument('--metrics-avg', type=str, default=None,
+                    choices=['micro', 'macro', 'weighted'],
+                    help='Enable precision, recall, F1-score calculation and specify the averaging method. '
+                         'Requires scikit-learn. (default: None)')
+
 # NaFlex loader arguments
 parser.add_argument('--naflex-loader', action='store_true', default=False,
                     help='Use NaFlex loader (Requires NaFlex compatible model)')
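
The new --metrics-avg flag (added, say, as --metrics-avg macro on an existing validate.py command line) is forwarded directly as scikit-learn's average argument. Below is a minimal sketch of what the three accepted values compute; the labels are made up for illustration and are not part of this commit.

# Toy illustration of the three averaging modes accepted by --metrics-avg.
# Only the sklearn calls mirror the new code path; the labels are invented.
from sklearn.metrics import precision_score, recall_score, f1_score

y_true = [0, 0, 0, 1, 1, 2]   # imbalanced ground truth
y_pred = [0, 0, 1, 1, 1, 1]   # class 2 is never predicted

for avg in ('micro', 'macro', 'weighted'):
    p = precision_score(y_true, y_pred, average=avg, zero_division=0)
    r = recall_score(y_true, y_pred, average=avg, zero_division=0)
    f = f1_score(y_true, y_pred, average=avg, zero_division=0)
    print(f'{avg:>8}: precision={p:.3f} recall={r:.3f} f1={f:.3f}')

# micro    -> pools every sample (global TP/FP/FN counts)
# macro    -> unweighted mean over classes, so rare classes count as much as common ones
# weighted -> per-class scores weighted by class support
# zero_division=0 keeps classes with no predictions (class 2 here) from raising warnings.
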
@@ -176,6 +187,11 @@ def validate(args):
 
     device = torch.device(args.device)
 
+    if args.metrics_avg and not has_sklearn:
+        _logger.warning(
+            f"scikit-learn not installed, disabling metrics calculation. Please install with 'pip install scikit-learn'.")
+        args.metrics_avg = None
+
     model_dtype = None
     if args.model_dtype:
         assert args.model_dtype in ('float32', 'float16', 'bfloat16')
@@ -346,6 +362,10 @@ def validate(args):
     top1 = AverageMeter()
     top5 = AverageMeter()
 
+    if args.metrics_avg:
+        all_preds = []
+        all_targets = []
+
     model.eval()
     with torch.inference_mode():
         # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
@@ -382,6 +402,11 @@ def validate(args):
             top1.update(acc1.item(), batch_size)
             top5.update(acc5.item(), batch_size)
 
+            if args.metrics_avg:
+                predictions = torch.argmax(output, dim=1)
+                all_preds.append(predictions.cpu())
+                all_targets.append(target.cpu())
+
             # measure elapsed time
             batch_time.update(time.time() - end)
             end = time.time()
@@ -408,18 +433,41 @@ def validate(args):
         top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
     else:
         top1a, top5a = top1.avg, top5.avg
+
+    metric_results = {}
+    if args.metrics_avg:
+        all_preds = torch.cat(all_preds).numpy()
+        all_targets = torch.cat(all_targets).numpy()
+        precision = precision_score(all_targets, all_preds, average=args.metrics_avg, zero_division=0)
+        recall = recall_score(all_targets, all_preds, average=args.metrics_avg, zero_division=0)
+        f1 = f1_score(all_targets, all_preds, average=args.metrics_avg, zero_division=0)
+        metric_results = {
+            f'{args.metrics_avg}_precision': round(precision, 4),
+            f'{args.metrics_avg}_recall': round(recall, 4),
+            f'{args.metrics_avg}_f1_score': round(f1, 4),
+        }
+
     results = OrderedDict(
         model=args.model,
         top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
         top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
+        **metric_results,
         param_count=round(param_count / 1e6, 2),
         img_size=data_config['input_size'][-1],
         crop_pct=crop_pct,
         interpolation=data_config['interpolation'],
     )
 
-    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
-        results['top1'], results['top1_err'], results['top5'], results['top5_err']))
+    log_string = ' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
+        results['top1'], results['top1_err'], results['top5'], results['top5_err'])
+    if metric_results:
+        log_string += ' | Precision({avg}) {prec:.3f} | Recall({avg}) {rec:.3f} | F1-score({avg}) {f1:.3f}'.format(
+            avg=args.metrics_avg,
+            prec=metric_results[f'{args.metrics_avg}_precision'],
+            rec=metric_results[f'{args.metrics_avg}_recall'],
+            f1=metric_results[f'{args.metrics_avg}_f1_score'],
+        )
+    _logger.info(log_string)
 
     return results
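
Taken together, the in-loop accumulation and the post-loop scoring added above boil down to the pattern below. This is a self-contained sketch with random logits and labels standing in for real model output and loader batches; the shapes, batch count, and class count are arbitrary and not from the commit.

# Stand-alone sketch of the accumulate -> concatenate -> score pattern added above.
# Dummy tensors replace real model output and loader batches.
import torch
from sklearn.metrics import precision_score, recall_score, f1_score

metrics_avg = 'macro'          # what --metrics-avg would supply
num_classes, num_batches = 5, 4

all_preds, all_targets = [], []
for _ in range(num_batches):
    output = torch.randn(8, num_classes)            # fake model output for one batch
    target = torch.randint(0, num_classes, (8,))    # fake labels for one batch
    all_preds.append(torch.argmax(output, dim=1).cpu())
    all_targets.append(target.cpu())

preds = torch.cat(all_preds).numpy()
targets = torch.cat(all_targets).numpy()

metric_results = {
    f'{metrics_avg}_precision': round(precision_score(targets, preds, average=metrics_avg, zero_division=0), 4),
    f'{metrics_avg}_recall': round(recall_score(targets, preds, average=metrics_avg, zero_division=0), 4),
    f'{metrics_avg}_f1_score': round(f1_score(targets, preds, average=metrics_avg, zero_division=0), 4),
}
print(metric_results)  # e.g. keys: macro_precision, macro_recall, macro_f1_score

When --metrics-avg is set, these keys are merged into the results OrderedDict next to top1/top5 and flow through to whatever write_results emits.
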

@@ -534,4 +582,4 @@ def write_results(results_file, results, format='csv'):
 
 
 if __name__ == '__main__':
-    main()
+    main()
