Skip to content

Commit eb740b9

Browse files
authored
FEAT-#7004: use generators when returning from _deploy_ray_func remote function. (#7005)
Signed-off-by: arunjose696 <arunjose696@gmail.com>
1 parent 2ed2d49 commit eb740b9

File tree

3 files changed

+104
-39
lines changed

modin/core/dataframe/pandas/partitioning/axis_partition.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121
from modin.core.dataframe.base.partitioning.axis_partition import (
2222
BaseDataframeAxisPartition,
2323
)
24-
from modin.core.storage_formats.pandas.utils import split_result_of_axis_func_pandas
24+
from modin.core.storage_formats.pandas.utils import (
25+
generate_result_of_axis_func_pandas,
26+
split_result_of_axis_func_pandas,
27+
)
2528

2629
from .partition import PandasDataframePartition
2730

@@ -388,6 +391,7 @@ def deploy_axis_func(
388391
*partitions,
389392
lengths=None,
390393
manual_partition=False,
394+
return_generator=False,
391395
):
392396
"""
393397
Deploy a function along a full axis.
@@ -413,11 +417,14 @@ def deploy_axis_func(
413417
The list of lengths to shuffle the object.
414418
manual_partition : bool, default: False
415419
If True, partition the result with `lengths`.
420+
return_generator : bool, default: False
421+
If True, return a generator instead of a list. Set to `True` for the Ray backend,
422+
since Ray remote functions can return generators.
416423
417424
Returns
418425
-------
419-
list
420-
A list of pandas DataFrames.
426+
list | Generator
427+
A list or generator of pandas DataFrames.
421428
"""
422429
dataframe = pandas.concat(list(partitions), axis=axis, copy=False)
423430
with warnings.catch_warnings():
@@ -451,7 +458,12 @@ def deploy_axis_func(
451458
lengths = [len(part.columns) for part in partitions]
452459
if sum(lengths) != len(result.columns):
453460
lengths = None
454-
return split_result_of_axis_func_pandas(axis, num_splits, result, lengths)
461+
if return_generator:
462+
return generate_result_of_axis_func_pandas(
463+
axis, num_splits, result, lengths
464+
)
465+
else:
466+
return split_result_of_axis_func_pandas(axis, num_splits, result, lengths)
455467

456468
@classmethod
457469
def deploy_func_between_two_axis_partitions(
@@ -464,6 +476,7 @@ def deploy_func_between_two_axis_partitions(
464476
len_of_left,
465477
other_shape,
466478
*partitions,
479+
return_generator=False,
467480
):
468481
"""
469482
Deploy a function along a full axis between two data sets.
@@ -487,11 +500,14 @@ def deploy_func_between_two_axis_partitions(
487500
(other_shape[i-1], other_shape[i]) will indicate slice to restore i-1 axis partition.
488501
*partitions : iterable
489502
All partitions that make up the full axis (row or column) for both data sets.
503+
return_generator : bool, default: False
504+
If True, return a generator instead of a list. Set to `True` for the Ray backend,
505+
since Ray remote functions can return generators.
490506
491507
Returns
492508
-------
493-
list
494-
A list of pandas DataFrames.
509+
list | Generator
510+
A list or generator of pandas DataFrames.
495511
"""
496512
lt_frame = pandas.concat(partitions[:len_of_left], axis=axis, copy=False)
497513

@@ -510,7 +526,18 @@ def deploy_func_between_two_axis_partitions(
510526
with warnings.catch_warnings():
511527
warnings.filterwarnings("ignore", category=FutureWarning)
512528
result = func(lt_frame, rt_frame, *f_args, **f_kwargs)
513-
return split_result_of_axis_func_pandas(axis, num_splits, result)
529+
if return_generator:
530+
return generate_result_of_axis_func_pandas(
531+
axis,
532+
num_splits,
533+
result,
534+
)
535+
else:
536+
return split_result_of_axis_func_pandas(
537+
axis,
538+
num_splits,
539+
result,
540+
)
514541

515542
@classmethod
516543
def drain(cls, df: pandas.DataFrame, call_queue: list):

modin/core/execution/ray/implementations/pandas_on_ray/partitioning/virtual_partition.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def deploy_axis_func(
189189
f_kwargs=f_kwargs,
190190
manual_partition=manual_partition,
191191
lengths=lengths,
192+
return_generator=True,
192193
)
193194

194195
@classmethod
@@ -244,6 +245,7 @@ def deploy_func_between_two_axis_partitions(
244245
f_to_deploy=func,
245246
f_len_args=len(f_args),
246247
f_kwargs=f_kwargs,
248+
return_generator=True,
247249
)
248250

249251
def wait(self):
@@ -320,12 +322,16 @@ def _deploy_ray_func(
320322
f_args = positional_args[:f_len_args]
321323
deploy_args = positional_args[f_len_args:]
322324
result = deployer(axis, f_to_deploy, f_args, f_kwargs, *deploy_args, **kwargs)
325+
323326
if not extract_metadata:
324-
return result
325-
ip = get_node_ip_address()
326-
if isinstance(result, pandas.DataFrame):
327-
return result, len(result), len(result.columns), ip
328-
elif all(isinstance(r, pandas.DataFrame) for r in result):
329-
return [i for r in result for i in [r, len(r), len(r.columns), ip]]
327+
for item in result:
328+
yield item
330329
else:
331-
return [i for r in result for i in [r, None, None, ip]]
330+
ip = get_node_ip_address()
331+
for r in result:
332+
if isinstance(r, pandas.DataFrame):
333+
for item in [r, len(r), len(r.columns), ip]:
334+
yield item
335+
else:
336+
for item in [r, None, None, ip]:
337+
yield item

modin/core/storage_formats/pandas/utils.py

Lines changed: 57 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -84,34 +84,66 @@ def split_result_of_axis_func_pandas(
8484
list of pandas.DataFrames
8585
Split dataframe represented as a list of frames.
8686
"""
87+
return list(
88+
generate_result_of_axis_func_pandas(
89+
axis, num_splits, result, length_list, min_block_size
90+
)
91+
)
92+
93+
94+
def generate_result_of_axis_func_pandas(
95+
axis, num_splits, result, length_list=None, min_block_size=None
96+
):
97+
"""
98+
Generate pandas DataFrame evenly based on the provided number of splits.
99+
100+
Parameters
101+
----------
102+
axis : {0, 1}
103+
Axis to split across. 0 means the index axis, while 1 means the column axis.
104+
num_splits : int
105+
Number of splits to separate the DataFrame into.
106+
This parameter is ignored if `length_list` is specified.
107+
result : pandas.DataFrame
108+
DataFrame to split.
109+
length_list : list of ints, optional
110+
List of slice lengths to split DataFrame into. This is used to
111+
return the DataFrame to its original partitioning schema.
112+
min_block_size : int, optional
113+
Minimum number of rows/columns in a single split.
114+
If not specified, the value is assumed equal to ``MinPartitionSize``.
115+
116+
Yields
117+
------
118+
pandas.DataFrame
119+
The next chunk of the split result (the whole `result` when `num_splits` is 1).
120+
"""
87121
if num_splits == 1:
88-
return [result]
89-
90-
if length_list is None:
91-
length_list = get_length_list(result.shape[axis], num_splits, min_block_size)
92-
# Inserting the first "zero" to properly compute cumsum indexing slices
93-
length_list = np.insert(length_list, obj=0, values=[0])
94-
95-
sums = np.cumsum(length_list)
96-
axis = 0 if isinstance(result, pandas.Series) else axis
97-
# We do this to restore block partitioning
98-
if axis == 0:
99-
chunked = [result.iloc[sums[i] : sums[i + 1]] for i in range(len(sums) - 1)]
122+
yield result
100123
else:
101-
chunked = [result.iloc[:, sums[i] : sums[i + 1]] for i in range(len(sums) - 1)]
102-
103-
return [
104-
# Sliced MultiIndex still stores all encoded values of the original index, explicitly
105-
# asking it to drop unused values in order to save memory.
106-
(
107-
chunk.set_axis(
108-
chunk.axes[axis].remove_unused_levels(), axis=axis, copy=False
124+
if length_list is None:
125+
length_list = get_length_list(
126+
result.shape[axis], num_splits, min_block_size
109127
)
110-
if isinstance(chunk.axes[axis], pandas.MultiIndex)
111-
else chunk
112-
)
113-
for chunk in chunked
114-
]
128+
# Inserting the first "zero" to properly compute cumsum indexing slices
129+
length_list = np.insert(length_list, obj=0, values=[0])
130+
sums = np.cumsum(length_list)
131+
axis = 0 if isinstance(result, pandas.Series) else axis
132+
133+
for i in range(len(sums) - 1):
134+
# We do this to restore block partitioning
135+
if axis == 0:
136+
chunk = result.iloc[sums[i] : sums[i + 1]]
137+
else:
138+
chunk = result.iloc[:, sums[i] : sums[i + 1]]
139+
140+
# Sliced MultiIndex still stores all encoded values of the original index, explicitly
141+
# asking it to drop unused values in order to save memory.
142+
if isinstance(chunk.axes[axis], pandas.MultiIndex):
143+
chunk = chunk.set_axis(
144+
chunk.axes[axis].remove_unused_levels(), axis=axis, copy=False
145+
)
146+
yield chunk
115147

116148

117149
def get_length_list(axis_len: int, num_splits: int, min_block_size=None) -> list:

0 commit comments

Comments
 (0)