# limitations under the License.
"""

+import os
 from contextlib import contextmanager
-from dataclasses import dataclass
-from typing import Callable, Dict, Optional
+from dataclasses import dataclass, field
+from typing import Callable, Dict, List, Optional

-import paddle.jit.dy2static.utils as jit_utils
 import paddle.nn.layer
 from paddle.device.cuda import graphs

+from fastdeploy import envs
 from fastdeploy.config import FDConfig
 from fastdeploy.distributed.communication import capture_custom_allreduce
 from fastdeploy.utils import get_logger
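The import changes above drop the module-level `import paddle.jit.dy2static.utils as jit_utils` in favor of function-local imports of `CUDAGraphState` (the NOTE in the next hunk attributes this to RLHF version problems). A minimal sketch of the deferred-import pattern, assuming only that the symbol exists on compatible Paddle builds:

```python
# Sketch: resolve CUDAGraphState at call time rather than module-import time.
# Importing *this* module then succeeds even on a Paddle build that lacks
# paddle.jit.dy2static.utils.CUDAGraphState; failure is deferred until use.
def default_graph_state():
    from paddle.jit.dy2static.utils import CUDAGraphState  # local import

    return CUDAGraphState.DISABLE
```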
@@ -40,36 +41,39 @@ class ConcreteSizeEntry:
    # Has the runtime batch size been captured before
    captured: bool = False

-    # Need to be captured callable object(dynamic graph or static graph backend)
+    # Callable object to be captured (dynamic graph or static graph backend)
    runnable: Callable = None  # type: ignore
    # Number of completed warmups
    num_finished_warmup: int = 0
    # Captured CUDA graph object corresponding to the current real shape
    cuda_graph: Optional[graphs.CUDAGraph] = None
-    # Output buffer of cudagraph
-    output_buffer: Optional[paddle.Tensor] = None
+    # Output buffers of cudagraph
+    output_buffers: List[Optional[paddle.Tensor]] = field(default_factory=list)
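The new `output_buffers` field has to use `field(default_factory=list)`: dataclasses reject a bare mutable default, and a single class-level list would be shared and mutated across every entry. A self-contained illustration:

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Entry:
    # `buffers: List[int] = []` would raise ValueError at class definition;
    # default_factory builds a fresh list per instance instead.
    buffers: List[Optional[int]] = field(default_factory=list)


a, b = Entry(), Entry()
a.buffers.append(1)
assert b.buffers == []  # each ConcreteSizeEntry-style object owns its own list
```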


class Dy2StCudaGraphManager:
    def __init__(self):
+        # NOTE(gongshaotian): Use local import to avoid RLHF version problems
+        from paddle.jit.dy2static.utils import CUDAGraphState

-        self.state = jit_utils.CUDAGraphState.DISABLE
+        self.state = CUDAGraphState.DISABLE
        self.captured_batch_size = set()
        self.batch_size = -1

    def run_impl(self, original_run_impl, inputs, parameters, attrs):
+        from paddle.jit.dy2static.utils import CUDAGraphState

        run_state = self.state
        prog_attrs, cuda_graph_attrs = attrs
-        if run_state == jit_utils.CUDAGraphState.REPLAY:
+        if run_state == CUDAGraphState.REPLAY:
            if self.batch_size not in self.captured_batch_size:
-                run_state = jit_utils.CUDAGraphState.DISABLE
-        elif run_state == jit_utils.CUDAGraphState.CAPTURE:
+                run_state = CUDAGraphState.DISABLE
+        elif run_state == CUDAGraphState.CAPTURE:
            self.captured_batch_size.add(self.batch_size)

        cuda_graph_attrs |= {
            "cuda_graph_state": run_state,
-            "cuda_graph_dispatch_key": self.batch_size if run_state != jit_utils.CUDAGraphState.DISABLE else 0,
+            "cuda_graph_dispatch_key": self.batch_size if run_state != CUDAGraphState.DISABLE else 0,
        }
        return original_run_impl(inputs, parameters, (prog_attrs, cuda_graph_attrs))
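`run_impl` is a small state machine: a REPLAY request silently degrades to DISABLE when the current batch size was never captured, CAPTURE records the batch size for later replays, and the dispatch key is zeroed whenever graphs are disabled. A condensed, Paddle-free sketch of just that decision logic (the enum below is a stand-in for `CUDAGraphState`):

```python
from enum import Enum


class State(Enum):  # stand-in for paddle's CUDAGraphState
    DISABLE = 0
    CAPTURE = 1
    REPLAY = 2


def resolve(state: State, batch_size: int, captured: set):
    if state is State.REPLAY and batch_size not in captured:
        state = State.DISABLE  # nothing captured for this size: run eagerly
    elif state is State.CAPTURE:
        captured.add(batch_size)  # remember this size for future replays
    dispatch_key = batch_size if state is not State.DISABLE else 0
    return state, dispatch_key


captured = set()
assert resolve(State.REPLAY, 8, captured) == (State.DISABLE, 0)
assert resolve(State.CAPTURE, 8, captured) == (State.CAPTURE, 8)
assert resolve(State.REPLAY, 8, captured) == (State.REPLAY, 8)
```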
@@ -102,6 +106,7 @@ def __init__(
        self.cuda_graph_manager = Dy2StCudaGraphManager()

    def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
+        from paddle.jit.dy2static.utils import CUDAGraphState

        if not entry.captured:
            # Warmup the model
@@ -118,21 +123,21 @@ def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
            entry.input_addresses = input_addresses

            # Capture
-            self.cuda_graph_manager.state = jit_utils.CUDAGraphState.CAPTURE
+            self.cuda_graph_manager.state = CUDAGraphState.CAPTURE
            self.cuda_graph_manager.batch_size = entry.real_shape
            entry.captured = True
            with self.cuda_graph_manager.run_impl_guard():
                entry.runnable(**kwargs)

        # Replay
-        self.cuda_graph_manager.state = jit_utils.CUDAGraphState.REPLAY
+        self.cuda_graph_manager.state = CUDAGraphState.REPLAY
        self.cuda_graph_manager.batch_size = entry.real_shape
        with self.cuda_graph_manager.run_impl_guard():
            return entry.runnable(**kwargs)
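`run_impl_guard()` itself is not part of this diff; both the capture and replay paths enter it before invoking `entry.runnable`. Since `run_impl` receives an `original_run_impl` to delegate to, a plausible shape is a patch-yield-restore contextmanager. Everything below (the hook owner and attribute name) is assumed purely for illustration:

```python
from contextlib import contextmanager


@contextmanager
def run_impl_guard(manager, hook_owner, hook_name="run_impl"):  # names assumed
    # Route the backend hook through manager.run_impl, restoring it on exit.
    original = getattr(hook_owner, hook_name)
    setattr(hook_owner, hook_name, lambda *args: manager.run_impl(original, *args))
    try:
        yield
    finally:
        setattr(hook_owner, hook_name, original)
```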

    def __call__(self, **kwargs):
        # Get the real shape (total number of tokens)
-        ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
+        ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
        real_shape = ids_remove_padding.shape[0]
        padding_real_shape = self.real_shape_to_captured_size[real_shape]
        logger.debug(
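`real_shape_to_captured_size` (built elsewhere in this file) buckets an arbitrary runtime token count onto one of the pre-captured sizes, so a handful of graphs can serve every request shape. The usual scheme, assumed here since the construction is not in this hunk, rounds up to the smallest captured size that fits:

```python
import bisect

captured_sizes = [1, 2, 4, 8, 16, 32]  # example capture list (assumed)


def to_captured_size(real_shape: int) -> int:
    # Round up to the smallest captured bucket >= real_shape.
    return captured_sizes[bisect.bisect_left(captured_sizes, real_shape)]


assert to_captured_size(3) == 4
assert to_captured_size(8) == 8
```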
@@ -173,14 +178,22 @@ def __call__(self, **kwargs):
            # Capture
            with capture_custom_allreduce():
                new_grpah.capture_begin()
-                output = entry.runnable(**kwargs)
+                outputs = entry.runnable(**kwargs)
+                if isinstance(outputs, paddle.Tensor):
+                    outputs = [outputs]
                new_grpah.capture_end()

            # Store output buffers
            entry.cuda_graph = new_grpah
-            entry.output_buffer = paddle.zeros_like(output)
-            output._share_buffer_to(entry.output_buffer)
-            output._clear
+            for output in outputs:
+                if output is not None:
+                    output_buffer = paddle.zeros_like(output)
+                    output._share_buffer_to(output_buffer)
+                    output._clear()
+                    entry.output_buffers.append(output_buffer)
+                else:
+                    entry.output_buffers.append(None)

            paddle.device.synchronize()
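The loop above pins every captured output into a buffer the manager owns: a replayed CUDA graph rewrites the same device memory on each run, so callers must read results from stable buffers rather than from the capture-time tensors. `_share_buffer_to` and `_clear` are the private Paddle tensor APIs used in the diff; the sketch below shows the retention contract on a toy graph (illustrative only, requires a CUDA build):

```python
import paddle
from paddle.device.cuda import graphs

x = paddle.zeros([8])
g = graphs.CUDAGraph()
g.capture_begin()
y = x * 2.0  # captured work; y's storage is what replays will rewrite
g.capture_end()

buf = paddle.zeros_like(y)
y._share_buffer_to(buf)  # buf now aliases the graph's output storage
y._clear()  # drop the temporary handle, as the diff does

paddle.assign(paddle.ones([8]), x)
g.replay()  # rewrites buf in place; read results from buf, never from y
```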
@@ -191,7 +204,9 @@ def __call__(self, **kwargs):
        # Replay
        entry.cuda_graph.replay()
        logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}")
-        return entry.output_buffer
+        if len(entry.output_buffers) == 1:
+            return entry.output_buffers[0]
+        return entry.output_buffers

    def _create_entry_dict(self):
        """ """
@@ -221,8 +236,11 @@ def clear_graph(self):

    def _save_cudagrpah_dot_files(self, entry):
        """Print CUDAGraph to dot files"""
+        log_dir = envs.FD_LOG_DIR
+        os.makedirs(log_dir, exist_ok=True)
        if entry.cuda_graph:
            entry.cuda_graph.print_to_dot_files(
-                f"./log/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
+                f"{log_dir}/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
                1 << 0,
            )
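With this change the dot dumps land under `envs.FD_LOG_DIR` instead of a hard-coded `./log`, and the directory is created on demand. The `1 << 0` flag matches the verbose bit of CUDA's `cudaGraphDebugDotFlags`, so each file carries the full graph description. A hypothetical debugging workflow (paths assumed):

```python
import os

# Set before the fastdeploy process starts so envs.FD_LOG_DIR picks it up.
os.environ["FD_LOG_DIR"] = "/tmp/fd_logs"
# After a capture, dumps appear at:
#   /tmp/fd_logs/GraphDotFiles/backend<id(self)>_shape<real_shape>
# and render with graphviz: dot -Tsvg <file> -o graph.svg
```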