Skip to content

Commit 6ad1dd6

Browse files
authored
Support pickleability of classes with callable args (#728)
1 parent b111f10 commit 6ad1dd6

File tree

3 files changed

+23
-6
lines changed

3 files changed

+23
-6
lines changed

CHANGELOG.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,16 @@ The semantic versioning only considers the public API as described in
1111
:ref:`api-ref`. Components not mentioned in :ref:`api-ref` or different import
1212
paths are considered internals and can change in minor and patch releases.
1313

14-
1514
v4.40.1 (2025-05-??)
1615
--------------------
1716

1817
Fixed
1918
^^^^^
2019
- ``print_shtab`` incorrectly parsed from environment variable (`#725
2120
<https://github.com/omni-us/jsonargparse/pull/725>`__).
21+
- ``adapt_class_type`` used a locally defined `partial_instance` wrapper
22+
function that is not pickleable (`#728
23+
<https://github.com/omni-us/jsonargparse/pull/728>`__).
2224

2325

2426
v4.40.0 (2025-05-16)

jsonargparse/_typehints.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1424,11 +1424,11 @@ def adapt_class_type(
14241424
instantiator_fn = get_class_instantiator()
14251425

14261426
if partial_classes:
1427-
1428-
def partial_instance(*args):
1429-
return instantiator_fn(val_class, *args, **{**init_args, **dict_kwargs})
1430-
1431-
return partial_instance
1427+
return partial(
1428+
instantiator_fn,
1429+
val_class,
1430+
**{**init_args, **dict_kwargs},
1431+
)
14321432
return instantiator_fn(val_class, **{**init_args, **dict_kwargs})
14331433

14341434
prev_init_args = prev_val.get("init_args") if isinstance(prev_val, Namespace) else None

jsonargparse_tests/test_typehints.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,6 +1239,21 @@ def test_callable_args_return_type_class_subconfig(parser, tmp_cwd):
12391239
assert optimizer.momentum == 0.8
12401240

12411241

1242+
def test_callable_args_pickleable(parser, tmp_cwd):
1243+
config = {
1244+
"class_path": "Adam",
1245+
"init_args": {"momentum": 0.8},
1246+
}
1247+
Path("optimizer.yaml").write_text(json_or_yaml_dump(config))
1248+
parser.add_class_arguments(CallableSubconfig, "m", sub_configs=True)
1249+
cfg = parser.parse_args(["--m.o=optimizer.yaml"])
1250+
init = parser.instantiate_classes(cfg)
1251+
1252+
filepath = str(tmp_cwd) + "/pickled.pkl"
1253+
with open(filepath, "wb") as f:
1254+
pickle.dump(init, f)
1255+
1256+
12421257
class Module:
12431258
pass
12441259

0 commit comments

Comments
 (0)