Skip to content

Commit 30b67c3

Browse files
committed
Allow different metrics than the loss for best model selection
1 parent 025c756 commit 30b67c3

File tree

5 files changed

+35
-11
lines changed

5 files changed

+35
-11
lines changed

src/metatrain/experimental/phace/default-hypers.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,4 @@ architecture:
3939
type: mse
4040
weights: {}
4141
reduction: sum
42+
best_model_metric: "rmse_prod"

src/metatrain/experimental/phace/schema-hypers.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@
133133
"log_mae": {
134134
"type": "boolean"
135135
},
136+
"best_model_metric": {
137+
"type": "string",
138+
"enum": ["rmse_prod", "mae_prod", "loss"]
139+
},
136140
"loss": {
137141
"type": "object",
138142
"properties": {

src/metatrain/experimental/phace/trainer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from ...utils.io import check_file_extension
1818
from ...utils.logging import MetricLogger
1919
from ...utils.loss import TensorMapDictLoss
20-
from ...utils.metrics import MAEAccumulator, RMSEAccumulator
20+
from ...utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric
2121
from ...utils.neighbor_lists import (
2222
get_requested_neighbor_lists,
2323
get_system_with_neighbor_lists,
@@ -430,8 +430,11 @@ def train(
430430
patience=self.hypers["scheduler_patience"],
431431
)
432432

433-
if val_loss < self.best_loss:
434-
self.best_loss = val_loss
433+
metric = get_selected_metric(
434+
finalized_val_info, self.hypers["best_model_metric"]
435+
)
436+
if metric < self.best_loss:
437+
self.best_loss = metric
435438
self.best_model_state_dict = copy.deepcopy(
436439
(
437440
scripted_model.module if is_distributed else scripted_model

src/metatrain/utils/metrics.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,24 @@ def finalize(
178178
finalized_info[out_key] = value[0] / value[1]
179179

180180
return finalized_info
181+
182+
183+
def get_selected_metric(metric_dict: Dict[str, float], selected_metric: str) -> float:
184+
if selected_metric == "loss":
185+
metric = metric_dict["loss"]
186+
elif selected_metric == "rmse_prod":
187+
metric = 1
188+
for key in metric_dict:
189+
if "RMSE" in key:
190+
metric *= metric_dict[key]
191+
elif selected_metric == "mae_prod":
192+
metric = 1
193+
for key in metric_dict:
194+
if "MAE" in key:
195+
metric *= metric_dict[key]
196+
else:
197+
raise ValueError(
198+
f"Selected metric {selected_metric} not recognized. "
199+
"Please select from 'loss', 'rmse_prod', or 'mae_prod'."
200+
)
201+
return metric

tests/resources/options.yaml

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,10 @@
11
seed: 42
22

33
architecture:
4-
name: experimental.phace
5-
# model:
6-
# radial_basis:
7-
# optimizable_lengthscales: true
4+
name: experimental.soap_bpnn
85
training:
9-
batch_size: 8
10-
num_epochs: 10
11-
gradient_clipping: 10
12-
log_interval: 1
6+
batch_size: 2
7+
num_epochs: 1
138

149
training_set:
1510
systems:

0 commit comments

Comments
 (0)