@@ -42,6 +42,12 @@

has_compile = hasattr(torch, 'compile')

+try:
+    from sklearn.metrics import precision_score, recall_score, f1_score
+    has_sklearn = True
+except ImportError:
+    has_sklearn = False
+
_logger = logging.getLogger('validate')

@@ -158,6 +164,11 @@
parser.add_argument('--retry', default=False, action='store_true',
                    help='Enable batch size decay & retry for single model validation')

+parser.add_argument('--metrics-avg', type=str, default=None,
+                    choices=['micro', 'macro', 'weighted'],
+                    help='Enable precision, recall, F1-score calculation and specify the averaging method. '
+                         'Requires scikit-learn. (default: None)')
+
# NaFlex loader arguments
parser.add_argument('--naflex-loader', action='store_true', default=False,
                    help='Use NaFlex loader (Requires NaFlex compatible model)')
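
The three --metrics-avg choices map directly onto scikit-learn's average parameter. The following is a minimal, standalone sketch (toy labels, not taken from the PR) of how the averaging mode changes the reported numbers; for single-label multiclass classification, micro-averaged precision, recall, and F1 all coincide with top-1 accuracy.

# Illustration only: effect of the `average` argument used by this patch.
# The labels below are made up for demonstration.
from sklearn.metrics import precision_score, recall_score, f1_score

y_true = [0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 1, 1, 2, 2]

for avg in ('micro', 'macro', 'weighted'):
    p = precision_score(y_true, y_pred, average=avg, zero_division=0)
    r = recall_score(y_true, y_pred, average=avg, zero_division=0)
    f = f1_score(y_true, y_pred, average=avg, zero_division=0)
    print(f'{avg:>8}: precision={p:.3f}  recall={r:.3f}  f1={f:.3f}')
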
@@ -176,6 +187,11 @@ def validate(args):

    device = torch.device(args.device)

+    if args.metrics_avg and not has_sklearn:
+        _logger.warning(
+            f"scikit-learn not installed, disabling metrics calculation. Please install with 'pip install scikit-learn'.")
+        args.metrics_avg = None
+
    model_dtype = None
    if args.model_dtype:
        assert args.model_dtype in ('float32', 'float16', 'bfloat16')
@@ -346,6 +362,10 @@ def validate(args):
    top1 = AverageMeter()
    top5 = AverageMeter()

+    if args.metrics_avg:
+        all_preds = []
+        all_targets = []
+
    model.eval()
    with torch.inference_mode():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
@@ -382,6 +402,11 @@ def validate(args):
            top1.update(acc1.item(), batch_size)
            top5.update(acc5.item(), batch_size)

+            if args.metrics_avg:
+                predictions = torch.argmax(output, dim=1)
+                all_preds.append(predictions.cpu())
+                all_targets.append(target.cpu())
+
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
@@ -408,18 +433,41 @@ def validate(args):
        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
    else:
        top1a, top5a = top1.avg, top5.avg
+
+    metric_results = {}
+    if args.metrics_avg:
+        all_preds = torch.cat(all_preds).numpy()
+        all_targets = torch.cat(all_targets).numpy()
+        precision = precision_score(all_targets, all_preds, average=args.metrics_avg, zero_division=0)
+        recall = recall_score(all_targets, all_preds, average=args.metrics_avg, zero_division=0)
+        f1 = f1_score(all_targets, all_preds, average=args.metrics_avg, zero_division=0)
+        metric_results = {
+            f'{args.metrics_avg}_precision': round(precision, 4),
+            f'{args.metrics_avg}_recall': round(recall, 4),
+            f'{args.metrics_avg}_f1_score': round(f1, 4),
+        }
+
    results = OrderedDict(
        model=args.model,
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
+        **metric_results,
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        crop_pct=crop_pct,
        interpolation=data_config['interpolation'],
    )

-    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
-        results['top1'], results['top1_err'], results['top5'], results['top5_err']))
+    log_string = ' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
+        results['top1'], results['top1_err'], results['top5'], results['top5_err'])
+    if metric_results:
+        log_string += ' | Precision({avg}) {prec:.3f} | Recall({avg}) {rec:.3f} | F1-score({avg}) {f1:.3f}'.format(
+            avg=args.metrics_avg,
+            prec=metric_results[f'{args.metrics_avg}_precision'],
+            rec=metric_results[f'{args.metrics_avg}_recall'],
+            f1=metric_results[f'{args.metrics_avg}_f1_score'],
+        )
+    _logger.info(log_string)

    return results

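
For reference, the collect-then-aggregate flow added in the hunk above, reduced to a standalone sketch; the batch size, class count, and random tensors are assumptions for illustration only.

# Standalone illustration of the pattern added above: per-batch argmax
# predictions are accumulated on CPU, then concatenated once for scikit-learn.
import torch
from sklearn.metrics import precision_score, recall_score, f1_score

all_preds, all_targets = [], []
for _ in range(3):                        # stand-in for the validation loop
    output = torch.randn(8, 10)           # fake logits: batch of 8, 10 classes
    target = torch.randint(0, 10, (8,))   # fake integer labels
    all_preds.append(torch.argmax(output, dim=1).cpu())
    all_targets.append(target.cpu())

preds = torch.cat(all_preds).numpy()
targets = torch.cat(all_targets).numpy()
print(precision_score(targets, preds, average='macro', zero_division=0))
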
@@ -534,4 +582,4 @@ def write_results(results_file, results, format='csv'):


if __name__ == '__main__':
-    main()
+    main()