From 20c631d97307626fe915bdb9fb3f1732605ce8dd Mon Sep 17 00:00:00 2001
From: zzhx1
Date: Wed, 10 Sep 2025 21:44:00 +0800
Subject: [PATCH] refactor

Signed-off-by: zzhx1
---
 vllm_ascend/ops/linear.py | 143 ++++++++++++++------------------------
 1 file changed, 52 insertions(+), 91 deletions(-)

diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py
index c29837a35d..37cd84c981 100644
--- a/vllm_ascend/ops/linear.py
+++ b/vllm_ascend/ops/linear.py
@@ -65,15 +65,31 @@ def __init__(
         *,
         return_bias: bool = True,
     ):
+        # LinearBase.__init__ initializes the default TP communication group; it is overridden below with a custom communication parallel group.
+        LinearBase.__init__(self,
+                            input_size,
+                            output_size,
+                            skip_bias_add,
+                            params_dtype,
+                            quant_config,
+                            prefix,
+                            return_bias=return_bias)
+
         self.comm_group = None
         if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
             self.comm_group = get_mlp_tp_group()
+            self.forward_type = "mlp_tp"
         else:
             self.comm_group = get_tp_group()
+            self.forward_type = "normal_tp"
 
         self.tp_size = self.comm_group.world_size
         self.tp_rank = self.comm_group.rank_in_group
 
+        self.output_sizes = output_sizes
+        assert all(output_size % self.tp_size == 0
+                   for output_size in output_sizes)
+
         self.input_size_per_partition = input_size
         self.output_size_per_partition = divide(output_size, self.tp_size)
         self.output_partition_sizes = [self.output_size_per_partition]
@@ -83,17 +99,8 @@ def __init__(
                 divide(output_size, self.tp_size)
                 for output_size in self.output_sizes
             ]
-        AscendLinearBase.__init__(self,
-                                  input_size,
-                                  output_size,
-                                  skip_bias_add,
-                                  params_dtype,
-                                  quant_config,
-                                  prefix,
-                                  return_bias=return_bias)
 
         self.gather_output = gather_output
-
         if output_sizes is None:
             output_sizes = [output_size]
 
@@ -118,6 +125,30 @@ def __init__(
                 })
         else:
             self.register_parameter("bias", None)
+
+    def forward(
+        self,
+        input_,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        if self.forward_type == "mlp_tp":
+            return self._forward_mlp_tp(input_)
+        else:
+            return super().forward(input_)
+
+    def _forward_mlp_tp(
+        self,
+        input_: torch.Tensor,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        bias = self.bias if not self.skip_bias_add else None
+        # Matrix multiply.
+        assert self.quant_method is not None
+        input_parallel = get_mlp_tp_group().all_gather(input_, 0)
+        output = self.quant_method.apply(self, input_parallel, bias)
+
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
 
 
 class AscendRowParallelLinear(RowParallelLinear):
@@ -140,6 +171,16 @@ def __init__(
         *,
         return_bias: bool = True,
     ):
+        # LinearBase.__init__ initializes the default TP communication group; it is overridden below with a custom communication parallel group.
+        LinearBase.__init__(self,
+                            input_size,
+                            output_size,
+                            skip_bias_add,
+                            params_dtype,
+                            quant_config,
+                            prefix,
+                            return_bias=return_bias)
+
         if prefix.find("down_proj") != -1 and mlp_tp_enable():
             comm_group = get_mlp_tp_group()
             self.forward_type = "mlp_tp"
@@ -163,15 +204,6 @@ def __init__(
         self.output_size_per_partition = output_size
         self.output_partition_sizes = [output_size]
 
-        AscendLinearBase.__init__(self,
-                                  input_size,
-                                  output_size,
-                                  skip_bias_add,
-                                  params_dtype,
-                                  quant_config,
-                                  prefix,
-                                  return_bias=return_bias)
-
         self.input_is_parallel = input_is_parallel
         self.reduce_results = reduce_results
 
@@ -356,20 +388,7 @@ def __init__(
         prefix: str = "",
         *,
         return_bias: bool = True,
-    ):
-        self.comm_group = None
-        if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
-            self.comm_group = get_mlp_tp_group()
-            self.forward_type = "mlp_tp"
-        else:
-            self.comm_group = get_tp_group()
-            self.forward_type = "normal_tp"
-        self.tp_rank = self.comm_group.rank_in_group
-        self.tp_size = self.comm_group.world_size
-
-        self.output_sizes = output_sizes
-        assert all(output_size % self.tp_size == 0
-                   for output_size in output_sizes)
+    ):
         AscendColumnParallelLinear.__init__(self,
                                             input_size=input_size,
                                             output_size=sum(output_sizes),
@@ -379,63 +398,5 @@ def __init__(
                                             params_dtype=params_dtype,
                                             quant_config=quant_config,
                                             prefix=prefix,
-                                            return_bias=return_bias)
+                                            return_bias=return_bias) 
 
-    def forward(
-        self,
-        input_,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-        if self.forward_type == "mlp_tp":
-            return self._forward_mlp_tp(input_)
-        else:
-            return super().forward(input_)
-
-    def _forward_mlp_tp(
-        self,
-        input_: torch.Tensor,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-        bias = self.bias if not self.skip_bias_add else None
-        # Matrix multiply.
-        assert self.quant_method is not None
-        input_parallel = get_mlp_tp_group().all_gather(input_, 0)
-        output = self.quant_method.apply(self, input_parallel, bias)
-
-        output_bias = self.bias if self.skip_bias_add else None
-        if not self.return_bias:
-            return output
-        return output, output_bias
-
-
-class AscendLinearBase(LinearBase):
-
-    def __init__(
-        self,
-        input_size: int,
-        output_size: int,
-        skip_bias_add: bool = False,
-        params_dtype: Optional[torch.dtype] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-        *,
-        return_bias: bool = True,
-        disable_tp: bool = False,
-    ):
-        nn.Module.__init__(self)
-
-        # Keep input parameters
-        self.input_size = input_size
-        self.output_size = output_size
-        self.skip_bias_add = skip_bias_add
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
-        self.quant_config = quant_config
-        self.prefix = prefix
-        if quant_config is None:
-            self.quant_method: Optional[
-                QuantizeMethodBase] = UnquantizedLinearMethod()
-        else:
-            self.quant_method = quant_config.get_quant_method(self,
-                                                              prefix=prefix)
-        self.return_bias = return_bias
-        self.disable_tp = disable_tp
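
Note on the mlp_tp forward path (illustrative sketch, not part of the patch): with this refactor, gate_up_proj/down_proj layers swap the default tensor-parallel group for the smaller MLP TP group when mlp_tp_enable() is true, and _forward_mlp_tp all-gathers the input along dim 0 across that group before multiplying by the locally sharded weight; the matching down_proj is then expected to reduce across the same group. The snippet below sketches that gather-then-matmul step under stated assumptions: ToyMlpTpGroup stands in for get_mlp_tp_group() and a plain torch linear stands in for self.quant_method.apply(); neither name is a vllm-ascend API.

from typing import Optional

import torch
import torch.nn.functional as F


class ToyMlpTpGroup:
    """Single-process stand-in for the custom MLP tensor-parallel group."""

    def __init__(self, world_size: int = 2):
        self.world_size = world_size

    def all_gather(self, tensor: torch.Tensor, dim: int = 0) -> torch.Tensor:
        # A real communicator concatenates the shard held by every rank
        # along `dim`; with a single process we simply repeat the local shard.
        return torch.cat([tensor] * self.world_size, dim=dim)


def forward_mlp_tp(weight_shard: torch.Tensor,
                   input_: torch.Tensor,
                   group: ToyMlpTpGroup,
                   bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Same shape of computation as _forward_mlp_tp: gather the sharded
    # activations first, then apply the column-sharded weight locally.
    input_parallel = group.all_gather(input_, 0)
    return F.linear(input_parallel, weight_shard, bias)


if __name__ == "__main__":
    group = ToyMlpTpGroup(world_size=2)
    local_tokens = torch.randn(4, 16)    # this rank's slice of the tokens
    weight_shard = torch.randn(32, 16)   # local shard of the gate_up weight
    out = forward_mlp_tp(weight_shard, local_tokens, group)
    print(out.shape)                     # torch.Size([8, 32])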