Skip to content

Commit 128509d

Browse files
perf
1 parent ca80fd1 commit 128509d

File tree

4 files changed

+162
-103
lines changed

4 files changed

+162
-103
lines changed

modin/core/execution/ray/common/deferred_execution.py

Lines changed: 36 additions & 75 deletions
Original file line number | Diff line number | Diff line change
@@ -161,24 +161,6 @@ def exec(
161161
and self.flat_kwargs
162162
and self.num_returns == 1
163163
):
164-
# self.data = RayWrapper.materialize(self.data)
165-
# self.args = [
166-
# RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
167-
# for o in self.args
168-
# ]
169-
# self.kwargs = {
170-
# k: RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
171-
# for k, o in self.kwargs.items()
172-
# }
173-
# obj = _REMOTE_EXEC.exec_func(
174-
# RayWrapper.materialize(self.func), self.data, self.args, self.kwargs
175-
# )
176-
# result, length, width, ip = (
177-
# obj,
178-
# len(obj) if hasattr(obj, "__len__") else 0,
179-
# len(obj.columns) if hasattr(obj, "columns") else 0,
180-
# "",
181-
# )
182164
result, length, width, ip = remote_exec_func.remote(
183165
self.func, self.data, *self.args, **self.kwargs
184166
)
@@ -191,13 +173,6 @@ def exec(
191173
self.subscribers += 2
192174
consumers, output = self._deconstruct()
193175

194-
# assert not any(isinstance(o, ListOrTuple) for o in output)
195-
# tmp = [
196-
# RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
197-
# for o in output
198-
# ]
199-
# list(_REMOTE_EXEC.construct(tmp))
200-
201176
# The last result is the MetaList, so adding +1 here.
202177
num_returns = sum(c.num_returns for c in consumers) + 1
203178
results = self._remote_exec_chain(num_returns, *output)
@@ -336,7 +311,9 @@ def _deconstruct_chain(
336311
out_extend = output.extend
337312
while True:
338313
de.unsubscribe()
339-
if not de.has_result and (out_pos := getattr(de, "out_pos", None)):
314+
if not (has_result := de.has_result) and (
315+
out_pos := getattr(de, "out_pos", None)
316+
):
340317
out_append(_Tag.REF)
341318
out_append(out_pos)
342319
output[out_pos] = out_pos
@@ -357,7 +334,7 @@ def _deconstruct_chain(
357334
)
358335
else:
359336
out_append(data)
360-
if not de.has_result:
337+
if not has_result:
361338
stack.append(de)
362339
break
363340
else:
@@ -425,28 +402,24 @@ def _deconstruct_list(
425402
"""
426403
for obj in lst:
427404
if isinstance(obj, DeferredExecution):
428-
if out_pos := getattr(obj, "out_pos", None):
405+
if obj.has_result:
406+
obj = obj.data
407+
elif out_pos := getattr(obj, "out_pos", None):
429408
obj.unsubscribe()
430-
if obj.has_result:
431-
if isinstance(obj.data, ListOrTuple):
432-
out_append(_Tag.LIST)
433-
yield cls._deconstruct_list(
434-
obj.data, output, stack, result_consumers, out_append
435-
)
436-
else:
437-
out_append(obj.data)
438-
else:
439-
out_append(_Tag.REF)
440-
out_append(out_pos)
441-
output[out_pos] = out_pos
442-
if obj.subscribers == 0:
443-
output[out_pos + 1] = 0
444-
result_consumers.remove(obj)
409+
out_append(_Tag.REF)
410+
out_append(out_pos)
411+
output[out_pos] = out_pos
412+
if obj.subscribers == 0:
413+
output[out_pos + 1] = 0
414+
result_consumers.remove(obj)
415+
continue
445416
else:
446417
out_append(_Tag.CHAIN)
447418
yield cls._deconstruct_chain(obj, output, stack, result_consumers)
448419
out_append(_Tag.END)
449-
elif isinstance(obj, ListOrTuple):
420+
continue
421+
422+
if isinstance(obj, ListOrTuple):
450423
out_append(_Tag.LIST)
451424
yield cls._deconstruct_list(
452425
obj, output, stack, result_consumers, out_append
@@ -517,27 +490,13 @@ class DeferredGetItem(DeferredExecution):
517490
----------
518491
data : ObjectRefOrDeType
519492
The object to get the item from.
520-
idx : int
493+
index : int
521494
The item index.
522495
"""
523496

524-
def __init__(self, data: ObjectRefOrDeType, idx: int):
525-
super().__init__(data, self._remote_fn(), [idx])
526-
self.index = idx
527-
528-
@_inherit_docstrings(DeferredExecution.exec)
529-
def exec(self) -> Tuple[ObjectRefType, "MetaList", int]:
530-
if self.has_result:
531-
return self.data, self.meta, self.meta_offset
532-
533-
if not isinstance(self.data, DeferredExecution) or self.data.num_returns == 1:
534-
return super().exec()
535-
536-
# If `data` is a `DeferredExecution`, that returns multiple results,
537-
# it's not required to execute `_remote_fn()`. We can only execute
538-
# `data` and get the result by index.
539-
self._data_exec()
540-
return self.data, self.meta, self.meta_offset
497+
def __init__(self, data: ObjectRefOrDeType, index: int):
498+
super().__init__(data, self._remote_fn(), [index])
499+
self.index = index
541500

542501
@property
543502
@_inherit_docstrings(DeferredExecution.has_result)
@@ -550,16 +509,18 @@ def has_result(self):
550509
and self.data.has_result
551510
and self.data.num_returns != 1
552511
):
553-
self._data_exec()
512+
# If `data` is a `DeferredExecution`, that returns multiple results,
513+
# it's not required to execute `_remote_fn()`. It is enough to take the result of
514+
# `data` and get the result by index.
515+
self._set_result(
516+
self.data.data[self.index],
517+
self.data.meta,
518+
self.data.meta_offset[self.index],
519+
)
554520
return True
555521

556522
return False
557523

558-
def _data_exec(self):
559-
"""Execute the `data` task and get the result."""
560-
obj, meta, offsets = self.data.exec()
561-
self._set_result(obj[self.index], meta, offsets[self.index])
562-
563524
@classmethod
564525
def _remote_fn(cls) -> ObjectRefType:
565526
"""
@@ -592,7 +553,8 @@ def __init__(self, obj: Union[ray.ObjectID, ClientObjectRef, List]):
592553

593554
def materialize(self):
594555
"""Materialize the list, if required."""
595-
self._obj = RayWrapper.materialize(self._obj)
556+
if not isinstance(self._obj, list):
557+
self._obj = RayWrapper.materialize(self._obj)
596558

597559
def __getitem__(self, index):
598560
"""
@@ -632,14 +594,13 @@ class MetaListHook(MaterializationHook, DeferredGetItem):
632594
----------
633595
meta : MetaList
634596
Non-materialized list to get the value from.
635-
idx : int
597+
index : int
636598
The value index in the list.
637599
"""
638600

639-
def __init__(self, meta: MetaList, idx: int):
640-
super().__init__(meta._obj, idx)
601+
def __init__(self, meta: MetaList, index: int):
602+
super().__init__(meta._obj, index)
641603
self.meta = meta
642-
self.idx = idx
643604

644605
def pre_materialize(self):
645606
"""
@@ -650,7 +611,7 @@ def pre_materialize(self):
650611
object
651612
"""
652613
obj = self.meta._obj
653-
return obj[self.idx] if isinstance(obj, list) else obj
614+
return obj[self.index] if isinstance(obj, list) else obj
654615

655616
def post_materialize(self, materialized):
656617
"""
@@ -665,7 +626,7 @@ def post_materialize(self, materialized):
665626
object
666627
"""
667628
self.meta._obj = materialized
668-
return materialized[self.idx]
629+
return materialized[self.index]
669630

670631

671632
class _Tag(Enum): # noqa: PR01

modin/core/execution/ray/common/engine_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -106,7 +106,7 @@ def materialize(cls, obj_id):
106106
107107
Parameters
108108
----------
109-
obj_id : ray.ObjectID
109+
obj_id : ObjectRefTypes
110110
Ray object identifier to get the value by.
111111
112112
Returns

modin/core/execution/ray/implementations/pandas_on_ray/partitioning/partition_manager.py

Lines changed: 67 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,13 +12,16 @@
1212
# governing permissions and limitations under the License.
1313

1414
"""Module houses class that implements ``GenericRayDataframePartitionManager`` using Ray."""
15+
import math
1516

1617
import numpy as np
18+
import pandas
1719
from pandas.core.dtypes.common import is_numeric_dtype
1820

19-
from modin.config import AsyncReadMode
21+
from modin.config import AsyncReadMode, MinPartitionSize
2022
from modin.core.execution.modin_aqp import progress_bar_wrapper
2123
from modin.core.execution.ray.common import RayWrapper
24+
from modin.core.execution.ray.common.deferred_execution import DeferredExecution
2225
from modin.core.execution.ray.generic.partitioning import (
2326
GenericRayDataframePartitionManager,
2427
)
@@ -29,6 +32,7 @@
2932
from .virtual_partition import (
3033
PandasOnRayDataframeColumnPartition,
3134
PandasOnRayDataframeRowPartition,
35+
PandasOnRayDataframeVirtualPartition,
3236
)
3337

3438

@@ -42,6 +46,68 @@ class PandasOnRayDataframePartitionManager(GenericRayDataframePartitionManager):
4246
_execution_wrapper = RayWrapper
4347
materialize_futures = RayWrapper.materialize
4448

49+
@classmethod
@_inherit_docstrings(GenericRayDataframePartitionManager.get_indices)
def get_indices(cls, axis, partitions, index_func=None):
    # No docstring is defined here on purpose: the ``_inherit_docstrings``
    # decorator above is applied to take it from the generic implementation.
    #
    # Returns a pair ``(index, part_indices)`` where ``part_indices`` holds
    # one index object per inspected partition.

    # Only the first row of the grid (the first column when ``axis == 0``)
    # is inspected.
    partitions = partitions.T if axis == 0 else partitions
    if len(partitions) == 0:
        return pandas.Index([]), []

    partitions = list(partitions[0])
    # NOTE(review): ``find_non_split_block`` presumably returns a single
    # block the partitions were split from (or ``None``) together with the
    # partition lengths -- confirm against its implementation.
    non_split, lengths, _ = (
        PandasOnRayDataframeVirtualPartition.find_non_split_block(partitions)
    )
    if non_split is not None:
        # One block is enough: its index is computed once and sliced
        # locally further below.
        partitions = [non_split]
    else:
        partitions = [part._data for part in partitions]

    if index_func is None:
        attr_name = f"_GET_AXIS_{axis}"
        if (fn := getattr(cls, attr_name, None)) is None:

            def get_cols(*dfs, axis=axis):
                return [df.axes[axis] for df in dfs]

            # Cache the result of ``RayWrapper.put`` (not the raw function)
            # so that the function is put only once and later calls reuse
            # it, the same way ``_GET_AXIS_FN`` is cached below.
            fn = RayWrapper.put(get_cols)
            setattr(cls, attr_name, fn)
        data, args = partitions[0], partitions[1:]
    else:
        if (fn := getattr(cls, "_GET_AXIS_FN", None)) is None:

            def apply_index(fn, *dfs):
                return [fn(df) for df in dfs]

            cls._GET_AXIS_FN = fn = RayWrapper.put(apply_index)
        # Here the user function is the "data" argument and all the
        # partitions are passed as positional arguments.
        data, args = index_func, partitions

    de = DeferredExecution(data, fn, args, num_returns=len(partitions))
    part_indices = de.exec()[0]

    if non_split is not None:
        # Materialize the single index together with the partition lengths
        # in one call.
        materialized = RayWrapper.materialize([part_indices] + lengths)
        idx = materialized[0][0]
        lengths = materialized[1:]

        if (idx_len := len(idx)) != sum(lengths):
            # The known lengths do not add up to the index length, so fall
            # back to equally sized chunks (at least ``MinPartitionSize``).
            count = len(lengths)
            chunk_len = max(math.ceil(idx_len / count), MinPartitionSize.get())
            lengths = [chunk_len] * count

        # Slice the full index into consecutive per-partition pieces.
        part_indices = []
        start = 0
        for length in lengths:
            end = start + length
            part_indices.append(idx[start:end])
            start = end
        return idx, part_indices

    part_indices = RayWrapper.materialize(part_indices)
    # Empty indices are skipped when building the combined index.
    indices = [idx for idx in part_indices if len(idx)]
    if not indices:
        return part_indices[0], part_indices
    return indices[0].append(indices[1:]), part_indices
110+
45111
@classmethod
46112
def wait_partitions(cls, partitions):
47113
"""

0 commit comments

Comments
 (0)