
Commit b738755

[train] Add Torch process group shutdown timeout (#56182)
Shutting down a healthy torch process group, which we may want to do for reasons like restarting a group of workers if an async checkpoint upload fails, can hang. This commit adds a shutdown timeout as a workaround until we figure out how to avoid the hang. When the timeout is hit, `before_worker_group_shutdown` finishes and the workers are then killed by `ray.kill`: https://github.com/ray-project/ray/blob/master/python/ray/train/v2/_internal/execution/worker_group/state.py#L127

Signed-off-by: Timothy Seah <tseah@anyscale.com>
1 parent aac861a commit b738755
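
As a usage sketch (not part of the diff below): the timeout is controlled by the `TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S` environment variable and defaults to 30 seconds, so it can be tuned by setting the variable in the process that runs the Torch backend's `on_shutdown`; the value shown here is only illustrative.

import os

# Illustrative only: raise the process group shutdown timeout to 60 seconds.
# The variable must be visible to the process that executes on_shutdown,
# where it is read via ray_constants.env_integer.
os.environ["TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S"] = "60"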

3 files changed: +44 -2 lines changed

python/ray/train/constants.py

Lines changed: 7 additions & 0 deletions
@@ -125,6 +125,12 @@ def _v2_migration_warnings_enabled() -> bool:
     "TUNE_ONLY_STORE_CHECKPOINT_SCORE_ATTRIBUTE"
 )
 
+# Seconds to wait for torch process group to shut down.
+# Shutting down a healthy torch process group, which we may want to do for reasons
+# like restarting a group of workers if an async checkpoint upload fails, can hang.
+# This is a workaround until we figure out how to avoid this hang.
+TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S = "TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S"
+DEFAULT_TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S = 30
 
 # NOTE: When adding a new environment variable, please track it in this list.
 TRAIN_ENV_VARS = {
@@ -137,6 +143,7 @@ def _v2_migration_warnings_enabled() -> bool:
     RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE,
     RAY_TRAIN_ENABLE_STATE_TRACKING,
     TUNE_ONLY_STORE_CHECKPOINT_SCORE_ATTRIBUTE,
+    TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
 }
 
 # Key for AIR Checkpoint metadata in TrainingResult metadata

python/ray/train/tests/test_backend.py

Lines changed: 19 additions & 0 deletions
@@ -28,6 +28,7 @@
 from ray.train.constants import (
     ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
     ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV,
+    TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
     TRAIN_ENABLE_WORKER_SPREAD_ENV,
 )
 from ray.train.torch import TorchConfig
@@ -364,6 +365,24 @@ def check_process_group():
     assert not any(e.finish_training())
 
 
+@pytest.mark.parametrize(
+    "init_method, timeout_s", [("env", 5), ("tcp", 5), ("env", 0), ("tcp", 0)]
+)
+def test_torch_process_group_shutdown_timeout(
+    ray_start_2_cpus, monkeypatch, init_method, timeout_s
+):
+    monkeypatch.setenv(TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S, timeout_s)
+    torch_config = TorchConfig(backend="gloo", init_method=init_method)
+    e = BackendExecutor(torch_config, num_workers=2)
+    e.start()
+
+    _start_training(e, lambda: 1)
+    assert e.finish_training() == [1, 1]
+
+    # Verify that we do not raise an exception even if we time out
+    e._backend.on_shutdown(e.worker_group, e._backend_config)
+
+
 @pytest.mark.parametrize(
     "worker_results",
     [

python/ray/train/torch/config.py

Lines changed: 18 additions & 2 deletions
@@ -10,10 +10,16 @@
 
 import ray
 from ray._common.network_utils import build_address
+from ray._private import ray_constants
 from ray.air._internal.device_manager import register_custom_torch_dist_backend
+from ray.exceptions import GetTimeoutError
 from ray.train._internal.utils import get_address_and_port
 from ray.train._internal.worker_group import WorkerGroup
 from ray.train.backend import Backend, BackendConfig
+from ray.train.constants import (
+    DEFAULT_TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
+    TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
+)
 from ray.util import PublicAPI
 
 logger = logging.getLogger(__name__)
@@ -202,11 +208,21 @@ def set_env_vars(addr, port):
         else:
             raise RuntimeError("Distributed torch is not available.")
 
-    def on_shutdown(self, worker_group: WorkerGroup, backend_config: TorchConfig):
-        worker_group.execute(
+    def on_shutdown(self, worker_group, backend_config):
+        futures = worker_group.execute_async(
             _shutdown_torch,
             destroy_process_group=len(worker_group) > 1,
         )
+        timeout_s = ray_constants.env_integer(
+            TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
+            DEFAULT_TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
+        )
+        try:
+            ray.get(futures, timeout=timeout_s)
+        except GetTimeoutError:
+            logger.warning(
+                f"Torch process group shutdown timed out after {timeout_s} seconds"
+            )
 
     def on_training_start(
         self, worker_group: WorkerGroup, backend_config: BackendConfig

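For context, here is a minimal standalone sketch of the timeout pattern that the new `on_shutdown` uses, with a hypothetical `_slow_shutdown` task standing in for the per-worker `_shutdown_torch` call: `ray.get(..., timeout=...)` raises `GetTimeoutError` if the tasks do not finish in time, and the caller logs (here, prints) and moves on instead of hanging.

import time

import ray
from ray.exceptions import GetTimeoutError


@ray.remote
def _slow_shutdown():
    # Hypothetical stand-in for a per-worker _shutdown_torch call that hangs.
    time.sleep(60)


ray.init()
refs = [_slow_shutdown.remote() for _ in range(2)]
try:
    # Mirrors ray.get(futures, timeout=timeout_s) in on_shutdown above.
    ray.get(refs, timeout=5)
except GetTimeoutError:
    print("Shutdown timed out; continuing so the worker group can be torn down.")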