
Commit dc20609

fix GLM accuracy bug
Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent fc2bcbe commit dc20609

File tree: 4 files changed, +282 -25 lines changed


vllm_ascend/ascend_forward_context.py

Lines changed: 6 additions & 11 deletions
@@ -41,16 +41,11 @@ def _get_fused_moe_state(ep_size: int, with_prefill: bool,
     else:
         return FusedMoEState.MC2
 
-
-def get_dispatcher_name(ep_size: int, with_prefill: bool) -> str:
-    if ep_size == 1:
-        return "TokenDispatcherWithAllGather"
-    elif envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1:
-        return "TokenDispatcherWithAllGather"
-    elif ep_size < 16 or with_prefill:
-        return "TokenDispatcherWithAll2AllV"
-    else:
-        return "TokenDispatcherWithMC2"
+_moe_method_to_dispatcher = {
+    "allgather": "TokenDispatcherWithAllGather",
+    "alltoall": "TokenDispatcherWithAll2AllV",
+    "mc2": "TokenDispatcherWithMC2",
+}
 
 
 @contextmanager
@@ -98,7 +93,7 @@ def set_ascend_forward_context(
         forward_context.in_profile_run = in_profile_run
 
         from vllm_ascend.ops.moe.token_dispatcher import get_token_dispatcher
-        dispatcher_name = get_dispatcher_name(ep_size, with_prefill)
+        dispatcher_name = _moe_method_to_dispatcher[moe_comm_method]
         dispatcher = get_token_dispatcher(dispatcher_name)
         forward_context.token_dispatcher = dispatcher
 

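Note on the change above: the dispatcher name is now looked up directly from moe_comm_method instead of being re-derived from ep_size and with_prefill, so the dispatcher always matches the communication method chosen for the step. A minimal, self-contained sketch of that lookup follows; the resolve_dispatcher_name helper and its error handling are illustrative additions, not code from this commit.

# Sketch only: the mapping mirrors _moe_method_to_dispatcher in the diff above;
# the helper function and KeyError handling are illustrative, not part of the commit.
_moe_method_to_dispatcher = {
    "allgather": "TokenDispatcherWithAllGather",
    "alltoall": "TokenDispatcherWithAll2AllV",
    "mc2": "TokenDispatcherWithMC2",
}

def resolve_dispatcher_name(moe_comm_method: str) -> str:
    try:
        return _moe_method_to_dispatcher[moe_comm_method]
    except KeyError as exc:
        raise ValueError(
            f"Unknown moe_comm_method {moe_comm_method!r}; expected one of "
            f"{sorted(_moe_method_to_dispatcher)}") from exc

assert resolve_dispatcher_name("mc2") == "TokenDispatcherWithMC2"
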
vllm_ascend/models/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -53,3 +53,6 @@ def register_model():
         "PanguProMoEForCausalLM",
         "vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
     )
+    ModelRegistry.register_model(
+        "Glm4MoeForCausalLM",
+        "vllm_ascend.models.glm4_moe:CustomGlm4MoeForCausalLM")

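The registration above uses vLLM's lazy "module:ClassName" string form, so vllm_ascend.models.glm4_moe is only imported when a Glm4MoeForCausalLM architecture is actually requested. A hedged way to confirm the override took effect is sketched below; the assertion is illustrative only and assumes register_model() has been invoked (vllm-ascend normally does this during platform registration).

# Illustrative check, not part of the commit.
from vllm import ModelRegistry

from vllm_ascend.models import register_model

register_model()
assert "Glm4MoeForCausalLM" in ModelRegistry.get_supported_archs()
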
vllm_ascend/models/glm4_moe.py

Lines changed: 270 additions & 0 deletions
@@ -0,0 +1,270 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2025 The ZhipuAI Team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from vllm/model_executor/models/glm4_moe.py
# This file is a part of the vllm-ascend project.
"""Inference-only GLM-4.5 model compatible with HuggingFace weights."""
from typing import Optional

import torch
from torch import nn
from transformers.models.glm4_moe import Glm4MoeConfig

from vllm.forward_context import get_forward_context
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, tensor_model_parallel_all_reduce
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.glm4_moe import (Glm4MoE, Glm4MoeAttention,
                                                 Glm4MoeForCausalLM,
                                                 Glm4MoeMLP, Glm4MoeModel)
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.utils import (
    PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
    maybe_prefix)

logger = init_logger(__name__)


class CustomGlm4MoE(Glm4MoE):

    def __init__(
        self,
        config: Glm4MoeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ):
        super().__init__(config, quant_config, prefix, enable_eplb)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        forward_context = get_forward_context()
        moe_comm_method_name = forward_context.moe_comm_method_name

        if self.n_shared_experts is not None:
            shared_output = self.shared_experts(hidden_states)
        else:
            shared_output = None
        router_logits = self.gate(hidden_states.to(dtype=torch.float32))
        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            router_logits=router_logits) * self.routed_scaling_factor
        if shared_output is not None:
            if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
                shared_output = tensor_model_parallel_all_reduce(shared_output)
            final_hidden_states = final_hidden_states + shared_output
        if self.tp_size > 1:
            final_hidden_states = (
                self.experts.maybe_all_reduce_tensor_model_parallel(
                    final_hidden_states))
        return final_hidden_states.view(num_tokens, hidden_dim)


class CustomGlm4MoeDecoderLayer(nn.Module):

    def __init__(
        self,
        config: Glm4MoeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          131072)
        # DecoderLayers are created with `make_layers` which passes the prefix
        # with the layer's index.
        layer_idx = int(prefix.split(sep='.')[-1])
        self.layer_idx = layer_idx

        self.self_attn = Glm4MoeAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            head_dim=config.head_dim,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=config.attention_bias,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            use_qk_norm=config.use_qk_norm,
        )

        if (config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace):
            self.mlp = CustomGlm4MoE(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
                enable_eplb=enable_eplb,
            )
        else:
            self.mlp = Glm4MoeMLP(hidden_size=config.hidden_size,
                                  intermediate_size=config.intermediate_size,
                                  hidden_act=config.hidden_act,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.mlp")

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
        self.routed_scaling_factor = config.routed_scaling_factor

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states)
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    })
class CustomGlm4MoeModel(Glm4MoeModel):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        enable_eplb = vllm_config.parallel_config.enable_eplb
        self.config = config

        self.vocab_size = config.vocab_size

        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                prefix=f"{prefix}.embed_tokens")
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: CustomGlm4MoeDecoderLayer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
                enable_eplb=enable_eplb,
            ),
            prefix=f"{prefix}.layers")

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))


class CustomGlm4MoeForCausalLM(Glm4MoeForCausalLM):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        SupportsPP.__init__(self)
        SupportsLoRA.__init__(self)
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = CustomGlm4MoeModel(vllm_config=vllm_config,
                                        prefix=maybe_prefix(prefix, "model"))
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.hidden_size,
                                          quant_config=quant_config,
                                          prefix=maybe_prefix(prefix, "lm_head"))
        else:
            self.lm_head = PPMissingLayer()
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
        self.expert_weights = []

        # Set MoE hyperparameters
        self.num_moe_layers = (config.num_hidden_layers -
                               config.first_k_dense_replace)
        self.num_expert_groups = config.n_group

        self.moe_layers: list[FusedMoE] = []
        example_moe = None
        for layer in self.model.layers:
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(layer, CustomGlm4MoeDecoderLayer)
            if isinstance(layer.mlp, CustomGlm4MoE):
                # Pick last one layer since the first ones may be dense layers.
                example_moe = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_moe is None:
            raise RuntimeError("No Glm4MoE layer found in model.layers.")

        self.num_logical_experts = example_moe.n_logical_experts
        self.num_physical_experts = example_moe.n_physical_experts
        self.num_local_physical_experts = example_moe.n_local_physical_experts
        self.num_routed_experts = example_moe.n_routed_experts
        self.num_shared_experts = example_moe.n_shared_experts
        self.num_redundant_experts = example_moe.n_redundant_experts

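The accuracy-relevant piece of the new file is CustomGlm4MoE.forward: when the MoE communication method is alltoallcommimpl or mc2commimpl, the shared-expert output is all-reduced across the tensor-parallel group before being added to the routed-expert output, so the two contributions are combined in a consistent layout. As a usage illustration only, a GLM-4.5 checkpoint should load through the normal vLLM entry points once this model is registered; the checkpoint id, parallel sizes, and sampling settings below are placeholders, not values taken from this commit.

# Illustrative only: model id and engine settings are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(
    model="zai-org/GLM-4.5",         # placeholder checkpoint id
    tensor_parallel_size=4,          # placeholder parallel layout
    enable_expert_parallel=True,
)
outputs = llm.generate(["What does MoE stand for?"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
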
vllm_ascend/ops/moe/token_dispatcher.py

Lines changed: 3 additions & 14 deletions
@@ -46,20 +46,9 @@ def get_token_dispatcher(name: str):
 
 
 def setup_token_dispatchers(ep_size: int, **kwargs):
-    existing_dispatchers = set(_Dispatchers.keys())
-
-    if ep_size == 1 and "TokenDispatcherWithAllGather" not in existing_dispatchers:
-        _register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs))
-    elif envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1 \
-            and "TokenDispatcherWithAllGather" not in existing_dispatchers:
-        _register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs))
-    elif ep_size < 16 and "TokenDispatcherWithAll2AllV" not in existing_dispatchers:
-        _register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
-    elif ep_size >= 16:
-        if "TokenDispatcherWithAll2AllV" not in existing_dispatchers:
-            _register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
-        if "TokenDispatcherWithMC2" not in existing_dispatchers:
-            _register_token_dispatcher(TokenDispatcherWithMC2(**kwargs))
+    _register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs))
+    _register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
+    _register_token_dispatcher(TokenDispatcherWithMC2(**kwargs))
 
 
 class MoETokenDispatcher(ABC):

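With the conditional registration removed, setup_token_dispatchers now registers all three dispatchers up front, and the per-step choice is made purely by name through get_token_dispatcher, driven by moe_comm_method as in the ascend_forward_context.py change above. The toy registry below mimics that flow for illustration only; the stub class and dictionary are stand-ins, not the real implementations in token_dispatcher.py.

# Toy mock of the dispatcher registry; names mirror token_dispatcher.py,
# bodies are stand-ins for illustration only.
_DISPATCHERS: dict[str, object] = {}

class _StubDispatcher:
    def __init__(self, name: str, **kwargs):
        self.name = name

def _register(dispatcher: _StubDispatcher) -> None:
    # Stand-in for _register_token_dispatcher: store by class/dispatcher name.
    _DISPATCHERS[dispatcher.name] = dispatcher

def setup_token_dispatchers(ep_size: int, **kwargs) -> None:
    # After this commit: register every dispatcher regardless of ep_size;
    # selection happens later via the moe_comm_method -> name mapping.
    for name in ("TokenDispatcherWithAllGather",
                 "TokenDispatcherWithAll2AllV",
                 "TokenDispatcherWithMC2"):
        _register(_StubDispatcher(name, **kwargs))

def get_token_dispatcher(name: str) -> _StubDispatcher:
    return _DISPATCHERS[name]

setup_token_dispatchers(ep_size=8)
assert get_token_dispatcher("TokenDispatcherWithMC2").name == "TokenDispatcherWithMC2"
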