                              tensor_model_parallel_all_reduce,
                              tensor_model_parallel_reduce_scatter)
 from vllm.forward_context import get_forward_context
+from vllm.logger import logger
 from vllm.utils import direct_register_custom_op
 
 import vllm_ascend.envs as envs_ascend
 
 
 def _maybe_chunk_residual_impl(x: torch.Tensor,
                                residual: torch.Tensor) -> torch.Tensor:
+    try:
+        forward_context = get_forward_context()
+    except AssertionError:
+        logger.info("Forward context is None, skipping the operation.")
+        return residual
+
     if x.size(0) != residual.size(0):
-        flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
+        flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
         assert flashcomm_v1_enabled is True, (
             "Currently, this situation only occurs "
             "when flashcomm_v1 is enabled")
-        pad_size = get_forward_context().pad_size
+        pad_size = forward_context.pad_size
         if pad_size > 0:
             residual = F.pad(residual, (0, 0, 0, pad_size))
         tp_size = get_tensor_model_parallel_world_size()
@@ -31,19 +38,31 @@ def _maybe_chunk_residual_impl(x: torch.Tensor,
 
 def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
                                            label: bool) -> torch.Tensor:
-    flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
+    try:
+        forward_context = get_forward_context()
+    except AssertionError:
+        logger.info("Forward context is None, skipping the operation.")
+        return x
+
+    flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
     if flashcomm_v1_enabled and label:
         x = tensor_model_parallel_all_gather(x, 0)
-        pad_size = get_forward_context().pad_size
+        pad_size = forward_context.pad_size
         if pad_size > 0:
             x = x[:-pad_size, :]
     return x
 
 
 def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
-    flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
+    try:
+        forward_context = get_forward_context()
+    except AssertionError:
+        logger.info("Forward context is None, skipping the operation.")
+        return tensor_model_parallel_all_reduce(x)
+
+    flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
     if flashcomm_v1_enabled:
-        pad_size = get_forward_context().pad_size
+        pad_size = forward_context.pad_size
         if pad_size > 0:
             x = F.pad(x, (0, 0, 0, pad_size))
         return tensor_model_parallel_reduce_scatter(x, 0)
@@ -53,7 +72,12 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
 
 def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
                                           prefix: str) -> None:
-    forward_context = get_forward_context()
+    try:
+        forward_context = get_forward_context()
+    except AssertionError:
+        logger.info("Forward context is None, skipping the operation.")
+        return
+
     if not forward_context.prefetch_mlp_enabled:
         return
     model_instance = forward_context.model_instance
@@ -67,9 +91,9 @@ def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
     prefetch_stream.wait_stream(torch.npu.current_stream())
 
     with torch.npu.stream(prefetch_stream):
-        MLP_GATE_UP_PREFETCH_SIZE = envs_ascend.MLP_GATE_UP_PREFETCH_SIZE
+        mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE
         torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight, \
-                               x_dependency, MLP_GATE_UP_PREFETCH_SIZE)
+                               x_dependency, mlp_gate_up_prefetch_size)
     return
 
 
@@ -79,7 +103,12 @@ def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor,
 
 
 def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
-    forward_context = get_forward_context()
+    try:
+        forward_context = get_forward_context()
+    except AssertionError:
+        logger.info("Forward context is None, skipping the operation.")
+        return
+
     if not forward_context.prefetch_mlp_enabled:
         return
     forward_context.prefetch_mlp_down_proj = True
@@ -91,9 +120,9 @@ def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
     prefetch_stream.wait_stream(torch.npu.current_stream())
 
     with torch.npu.stream(prefetch_stream):
-        MLP_DOWN_PREFETCH_SIZE = envs_ascend.MLP_DOWN_PREFETCH_SIZE
+        mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
         torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.down_proj.weight, \
-                               x_dependency, MLP_DOWN_PREFETCH_SIZE)
+                               x_dependency, mlp_down_prefetch_size)
     forward_context.layer_idx += 1
     return
 
@@ -104,12 +133,17 @@ def _maybe_prefetch_mlp_down_proj_impl_fake(
 
 
 def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
-    forward_context = get_forward_context()
+    try:
+        forward_context = get_forward_context()
+    except AssertionError:
+        logger.info("Forward context is None, skipping the operation.")
+        return
+
     if not forward_context.prefetch_mlp_enabled:
         return
     if forward_context.prefetch_mlp_gate_up_proj or \
         forward_context.prefetch_mlp_down_proj:
-        prefetch_stream = get_forward_context().prefetch_stream
+        prefetch_stream = forward_context.prefetch_stream
         # wait until prefetch done
         torch.npu.current_stream().wait_stream(prefetch_stream)
     forward_context.prefetch_mlp_gate_up_proj = False
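The change shared by all six wrappers is one graceful-degradation pattern: vLLM's get_forward_context() asserts that a forward context has been set, so each op catches the resulting AssertionError and falls back to a safe default (pass-through for the chunk/gather/unpad helpers, a plain tensor_model_parallel_all_reduce for the pad-and-reduce path, an early return for the prefetch helpers). Below is a minimal, self-contained sketch of that pattern; ToyForwardContext, the module-level _forward_context, and maybe_pad are stand-in names invented for illustration and are not part of the patch or of vLLM.

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F


@dataclass
class ToyForwardContext:
    # Toy stand-in for the fields the patched ops read from the real context.
    flashcomm_v1_enabled: bool = False
    pad_size: int = 0


_forward_context: Optional[ToyForwardContext] = None


def get_forward_context() -> ToyForwardContext:
    # Mirrors the accessor's behaviour assumed by the patch: it asserts that
    # a context has been set, so callers see an AssertionError otherwise.
    assert _forward_context is not None, "Forward context is not set."
    return _forward_context


def maybe_pad(x: torch.Tensor) -> torch.Tensor:
    try:
        forward_context = get_forward_context()
    except AssertionError:
        # Outside a forward pass (warm-up, tracing, offline calls): pass through.
        return x
    if forward_context.flashcomm_v1_enabled and forward_context.pad_size > 0:
        x = F.pad(x, (0, 0, 0, forward_context.pad_size))
    return x


x = torch.randn(3, 4)
print(maybe_pad(x).shape)  # no context set -> passthrough, torch.Size([3, 4])

_forward_context = ToyForwardContext(flashcomm_v1_enabled=True, pad_size=1)
print(maybe_pad(x).shape)  # context set -> padded along dim 0, torch.Size([4, 4])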