Skip to content

Commit 004e6e6

Browse files
committed
calling split_result_of_axis_func_pandas only for ray
Signed-off-by: arunjose696 <arunjose696@gmail.com>
1 parent 5fb4465 commit 004e6e6

File tree

6 files changed

+103
-11
lines changed

6 files changed

+103
-11
lines changed

modin/core/dataframe/pandas/partitioning/axis_partition.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121
from modin.core.dataframe.base.partitioning.axis_partition import (
2222
BaseDataframeAxisPartition,
2323
)
24-
from modin.core.storage_formats.pandas.utils import split_result_of_axis_func_pandas
24+
from modin.core.storage_formats.pandas.utils import (
25+
generate_result_of_axis_func_pandas,
26+
split_result_of_axis_func_pandas,
27+
)
2528

2629
from .partition import PandasDataframePartition
2730

@@ -388,6 +391,7 @@ def deploy_axis_func(
388391
*partitions,
389392
lengths=None,
390393
manual_partition=False,
394+
return_generator=False,
391395
):
392396
"""
393397
Deploy a function along a full axis.
@@ -413,11 +417,14 @@ def deploy_axis_func(
413417
The list of lengths to shuffle the object.
414418
manual_partition : bool, default: False
415419
If True, partition the result with `lengths`.
420+
return_generator : bool, default: False
421+
Return a generator from the function; set to `True` for the Ray backend,
422+
as Ray remote functions can return Generators.
416423
417424
Returns
418425
-------
419-
list
420-
A list of pandas DataFrames.
426+
list | Generator
427+
A list or generator of pandas DataFrames.
421428
"""
422429
dataframe = pandas.concat(list(partitions), axis=axis, copy=False)
423430
with warnings.catch_warnings():
@@ -451,7 +458,12 @@ def deploy_axis_func(
451458
lengths = [len(part.columns) for part in partitions]
452459
if sum(lengths) != len(result.columns):
453460
lengths = None
454-
return split_result_of_axis_func_pandas(axis, num_splits, result, lengths)
461+
if return_generator:
462+
return generate_result_of_axis_func_pandas(
463+
axis, num_splits, result, lengths
464+
)
465+
else:
466+
return split_result_of_axis_func_pandas(axis, num_splits, result, lengths)
455467

456468
@classmethod
457469
def deploy_func_between_two_axis_partitions(
@@ -464,6 +476,7 @@ def deploy_func_between_two_axis_partitions(
464476
len_of_left,
465477
other_shape,
466478
*partitions,
479+
return_generator=False,
467480
):
468481
"""
469482
Deploy a function along a full axis between two data sets.
@@ -487,11 +500,14 @@ def deploy_func_between_two_axis_partitions(
487500
(other_shape[i-1], other_shape[i]) will indicate slice to restore i-1 axis partition.
488501
*partitions : iterable
489502
All partitions that make up the full axis (row or column) for both data sets.
503+
return_generator : bool, default: False
504+
Return a generator from the function; set to `True` for the Ray backend,
505+
as Ray remote functions can return Generators.
490506
491507
Returns
492508
-------
493-
list
494-
A list of pandas DataFrames.
509+
list | Generator
510+
A list or generator of pandas DataFrames.
495511
"""
496512
lt_frame = pandas.concat(partitions[:len_of_left], axis=axis, copy=False)
497513

@@ -510,7 +526,18 @@ def deploy_func_between_two_axis_partitions(
510526
with warnings.catch_warnings():
511527
warnings.filterwarnings("ignore", category=FutureWarning)
512528
result = func(lt_frame, rt_frame, *f_args, **f_kwargs)
513-
return split_result_of_axis_func_pandas(axis, num_splits, result)
529+
if return_generator:
530+
return generate_result_of_axis_func_pandas(
531+
axis,
532+
num_splits,
533+
result,
534+
)
535+
else:
536+
return split_result_of_axis_func_pandas(
537+
axis,
538+
num_splits,
539+
result,
540+
)
514541

515542
@classmethod
516543
def drain(cls, df: pandas.DataFrame, call_queue: list):

modin/core/execution/ray/implementations/cudf_on_ray/partitioning/partition_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def from_pandas(cls, df, return_dims=False):
125125
num_splits = GpuCount.get()
126126
put_func = cls._partition_class.put
127127
# For now, we default to row partitioning
128-
pandas_dfs = list(split_result_of_axis_func_pandas(0, num_splits, df))
128+
pandas_dfs = split_result_of_axis_func_pandas(0, num_splits, df)
129129
keys = [
130130
put_func(cls._get_gpu_managers()[i], pandas_dfs[i])
131131
for i in range(num_splits)

modin/core/execution/ray/implementations/pandas_on_ray/partitioning/virtual_partition.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def deploy_axis_func(
189189
f_kwargs=f_kwargs,
190190
manual_partition=manual_partition,
191191
lengths=lengths,
192+
return_generator=True,
192193
)
193194

194195
@classmethod
@@ -244,6 +245,7 @@ def deploy_func_between_two_axis_partitions(
244245
f_to_deploy=func,
245246
f_len_args=len(f_args),
246247
f_kwargs=f_kwargs,
248+
return_generator=True,
247249
)
248250

249251
def wait(self):

modin/core/storage_formats/cudf/parser.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ def _split_result_for_readers(axis, num_splits, df): # pragma: no cover
3939
Returns:
4040
A list of pandas DataFrames.
4141
"""
42-
splits = list(split_result_of_axis_func_pandas(axis, num_splits, df))
42+
splits = split_result_of_axis_func_pandas(axis, num_splits, df)
43+
if not isinstance(splits, list):
44+
splits = [splits]
4345
return splits
4446

4547

modin/core/storage_formats/pandas/parsers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,9 @@ def _split_result_for_readers(axis, num_splits, df): # pragma: no cover
113113
list
114114
A list of pandas DataFrames.
115115
"""
116-
splits = list(split_result_of_axis_func_pandas(axis, num_splits, df))
116+
splits = split_result_of_axis_func_pandas(axis, num_splits, df)
117+
if not isinstance(splits, list):
118+
splits = [splits]
117119
return splits
118120

119121

modin/core/storage_formats/pandas/utils.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,70 @@ def split_result_of_axis_func_pandas(
8484
list of pandas.DataFrames
8585
Splitted dataframe represented by list of frames.
8686
"""
87+
if num_splits == 1:
88+
return [result]
89+
90+
if length_list is None:
91+
length_list = get_length_list(result.shape[axis], num_splits, min_block_size)
92+
# Inserting the first "zero" to properly compute cumsum indexing slices
93+
length_list = np.insert(length_list, obj=0, values=[0])
94+
95+
sums = np.cumsum(length_list)
96+
axis = 0 if isinstance(result, pandas.Series) else axis
97+
# We do this to restore block partitioning
98+
if axis == 0:
99+
chunked = [result.iloc[sums[i] : sums[i + 1]] for i in range(len(sums) - 1)]
100+
else:
101+
chunked = [result.iloc[:, sums[i] : sums[i + 1]] for i in range(len(sums) - 1)]
102+
103+
return [
104+
# Sliced MultiIndex still stores all encoded values of the original index, explicitly
105+
# asking it to drop unused values in order to save memory.
106+
(
107+
chunk.set_axis(
108+
chunk.axes[axis].remove_unused_levels(), axis=axis, copy=False
109+
)
110+
if isinstance(chunk.axes[axis], pandas.MultiIndex)
111+
else chunk
112+
)
113+
for chunk in chunked
114+
]
115+
116+
117+
def generate_result_of_axis_func_pandas(
118+
axis, num_splits, result, length_list=None, min_block_size=None
119+
):
120+
"""
121+
Generate evenly split pandas DataFrames based on the provided number of splits.
122+
123+
Parameters
124+
----------
125+
axis : {0, 1}
126+
Axis to split across. 0 means the index axis, while 1 means the column axis.
127+
num_splits : int
128+
Number of splits to separate the DataFrame into.
129+
This parameter is ignored if `length_list` is specified.
130+
result : pandas.DataFrame
131+
DataFrame to split.
132+
length_list : list of ints, optional
133+
List of slice lengths to split DataFrame into. This is used to
134+
return the DataFrame to its original partitioning schema.
135+
min_block_size : int, optional
136+
Minimum number of rows/columns in a single split.
137+
If not specified, the value is assumed equal to ``MinPartitionSize``.
138+
139+
Yields
140+
------
141+
Generator
142+
Generates 'num_splits' dataframes as a result of axis function.
143+
"""
87144
if num_splits == 1:
88145
yield result
89146
else:
90147
if length_list is None:
91-
length_list = get_length_list(result.shape[axis], num_splits,min_block_size)
148+
length_list = get_length_list(
149+
result.shape[axis], num_splits, min_block_size
150+
)
92151
# Inserting the first "zero" to properly compute cumsum indexing slices
93152
length_list = np.insert(length_list, obj=0, values=[0])
94153
sums = np.cumsum(length_list)

0 commit comments

Comments
 (0)