File tree Expand file tree Collapse file tree 5 files changed +35
-11
lines changed Expand file tree Collapse file tree 5 files changed +35
-11
lines changed Original file line number Diff line number Diff line change @@ -39,3 +39,4 @@ architecture:
39
39
type : mse
40
40
weights : {}
41
41
reduction : sum
42
+ best_model_metric : " rmse_prod"
Original file line number Diff line number Diff line change 133
133
"log_mae" : {
134
134
"type" : " boolean"
135
135
},
136
+ "best_model_metric" : {
137
+ "type" : " string" ,
138
+ "enum" : [" rmse_prod" , " mae_prod" , " loss" ]
139
+ },
136
140
"loss" : {
137
141
"type" : " object" ,
138
142
"properties" : {
Original file line number Diff line number Diff line change 17
17
from ...utils .io import check_file_extension
18
18
from ...utils .logging import MetricLogger
19
19
from ...utils .loss import TensorMapDictLoss
20
- from ...utils .metrics import MAEAccumulator , RMSEAccumulator
20
+ from ...utils .metrics import MAEAccumulator , RMSEAccumulator , get_selected_metric
21
21
from ...utils .neighbor_lists import (
22
22
get_requested_neighbor_lists ,
23
23
get_system_with_neighbor_lists ,
@@ -430,8 +430,11 @@ def train(
430
430
patience = self .hypers ["scheduler_patience" ],
431
431
)
432
432
433
- if val_loss < self .best_loss :
434
- self .best_loss = val_loss
433
+ metric = get_selected_metric (
434
+ finalized_val_info , self .hypers ["best_model_metric" ]
435
+ )
436
+ if metric < self .best_loss :
437
+ self .best_loss = metric
435
438
self .best_model_state_dict = copy .deepcopy (
436
439
(
437
440
scripted_model .module if is_distributed else scripted_model
Original file line number Diff line number Diff line change @@ -178,3 +178,24 @@ def finalize(
178
178
finalized_info [out_key ] = value [0 ] / value [1 ]
179
179
180
180
return finalized_info
181
+
182
+
183
+ def get_selected_metric (metric_dict : Dict [str , float ], selected_metric : str ) -> float :
184
+ if selected_metric == "loss" :
185
+ metric = metric_dict ["loss" ]
186
+ elif selected_metric == "rmse_prod" :
187
+ metric = 1
188
+ for key in metric_dict :
189
+ if "RMSE" in key :
190
+ metric *= metric_dict [key ]
191
+ elif selected_metric == "mae_prod" :
192
+ metric = 1
193
+ for key in metric_dict :
194
+ if "MAE" in key :
195
+ metric *= metric_dict [key ]
196
+ else :
197
+ raise ValueError (
198
+ f"Selected metric { selected_metric } not recognized. "
199
+ "Please select from 'loss', 'rmse_prod', or 'mae_prod'."
200
+ )
201
+ return metric
Original file line number Diff line number Diff line change 1
1
seed : 42
2
2
3
3
architecture :
4
- name : experimental.phace
5
- # model:
6
- # radial_basis:
7
- # optimizable_lengthscales: true
4
+ name : experimental.soap_bpnn
8
5
training :
9
- batch_size : 8
10
- num_epochs : 10
11
- gradient_clipping : 10
12
- log_interval : 1
6
+ batch_size : 2
7
+ num_epochs : 1
13
8
14
9
training_set :
15
10
systems :
You can’t perform that action at this time.
0 commit comments