
Commit fc2bcbe

[Ops] Fix bug in register_custom_ops without forward_context (#2883)
### What this PR does / why we need it?
This PR fixes a bug in register_custom_ops when no forward_context is available. We wrap the forward-context lookup in a try-except to handle this situation.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: main
- vLLM main: vllm-project/vllm@7920de0

Signed-off-by: rjg-lyh <1318825571@qq.com>
1 parent 6d8bc38 commit fc2bcbe
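For a quick sense of the fix without reading the full diff: each custom-op impl now looks up the forward context inside a try-except and falls back to a safe result when none is set. The sketch below illustrates that pattern in isolation; the `ForwardContext` stand-in, the module-level `_forward_context`, and the `__main__` demo are assumptions for illustration only, and just the try/except-on-AssertionError shape mirrors the actual change.

import logging
from dataclasses import dataclass
from typing import Optional

import torch

logger = logging.getLogger(__name__)

# Set only while a model forward pass is running; None otherwise (illustrative stand-in).
_forward_context: Optional["ForwardContext"] = None


@dataclass
class ForwardContext:
    flashcomm_v1_enabled: bool = False
    pad_size: int = 0


def get_forward_context() -> "ForwardContext":
    # vLLM's helper asserts that a context has been set, so callers outside a
    # forward pass (warm-up, direct invocation) hit an AssertionError.
    assert _forward_context is not None, "Forward context is not set"
    return _forward_context


def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
                                           label: bool) -> torch.Tensor:
    try:
        forward_context = get_forward_context()
    except AssertionError:
        # The fix: fall back to a no-op instead of crashing.
        logger.info("Forward context is None, skipping the operation.")
        return x
    # ...the normal flashcomm_v1 gather/unpad path would use forward_context here...
    return x


if __name__ == "__main__":
    out = _maybe_all_gather_and_maybe_unpad_impl(torch.randn(4, 8), label=True)
    print(out.shape)  # torch.Size([4, 8]) -- graceful fallback, no AssertionError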

2 files changed: +54, -18 lines

vllm_ascend/envs.py

Lines changed: 6 additions & 4 deletions

@@ -139,11 +139,13 @@
     "VLLM_ASCEND_ENABLE_PREFETCH_MLP":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
     # buffer size for gate up prefetch
-    "MLP_GATE_UP_PREFETCH_SIZE":
-    lambda: int(os.getenv("MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)),
+    "VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE":
+    lambda: int(
+        os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)),
     # buffer size for down proj prefetch
-    "MLP_DOWN_PREFETCH_SIZE":
-    lambda: int(os.getenv("MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)),
+    "VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE":
+    lambda: int(
+        os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)),
     # Whether to enable dense model and general optimizations for better performance.
     # Since we modified the base parent class `linear`, this optimization is also applicable to other model types.
     # However, there might be hidden issues, and it is currently recommended to prioritize its use with dense models.
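The two knobs above keep their 18 MiB defaults but are now read under the `VLLM_ASCEND_` prefix. Below is a minimal sketch of how a lambda-based env registry like this one typically resolves them; the `env_variables`/`__getattr__` wiring shown is an assumption modeled on the usual vLLM envs-module pattern, and only the variable names and defaults come from the diff above.

import os

env_variables = {
    # buffer size for gate up prefetch
    "VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE":
    lambda: int(
        os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)),
    # buffer size for down proj prefetch
    "VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE":
    lambda: int(
        os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)),
}


def __getattr__(name: str):
    # Evaluate the lambda lazily so the environment is read at access time.
    if name in env_variables:
        return env_variables[name]()
    raise AttributeError(f"module has no attribute {name!r}")

Under this scheme, `envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE` returns 18 * 1024 * 1024 unless the environment variable is exported; the old unprefixed names are no longer consulted.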

vllm_ascend/ops/register_custom_ops.py

Lines changed: 48 additions & 14 deletions

@@ -7,19 +7,26 @@
     tensor_model_parallel_all_reduce,
     tensor_model_parallel_reduce_scatter)
 from vllm.forward_context import get_forward_context
+from vllm.logger import logger
 from vllm.utils import direct_register_custom_op

 import vllm_ascend.envs as envs_ascend


 def _maybe_chunk_residual_impl(x: torch.Tensor,
                                residual: torch.Tensor) -> torch.Tensor:
+    try:
+        forward_context = get_forward_context()
+    except AssertionError:
+        logger.info("Forward context is None, skipping the operation.")
+        return residual
+
     if x.size(0) != residual.size(0):
-        flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
+        flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
         assert flashcomm_v1_enabled is True, (
             "Currently, this situation only occurs "
             "when flashcomm_v1 is enabled")
-        pad_size = get_forward_context().pad_size
+        pad_size = forward_context.pad_size
         if pad_size > 0:
             residual = F.pad(residual, (0, 0, 0, pad_size))
         tp_size = get_tensor_model_parallel_world_size()
@@ -31,19 +38,31 @@ def _maybe_chunk_residual_impl(x: torch.Tensor,


 def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
                                            label: bool) -> torch.Tensor:
-    flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
+    try:
+        forward_context = get_forward_context()
+    except AssertionError:
+        logger.info("Forward context is None, skipping the operation.")
+        return x
+
+    flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
     if flashcomm_v1_enabled and label:
         x = tensor_model_parallel_all_gather(x, 0)
-        pad_size = get_forward_context().pad_size
+        pad_size = forward_context.pad_size
         if pad_size > 0:
             x = x[:-pad_size, :]
     return x


 def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
-    flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
+    try:
+        forward_context = get_forward_context()
+    except AssertionError:
+        logger.info("Forward context is None, skipping the operation.")
+        return tensor_model_parallel_all_reduce(x)
+
+    flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
     if flashcomm_v1_enabled:
-        pad_size = get_forward_context().pad_size
+        pad_size = forward_context.pad_size
         if pad_size > 0:
             x = F.pad(x, (0, 0, 0, pad_size))
     return tensor_model_parallel_reduce_scatter(x, 0)
@@ -53,7 +72,12 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:

 def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
                                           prefix: str) -> None:
-    forward_context = get_forward_context()
+    try:
+        forward_context = get_forward_context()
+    except AssertionError:
+        logger.info("Forward context is None, skipping the operation.")
+        return
+
     if not forward_context.prefetch_mlp_enabled:
         return
     model_instance = forward_context.model_instance
@@ -67,9 +91,9 @@ def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
     prefetch_stream.wait_stream(torch.npu.current_stream())

     with torch.npu.stream(prefetch_stream):
-        MLP_GATE_UP_PREFETCH_SIZE = envs_ascend.MLP_GATE_UP_PREFETCH_SIZE
+        mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE
         torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight, \
-                               x_dependency, MLP_GATE_UP_PREFETCH_SIZE)
+                               x_dependency, mlp_gate_up_prefetch_size)
     return


@@ -79,7 +103,12 @@ def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor,


 def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
-    forward_context = get_forward_context()
+    try:
+        forward_context = get_forward_context()
+    except AssertionError:
+        logger.info("Forward context is None, skipping the operation.")
+        return
+
     if not forward_context.prefetch_mlp_enabled:
         return
     forward_context.prefetch_mlp_down_proj = True
@@ -91,9 +120,9 @@ def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
     prefetch_stream.wait_stream(torch.npu.current_stream())

     with torch.npu.stream(prefetch_stream):
-        MLP_DOWN_PREFETCH_SIZE = envs_ascend.MLP_DOWN_PREFETCH_SIZE
+        mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
         torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.down_proj.weight, \
-                               x_dependency, MLP_DOWN_PREFETCH_SIZE)
+                               x_dependency, mlp_down_prefetch_size)
     forward_context.layer_idx += 1
     return

@@ -104,12 +133,17 @@ def _maybe_prefetch_mlp_down_proj_impl_fake(


 def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
-    forward_context = get_forward_context()
+    try:
+        forward_context = get_forward_context()
+    except AssertionError:
+        logger.info("Forward context is None, skipping the operation.")
+        return
+
     if not forward_context.prefetch_mlp_enabled:
         return
     if forward_context.prefetch_mlp_gate_up_proj or \
             forward_context.prefetch_mlp_down_proj:
-        prefetch_stream = get_forward_context().prefetch_stream
+        prefetch_stream = forward_context.prefetch_stream
         # wait until prefetch done
         torch.npu.current_stream().wait_stream(prefetch_stream)
     forward_context.prefetch_mlp_gate_up_proj = False
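Each guarded `_impl` in this file is paired with a `_impl_fake` and registered through vLLM's `direct_register_custom_op` (imported at the top of the file). The sketch below shows roughly what that wiring looks like for one op; the op name `maybe_wait_prefetch_done`, the empty `mutates_args` list, and the `torch.ops.vllm` call path are assumptions for illustration, not lines taken from this diff.

import torch
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.utils import direct_register_custom_op


def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
    # Guarded implementation: same try-except pattern as in the diff above.
    try:
        forward_context = get_forward_context()  # the real impl uses this below
    except AssertionError:
        logger.info("Forward context is None, skipping the operation.")
        return
    # ...wait on the prefetch stream recorded in forward_context here...


def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
    # Fake (meta) implementation used during tracing/compilation: no side effects.
    return


direct_register_custom_op(op_name="maybe_wait_prefetch_done",
                          op_func=_maybe_wait_prefetch_done_impl,
                          mutates_args=[],
                          fake_impl=_maybe_wait_prefetch_done_impl_fake)

# Call sites would then dispatch through torch.ops.vllm.maybe_wait_prefetch_done(x).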
