143 changes: 52 additions & 91 deletions vllm_ascend/ops/linear.py
@@ -65,15 +65,31 @@ def __init__(
*,
return_bias: bool = True,
):
        # LinearBase.__init__ sets up the original TP communication group; we override it below with a custom communication parallel group.
LinearBase.__init__(self,
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias)

self.comm_group = None
if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
self.comm_group = get_mlp_tp_group()
self.forward_type = "mlp_tp"
else:
self.comm_group = get_tp_group()
self.forward_type = "normal_tp"

self.tp_size = self.comm_group.world_size
self.tp_rank = self.comm_group.rank_in_group

self.output_sizes = output_sizes
assert all(output_size % self.tp_size == 0
for output_size in output_sizes)
Comment on lines +90 to +91
critical

This assertion will raise a TypeError if output_sizes is None. Since output_sizes is an optional argument in AscendColumnParallelLinear.__init__ with a default value of None, this will cause a crash when the class is initialized without this argument.

A similar issue exists in the if hasattr(self, "output_sizes") block at line 97, which is not part of the diff but is affected by this change. The list comprehension at line 100 will also fail if self.output_sizes is None.

Please add a check for output_sizes being non-None before attempting to iterate over it.

Suggested change
-        assert all(output_size % self.tp_size == 0
-                   for output_size in output_sizes)
+        if output_sizes:
+            assert all(output_size % self.tp_size == 0 for output_size in output_sizes)
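
To make the failure mode concrete, here is a minimal standalone sketch (hypothetical values, not code from this PR) of why the unguarded assertion crashes and how the suggested guard avoids it:

    output_sizes = None  # the argument's default value
    tp_size = 2

    try:
        # The generator tries to iterate over None before `all` evaluates anything.
        assert all(size % tp_size == 0 for size in output_sizes)
    except TypeError as exc:
        print(exc)  # 'NoneType' object is not iterable

    # The suggested guard skips the check when output_sizes is None or empty:
    if output_sizes:
        assert all(size % tp_size == 0 for size in output_sizes)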


self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
@@ -83,17 +99,8 @@ def __init__(
divide(output_size, self.tp_size)
for output_size in self.output_sizes
]
AscendLinearBase.__init__(self,
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias)

self.gather_output = gather_output

if output_sizes is None:
output_sizes = [output_size]

@@ -118,6 +125,30 @@ def __init__(
})
else:
self.register_parameter("bias", None)

def forward(
self,
input_,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.forward_type == "mlp_tp":
return self._forward_mlp_tp(input_)
else:
return super().forward(input_)

def _forward_mlp_tp(
self,
input_: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
assert self.quant_method is not None
input_parallel = get_mlp_tp_group().all_gather(input_, 0)
output = self.quant_method.apply(self, input_parallel, bias)

output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias


class AscendRowParallelLinear(RowParallelLinear):
@@ -140,6 +171,16 @@ def __init__(
*,
return_bias: bool = True,
):
        # LinearBase.__init__ sets up the original TP communication group; we override it below with a custom communication parallel group.
LinearBase.__init__(self,
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias)

if prefix.find("down_proj") != -1 and mlp_tp_enable():
comm_group = get_mlp_tp_group()
self.forward_type = "mlp_tp"
@@ -163,15 +204,6 @@ def __init__(
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]

AscendLinearBase.__init__(self,
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias)

self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results

@@ -356,20 +388,7 @@ def __init__(
prefix: str = "",
*,
return_bias: bool = True,
):
self.comm_group = None
if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
self.comm_group = get_mlp_tp_group()
self.forward_type = "mlp_tp"
else:
self.comm_group = get_tp_group()
self.forward_type = "normal_tp"
self.tp_rank = self.comm_group.rank_in_group
self.tp_size = self.comm_group.world_size

self.output_sizes = output_sizes
assert all(output_size % self.tp_size == 0
for output_size in output_sizes)
AscendColumnParallelLinear.__init__(self,
input_size=input_size,
output_size=sum(output_sizes),
@@ -379,63 +398,5 @@ def __init__(
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
                                             return_bias=return_bias)

def forward(
self,
input_,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.forward_type == "mlp_tp":
return self._forward_mlp_tp(input_)
else:
return super().forward(input_)

def _forward_mlp_tp(
self,
input_: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
assert self.quant_method is not None
input_parallel = get_mlp_tp_group().all_gather(input_, 0)
output = self.quant_method.apply(self, input_parallel, bias)

output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias


class AscendLinearBase(LinearBase):

def __init__(
self,
input_size: int,
output_size: int,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
nn.Module.__init__(self)

# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.skip_bias_add = skip_bias_add
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
self.quant_config = quant_config
self.prefix = prefix
if quant_config is None:
self.quant_method: Optional[
QuantizeMethodBase] = UnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self,
prefix=prefix)
self.return_bias = return_bias
self.disable_tp = disable_tp
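
For readers following the refactor, here is a minimal standalone sketch (hypothetical class and group names, not code from this PR) of the init-then-override pattern the diff adopts: the upstream base __init__ runs first with its default communication group, and the subclass then swaps in the custom MLP TP group when the prefix matches.

    class Base:
        def __init__(self):
            # Stand-in for the default TP group that LinearBase.__init__ sets up.
            self.comm_group = "default_tp_group"

    class MlpTpLinear(Base):
        def __init__(self, prefix: str, mlp_tp_enabled: bool):
            super().__init__()  # default group is configured first
            if "gate_up_proj" in prefix and mlp_tp_enabled:
                self.comm_group = "mlp_tp_group"  # then overridden
                self.forward_type = "mlp_tp"
            else:
                self.forward_type = "normal_tp"

    layer = MlpTpLinear("model.layers.0.mlp.gate_up_proj", True)
    print(layer.comm_group)  # mlp_tp_group

Reusing LinearBase's setup this way is what lets the diff delete the AscendLinearBase helper shown above and patch only the communication group afterwards.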