@@ -161,24 +161,6 @@ def exec(
161
161
and self .flat_kwargs
162
162
and self .num_returns == 1
163
163
):
164
- # self.data = RayWrapper.materialize(self.data)
165
- # self.args = [
166
- # RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
167
- # for o in self.args
168
- # ]
169
- # self.kwargs = {
170
- # k: RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
171
- # for k, o in self.kwargs.items()
172
- # }
173
- # obj = _REMOTE_EXEC.exec_func(
174
- # RayWrapper.materialize(self.func), self.data, self.args, self.kwargs
175
- # )
176
- # result, length, width, ip = (
177
- # obj,
178
- # len(obj) if hasattr(obj, "__len__") else 0,
179
- # len(obj.columns) if hasattr(obj, "columns") else 0,
180
- # "",
181
- # )
182
164
result , length , width , ip = remote_exec_func .remote (
183
165
self .func , self .data , * self .args , ** self .kwargs
184
166
)
@@ -191,13 +173,6 @@ def exec(
191
173
self .subscribers += 2
192
174
consumers , output = self ._deconstruct ()
193
175
194
- # assert not any(isinstance(o, ListOrTuple) for o in output)
195
- # tmp = [
196
- # RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
197
- # for o in output
198
- # ]
199
- # list(_REMOTE_EXEC.construct(tmp))
200
-
201
176
# The last result is the MetaList, so adding +1 here.
202
177
num_returns = sum (c .num_returns for c in consumers ) + 1
203
178
results = self ._remote_exec_chain (num_returns , * output )
@@ -336,7 +311,9 @@ def _deconstruct_chain(
336
311
out_extend = output .extend
337
312
while True :
338
313
de .unsubscribe ()
339
- if not de .has_result and (out_pos := getattr (de , "out_pos" , None )):
314
+ if not (has_result := de .has_result ) and (
315
+ out_pos := getattr (de , "out_pos" , None )
316
+ ):
340
317
out_append (_Tag .REF )
341
318
out_append (out_pos )
342
319
output [out_pos ] = out_pos
@@ -357,7 +334,7 @@ def _deconstruct_chain(
357
334
)
358
335
else :
359
336
out_append (data )
360
- if not de . has_result :
337
+ if not has_result :
361
338
stack .append (de )
362
339
break
363
340
else :
@@ -425,28 +402,24 @@ def _deconstruct_list(
425
402
"""
426
403
for obj in lst :
427
404
if isinstance (obj , DeferredExecution ):
428
- if out_pos := getattr (obj , "out_pos" , None ):
405
+ if obj .has_result :
406
+ obj = obj .data
407
+ elif out_pos := getattr (obj , "out_pos" , None ):
429
408
obj .unsubscribe ()
430
- if obj .has_result :
431
- if isinstance (obj .data , ListOrTuple ):
432
- out_append (_Tag .LIST )
433
- yield cls ._deconstruct_list (
434
- obj .data , output , stack , result_consumers , out_append
435
- )
436
- else :
437
- out_append (obj .data )
438
- else :
439
- out_append (_Tag .REF )
440
- out_append (out_pos )
441
- output [out_pos ] = out_pos
442
- if obj .subscribers == 0 :
443
- output [out_pos + 1 ] = 0
444
- result_consumers .remove (obj )
409
+ out_append (_Tag .REF )
410
+ out_append (out_pos )
411
+ output [out_pos ] = out_pos
412
+ if obj .subscribers == 0 :
413
+ output [out_pos + 1 ] = 0
414
+ result_consumers .remove (obj )
415
+ continue
445
416
else :
446
417
out_append (_Tag .CHAIN )
447
418
yield cls ._deconstruct_chain (obj , output , stack , result_consumers )
448
419
out_append (_Tag .END )
449
- elif isinstance (obj , ListOrTuple ):
420
+ continue
421
+
422
+ if isinstance (obj , ListOrTuple ):
450
423
out_append (_Tag .LIST )
451
424
yield cls ._deconstruct_list (
452
425
obj , output , stack , result_consumers , out_append
@@ -517,27 +490,13 @@ class DeferredGetItem(DeferredExecution):
517
490
----------
518
491
data : ObjectRefOrDeType
519
492
The object to get the item from.
520
- idx : int
493
+ index : int
521
494
The item index.
522
495
"""
523
496
524
- def __init__ (self , data : ObjectRefOrDeType , idx : int ):
525
- super ().__init__ (data , self ._remote_fn (), [idx ])
526
- self .index = idx
527
-
528
- @_inherit_docstrings (DeferredExecution .exec )
529
- def exec (self ) -> Tuple [ObjectRefType , "MetaList" , int ]:
530
- if self .has_result :
531
- return self .data , self .meta , self .meta_offset
532
-
533
- if not isinstance (self .data , DeferredExecution ) or self .data .num_returns == 1 :
534
- return super ().exec ()
535
-
536
- # If `data` is a `DeferredExecution`, that returns multiple results,
537
- # it's not required to execute `_remote_fn()`. We can only execute
538
- # `data` and get the result by index.
539
- self ._data_exec ()
540
- return self .data , self .meta , self .meta_offset
497
+ def __init__ (self , data : ObjectRefOrDeType , index : int ):
498
+ super ().__init__ (data , self ._remote_fn (), [index ])
499
+ self .index = index
541
500
542
501
@property
543
502
@_inherit_docstrings (DeferredExecution .has_result )
@@ -550,16 +509,18 @@ def has_result(self):
550
509
and self .data .has_result
551
510
and self .data .num_returns != 1
552
511
):
553
- self ._data_exec ()
512
+ # If `data` is a `DeferredExecution`, that returns multiple results,
513
+ # it's not required to execute `_remote_fn()`. We can only execute
514
+ # `data` and get the result by index.
515
+ self ._set_result (
516
+ self .data .data [self .index ],
517
+ self .data .meta ,
518
+ self .data .meta_offset [self .index ],
519
+ )
554
520
return True
555
521
556
522
return False
557
523
558
- def _data_exec (self ):
559
- """Execute the `data` task and get the result."""
560
- obj , meta , offsets = self .data .exec ()
561
- self ._set_result (obj [self .index ], meta , offsets [self .index ])
562
-
563
524
@classmethod
564
525
def _remote_fn (cls ) -> ObjectRefType :
565
526
"""
@@ -592,7 +553,8 @@ def __init__(self, obj: Union[ray.ObjectID, ClientObjectRef, List]):
592
553
593
554
def materialize (self ):
594
555
"""Materialized the list, if required."""
595
- self ._obj = RayWrapper .materialize (self ._obj )
556
+ if not isinstance (self ._obj , list ):
557
+ self ._obj = RayWrapper .materialize (self ._obj )
596
558
597
559
def __getitem__ (self , index ):
598
560
"""
@@ -632,14 +594,13 @@ class MetaListHook(MaterializationHook, DeferredGetItem):
632
594
----------
633
595
meta : MetaList
634
596
Non-materialized list to get the value from.
635
- idx : int
597
+ index : int
636
598
The value index in the list.
637
599
"""
638
600
639
- def __init__ (self , meta : MetaList , idx : int ):
640
- super ().__init__ (meta ._obj , idx )
601
+ def __init__ (self , meta : MetaList , index : int ):
602
+ super ().__init__ (meta ._obj , index )
641
603
self .meta = meta
642
- self .idx = idx
643
604
644
605
def pre_materialize (self ):
645
606
"""
@@ -650,7 +611,7 @@ def pre_materialize(self):
650
611
object
651
612
"""
652
613
obj = self .meta ._obj
653
- return obj [self .idx ] if isinstance (obj , list ) else obj
614
+ return obj [self .index ] if isinstance (obj , list ) else obj
654
615
655
616
def post_materialize (self , materialized ):
656
617
"""
@@ -665,7 +626,7 @@ def post_materialize(self, materialized):
665
626
object
666
627
"""
667
628
self .meta ._obj = materialized
668
- return materialized [self .idx ]
629
+ return materialized [self .index ]
669
630
670
631
671
632
class _Tag (Enum ): # noqa: PR01
0 commit comments