Skip to content

Commit 63d6ac3

Browse files
committed
Suggestions from review
1 parent da82cc8 commit 63d6ac3

File tree

7 files changed

+19
-17
lines changed

7 files changed

+19
-17
lines changed

src/metatrain/cli/train.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -422,14 +422,17 @@ def train_model(
422422
###########################
423423

424424
logger.info("Calling trainer")
425-
trainer.train(
426-
model=model,
427-
dtype=dtype,
428-
devices=devices,
429-
train_datasets=train_datasets,
430-
val_datasets=val_datasets,
431-
checkpoint_dir=str(checkpoint_dir),
432-
)
425+
try:
426+
trainer.train(
427+
model=model,
428+
dtype=dtype,
429+
devices=devices,
430+
train_datasets=train_datasets,
431+
val_datasets=val_datasets,
432+
checkpoint_dir=str(checkpoint_dir),
433+
)
434+
except Exception as e:
435+
raise ArchitectureError(e)
433436

434437
if not is_main_process():
435438
return # only save and evaluate on the main process

src/metatrain/experimental/phace/default-hypers.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@ architecture:
22
name: experimental.phace
33

44
model:
5-
nu_max: 3
5+
max_correlation_order_per_layer: 3
66
num_message_passing_layers: 2
77
cutoff: 5.0
88
cutoff_width: 1.0
99
num_element_channels: 64
1010
radial_basis:
1111
mlp: true
12-
E_max: 50.0
12+
max_eigenvalue: 50.0
1313
scale: 0.7
1414
optimizable_lengthscales: false
1515
nu_scaling: 0.1

src/metatrain/experimental/phace/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
7474
)
7575
self.register_buffer("species_to_species_index", species_to_species_index)
7676

77-
self.nu_max = model_hypers["nu_max"]
77+
self.nu_max = model_hypers["max_correlation_order_per_layer"]
7878
self.num_message_passing_layers = model_hypers["num_message_passing_layers"]
7979
if self.num_message_passing_layers < 1:
8080
raise ValueError("Number of message-passing layers must be at least 1")

src/metatrain/experimental/phace/modules/physical_basis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def get_physical_basis_spliner(E_max, r_cut, normalize):
2929
for l in range(l_max + 1): # noqa: E741
3030
n_max_l.append(np.where(E_nl[:, l] <= E_max)[0][-1] + 1)
3131
if n_max_l[0] > n_max:
32-
raise ValueError("n_max too large, try decreasing E_max")
32+
raise ValueError("n_max too large, try decreasing max_eigenvalue")
3333

3434
def function_for_splining(n, l, x): # noqa: E741
3535
ret = physical_basis.compute(n, l, x)

src/metatrain/experimental/phace/modules/radial_basis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def __init__(self, hypers, all_species) -> None:
1616
for species in all_species:
1717
lengthscales[species] = np.log(hypers["scale"] * covalent_radii[species])
1818
self.n_max_l, self.spliner = get_physical_basis_spliner(
19-
hypers["E_max"], hypers["cutoff"], normalize=True
19+
hypers["max_eigenvalue"], hypers["cutoff"], normalize=True
2020
)
2121
if hypers["optimizable_lengthscales"]:
2222
self.lengthscales = torch.nn.Parameter(lengthscales)

src/metatrain/experimental/phace/schema-hypers.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"model": {
1010
"type": "object",
1111
"properties": {
12-
"nu_max": {
12+
"max_correlation_order_per_layer": {
1313
"type": "integer"
1414
},
1515
"num_message_passing_layers": {
@@ -30,7 +30,7 @@
3030
"mlp": {
3131
"type": "boolean"
3232
},
33-
"E_max": {
33+
"max_eigenvalue": {
3434
"type": "number"
3535
},
3636
"scale": {

tox.ini

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,10 @@ commands =
137137
pytest {posargs}
138138

139139
[testenv:phace-tests]
140-
description = Run NanoPET tests with pytest
140+
description = Run PhACE tests with pytest
141141
passenv = *
142142
deps =
143143
pytest
144-
spherical # for nanoPET spherical target
145144
extras = phace
146145
changedir = src/metatrain/experimental/phace/tests/
147146
commands =

0 commit comments

Comments
 (0)