Skip to content

Commit 83e5c2c

Browse files
committed
Final changes
1 parent c0665f3 commit 83e5c2c

File tree

2 files changed

+5
-8
lines changed

2 files changed

+5
-8
lines changed

src/metatrain/experimental/phace/model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,9 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
138138
cgs,
139139
irreps_in=equivariant_message_passer.irreps_out,
140140
# TODO: speed up with something like this?
141-
# requested_LS_string=(
142-
# "0_1" if idx == self.num_message_passing_layers - 2 else None
143-
# ),
141+
requested_LS_string=(
142+
"0_1" if idx == self.num_message_passing_layers - 2 else None
143+
),
144144
)
145145
generalized_cg_iterators.append(generalized_cg_iterator)
146146
self.equivariant_message_passers = torch.nn.ModuleList(
@@ -499,7 +499,7 @@ def _add_output(self, target_name: str, target_info: TargetInfo) -> None:
499499
# TODO: the equivariant could be a scalar...
500500
else:
501501
# specified by the user
502-
use_mlp = self.head_types[target_name] == "mlp"
502+
use_mlp = (self.head_types[target_name] == "mlp")
503503

504504
self.outputs[target_name] = ModelOutput(
505505
quantity=target_info.quantity,

src/metatrain/experimental/phace/modules/layers.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,8 @@ def __init__(self, n_inputs: int) -> None:
2323
self.mlp = torch.nn.Sequential(
2424
Linear(n_inputs, 4 * n_inputs),
2525
torch.nn.SiLU(),
26-
# Linear(4*n_inputs, 4*n_inputs),
27-
# torch.nn.SiLU(),
28-
# Linear(4*n_inputs, 4*n_inputs),
29-
# torch.nn.SiLU(),
3026
Linear(4 * n_inputs, n_inputs),
27+
torch.nn.SiLU(),
3128
)
3229

3330
def forward(self, features: TensorMap) -> TensorMap:

0 commit comments

Comments
 (0)