Skip to content

Commit 7d4e663

Browse files
committed
adding generators to split_result_of_axis_func_pandas
1 parent 1a4ef3c commit 7d4e663

File tree

2 files changed

+20
-21
lines changed

2 files changed

+20
-21
lines changed

modin/core/execution/ray/implementations/pandas_on_ray/partitioning/virtual_partition.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -320,16 +320,17 @@ def _deploy_ray_func(
320320
f_args = positional_args[:f_len_args]
321321
deploy_args = positional_args[f_len_args:]
322322
result = deployer(axis, f_to_deploy, f_args, f_kwargs, *deploy_args, **kwargs)
323+
323324
if not extract_metadata:
324-
return result
325+
for item in result:
326+
yield item
327+
return
328+
325329
ip = get_node_ip_address()
326-
if isinstance(result, pandas.DataFrame):
327-
return result, len(result), len(result.columns), ip
328-
elif all(isinstance(r, pandas.DataFrame) for r in result):
329-
for r in result:
330+
for r in result:
331+
if isinstance(r, pandas.DataFrame):
330332
for item in [r, len(r), len(r.columns), ip]:
331333
yield item
332-
else:
333-
for r in result:
334+
else:
334335
for item in [r, None, None, ip]:
335336
yield item

modin/core/storage_formats/pandas/utils.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ def split_result_of_axis_func_pandas(axis, num_splits, result, length_list=None)
8080
Splitted dataframe represented by list of frames.
8181
"""
8282
if num_splits == 1:
83-
return [result]
83+
yield result
84+
return
8485

8586
if length_list is None:
8687
length_list = get_length_list(result.shape[axis], num_splits)
@@ -89,24 +90,21 @@ def split_result_of_axis_func_pandas(axis, num_splits, result, length_list=None)
8990

9091
sums = np.cumsum(length_list)
9192
axis = 0 if isinstance(result, pandas.Series) else axis
92-
# We do this to restore block partitioning
93-
if axis == 0:
94-
chunked = [result.iloc[sums[i] : sums[i + 1]] for i in range(len(sums) - 1)]
95-
else:
96-
chunked = [result.iloc[:, sums[i] : sums[i + 1]] for i in range(len(sums) - 1)]
9793

98-
return [
94+
for i in range(len(sums) - 1):
95+
# We do this to restore block partitioning
96+
if axis == 0:
97+
chunk = result.iloc[sums[i] : sums[i + 1]]
98+
else:
99+
chunk = result.iloc[:, sums[i] : sums[i + 1]]
100+
99101
# Sliced MultiIndex still stores all encoded values of the original index, explicitly
100102
# asking it to drop unused values in order to save memory.
101-
(
102-
chunk.set_axis(
103+
if isinstance(chunk.axes[axis], pandas.MultiIndex):
104+
chunk = chunk.set_axis(
103105
chunk.axes[axis].remove_unused_levels(), axis=axis, copy=False
104106
)
105-
if isinstance(chunk.axes[axis], pandas.MultiIndex)
106-
else chunk
107-
)
108-
for chunk in chunked
109-
]
107+
yield chunk
110108

111109

112110
def get_length_list(axis_len: int, num_splits: int) -> list:

0 commit comments

Comments
 (0)