@@ -88,38 +88,35 @@ def __init__(
88
88
if not isinstance (data , Collection ) or len (data ) == 1 :
89
89
if not isinstance (data , Collection ):
90
90
data = [data ]
91
- self ._set_data_ref (data [0 ]._data_ref )
91
+ self ._set_data_ref (data [0 ]._data )
92
92
self ._num_splits = 1
93
93
self ._list_of_block_partitions = data
94
94
return
95
95
96
96
self ._num_splits = len (data )
97
97
self ._list_of_block_partitions = data
98
- refs = [part ._data_ref for part in self . _list_of_block_partitions ]
98
+ refs = [part ._data_ref for part in data ]
99
99
100
100
if (
101
101
isinstance (refs [0 ], _DeferredGetChunk )
102
+ and isinstance (split := refs [0 ].data , _DeferredSplit )
102
103
and (refs [0 ].index == 0 )
103
104
and all (prev .is_next_chunk (next ) for prev , next in zip (refs [:- 1 ], refs [1 :]))
104
105
):
105
- self ._chunk_lengths_cache = (
106
- None
107
- if any (chunk .length is None for chunk in refs )
108
- else [chunk .length for chunk in refs ]
109
- )
106
+ if all (chunk .length is not None for chunk in refs ):
107
+ self ._chunk_lengths_cache = [chunk .length for chunk in refs ]
110
108
111
- split : _DeferredSplit = refs [0 ].split
112
- if split .num_splits == refs [- 1 ].index :
109
+ if split .num_splits == refs [- 1 ].index + 1 :
113
110
# All the partitions are the chunks of the same DataFrame. Concatenation of
114
111
# all these chunks will get a df identical to the original one. Thus, we
115
112
# don't need to concatenate but can get the original one instead.
116
- self ._set_data_ref (split .non_split )
113
+ self ._set_data_ref (split .data )
117
114
return
118
115
119
116
# TODO: We have a subset of the same frame here and can just get a single chunk
120
117
# from the original frame instead of concatenating all these chunks.
121
118
122
- self ._set_data_ref (self ._concat (refs ))
119
+ self ._set_data_ref (self ._concat ([ part . _data for part in data ] ))
123
120
124
121
def _set_data_ref (
125
122
self , data : Union [DeferredExecution , ObjectRefType ]
@@ -165,16 +162,16 @@ def apply(
165
162
if other_axis_partition is not None :
166
163
if isinstance (other_axis_partition , Collection ):
167
164
if len (other_axis_partition ) == 1 :
168
- other_part = other_axis_partition [0 ]._data_ref
165
+ other_part = other_axis_partition [0 ]._data
169
166
else :
170
167
concat_fn = (
171
168
PandasOnRayDataframeColumnPartition
172
169
if self .axis
173
170
else PandasOnRayDataframeRowPartition
174
171
)._concat
175
- other_part = concat_fn ([p ._data_ref for p in other_axis_partition ])
172
+ other_part = concat_fn ([p ._data for p in other_axis_partition ])
176
173
else :
177
- other_part = other_axis_partition ._data_ref
174
+ other_part = other_axis_partition ._data
178
175
args = [other_part ] + list (args )
179
176
180
177
de = self ._apply (func , args , kwargs )
@@ -224,10 +221,6 @@ def split(
224
221
def _length_cache (self ): # noqa: GL08
225
222
return self ._meta [self ._meta_offset ]
226
223
227
- @_length_cache .setter
228
- def _length_cache (self , value ): # noqa: GL08
229
- self ._meta [self ._meta_offset ] = value
230
-
231
224
def length (self , materialize = True ): # noqa: GL08
232
225
if self ._length_cache is None :
233
226
self ._calculate_lengths (materialize )
@@ -237,10 +230,6 @@ def length(self, materialize=True): # noqa: GL08
237
230
def _width_cache (self ): # noqa: GL08
238
231
return self ._meta [self ._meta_offset + 1 ]
239
232
240
- @_width_cache .setter
241
- def _width_cache (self , value ): # noqa: GL08
242
- self ._meta [self ._meta_offset + 1 ] = value
243
-
244
233
def width (self , materialize = True ): # noqa: GL08
245
234
if self ._width_cache is None :
246
235
self ._calculate_lengths (materialize )
@@ -417,18 +406,17 @@ def split(
417
406
class _DeferredSplit (DeferredExecution ): # noqa: GL08
418
407
def __init__ (
419
408
self ,
420
- non_split : ObjectRefOrDeType ,
409
+ obj : ObjectRefOrDeType ,
421
410
func : ObjectRefType ,
422
411
num_splits : int ,
423
- lengths : Optional [List [int ]],
412
+ lengths : Union [List [int ], None ],
424
413
):
425
- self .non_split = non_split
426
414
self .num_splits = num_splits
427
415
self .skip_chunks = set ()
428
416
args = [num_splits , MinPartitionSize .get (), self .skip_chunks ]
429
417
if lengths and (len (lengths ) == num_splits ):
430
418
args .extend (lengths )
431
- super ().__init__ (non_split , func , args , num_returns = num_splits )
419
+ super ().__init__ (obj , func , args , num_returns = num_splits )
432
420
433
421
434
422
class _DeferredGetChunk (DeferredGetItem ): # noqa: GL08
@@ -439,13 +427,13 @@ def __init__(self, split: _DeferredSplit, index: int, length: Optional[int] = No
439
427
440
428
def __del__ (self ):
441
429
"""Remove this chunk from _DeferredSplit if it's not executed yet."""
442
- if self .data is self . split :
443
- self .split .skip_chunks .add (self .index )
430
+ if isinstance ( self .data , _DeferredSplit ) :
431
+ self .data .skip_chunks .add (self .index )
444
432
445
433
def is_next_chunk (self , other ): # noqa: GL08
446
434
return (
447
435
isinstance (other , _DeferredGetChunk )
448
- and (self .split is other .split )
436
+ and (self .data is other .data )
449
437
and (other .index == self .index + 1 )
450
438
)
451
439
0 commit comments