Commit 7c50810

refine cudagraph
1 parent 2650f58 commit 7c50810

2 files changed: +44 -24 lines changed


fastdeploy/input/preprocess.py

Lines changed: 6 additions & 4 deletions
@@ -72,11 +72,13 @@ def create_processor(self):
         reasoning_parser_obj = None
         tool_parser_obj = None
         try:
-            from fastdeploy.plugins.reasoning_parser import (
-                load_reasoning_parser_plugins,
-            )
+            if self.reasoning_parser == "custom_reasoning_parser":
+                from fastdeploy.plugins.reasoning_parser import (
+                    load_reasoning_parser_plugins,
+                )

-            reasoning_parser_obj = load_reasoning_parser_plugins()
+                custom_reasoning_parser = load_reasoning_parser_plugins()
+                reasoning_parser_obj = custom_reasoning_parser
         except:
             if self.reasoning_parser:
                 reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(self.reasoning_parser)
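
The hunk above narrows plugin loading: the import of load_reasoning_parser_plugins is only attempted when the configured parser name is the sentinel value "custom_reasoning_parser", and any failure still falls through to the built-in lookup. Below is a standalone sketch of that control flow; apart from load_reasoning_parser_plugins, every name (resolve_reasoning_parser, builtin_registry) is a hypothetical stand-in, not FastDeploy API.

def resolve_reasoning_parser(parser_name, builtin_registry):
    parser_obj = None
    try:
        if parser_name == "custom_reasoning_parser":
            # Plugin loading is attempted only for the sentinel name.
            from fastdeploy.plugins.reasoning_parser import (
                load_reasoning_parser_plugins,
            )

            parser_obj = load_reasoning_parser_plugins()
    except Exception:
        # Plugin import/loading failed: fall back to a registered built-in parser.
        if parser_name:
            parser_obj = builtin_registry[parser_name]
    return parser_obj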

fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py

Lines changed: 38 additions & 20 deletions
@@ -14,14 +14,15 @@
 # limitations under the License.
 """

+import os
 from contextlib import contextmanager
-from dataclasses import dataclass
-from typing import Callable, Dict, Optional
+from dataclasses import dataclass, field
+from typing import Callable, Dict, List, Optional

-import paddle.jit.dy2static.utils as jit_utils
 import paddle.nn.layer
 from paddle.device.cuda import graphs

+from fastdeploy import envs
 from fastdeploy.config import FDConfig
 from fastdeploy.distributed.communication import capture_custom_allreduce
 from fastdeploy.utils import get_logger

@@ -40,36 +41,39 @@ class ConcreteSizeEntry:
     # Has runtime-bs been captured before
     captured: bool = False

-    # Need to be captured callable object(dynamic graph or static graph backend)
+    # Need to be captured callable object(dynamic graph or static grpah backend)
     runnable: Callable = None  # type: ignore
     # Number of completed warmups
     num_finished_warmup: int = 0
     # Captured cuda graph object corresponding to the current real shape
     cuda_graph: Optional[graphs.CUDAGraph] = None
-    # Output buffer of cudagraph
-    output_buffer: Optional[paddle.Tensor] = None
+    # Output buffers of cudagraph
+    output_buffers: List[Optional[paddle.Tensor]] = field(default_factory=list)


 class Dy2StCudaGraphManager:
     def __init__(self):
+        # NOTE(gongshaotian): Use local import to avoid RLHF version problems
+        from paddle.jit.dy2static.utils import CUDAGraphState

-        self.state = jit_utils.CUDAGraphState.DISABLE
+        self.state = CUDAGraphState.DISABLE
         self.captured_batch_size = set()
         self.batch_size = -1

     def run_impl(self, original_run_impl, inputs, parameters, attrs):
+        from paddle.jit.dy2static.utils import CUDAGraphState

         run_state = self.state
         prog_attrs, cuda_graph_attrs = attrs
-        if run_state == jit_utils.CUDAGraphState.REPLAY:
+        if run_state == CUDAGraphState.REPLAY:
             if self.batch_size not in self.captured_batch_size:
-                run_state = jit_utils.CUDAGraphState.DISABLE
-        elif run_state == jit_utils.CUDAGraphState.CAPTURE:
+                run_state = CUDAGraphState.DISABLE
+        elif run_state == CUDAGraphState.CAPTURE:
             self.captured_batch_size.add(self.batch_size)

         cuda_graph_attrs |= {
             "cuda_graph_state": run_state,
-            "cuda_graph_dispatch_key": self.batch_size if run_state != jit_utils.CUDAGraphState.DISABLE else 0,
+            "cuda_graph_dispatch_key": self.batch_size if run_state != CUDAGraphState.DISABLE else 0,
         }
         return original_run_impl(inputs, parameters, (prog_attrs, cuda_graph_attrs))

@@ -102,6 +106,7 @@ def __init__(
         self.cuda_graph_manager = Dy2StCudaGraphManager()

     def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
+        from paddle.jit.dy2static.utils import CUDAGraphState

         if not entry.captured:
             # Warmup the model

@@ -118,21 +123,21 @@ def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
             entry.input_addresses = input_addresses

             # Capture
-            self.cuda_graph_manager.state = jit_utils.CUDAGraphState.CAPTURE
+            self.cuda_graph_manager.state = CUDAGraphState.CAPTURE
             self.cuda_graph_manager.batch_size = entry.real_shape
             entry.captured = True
             with self.cuda_graph_manager.run_impl_guard():
                 entry.runnable(**kwargs)

         # Replay
-        self.cuda_graph_manager.state = jit_utils.CUDAGraphState.REPLAY
+        self.cuda_graph_manager.state = CUDAGraphState.REPLAY
         self.cuda_graph_manager.batch_size = entry.real_shape
         with self.cuda_graph_manager.run_impl_guard():
             return entry.runnable(**kwargs)

     def __call__(self, **kwargs):
         # Get real shape(all num tokens)
-        ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
+        ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
         real_shape = ids_remove_padding.shape[0]
         padding_real_shape = self.real_shape_to_captured_size[real_shape]
         logger.debug(

@@ -173,14 +178,22 @@ def __call__(self, **kwargs):
             # Capture
             with capture_custom_allreduce():
                 new_grpah.capture_begin()
-                output = entry.runnable(**kwargs)
+                outputs = entry.runnable(**kwargs)
+                if isinstance(outputs, paddle.Tensor):
+                    assert outputs is not None
+                    outputs = [outputs]
                 new_grpah.capture_end()

             # Store output buffer
             entry.cuda_graph = new_grpah
-            entry.output_buffer = paddle.zeros_like(output)
-            output._share_buffer_to(entry.output_buffer)
-            output._clear
+            for output in outputs:
+                if output is not None:
+                    output_buffer = paddle.zeros_like(output)
+                    output._share_buffer_to(output_buffer)
+                    output._clear
+                    entry.output_buffers.append(output_buffer)
+                else:
+                    entry.output_buffers.append(None)

             paddle.device.synchronize()

@@ -191,7 +204,9 @@ def __call__(self, **kwargs):
             # Replay
             entry.cuda_graph.replay()
             logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}")
-            return entry.output_buffer
+            if len(entry.output_buffers) == 1:
+                return entry.output_buffers[0]
+            return entry.output_buffers

     def _create_entry_dict(self):
         """ """

@@ -221,8 +236,11 @@ def clear_graph(self):

     def _save_cudagrpah_dot_files(self, entry):
         """Print CUDAGrpah to dot files"""
+        log_dir = envs.FD_LOG_DIR
+        if not os.path.exists(log_dir):
+            os.makedirs(log_dir, exist_ok=True)
         if entry.cuda_graph:
             entry.cuda_graph.print_to_dot_files(
-                f"./log/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
+                f"{log_dir}/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
                 1 << 0,
             )
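
The backend above captures entry.runnable once per padded shape and then returns stable output buffers on every replay; this commit generalizes the single output_buffer to a list of output_buffers. Below is a minimal sketch of that capture/replay pattern using the same Paddle calls seen in the hunk (graphs.CUDAGraph, _share_buffer_to, paddle.device.synchronize); it assumes a GPU build of Paddle, and run_model is a hypothetical stand-in for entry.runnable, not FastDeploy API.

import paddle
from paddle.device.cuda import graphs


def capture_once(run_model, **kwargs):
    # Record the kernels launched by run_model into a CUDA graph.
    graph = graphs.CUDAGraph()
    graph.capture_begin()
    outputs = run_model(**kwargs)
    if isinstance(outputs, paddle.Tensor):
        outputs = [outputs]  # normalize a single-tensor return, as the hunk does
    graph.capture_end()

    # Keep buffers that alias the memory written during capture, so each
    # later replay() fills exactly these tensors.
    buffers = []
    for out in outputs:
        if out is None:
            buffers.append(None)
            continue
        buf = paddle.zeros_like(out)
        out._share_buffer_to(buf)  # private Tensor API, used the same way above
        buffers.append(buf)
    paddle.device.synchronize()
    return graph, buffers


def replay(graph, buffers):
    # Re-launch the captured kernels; results appear in the shared buffers.
    graph.replay()
    return buffers[0] if len(buffers) == 1 else buffers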
