@@ -14,14 +14,16 @@
 # limitations under the License.
 """

+import os
 from contextlib import contextmanager
-from dataclasses import dataclass
-from typing import Callable, Dict, Optional
+from dataclasses import dataclass, field
+from typing import Callable, Dict, List, Optional

 import paddle.jit.dy2static.utils as jit_utils
 import paddle.nn.layer
 from paddle.device.cuda import graphs

+from fastdeploy import envs
 from fastdeploy.config import FDConfig
 from fastdeploy.distributed.communication import capture_custom_allreduce
 from fastdeploy.utils import get_logger
@@ -46,8 +48,8 @@ class ConcreteSizeEntry:
     num_finished_warmup: int = 0
     # Captured cuda graph object corresponding to the current real shape
     cuda_graph: Optional[graphs.CUDAGraph] = None
-    # Output buffer of cudagraph
-    output_buffer: Optional[paddle.Tensor] = None
+    # Output buffers of cudagraph
+    output_buffers: List[Optional[paddle.Tensor]] = field(default_factory=list)


 class Dy2StCudaGraphManager:
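
The move from a single `output_buffer` to `output_buffers` is why `field` is now imported: dataclasses reject a plain mutable default, and `default_factory` gives each `ConcreteSizeEntry` its own list. A minimal sketch of the difference:

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Entry:
    # buffers: List[Optional[int]] = []   # ValueError: mutable default not allowed
    buffers: List[Optional[int]] = field(default_factory=list)


a, b = Entry(), Entry()
a.buffers.append(1)
print(b.buffers)  # [] -- each instance gets a fresh list
```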
@@ -130,9 +132,9 @@ def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
         with self.cuda_graph_manager.run_impl_guard():
             return entry.runnable(**kwargs)

-    def __call__(self, **kwargs):
+    def __call__(self, **kwargs) -> List[paddle.Tensor] | paddle.Tensor:
         # Get real shape(all num tokens)
-        ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
+        ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
         real_shape = ids_remove_padding.shape[0]
         padding_real_shape = self.real_shape_to_captured_size[real_shape]
         logger.debug(
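
The real shape is now read off the `forward_meta` kwarg rather than a top-level `ids_remove_padding` entry. A hypothetical sketch of the access pattern (only the `ids_remove_padding` field is taken from this hunk; the rest of the stand-in class is an assumption):

```python
from dataclasses import dataclass

import paddle


@dataclass
class ForwardMeta:  # stand-in for fastdeploy's ForwardMeta
    ids_remove_padding: paddle.Tensor


meta = ForwardMeta(ids_remove_padding=paddle.arange(96))
real_shape = meta.ids_remove_padding.shape[0]  # 96, i.e. total token count
```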
@@ -173,14 +175,22 @@ def __call__(self, **kwargs):
             # Capture
             with capture_custom_allreduce():
                 new_grpah.capture_begin()
-                output = entry.runnable(**kwargs)
+                outputs = entry.runnable(**kwargs)
+                if isinstance(outputs, paddle.Tensor):
+                    assert outputs is not None
+                    outputs = [outputs]
                 new_grpah.capture_end()

             # Store output buffer
             entry.cuda_graph = new_grpah
-            entry.output_buffer = paddle.zeros_like(output)
-            output._share_buffer_to(entry.output_buffer)
-            output._clear
+            for output in outputs:
+                if output is not None:
+                    output_buffer = paddle.zeros_like(output)
+                    output._share_buffer_to(output_buffer)
+                    output._clear()
+                    entry.output_buffers.append(output_buffer)
+                else:
+                    entry.output_buffers.append(None)

             paddle.device.synchronize()

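The capture loop generalizes the old single-tensor bookkeeping: each tensor produced during capture is aliased into a long-lived buffer that subsequent replays keep writing into. A minimal sketch of the `_share_buffer_to` / `_clear` idiom in isolation (both are private Paddle Tensor APIs; the ownership behavior described here is inferred from this usage):

```python
import paddle

# Stand-in for a tensor produced while the graph was being captured.
src = paddle.full([4], 7.0)

# Alias src's storage into a stable buffer, then drop src's reference;
# buf now holds the memory that graph replays will overwrite in place.
buf = paddle.zeros_like(src)
src._share_buffer_to(buf)
src._clear()

print(buf)  # [7., 7., 7., 7.] -- same storage src pointed at
```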
@@ -191,7 +201,9 @@ def __call__(self, **kwargs):
         # Replay
         entry.cuda_graph.replay()
         logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}")
-        return entry.output_buffer
+        if len(entry.output_buffers) == 1:
+            return entry.output_buffers[0]
+        return entry.output_buffers

     def _create_entry_dict(self):
         """ """
@@ -221,8 +233,11 @@ def clear_graph(self):

     def _save_cudagrpah_dot_files(self, entry):
         """Print CUDAGrpah to dot files"""
+        log_dir = envs.FD_LOG_DIR
+        if not os.path.exists(log_dir):
+            os.makedirs(log_dir, exist_ok=True)
         if entry.cuda_graph:
             entry.cuda_graph.print_to_dot_files(
-                f"./log/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
+                f"{log_dir}/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
                 1 << 0,
             )
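
The dot-file dump now lands under the configurable `envs.FD_LOG_DIR` instead of a hard-coded `./log`. Since `os.makedirs(..., exist_ok=True)` is already idempotent, the `os.path.exists` guard is belt-and-braces; the `1 << 0` flag passed to `print_to_dot_files` appears to correspond to CUDA's `cudaGraphDebugDotFlagsVerbose`. A minimal sketch of the directory handling (the env-var name and default shown are assumptions about how `envs.FD_LOG_DIR` resolves):

```python
import os

# envs.FD_LOG_DIR presumably resolves to a plain path string like this.
log_dir = os.getenv("FD_LOG_DIR", "log")
os.makedirs(log_dir, exist_ok=True)  # no-op when the directory already exists
dot_dir = f"{log_dir}/GraphDotFiles"
```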