Skip to content

Commit 840971f

Browse files
author
Joost van Amersfoort
authored
Properly register parameters of LU decomposition
Currently when `lu_decomposed` is set to `True`, loading from a checkpoint doesn't work. This PR fixes that by properly registering them as a `nn.Parameter` so they will be part of the state dict.
1 parent 8f35267 commit 840971f

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

glow/modules.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,8 @@ def __init__(self, num_channels, LU_decomposed=False):
208208
l_mask = np.tril(np.ones(w_shape, dtype=np.float32), -1)
209209
eye = np.eye(*w_shape, dtype=np.float32)
210210

211-
self.p = torch.Tensor(np_p.astype(np.float32))
212-
self.sign_s = torch.Tensor(np_sign_s.astype(np.float32))
211+
self.p = nn.Parameter(torch.Tensor(np_p.astype(np.float32)))
212+
self.sign_s = nn.Parameter(torch.Tensor(np_sign_s.astype(np.float32)))
213213
self.l = nn.Parameter(torch.Tensor(np_l.astype(np.float32)))
214214
self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(np.float32)))
215215
self.u = nn.Parameter(torch.Tensor(np_u.astype(np.float32)))

0 commit comments

Comments
 (0)