diff --git a/fastdeploy/input/preprocess.py b/fastdeploy/input/preprocess.py
index 48626f380d..5b8eb3ccd3 100644
--- a/fastdeploy/input/preprocess.py
+++ b/fastdeploy/input/preprocess.py
@@ -71,15 +71,9 @@ def create_processor(self):
         """
         reasoning_parser_obj = None
         tool_parser_obj = None
-        try:
-            from fastdeploy.plugins.reasoning_parser import (
-                load_reasoning_parser_plugins,
-            )
-            reasoning_parser_obj = load_reasoning_parser_plugins()
-        except:
-            if self.reasoning_parser:
-                reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(self.reasoning_parser)
+        if self.reasoning_parser:
+            reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(self.reasoning_parser)
         if self.tool_parser:
             tool_parser_obj = ToolParserManager.get_tool_parser(self.tool_parser)
diff --git a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
index d975833ea0..8e488d57ee 100644
--- a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
+++ b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
@@ -14,14 +14,16 @@
 # limitations under the License.
 """

+import os
 from contextlib import contextmanager
-from dataclasses import dataclass
-from typing import Callable, Dict, Optional
+from dataclasses import dataclass, field
+from typing import Callable, Dict, List, Optional

 import paddle.jit.dy2static.utils as jit_utils
 import paddle.nn.layer
 from paddle.device.cuda import graphs

+from fastdeploy import envs
 from fastdeploy.config import FDConfig
 from fastdeploy.distributed.communication import capture_custom_allreduce
 from fastdeploy.utils import get_logger
@@ -46,8 +48,8 @@ class ConcreteSizeEntry:
     num_finished_warmup: int = 0
     # Captured cuda graph object corresponding to the current real shape
     cuda_graph: Optional[graphs.CUDAGraph] = None
-    # Output buffer of cudagraph
-    output_buffer: Optional[paddle.Tensor] = None
+    # Output buffers of cudagraph
+    output_buffers: List[Optional[paddle.Tensor]] = field(default_factory=list)


 class Dy2StCudaGraphManager:
@@ -130,9 +132,9 @@ def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
         with self.cuda_graph_manager.run_impl_guard():
             return entry.runnable(**kwargs)

-    def __call__(self, **kwargs):
+    def __call__(self, **kwargs) -> List[paddle.Tensor] | paddle.Tensor:
         # Get real shape(all num tokens)
-        ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
+        ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
         real_shape = ids_remove_padding.shape[0]
         padding_real_shape = self.real_shape_to_captured_size[real_shape]
         logger.debug(
@@ -173,14 +175,22 @@ def __call__(self, **kwargs):
             # Capture
             with capture_custom_allreduce():
                 new_grpah.capture_begin()
-                output = entry.runnable(**kwargs)
+                outputs = entry.runnable(**kwargs)
+                if isinstance(outputs, paddle.Tensor):
+                    assert outputs is not None
+                    outputs = [outputs]
                 new_grpah.capture_end()

             # Store output buffer
             entry.cuda_graph = new_grpah
-            entry.output_buffer = paddle.zeros_like(output)
-            output._share_buffer_to(entry.output_buffer)
-            output._clear
+            for output in outputs:
+                if output is not None:
+                    output_buffer = paddle.zeros_like(output)
+                    output._share_buffer_to(output_buffer)
+                    output._clear
+                    entry.output_buffers.append(output_buffer)
+                else:
+                    entry.output_buffers.append(None)

             paddle.device.synchronize()

@@ -191,7 +201,9 @@ def __call__(self, **kwargs):
         # Replay
         entry.cuda_graph.replay()
         logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}")
-        return entry.output_buffer
+        if len(entry.output_buffers) == 1:
+            return entry.output_buffers[0]
+        return entry.output_buffers

     def _create_entry_dict(self):
         """ """
@@ -221,8 +233,11 @@ def clear_graph(self):

     def _save_cudagrpah_dot_files(self, entry):
         """Print CUDAGrpah to dot files"""
+        log_dir = envs.FD_LOG_DIR
+        if not os.path.exists(log_dir):
+            os.makedirs(log_dir, exist_ok=True)
         if entry.cuda_graph:
             entry.cuda_graph.print_to_dot_files(
-                f"./log/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
+                f"{log_dir}/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
                 1 << 0,
             )
diff --git a/fastdeploy/plugins/input_processor/__init__.py b/fastdeploy/plugins/input_processor/__init__.py
index d7c698f44e..eca894f42c 100644
--- a/fastdeploy/plugins/input_processor/__init__.py
+++ b/fastdeploy/plugins/input_processor/__init__.py
@@ -23,5 +23,5 @@
 def load_input_processor_plugins():
     """load_input_processor_plugins"""
     plugins = load_plugins_by_group(group=PLUGINS_GROUP)
-    assert len(plugins) <= 1, "Most one plugin is allowed to be loaded."
+    assert len(plugins) == 1, "Only one plugin is allowed to be loaded."
     return next(iter(plugins.values()))()
diff --git a/fastdeploy/plugins/model_runner/__init__.py b/fastdeploy/plugins/model_runner/__init__.py
index 8897abfbc0..19ce33ce8a 100644
--- a/fastdeploy/plugins/model_runner/__init__.py
+++ b/fastdeploy/plugins/model_runner/__init__.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """

-from fastdeploy.plugins.utils import load_plugins_by_group, plugins_loaded
+from fastdeploy.plugins.utils import load_plugins_by_group

 # use for modle runner
 PLUGINS_GROUP = "fastdeploy.model_runner_plugins"
@@ -22,11 +22,6 @@

 def load_model_runner_plugins():
     """load_model_runner_plugins"""
-    global plugins_loaded
-    if plugins_loaded:
-        return
-    plugins_loaded = True
-
     plugins = load_plugins_by_group(group=PLUGINS_GROUP)
-    assert len(plugins) <= 1, "Most one plugin is allowed to be loaded."
+    assert len(plugins) == 1, "Only one plugin is allowed to be loaded."
     return next(iter(plugins.values()))()
diff --git a/fastdeploy/plugins/reasoning_parser/__init__.py b/fastdeploy/plugins/reasoning_parser/__init__.py
index bb19e0e70e..ba862d02ae 100644
--- a/fastdeploy/plugins/reasoning_parser/__init__.py
+++ b/fastdeploy/plugins/reasoning_parser/__init__.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """

-from fastdeploy.plugins.utils import load_plugins_by_group
+from fastdeploy.plugins.utils import load_plugins_by_group, plugins_loaded

 # make sure one process only loads plugins once
 PLUGINS_GROUP = "fastdeploy.reasoning_parser_plugins"
@@ -22,6 +22,12 @@

 def load_reasoning_parser_plugins():
     """load_reasoning_parser_plugins"""
+    global plugins_loaded
+    if plugins_loaded:
+        return
+    plugins_loaded = True
+
     plugins = load_plugins_by_group(group=PLUGINS_GROUP)
-    assert len(plugins) <= 1, "Most one plugin is allowed to be loaded."
-    return next(iter(plugins.values()))()
+    # general plugins, we only need to execute the loaded functions
+    for func in plugins.values():
+        func()
diff --git a/fastdeploy/reasoning/__init__.py b/fastdeploy/reasoning/__init__.py
index 51f59776e0..49c627895f 100644
--- a/fastdeploy/reasoning/__init__.py
+++ b/fastdeploy/reasoning/__init__.py
@@ -14,6 +14,8 @@
 # limitations under the License.
""" +from fastdeploy.plugins import load_reasoning_parser_plugins + from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager from .ernie_vl_reasoning_parsers import ErnieVLReasoningParser from .ernie_x1_reasoning_parsers import ErnieX1ReasoningParser @@ -26,3 +28,5 @@ "Qwen3ReasoningParser", "ErnieX1ReasoningParser", ] + +load_reasoning_parser_plugins() diff --git a/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py b/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py index c75182b014..8dbfb23ca9 100644 --- a/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py +++ b/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py @@ -1,14 +1,6 @@ +""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # -# -from collections.abc import Sequence -from typing import Tuple, Union - -from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage -from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager - -# -# # Licensed under the Apache License, Version 2.0 (the "License" # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,6 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" + +from collections.abc import Sequence +from typing import Tuple, Union + +from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage +from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager @ReasoningParserManager.register_module("ernie_x1") diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 04ea94e8b0..b7a905b9c8 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -248,6 +248,11 @@ def init_health_status(self) -> None: create=False, ) + def _broadcast_model_weights_signal(self, src: int, group) -> int: + model_weights_signal_tensor = paddle.full(shape=[1], fill_value=self.model_weights_signal[0], dtype="int32") + paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group) + return model_weights_signal_tensor.item() + def event_loop_normal(self) -> None: """Main event loop for Paddle Distributed Workers. TODO(gongshaotian): support remote calling of functions that control worker. 
@@ -257,15 +262,19 @@ def event_loop_normal(self) -> None:
         req_ids = []
         num_running_requests = 0
-        self.model_weights_signal = paddle.zeros([1], dtype=paddle.int32)
+        self.model_weights_signal = np.zeros([1], dtype=np.int32)
         while True:
             if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
                 if self.model_weights_status.value[0] != 0:
                     self.model_weights_signal[0] = int(self.model_weights_status.value[0])
             if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
-                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.ep_group)
-            if self.fd_config.load_config.dynamic_load_weight:
-                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.tp_group)
+                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
+                    src=0, group=self.parallel_config.ep_group
+                )
+            if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.tensor_parallel_size > 1:
+                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
+                    src=0, group=self.parallel_config.tp_group
+                )

             self.insert_step = False
             req_dicts = None
@@ -294,7 +303,9 @@ def event_loop_normal(self) -> None:
             else:
                 paddle.distributed.barrier(self.parallel_config.tp_group)
             if self.model_weights_signal[0] != 0:
-                logger.info(f"Rank: {self.local_rank} has updated parameters.")
+                logger.info(
+                    f"Rank: {self.local_rank} to update or clear parameters, signal is {self.model_weights_signal[0]}, [-1:clear, 1:update]"
+                )
                 from fastdeploy.rl.dynamic_weight_manager import (
                     DynamicWeightManager,
                 )
@@ -307,6 +318,7 @@ def event_loop_normal(self) -> None:
                     self.parallel_config.engine_worker_queue_port,
                 )
                 self.model_weights_signal[0] = 0
+                logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")

             if self.exist_task_signal.value[0] == 1 or self.task_queue.read_finish_flag.get() == 1:
                 logger.info(f"Rank: {self.local_rank} Detected new requests.")
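
Note on the plugin hunks above: after this change, reasoning-parser plugins are treated as general plugins that are loaded once per process and executed purely for their side effects (typically registering a parser class through ReasoningParserManager), while the model-runner and input-processor groups stay exclusive and must expose exactly one entry point whose return value is consumed. The sketch below illustrates that contract only; it assumes load_plugins_by_group is a thin wrapper over importlib.metadata entry points (Python 3.10+), which is an approximation and not the actual implementation in fastdeploy/plugins/utils.py.

    # Sketch only: an assumed approximation of load_plugins_by_group.
    from importlib.metadata import entry_points


    def load_plugins_by_group(group: str) -> dict:
        # Map entry-point name -> the object it resolves to (usually a callable).
        return {ep.name: ep.load() for ep in entry_points(group=group)}


    # General plugins (fastdeploy.reasoning_parser_plugins): every loaded callable
    # is simply executed; any parser registration happens as a side effect inside
    # the plugin itself, so several parser plugins can coexist.
    for func in load_plugins_by_group("fastdeploy.reasoning_parser_plugins").values():
        func()

    # Exclusive plugins (fastdeploy.model_runner_plugins, and analogously
    # fastdeploy.input_processor_plugins): exactly one entry point must be
    # installed, and its return value is used. The assert below fires unless a
    # plugin package providing this group is actually installed.
    plugins = load_plugins_by_group("fastdeploy.model_runner_plugins")
    assert len(plugins) == 1, "Only one plugin is allowed to be loaded."
    model_runner = next(iter(plugins.values()))()

This split explains why load_reasoning_parser_plugins() now returns nothing and is called once at import time of fastdeploy.reasoning, while the other two loaders keep the single-plugin assertion and return the loaded object.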