Skip to content

Commit ebb35b2

Browse files
committed
Suggestions from review
1 parent da82cc8 commit ebb35b2

File tree

10 files changed

+26
-29
lines changed

10 files changed

+26
-29
lines changed

src/metatrain/cli/train.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -422,14 +422,17 @@ def train_model(
422422
###########################
423423

424424
logger.info("Calling trainer")
425-
trainer.train(
426-
model=model,
427-
dtype=dtype,
428-
devices=devices,
429-
train_datasets=train_datasets,
430-
val_datasets=val_datasets,
431-
checkpoint_dir=str(checkpoint_dir),
432-
)
425+
try:
426+
trainer.train(
427+
model=model,
428+
dtype=dtype,
429+
devices=devices,
430+
train_datasets=train_datasets,
431+
val_datasets=val_datasets,
432+
checkpoint_dir=str(checkpoint_dir),
433+
)
434+
except Exception as e:
435+
raise ArchitectureError(e)
433436

434437
if not is_main_process():
435438
return # only save and evaluate on the main process

src/metatrain/experimental/phace/default-hypers.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@ architecture:
22
name: experimental.phace
33

44
model:
5-
nu_max: 3
5+
max_correlation_order_per_layer: 3
66
num_message_passing_layers: 2
77
cutoff: 5.0
88
cutoff_width: 1.0
99
num_element_channels: 64
1010
radial_basis:
1111
mlp: true
12-
E_max: 50.0
12+
max_eigenvalue: 50.0
1313
scale: 0.7
1414
optimizable_lengthscales: false
1515
nu_scaling: 0.1

src/metatrain/experimental/phace/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from .modules.cg import get_cg_coefficients
2424
from .modules.cg_iterator import CGIterator
2525
from .modules.initial_features import get_initial_features
26-
from .modules.layers import EquivariantLastLayer, InvariantMLP, NothingLayer
26+
from .modules.layers import EquivariantLastLayer, Identity, InvariantMLP
2727
from .modules.message_passing import EquivariantMessagePasser, InvariantMessagePasser
2828
from .modules.precomputations import Precomputer
2929
from .utils import systems_to_batch
@@ -74,7 +74,7 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
7474
)
7575
self.register_buffer("species_to_species_index", species_to_species_index)
7676

77-
self.nu_max = model_hypers["nu_max"]
77+
self.nu_max = model_hypers["max_correlation_order_per_layer"]
7878
self.num_message_passing_layers = model_hypers["num_message_passing_layers"]
7979
if self.num_message_passing_layers < 1:
8080
raise ValueError("Number of message-passing layers must be at least 1")
@@ -521,7 +521,7 @@ def _add_output(self, target_name: str, target_info: TargetInfo) -> None:
521521
self.k_max_l[0], self.head_num_layers
522522
)
523523
else:
524-
self.heads[target_name] = NothingLayer()
524+
self.heads[target_name] = Identity()
525525

526526
if target_info.is_scalar:
527527
self.last_layers[target_name] = EquivariantLastLayer(

src/metatrain/experimental/phace/modules/center_embedding.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import List
22

33
import torch
4-
from metatensor.torch import Labels, TensorBlock, TensorMap
4+
from metatensor.torch import TensorBlock, TensorMap
55

66

77
def embed_centers(equivariants: TensorMap, center_embeddings: torch.Tensor):
@@ -29,9 +29,6 @@ def embed_centers(equivariants: TensorMap, center_embeddings: torch.Tensor):
2929
)
3030

3131
return TensorMap(
32-
keys=Labels(
33-
names=["nu", "o3_lambda", "o3_sigma"],
34-
values=torch.stack(keys).to(equivariants.keys.values.device),
35-
),
32+
keys=equivariants.keys,
3633
blocks=blocks,
3734
)

src/metatrain/experimental/phace/modules/layers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def forward(self, features: TensorMap) -> TensorMap:
179179
return TensorMap(keys=new_keys, blocks=new_blocks)
180180

181181

182-
class NothingLayer(torch.nn.Module):
182+
class Identity(torch.nn.Module):
183183
# useful when the head for an output is a simple linear layer
184184

185185
def forward(self, features: TensorMap) -> TensorMap:

src/metatrain/experimental/phace/modules/message_passing.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,8 @@ class DummyAdder(torch.nn.Module):
1515
def __init__(self):
1616
super().__init__()
1717

18-
def forward(
19-
self, tmap_1: metatensor.torch.TensorMap, tmap_2: metatensor.torch.TensorMap
20-
) -> metatensor.torch.TensorMap:
21-
return metatensor.torch.TensorMap(
18+
def forward(self, tmap_1: TensorMap, tmap_2: TensorMap) -> TensorMap:
19+
return TensorMap(
2220
keys=Labels(names=["dummy"], values=torch.empty(1, 1)), blocks=[]
2321
)
2422

src/metatrain/experimental/phace/modules/physical_basis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def get_physical_basis_spliner(E_max, r_cut, normalize):
2929
for l in range(l_max + 1): # noqa: E741
3030
n_max_l.append(np.where(E_nl[:, l] <= E_max)[0][-1] + 1)
3131
if n_max_l[0] > n_max:
32-
raise ValueError("n_max too large, try decreasing E_max")
32+
raise ValueError("n_max too large, try decreasing max_eigenvalue")
3333

3434
def function_for_splining(n, l, x): # noqa: E741
3535
ret = physical_basis.compute(n, l, x)

src/metatrain/experimental/phace/modules/radial_basis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def __init__(self, hypers, all_species) -> None:
1616
for species in all_species:
1717
lengthscales[species] = np.log(hypers["scale"] * covalent_radii[species])
1818
self.n_max_l, self.spliner = get_physical_basis_spliner(
19-
hypers["E_max"], hypers["cutoff"], normalize=True
19+
hypers["max_eigenvalue"], hypers["cutoff"], normalize=True
2020
)
2121
if hypers["optimizable_lengthscales"]:
2222
self.lengthscales = torch.nn.Parameter(lengthscales)

src/metatrain/experimental/phace/schema-hypers.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"model": {
1010
"type": "object",
1111
"properties": {
12-
"nu_max": {
12+
"max_correlation_order_per_layer": {
1313
"type": "integer"
1414
},
1515
"num_message_passing_layers": {
@@ -30,7 +30,7 @@
3030
"mlp": {
3131
"type": "boolean"
3232
},
33-
"E_max": {
33+
"max_eigenvalue": {
3434
"type": "number"
3535
},
3636
"scale": {

tox.ini

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,10 @@ commands =
137137
pytest {posargs}
138138

139139
[testenv:phace-tests]
140-
description = Run NanoPET tests with pytest
140+
description = Run PhACE tests with pytest
141141
passenv = *
142142
deps =
143143
pytest
144-
spherical # for nanoPET spherical target
145144
extras = phace
146145
changedir = src/metatrain/experimental/phace/tests/
147146
commands =

0 commit comments

Comments
 (0)