
Commit 6f3689a

[data][train] Create a deepcopy of the data context on the split coordinator process (#56211)
The main change of this PR is to create a deepcopy of the base dataset's context before setting the process-global context. Otherwise, mutations to the base dataset's context during the planning phase are also propagated to the global context, which can affect future dataset executions launched from the same process.

Misc. drive-by changes:

* Utility to create a `StorageContext` from the `RunConfig` directly
* Pipe the `DatasetShardMetadata` through from the outermost level, among other changes, for easier patching

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
1 parent f038551 · commit 6f3689a

10 files changed: +53 −53 lines
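Before the per-file diffs, here is a minimal sketch of the failure mode the main change fixes. `Context`, `set_current`, and the field names below are illustrative stand-ins, not Ray's actual `DataContext` API:

```python
import copy
from dataclasses import dataclass, field


@dataclass
class Context:
    """Stand-in for Ray Data's DataContext: a mutable bag of execution settings."""

    overrides: dict = field(default_factory=dict)


# Stand-in for the process-global context set via DataContext._set_current().
_current = Context()


def set_current(ctx: Context) -> None:
    global _current
    _current = ctx


# Buggy version: sharing the dataset's context object means mutations made
# during the planning phase leak into the process-global context.
dataset_ctx = Context()
set_current(dataset_ctx)
dataset_ctx.overrides["planning_flag"] = True
assert _current.overrides == {"planning_flag": True}  # leaked!

# Fixed version: set a deep copy, so the global context stays isolated and
# future dataset executions in this process start from a clean state.
fresh_ctx = Context()
set_current(copy.deepcopy(fresh_ctx))
fresh_ctx.overrides["planning_flag"] = True
assert _current.overrides == {}  # no leak
```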

python/ray/data/_internal/iterator/stream_split_iterator.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -139,9 +139,13 @@ def __init__(
         locality_hints: Optional[List[NodeIdStr]],
     ):
         dataset = dataset_wrapper._dataset
+
         # Set current DataContext.
-        self._data_context = dataset.context
+        # This needs to be a deep copy so that updates to the base dataset's
+        # context does not affect this process's global DataContext.
+        self._data_context = dataset.context.copy()
         ray.data.DataContext._set_current(self._data_context)
+
         if self._data_context.execution_options.locality_with_output is True:
             self._data_context.execution_options.locality_with_output = locality_hints
             logger.info(f"Auto configuring locality_with_output={locality_hints}")
```

python/ray/data/iterator.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -158,6 +158,11 @@ def iter_batches(
             local_shuffle_seed=local_shuffle_seed,
         )
 
+    def _create_batch_iterator(
+        self, ref_bundles_iter: Iterator[RefBundle], **kwargs
+    ) -> BatchIterator:
+        return BatchIterator(ref_bundles_iter, **kwargs)
+
     def _iter_batches(
         self,
         *,
@@ -186,7 +191,7 @@ def _create_iterator() -> Iterator[DataBatch]:
 
             dataset_tag = self._get_dataset_tag()
 
-            batch_iterator = BatchIterator(
+            batch_iterator = self._create_batch_iterator(
                 ref_bundles_iterator,
                 stats=stats,
                 dataset_tag=dataset_tag,
```
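The new `_create_batch_iterator` hook is a small factory-method seam: all construction of `BatchIterator` now goes through one overridable method, which is what enables the "easier patching" mentioned in the commit message. A rough sketch of the pattern with stand-in classes (not the real `DataIterator`/`BatchIterator` signatures):

```python
from typing import Iterable, Iterator


class BatchIterator:
    """Stand-in for the real BatchIterator."""

    def __init__(self, batches: Iterable[list], **kwargs) -> None:
        self._batches = batches

    def __iter__(self) -> Iterator[list]:
        return iter(self._batches)


class DataIteratorLike:
    def _create_batch_iterator(self, batches: Iterable[list], **kwargs) -> BatchIterator:
        # Single construction point: subclasses and tests override this one
        # method instead of monkeypatching BatchIterator at every call site.
        return BatchIterator(batches, **kwargs)

    def iter_batches(self, batches: Iterable[list]) -> BatchIterator:
        return self._create_batch_iterator(batches)


class CountingBatchIterator(BatchIterator):
    """Example override: counts batches as they are yielded."""

    def __iter__(self) -> Iterator[list]:
        self.count = 0
        for batch in self._batches:
            self.count += 1
            yield batch


class CountingDataIterator(DataIteratorLike):
    def _create_batch_iterator(self, batches, **kwargs) -> BatchIterator:
        return CountingBatchIterator(batches, **kwargs)


it = CountingDataIterator().iter_batches([[1, 2], [3]])
assert list(it) == [[1, 2], [3]] and it.count == 2
```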

python/ray/train/v2/_internal/callbacks/datasets.py

Lines changed: 5 additions & 10 deletions
```diff
@@ -7,9 +7,9 @@
 from ray.train.v2._internal.data_integration.interfaces import (
     DatasetShardMetadata,
     DatasetShardProvider,
-    GenDataset,
 )
 from ray.train.v2._internal.execution.callback import WorkerGroupCallback
+from ray.train.v2._internal.execution.context import TrainRunContext
 from ray.train.v2._internal.execution.worker_group.worker_group import (
     Worker,
     WorkerGroup,
@@ -37,15 +37,10 @@ def get_dataset_shard(self, dataset_info: DatasetShardMetadata) -> DataIterator:
 class DatasetsSetupCallback(WorkerGroupCallback):
     """The callback to setup Ray Datasets for the worker group."""
 
-    def __init__(
-        self,
-        datasets: Dict[str, GenDataset],
-        data_config: ray.train.DataConfig,
-        scaling_config: ray.train.ScalingConfig,
-    ):
-        self._datasets = datasets
-        self._data_config = data_config
-        self._scaling_config = scaling_config
+    def __init__(self, train_run_context: TrainRunContext):
+        self._datasets = train_run_context.datasets
+        self._data_config = copy.deepcopy(train_run_context.dataset_config)
+        self._scaling_config = train_run_context.scaling_config
 
         # Capture the current DataContext to propagate it to
         # the Train workers later.
```
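The callback now takes the whole `TrainRunContext` instead of three separate arguments, and deep-copies the data config so its own adjustments cannot leak back into the run context. A minimal sketch of that parameter-object refactor, with stand-in classes:

```python
import copy
from dataclasses import dataclass, field


@dataclass
class RunContextLike:
    """Stand-in for TrainRunContext: one object carrying run-wide state."""

    datasets: dict = field(default_factory=dict)
    dataset_config: dict = field(default_factory=dict)
    scaling_config: dict = field(default_factory=dict)


class DatasetsSetupCallbackLike:
    def __init__(self, train_run_context: RunContextLike) -> None:
        self._datasets = train_run_context.datasets
        # Deep-copied so the callback can mutate its config (e.g. to fill in
        # defaults) without touching the caller's run context.
        self._data_config = copy.deepcopy(train_run_context.dataset_config)
        self._scaling_config = train_run_context.scaling_config


ctx = RunContextLike(dataset_config={"prefetch_batches": 1})
cb = DatasetsSetupCallbackLike(ctx)
cb._data_config["prefetch_batches"] = 4
assert ctx.dataset_config["prefetch_batches"] == 1  # caller's config untouched
```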

python/ray/train/v2/_internal/execution/controller/controller.py

Lines changed: 1 addition & 6 deletions
```diff
@@ -50,7 +50,6 @@
     ResizeDecision,
     ScalingPolicy,
 )
-from ray.train.v2._internal.execution.storage import StorageContext
 from ray.train.v2._internal.execution.worker_group import (
     WorkerGroup,
     WorkerGroupPollStatus,
@@ -126,11 +125,7 @@ def __init__(
         self._failure_policy = failure_policy
         self._run_config = self._train_run_context.run_config
         self._callbacks = callbacks or []
-        self._storage_context = StorageContext(
-            storage_path=self._run_config.storage_path,
-            experiment_dir_name=self._run_config.name,
-            storage_filesystem=self._run_config.storage_filesystem,
-        )
+        self._storage_context = self._train_run_context.run_config.storage_context
 
         self._checkpoint_manager = CheckpointManager(
             checkpoint_config=self._run_config.checkpoint_config,
```

python/ray/train/v2/_internal/execution/train_fn_utils.py

Lines changed: 7 additions & 14 deletions
```diff
@@ -4,6 +4,7 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from ray.data import DataIterator
+from ray.train.v2._internal.data_integration.interfaces import DatasetShardMetadata
 from ray.train.v2._internal.execution import collective_impl
 from ray.train.v2._internal.execution.context import (
     get_train_context as get_internal_train_context,
@@ -68,14 +69,11 @@ def get_all_reported_checkpoints(self) -> List["ReportedCheckpoint"]:
         pass
 
     @abstractmethod
-    def get_dataset_shard(self, dataset_name: str) -> DataIterator:
+    def get_dataset_shard(self, dataset_info: DatasetShardMetadata) -> DataIterator:
         """Get the dataset shard for this training process.
 
-        This method is used by the public API function :func:`ray.train.get_dataset_shard`.
-        Users should typically call ``ray.train.get_dataset_shard()`` instead of calling this method directly.
-
         Args:
-            dataset_name: The name of the dataset to get the shard for.
+            dataset_info: The metadata of the dataset to get the shard for.
 
         Returns:
             The DataIterator shard for this worker.
@@ -131,14 +129,8 @@ def report(
     def get_checkpoint(self):
         return get_internal_train_context().get_checkpoint()
 
-    def get_dataset_shard(self, dataset_name: str) -> DataIterator:
-        from ray.train.v2._internal.data_integration.interfaces import (
-            DatasetShardMetadata,
-        )
-
-        return get_internal_train_context().get_dataset_shard(
-            DatasetShardMetadata(dataset_name=dataset_name)
-        )
+    def get_dataset_shard(self, dataset_info: DatasetShardMetadata) -> DataIterator:
+        return get_internal_train_context().get_dataset_shard(dataset_info)
 
     def get_context(self) -> DistributedTrainContext:
         return DistributedTrainContext()
@@ -182,7 +174,8 @@ def report(
     def get_checkpoint(self) -> Optional["Checkpoint"]:
         return self._last_checkpoint
 
-    def get_dataset_shard(self, dataset_name: str) -> DataIterator:
+    def get_dataset_shard(self, dataset_info: DatasetShardMetadata) -> DataIterator:
+        dataset_name = dataset_info.dataset_name
         assert (
             self._dataset_shards is not None and dataset_name in self._dataset_shards
         ), f"Dataset shard {dataset_name} not found."
```

python/ray/train/v2/_internal/execution/worker_group/worker_group.py

Lines changed: 1 addition & 6 deletions
```diff
@@ -37,7 +37,6 @@
 from ray.train.v2._internal.execution.checkpoint.sync_actor import SynchronizationActor
 from ray.train.v2._internal.execution.context import (
     DistributedContext,
-    StorageContext,
     TrainRunContext,
 )
 from ray.train.v2._internal.execution.worker_group.poll import (
@@ -145,11 +144,7 @@ def __init__(
         """
         self._train_run_context = train_run_context
         run_config = self._train_run_context.run_config
-        self._storage_context = StorageContext(
-            storage_path=run_config.storage_path,
-            experiment_dir_name=run_config.name,
-            storage_filesystem=run_config.storage_filesystem,
-        )
+        self._storage_context = run_config.storage_context
 
         self._worker_group_context: WorkerGroupContext = worker_group_context
```
python/ray/train/v2/api/config.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -1,5 +1,6 @@
 import logging
 from dataclasses import dataclass
+from functools import cached_property
 from pathlib import Path
 from typing import TYPE_CHECKING, List, Optional, Union
 
@@ -12,6 +13,7 @@
 )
 from ray.runtime_env import RuntimeEnv
 from ray.train.v2._internal.constants import _DEPRECATED
+from ray.train.v2._internal.execution.storage import StorageContext
 from ray.train.v2._internal.migration_utils import (
     FAIL_FAST_DEPRECATION_MESSAGE,
     TRAINER_RESOURCES_DEPRECATION_MESSAGE,
@@ -261,3 +263,11 @@ def __post_init__(self):
                 "See this issue for more context: "
                 "https://github.com/ray-project/ray/issues/49454"
             )
+
+    @cached_property
+    def storage_context(self) -> StorageContext:
+        return StorageContext(
+            storage_path=self.storage_path,
+            experiment_dir_name=self.name,
+            storage_filesystem=self.storage_filesystem,
+        )
```
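Exposing `storage_context` as a `cached_property` gives every consumer of a `RunConfig` the same lazily built `StorageContext`, replacing the duplicated construction removed from the controller and worker group above. A small sketch of the caching behavior, using simplified stand-ins for the real classes:

```python
from dataclasses import dataclass
from functools import cached_property
from typing import Optional


@dataclass
class StorageContextLike:
    """Stand-in for the real StorageContext."""

    storage_path: Optional[str]
    experiment_dir_name: Optional[str]


@dataclass
class RunConfigLike:
    name: Optional[str] = None
    storage_path: Optional[str] = None

    @cached_property
    def storage_context(self) -> StorageContextLike:
        # Built on first access, then memoized on the instance, so the
        # controller and worker group share one StorageContext per config.
        return StorageContextLike(
            storage_path=self.storage_path,
            experiment_dir_name=self.name,
        )


cfg = RunConfigLike(name="my_experiment", storage_path="/tmp/ray_results")
assert cfg.storage_context is cfg.storage_context  # same cached instance
```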

python/ray/train/v2/api/data_parallel_trainer.py

Lines changed: 2 additions & 4 deletions
```diff
@@ -30,7 +30,6 @@
     TPUReservationCallback,
     WorkingDirectorySetupCallback,
 )
-from ray.train.v2._internal.callbacks.datasets import GenDataset
 from ray.train.v2._internal.callbacks.env_callback import _initialize_env_callbacks
 from ray.train.v2._internal.callbacks.metrics import (
     ControllerMetricsCallback,
@@ -42,6 +41,7 @@
     METRICS_ENABLED_ENV_VAR,
     get_env_vars_to_propagate,
 )
+from ray.train.v2._internal.data_integration.interfaces import GenDataset
 from ray.train.v2._internal.execution.callback import RayTrainCallback
 from ray.train.v2._internal.execution.context import TrainRunContext
 from ray.train.v2._internal.execution.controller import TrainController
@@ -164,9 +164,7 @@ def _create_default_callbacks(self) -> List[RayTrainCallback]:
         )
         backend_setup_callback = BackendSetupCallback(self.backend_config)
         datasets_setup_callback = DatasetsSetupCallback(
-            datasets=self.datasets,
-            data_config=self.data_config,
-            scaling_config=self.scaling_config,
+            train_run_context=self.train_run_context
         )
         tpu_reservation_setup_callback = TPUReservationCallback()
         callbacks.extend(
```

python/ray/train/v2/api/train_fn_utils.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -1,5 +1,6 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
+from ray.train.v2._internal.data_integration.interfaces import DatasetShardMetadata
 from ray.train.v2._internal.execution.train_fn_utils import get_train_fn_utils
 from ray.train.v2.api.context import TrainContext
 from ray.util.annotations import PublicAPI
@@ -241,4 +242,6 @@ def train_loop_per_worker(config):
         The ``DataIterator`` shard to use for this worker.
         If no dataset is passed into Trainer, then return None.
     """
-    return get_train_fn_utils().get_dataset_shard(dataset_name)
+    return get_train_fn_utils().get_dataset_shard(
+        DatasetShardMetadata(dataset_name=dataset_name)
+    )
```
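Wrapping the string into `DatasetShardMetadata` at this outermost public entry point means the internal layers pass one typed object end to end, so fields added to the metadata later do not require new parameters on every intermediate signature. A compressed sketch of the idea; all names except `dataset_name` are hypothetical:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ShardMetadata:
    """Stand-in for DatasetShardMetadata; the real class may carry more fields."""

    dataset_name: str


def _internal_get_dataset_shard(dataset_info: ShardMetadata) -> str:
    # Internal layers receive the full metadata object, so a new field only
    # has to be threaded through once, at the outermost boundary.
    return f"shard:{dataset_info.dataset_name}"


def get_dataset_shard(dataset_name: str) -> str:
    # Public API keeps its simple string argument; the wrapping happens here.
    return _internal_get_dataset_shard(ShardMetadata(dataset_name=dataset_name))


assert get_dataset_shard("train") == "shard:train"
```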

python/ray/train/v2/tests/test_data_integration.py

Lines changed: 12 additions & 10 deletions
```diff
@@ -1,5 +1,3 @@
-from unittest.mock import MagicMock
-
 import pytest
 
 import ray.data
@@ -9,12 +7,15 @@
 from ray.data.tests.conftest import restore_data_context  # noqa: F401
 from ray.train.v2._internal.callbacks.datasets import DatasetsSetupCallback
 from ray.train.v2._internal.data_integration.interfaces import DatasetShardMetadata
-from ray.train.v2._internal.execution.context import TrainRunContext
 from ray.train.v2._internal.execution.worker_group.worker_group import (
     WorkerGroupContext,
 )
 from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer
-from ray.train.v2.tests.util import DummyObjectRefWrapper, DummyWorkerGroup
+from ray.train.v2.tests.util import (
+    DummyObjectRefWrapper,
+    DummyWorkerGroup,
+    create_dummy_run_context,
+)
 
 # TODO(justinvyu): Bring over more tests from ray/air/tests/test_new_dataset_config.py
 
@@ -77,17 +78,18 @@ def test_dataset_setup_callback(ray_start_4_cpus):
         num_workers=scaling_config.num_workers,
         resources_per_worker=scaling_config.resources_per_worker,
     )
+    train_run_context = create_dummy_run_context(
+        datasets={"train": train_ds, "valid": valid_ds},
+        dataset_config=data_config,
+        scaling_config=scaling_config,
+    )
     worker_group = DummyWorkerGroup(
-        train_run_context=MagicMock(spec=TrainRunContext),
+        train_run_context=train_run_context,
         worker_group_context=worker_group_context,
     )
     worker_group._start()
 
-    callback = DatasetsSetupCallback(
-        datasets={"train": train_ds, "valid": valid_ds},
-        data_config=data_config,
-        scaling_config=scaling_config,
-    )
+    callback = DatasetsSetupCallback(train_run_context)
    dataset_manager_for_each_worker = callback.before_init_train_context(
        worker_group.get_workers()
    )["dataset_shard_provider"]
```
