
Commit b1b3321

[CUDAGraph] Support multi output buffers and merge some fixes from feature/exp_0908 (#4062)

* refine cudagraph
* refine cudagraph
* typo
* fix
* fix plugins
* fix
* update
* update
* update

1 parent 9409665 commit b1b3321

File tree

8 files changed: +70 -45 lines changed


fastdeploy/input/preprocess.py

Lines changed: 2 additions & 8 deletions
@@ -71,15 +71,9 @@ def create_processor(self):
         """
         reasoning_parser_obj = None
         tool_parser_obj = None
-        try:
-            from fastdeploy.plugins.reasoning_parser import (
-                load_reasoning_parser_plugins,
-            )
 
-            reasoning_parser_obj = load_reasoning_parser_plugins()
-        except:
-            if self.reasoning_parser:
-                reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(self.reasoning_parser)
+        if self.reasoning_parser:
+            reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(self.reasoning_parser)
         if self.tool_parser:
             tool_parser_obj = ToolParserManager.get_tool_parser(self.tool_parser)

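Note on this change: the removed `try/except` used a bare `except:` that silently swallowed any plugin failure. Since reasoning-parser plugins are now loaded when `fastdeploy.reasoning` is imported (see `fastdeploy/reasoning/__init__.py` below), `create_processor` can resolve the parser directly from the registry. A minimal sketch of the resulting lookup; the parser name here is a hypothetical configured value:

```python
# Sketch of the direct registry lookup, assuming ReasoningParserManager maps
# names to parser classes (the register_module decorator further down implies this).
from fastdeploy.reasoning import ReasoningParserManager

reasoning_parser = "ernie_x1"  # hypothetical value of self.reasoning_parser
reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(reasoning_parser) if reasoning_parser else None
```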
fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py

Lines changed: 27 additions & 12 deletions
@@ -14,14 +14,16 @@
 # limitations under the License.
 """
 
+import os
 from contextlib import contextmanager
-from dataclasses import dataclass
-from typing import Callable, Dict, Optional
+from dataclasses import dataclass, field
+from typing import Callable, Dict, List, Optional
 
 import paddle.jit.dy2static.utils as jit_utils
 import paddle.nn.layer
 from paddle.device.cuda import graphs
 
+from fastdeploy import envs
 from fastdeploy.config import FDConfig
 from fastdeploy.distributed.communication import capture_custom_allreduce
 from fastdeploy.utils import get_logger
@@ -46,8 +48,8 @@ class ConcreteSizeEntry:
     num_finished_warmup: int = 0
     # Captured cuda graph object corresponding to the current real shape
     cuda_graph: Optional[graphs.CUDAGraph] = None
-    # Output buffer of cudagraph
-    output_buffer: Optional[paddle.Tensor] = None
+    # Output buffers of cudagraph
+    output_buffers: List[Optional[paddle.Tensor]] = field(default_factory=list)
 
 
 class Dy2StCudaGraphManager:
@@ -130,9 +132,9 @@ def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
         with self.cuda_graph_manager.run_impl_guard():
             return entry.runnable(**kwargs)
 
-    def __call__(self, **kwargs):
+    def __call__(self, **kwargs) -> List[paddle.Tensor] | paddle.Tensor:
         # Get real shape(all num tokens)
-        ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
+        ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
         real_shape = ids_remove_padding.shape[0]
         padding_real_shape = self.real_shape_to_captured_size[real_shape]
         logger.debug(
@@ -173,14 +175,22 @@ def __call__(self, **kwargs):
             # Capture
             with capture_custom_allreduce():
                 new_grpah.capture_begin()
-                output = entry.runnable(**kwargs)
+                outputs = entry.runnable(**kwargs)
+                if isinstance(outputs, paddle.Tensor):
+                    assert outputs is not None
+                    outputs = [outputs]
                 new_grpah.capture_end()
 
             # Store output buffer
             entry.cuda_graph = new_grpah
-            entry.output_buffer = paddle.zeros_like(output)
-            output._share_buffer_to(entry.output_buffer)
-            output._clear
+            for output in outputs:
+                if output is not None:
+                    output_buffer = paddle.zeros_like(output)
+                    output._share_buffer_to(output_buffer)
+                    output._clear
+                    entry.output_buffers.append(output_buffer)
+                else:
+                    entry.output_buffers.append(None)
 
             paddle.device.synchronize()
 
@@ -191,7 +201,9 @@ def __call__(self, **kwargs):
         # Replay
         entry.cuda_graph.replay()
         logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}")
-        return entry.output_buffer
+        if len(entry.output_buffers) == 1:
+            return entry.output_buffers[0]
+        return entry.output_buffers
 
     def _create_entry_dict(self):
         """ """
@@ -221,8 +233,11 @@ def clear_graph(self):
 
     def _save_cudagrpah_dot_files(self, entry):
         """Print CUDAGrpah to dot files"""
+        log_dir = envs.FD_LOG_DIR
+        if not os.path.exists(log_dir):
+            os.makedirs(log_dir, exist_ok=True)
         if entry.cuda_graph:
             entry.cuda_graph.print_to_dot_files(
-                f"./log/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
+                f"{log_dir}/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
                 1 << 0,
             )

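Why the buffer handling looks like this: `replay()` re-executes the recorded kernels into the memory captured between `capture_begin()` and `capture_end()`, so the backend must retain tensors that alias those allocations and return them on every call. A minimal sketch of the multi-output capture pattern, assuming a CUDA build of Paddle; `runnable` stands in for the compiled static model, and `_share_buffer_to` is the same private tensor helper the diff uses (note that the diff's `output._clear` lacks parentheses and so appears to be a no-op attribute access):

```python
import paddle
from paddle.device.cuda import graphs


def capture_with_output_buffers(runnable, **kwargs):
    """Capture one CUDA graph and keep an alias tensor per output."""
    graph = graphs.CUDAGraph()
    graph.capture_begin()
    outputs = runnable(**kwargs)
    if isinstance(outputs, paddle.Tensor):
        outputs = [outputs]  # normalize single-tensor returns, as the diff does
    graph.capture_end()

    output_buffers = []
    for output in outputs:
        if output is None:
            output_buffers.append(None)
            continue
        buffer = paddle.zeros_like(output)
        output._share_buffer_to(buffer)  # buffer now aliases graph-owned storage
        output_buffers.append(buffer)

    paddle.device.synchronize()
    # After graph.replay(), fresh results are read out of output_buffers.
    return graph, output_buffers
```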
fastdeploy/plugins/input_processor/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -23,5 +23,5 @@
 def load_input_processor_plugins():
     """load_input_processor_plugins"""
     plugins = load_plugins_by_group(group=PLUGINS_GROUP)
-    assert len(plugins) <= 1, "Most one plugin is allowed to be loaded."
+    assert len(plugins) == 1, "Only one plugin is allowed to be loaded."
     return next(iter(plugins.values()))()

fastdeploy/plugins/model_runner/__init__.py

Lines changed: 2 additions & 7 deletions
@@ -14,19 +14,14 @@
 # limitations under the License.
 """
 
-from fastdeploy.plugins.utils import load_plugins_by_group, plugins_loaded
+from fastdeploy.plugins.utils import load_plugins_by_group
 
 # use for modle runner
 PLUGINS_GROUP = "fastdeploy.model_runner_plugins"
 
 
 def load_model_runner_plugins():
     """load_model_runner_plugins"""
-    global plugins_loaded
-    if plugins_loaded:
-        return
-    plugins_loaded = True
-
     plugins = load_plugins_by_group(group=PLUGINS_GROUP)
-    assert len(plugins) <= 1, "Most one plugin is allowed to be loaded."
+    assert len(plugins) == 1, "Only one plugin is allowed to be loaded."
     return next(iter(plugins.values()))()

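`load_plugins_by_group` itself is outside this diff; for context, a hedged sketch of what a setuptools entry-point based implementation typically looks like (the real one lives in `fastdeploy/plugins/utils.py` and may differ):

```python
# Hypothetical sketch of group-based plugin discovery via entry points
# (Python 3.10+ selectable entry_points API).
from importlib.metadata import entry_points


def load_plugins_by_group(group: str) -> dict:
    # Map each advertised entry-point name to its loaded object.
    return {ep.name: ep.load() for ep in entry_points(group=group)}
```

Tightening the assertion from `<= 1` to `== 1` also makes a missing plugin fail with a clear message instead of an opaque `StopIteration` from `next(iter(plugins.values()))`.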
fastdeploy/plugins/reasoning_parser/__init__.py

Lines changed: 9 additions & 3 deletions
@@ -14,14 +14,20 @@
 # limitations under the License.
 """
 
-from fastdeploy.plugins.utils import load_plugins_by_group
+from fastdeploy.plugins.utils import load_plugins_by_group, plugins_loaded
 
 # make sure one process only loads plugins once
 PLUGINS_GROUP = "fastdeploy.reasoning_parser_plugins"
 
 
 def load_reasoning_parser_plugins():
     """load_reasoning_parser_plugins"""
+    global plugins_loaded
+    if plugins_loaded:
+        return
+    plugins_loaded = True
+
     plugins = load_plugins_by_group(group=PLUGINS_GROUP)
-    assert len(plugins) <= 1, "Most one plugin is allowed to be loaded."
-    return next(iter(plugins.values()))()
+    # general plugins, we only need to execute the loaded functions
+    for func in plugins.values():
+        func()

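Unlike the model-runner and input-processor groups, reasoning-parser plugins no longer return an object: each loaded function is executed purely for its side effect, which is also why the `preprocess.py` change above stops expecting a return value from `load_reasoning_parser_plugins()`. A plugin would then register its parser itself, roughly like this (package and parser names are hypothetical):

```python
# Hypothetical function advertised under the "fastdeploy.reasoning_parser_plugins"
# entry-point group; it runs only for its registration side effect.
from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager


def register_my_parsers():
    @ReasoningParserManager.register_module("my_reasoning_parser")
    class MyReasoningParser(ReasoningParser):
        """Custom parser; implementation elided."""
```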
fastdeploy/reasoning/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -14,6 +14,8 @@
 # limitations under the License.
 """
 
+from fastdeploy.plugins import load_reasoning_parser_plugins
+
 from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
 from .ernie_vl_reasoning_parsers import ErnieVLReasoningParser
 from .ernie_x1_reasoning_parsers import ErnieX1ReasoningParser
@@ -26,3 +28,5 @@
     "Qwen3ReasoningParser",
     "ErnieX1ReasoningParser",
 ]
+
+load_reasoning_parser_plugins()

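With the call sitting at the bottom of the module, any `import fastdeploy.reasoning` picks up plugin parsers, and the `plugins_loaded` guard added above keeps repeated imports or explicit calls from re-registering; since all module names are already bound at that point, a plugin that re-imports `fastdeploy.reasoning` during registration sees a fully initialized module. A sketch of the observable behavior:

```python
# Importing the package triggers plugin registration exactly once; built-in and
# plugin parsers then resolve through the same manager.
from fastdeploy.reasoning import ReasoningParserManager  # plugins load here

parser_cls = ReasoningParserManager.get_reasoning_parser("ernie_x1")  # built-in
# A plugin-registered name (hypothetical) would resolve the same way:
# parser_cls = ReasoningParserManager.get_reasoning_parser("my_reasoning_parser")
```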
fastdeploy/reasoning/ernie_x1_reasoning_parsers.py

Lines changed: 8 additions & 9 deletions
@@ -1,14 +1,6 @@
+"""
 # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 #
-#
-from collections.abc import Sequence
-from typing import Tuple, Union
-
-from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
-from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
-
-#
-#
 # Licensed under the Apache License, Version 2.0 (the "License"
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -20,6 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+
+from collections.abc import Sequence
+from typing import Tuple, Union
+
+from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
+from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
 
 
 @ReasoningParserManager.register_module("ernie_x1")

fastdeploy/worker/worker_process.py

Lines changed: 17 additions & 5 deletions
@@ -248,6 +248,11 @@ def init_health_status(self) -> None:
             create=False,
         )
 
+    def _broadcast_model_weights_signal(self, src: int, group) -> int:
+        model_weights_signal_tensor = paddle.full(shape=[1], fill_value=self.model_weights_signal[0], dtype="int32")
+        paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group)
+        return model_weights_signal_tensor.item()
+
     def event_loop_normal(self) -> None:
         """Main event loop for Paddle Distributed Workers.
         TODO(gongshaotian): support remote calling of functions that control worker.
@@ -257,15 +262,19 @@ def event_loop_normal(self) -> None:
         req_ids = []
         num_running_requests = 0
 
-        self.model_weights_signal = paddle.zeros([1], dtype=paddle.int32)
+        self.model_weights_signal = np.zeros([1], dtype=np.int32)
         while True:
             if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
                 if self.model_weights_status.value[0] != 0:
                     self.model_weights_signal[0] = int(self.model_weights_status.value[0])
                 if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
-                    paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.ep_group)
-                if self.fd_config.load_config.dynamic_load_weight:
-                    paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.tp_group)
+                    self.model_weights_signal[0] = self._broadcast_model_weights_signal(
+                        src=0, group=self.parallel_config.ep_group
+                    )
+                if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.tensor_parallel_size > 1:
+                    self.model_weights_signal[0] = self._broadcast_model_weights_signal(
+                        src=0, group=self.parallel_config.tp_group
+                    )
 
             self.insert_step = False
             req_dicts = None
@@ -294,7 +303,9 @@ def event_loop_normal(self) -> None:
             else:
                 paddle.distributed.barrier(self.parallel_config.tp_group)
             if self.model_weights_signal[0] != 0:
-                logger.info(f"Rank: {self.local_rank} has updated parameters.")
+                logger.info(
+                    f"Rank: {self.local_rank} to update or clear parameters, signal is {self.model_weights_signal[0]}, [-1:clear, 1:update]"
+                )
                 from fastdeploy.rl.dynamic_weight_manager import (
                     DynamicWeightManager,
                 )
@@ -307,6 +318,7 @@ def event_loop_normal(self) -> None:
                     self.parallel_config.engine_worker_queue_port,
                 )
                 self.model_weights_signal[0] = 0
+                logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
 
             if self.exist_task_signal.value[0] == 1 or self.task_queue.read_finish_flag.get() == 1:
                 logger.info(f"Rank: {self.local_rank} Detected new requests.")

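The weights-update flag moves from a paddle tensor to a NumPy array, so the hot loop reads and writes it on the host; only the collective round-trips through a tensor, since `paddle.distributed.broadcast` operates on tensors. The added `tensor_parallel_size > 1` guard also skips a needless TP-group broadcast on single-rank workers. A condensed, standalone sketch of the new signal flow (group wiring elided):

```python
import numpy as np
import paddle

model_weights_signal = np.zeros([1], dtype=np.int32)  # host-side flag


def broadcast_signal(value: int, src: int, group) -> int:
    # Stage the scalar in a tensor only for the collective, then unwrap it.
    tensor = paddle.full(shape=[1], fill_value=value, dtype="int32")
    paddle.distributed.broadcast(tensor, src=src, group=group)
    return tensor.item()
```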