Commit ca80fd1

Force block partitions materialization
1 parent: a3b508b

1 file changed: +17 −29 lines

modin/core/execution/ray/implementations/pandas_on_ray/partitioning/virtual_partition.py

Lines changed: 17 additions & 29 deletions
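The recurring edit below swaps part._data_ref for part._data wherever block partitions feed the virtual partition. Read together with the commit title, this suggests that _data is the accessor that forces the underlying deferred computation to run, while _data_ref merely hands back the still-pending reference. A minimal toy sketch of that pattern, using invented names (Deferred, ToyPartition) rather than Modin's actual classes:

class Deferred:
    """Toy stand-in for a pending computation (not Modin's DeferredExecution)."""

    def __init__(self, func, value):
        self.func, self.value = func, value

    def execute(self):
        return self.func(self.value)


class ToyPartition:
    """Illustrative partition exposing a raw reference and a forcing accessor."""

    def __init__(self, data_ref):
        self._data_ref = data_ref  # either a Deferred or an already-computed value

    @property
    def _data(self):
        # Accessing _data forces materialization: a pending Deferred is executed
        # once and its result is cached back into _data_ref.
        if isinstance(self._data_ref, Deferred):
            self._data_ref = self._data_ref.execute()
        return self._data_ref


part = ToyPartition(Deferred(sum, [1, 2, 3]))
print(part._data_ref)  # still a Deferred object, nothing computed yet
print(part._data)      # 6 -- the computation has been forced

With that reading, _set_data_ref(data[0]._data) and self._concat([part._data for part in data]) ensure each block partition has been computed before the virtual partition is wired to it.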
@@ -88,38 +88,35 @@ def __init__(
         if not isinstance(data, Collection) or len(data) == 1:
             if not isinstance(data, Collection):
                 data = [data]
-            self._set_data_ref(data[0]._data_ref)
+            self._set_data_ref(data[0]._data)
             self._num_splits = 1
             self._list_of_block_partitions = data
             return
 
         self._num_splits = len(data)
         self._list_of_block_partitions = data
-        refs = [part._data_ref for part in self._list_of_block_partitions]
+        refs = [part._data_ref for part in data]
 
         if (
             isinstance(refs[0], _DeferredGetChunk)
+            and isinstance(split := refs[0].data, _DeferredSplit)
             and (refs[0].index == 0)
             and all(prev.is_next_chunk(next) for prev, next in zip(refs[:-1], refs[1:]))
         ):
-            self._chunk_lengths_cache = (
-                None
-                if any(chunk.length is None for chunk in refs)
-                else [chunk.length for chunk in refs]
-            )
+            if all(chunk.length is not None for chunk in refs):
+                self._chunk_lengths_cache = [chunk.length for chunk in refs]
 
-            split: _DeferredSplit = refs[0].split
-            if split.num_splits == refs[-1].index:
+            if split.num_splits == refs[-1].index + 1:
                 # All the partitions are the chunks of the same DataFrame. Concatenation of
                 # all these chunks will get a df identical to the original one. Thus, we
                 # don't need to concatenate but can get the original one instead.
-                self._set_data_ref(split.non_split)
+                self._set_data_ref(split.data)
                 return
 
             # TODO: We have a subset of the same frame here and can just get a single chunk
             # from the original frame instead of concatenating all these chunks.
 
-        self._set_data_ref(self._concat(refs))
+        self._set_data_ref(self._concat([part._data for part in data]))
 
     def _set_data_ref(
         self, data: Union[DeferredExecution, ObjectRefType]
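The new condition isinstance(split := refs[0].data, _DeferredSplit) does two jobs: it checks that the first chunk still points at a _DeferredSplit, and it binds that object to split for the num_splits comparison below, replacing the deleted split: _DeferredSplit = refs[0].split line. The "+ 1" in split.num_splits == refs[-1].index + 1 also looks like an off-by-one fix, since chunk indices start at 0 and a complete run of chunks ends at index num_splits - 1. A standalone illustration of the walrus-in-isinstance idiom, with generic names unrelated to Modin:

def describe(obj):
    # The assignment expression binds the intermediate value while the
    # isinstance check decides which branch to take, so the attribute is
    # only read once.
    if isinstance(parent := getattr(obj, "data", None), dict):
        return f"dict parent with {len(parent)} key(s)"
    return f"parent is {type(parent).__name__}"


class Holder:
    def __init__(self, data):
        self.data = data


print(describe(Holder({"a": 1})))  # dict parent with 1 key(s)
print(describe(object()))          # parent is NoneType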
@@ -165,16 +162,16 @@ def apply(
         if other_axis_partition is not None:
             if isinstance(other_axis_partition, Collection):
                 if len(other_axis_partition) == 1:
-                    other_part = other_axis_partition[0]._data_ref
+                    other_part = other_axis_partition[0]._data
                 else:
                     concat_fn = (
                         PandasOnRayDataframeColumnPartition
                         if self.axis
                         else PandasOnRayDataframeRowPartition
                     )._concat
-                    other_part = concat_fn([p._data_ref for p in other_axis_partition])
+                    other_part = concat_fn([p._data for p in other_axis_partition])
             else:
-                other_part = other_axis_partition._data_ref
+                other_part = other_axis_partition._data
             args = [other_part] + list(args)
 
         de = self._apply(func, args, kwargs)
@@ -224,10 +221,6 @@ def split(
     def _length_cache(self): # noqa: GL08
         return self._meta[self._meta_offset]
 
-    @_length_cache.setter
-    def _length_cache(self, value): # noqa: GL08
-        self._meta[self._meta_offset] = value
-
     def length(self, materialize=True): # noqa: GL08
         if self._length_cache is None:
             self._calculate_lengths(materialize)
@@ -237,10 +230,6 @@ def length(self, materialize=True): # noqa: GL08
     def _width_cache(self): # noqa: GL08
         return self._meta[self._meta_offset + 1]
 
-    @_width_cache.setter
-    def _width_cache(self, value): # noqa: GL08
-        self._meta[self._meta_offset + 1] = value
-
     def width(self, materialize=True): # noqa: GL08
         if self._width_cache is None:
             self._calculate_lengths(materialize)
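Both cache hunks drop the property setters, leaving _length_cache and _width_cache as read-only views over the shared self._meta list; the caches can now only change when self._meta itself is updated (here via _calculate_lengths). A generic sketch of that read-only, metadata-backed cache pattern, with illustrative names only:

class MetaBacked:
    """Generic sketch: a read-only cache property over a shared metadata list."""

    def __init__(self, meta, offset):
        self._meta = meta          # shared metadata storage
        self._meta_offset = offset

    @property
    def _length_cache(self):
        # Read-only: no setter, so the cache changes only when _meta is updated.
        return self._meta[self._meta_offset]

    def length(self):
        if self._length_cache is None:
            # Populate the shared metadata in one place instead of through a setter.
            self._meta[self._meta_offset] = self._compute_length()
        return self._length_cache

    def _compute_length(self):
        return 42  # placeholder for the real length computation


meta = [None, None]
obj = MetaBacked(meta, 0)
print(obj.length())  # 42, and meta[0] is now filled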
@@ -417,18 +406,17 @@ def split(
 class _DeferredSplit(DeferredExecution): # noqa: GL08
     def __init__(
         self,
-        non_split: ObjectRefOrDeType,
+        obj: ObjectRefOrDeType,
         func: ObjectRefType,
         num_splits: int,
-        lengths: Optional[List[int]],
+        lengths: Union[List[int], None],
     ):
-        self.non_split = non_split
         self.num_splits = num_splits
         self.skip_chunks = set()
         args = [num_splits, MinPartitionSize.get(), self.skip_chunks]
         if lengths and (len(lengths) == num_splits):
             args.extend(lengths)
-        super().__init__(non_split, func, args, num_returns=num_splits)
+        super().__init__(obj, func, args, num_returns=num_splits)
 
 
 class _DeferredGetChunk(DeferredGetItem): # noqa: GL08
@@ -439,13 +427,13 @@ def __init__(self, split: _DeferredSplit, index: int, length: Optional[int] = No
 
     def __del__(self):
         """Remove this chunk from _DeferredSplit if it's not executed yet."""
-        if self.data is self.split:
-            self.split.skip_chunks.add(self.index)
+        if isinstance(self.data, _DeferredSplit):
+            self.data.skip_chunks.add(self.index)
 
     def is_next_chunk(self, other): # noqa: GL08
         return (
             isinstance(other, _DeferredGetChunk)
-            and (self.split is other.split)
+            and (self.data is other.data)
             and (other.index == self.index + 1)
         )
 
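The last two hunks remove the separate split/non_split attributes: a chunk now reaches its parent split through the inherited data field, so __del__ checks isinstance(self.data, _DeferredSplit) to decide whether the chunk was never executed before registering its index in skip_chunks. A toy sketch of that unregister-on-garbage-collection pattern, with made-up ToySplit/ToyChunk classes (it relies on CPython's reference counting to run __del__ promptly):

class ToySplit:
    """Produces chunks lazily; indices in skip_chunks need not be computed."""

    def __init__(self, num_splits):
        self.num_splits = num_splits
        self.skip_chunks = set()


class ToyChunk:
    def __init__(self, split, index):
        self.data = split  # replaced by the real result once the chunk is executed
        self.index = index

    def __del__(self):
        # If the chunk dies while still pointing at the split (i.e. it was never
        # executed), tell the split it can skip producing this index.
        if isinstance(self.data, ToySplit):
            self.data.skip_chunks.add(self.index)


split = ToySplit(num_splits=4)
chunk = ToyChunk(split, index=2)
del chunk  # CPython's reference counting triggers __del__ here
print(split.skip_chunks)  # {2}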
