@@ -4,6 +4,7 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from ray.data import DataIterator
+from ray.train.v2._internal.data_integration.interfaces import DatasetShardMetadata
 from ray.train.v2._internal.execution import collective_impl
 from ray.train.v2._internal.execution.context import (
     get_train_context as get_internal_train_context,
@@ -68,14 +69,11 @@ def get_all_reported_checkpoints(self) -> List["ReportedCheckpoint"]:
         pass
 
     @abstractmethod
-    def get_dataset_shard(self, dataset_name: str) -> DataIterator:
+    def get_dataset_shard(self, dataset_info: DatasetShardMetadata) -> DataIterator:
         """Get the dataset shard for this training process.
 
-        This method is used by the public API function :func:`ray.train.get_dataset_shard`.
-        Users should typically call ``ray.train.get_dataset_shard()`` instead of calling this method directly.
-
         Args:
-            dataset_name: The name of the dataset to get the shard for.
+            dataset_info: The metadata of the dataset to get the shard for.
 
         Returns:
             The DataIterator shard for this worker.
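
For implementers and callers of this interface, a shard is now requested with a metadata object rather than a bare string. A minimal sketch of the new call pattern, assuming `ctx` is some concrete implementation of the abstract context above (the function and parameter names here are hypothetical; the `DatasetShardMetadata(dataset_name=...)` constructor matches the one used elsewhere in this diff):

```python
from ray.data import DataIterator
from ray.train.v2._internal.data_integration.interfaces import DatasetShardMetadata


def consume_train_shard(ctx) -> None:
    # `ctx` is any concrete implementation of the abstract context above
    # (hypothetical parameter; the diff does not name a concrete class here).
    shard: DataIterator = ctx.get_dataset_shard(
        DatasetShardMetadata(dataset_name="train")
    )
    # Iterate the worker-local shard as usual.
    for batch in shard.iter_batches(batch_size=32):
        ...
```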
@@ -131,14 +129,8 @@ def report(
     def get_checkpoint(self):
         return get_internal_train_context().get_checkpoint()
 
-    def get_dataset_shard(self, dataset_name: str) -> DataIterator:
-        from ray.train.v2._internal.data_integration.interfaces import (
-            DatasetShardMetadata,
-        )
-
-        return get_internal_train_context().get_dataset_shard(
-            DatasetShardMetadata(dataset_name=dataset_name)
-        )
+    def get_dataset_shard(self, dataset_info: DatasetShardMetadata) -> DataIterator:
+        return get_internal_train_context().get_dataset_shard(dataset_info)
 
     def get_context(self) -> DistributedTrainContext:
         return DistributedTrainContext()
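
With this change the distributed context forwards the metadata object untouched; wrapping a plain name in `DatasetShardMetadata` becomes the responsibility of whichever caller sits above it, instead of happening inside this method as before. A rough sketch of what that caller-side construction could look like (hypothetical only; the real public wrapper is not part of this diff):

```python
# Hypothetical sketch; the actual public wrapper is not shown in this diff.
from ray.data import DataIterator
from ray.train.v2._internal.data_integration.interfaces import DatasetShardMetadata
from ray.train.v2._internal.execution.context import (
    get_train_context as get_internal_train_context,
)


def get_dataset_shard(dataset_name: str) -> DataIterator:
    # Wrap the bare name once at the API boundary, so every context
    # implementation below this point deals only in DatasetShardMetadata.
    metadata = DatasetShardMetadata(dataset_name=dataset_name)
    return get_internal_train_context().get_dataset_shard(metadata)
```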
@@ -182,7 +174,8 @@ def report(
     def get_checkpoint(self) -> Optional["Checkpoint"]:
         return self._last_checkpoint
 
-    def get_dataset_shard(self, dataset_name: str) -> DataIterator:
+    def get_dataset_shard(self, dataset_info: DatasetShardMetadata) -> DataIterator:
+        dataset_name = dataset_info.dataset_name
         assert (
             self._dataset_shards is not None and dataset_name in self._dataset_shards
         ), f"Dataset shard {dataset_name} not found."
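
The local path now unpacks the name from the metadata, which pins down the minimum shape `DatasetShardMetadata` must have for this diff to work. A minimal sketch of that shape, inferred only from the usages above (the real class in `ray.train.v2._internal.data_integration.interfaces` may carry additional fields):

```python
from dataclasses import dataclass


# Minimal shape implied by this diff: the changed code only ever constructs
# DatasetShardMetadata(dataset_name=...) and reads dataset_info.dataset_name.
# The real interface class may define more than this.
@dataclass(frozen=True)
class DatasetShardMetadata:
    dataset_name: str
```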