From 7e644d293ec507c83046565c124240e3c75424c2 Mon Sep 17 00:00:00 2001
From: phlrain <--global>
Date: Mon, 18 Aug 2025 12:24:36 +0800
Subject: [PATCH] Optimize MTP speed

Speed up the multi-token-prediction (MTP) training path with three custom
PyLayer ops:

* LMHeadFunction: a matmul whose backward accumulates the weight gradient
  directly into weight.main_grad via fused_linear_param_grad_add instead of
  materializing a separate gradient tensor.
* FastCrossEntropyFunction: fused softmax cross-entropy whose backward uses
  the downcasting kernel cross_entropy_with_softmax_bwd_w_downcast.
* EmbeddingFunction: an embedding lookup whose backward adds the gradient
  in place into main_grad (or grad) via embedding_grad_add_to_.

---
 .../transformers/deepseek_v2/modeling.py    | 60 +++++++++++++++++--
 .../transformers/deepseek_v2/modeling_pp.py | 28 ++++++++-
 2 files changed, 83 insertions(+), 5 deletions(-)

diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py
index e07a3e0fbce4..54a3368ad311 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling.py
@@ -209,6 +209,33 @@ def assign_kv_heads(num_kv_heads: int, num_gpus: int):
         assignment_list[i * num_card_per_heads + j].append(i)
     return assignment_list
 
+
+class LMHeadFunction(paddle.autograd.PyLayer):
+    @staticmethod
+    def forward(ctx, x, weight, transpose_y):
+        out = paddle.matmul(x, weight, transpose_y=transpose_y)
+
+        ctx.save_for_backward(x, weight, transpose_y)
+        return out
+
+    @staticmethod
+    def backward(ctx, dout):
+        if dout.dtype == paddle.float32:
+            dout = dout.cast(paddle.bfloat16)
+
+        x, weight, transpose_y = ctx.saved_tensor()
+
+        dx = paddle.matmul(dout, weight, transpose_y=not transpose_y)
+        # Accumulate dW directly into weight.main_grad with the fused kernel,
+        # so no standalone weight-gradient tensor is allocated.
+        if transpose_y:
+            with paddle.amp.auto_cast(False):
+                paddle._C_ops.fused_linear_param_grad_add(
+                    dout.reshape([-1, dout.shape[-1]]), x.reshape([-1, x.shape[-1]]), weight.main_grad, None, True, False
+                )
+        else:
+            with paddle.amp.auto_cast(False):
+                paddle._C_ops.fused_linear_param_grad_add(
+                    x.reshape([-1, x.shape[-1]]), dout.reshape([-1, dout.shape[-1]]), weight.main_grad, None, True, False
+                )
+        return dx, None
+
 
 def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True):
     is_fleet_init = True
@@ -235,8 +262,9 @@ def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True):
 
         return paddle.distributed.collective._c_concat(logits, group=model_parallel_group)
 
-    else:
-        logits = paddle.matmul(x, y, transpose_y=transpose_y)
+    else:
+        logits = LMHeadFunction.apply(x, y, transpose_y=transpose_y)
+
         return logits
 
 
@@ -2445,7 +2473,8 @@ def forward(
         hidden_states = self.hnorm(hidden_states)
         nextn_hidden_state = self.enorm(nextn_hidden_state)
 
-        hidden_states = self.eh_proj(paddle.concat([hidden_states, nextn_hidden_state], axis=-1))
+        concat_h = paddle.concat([hidden_states, nextn_hidden_state], axis=-1)
+        hidden_states = LMHeadFunction.apply(concat_h, self.eh_proj.weight, False)
 
         layer_outputs = super(DeepseekV2MTPLayer, self).forward(
             hidden_states,
@@ -3034,6 +3063,29 @@ def forward(
             mtp_outputs=mtp_outputs,
         )
 
+
+class FastCrossEntropyFunction(paddle.autograd.PyLayer):
+    @staticmethod
+    def forward(ctx, preds, labels):
+        # soft_label=False, use_softmax=True, numeric_stable_mode=False,
+        # ignore_index=-100, axis=-1
+        softmax_val, loss = paddle._C_ops.cross_entropy_with_softmax(
+            preds, labels, False, True, False, -100, -1
+        )
+
+        ctx.save_for_backward(labels, softmax_val)
+        return loss
+
+    @staticmethod
+    def backward(ctx, dout):
+        labels, softmax_val = ctx.saved_tensor()
+
+        # Fused backward that downcasts the gradient on the way out.
+        preds_grad = paddle.incubate.nn.functional.cross_entropy_with_softmax_bwd_w_downcast(
+            labels, softmax_val.cast(paddle.float32), dout.cast(paddle.float32)
+        )
+        return preds_grad, None
+
 
 class DeepseekV2PretrainingCriterion(nn.Layer):
     """
@@ -3062,7 +3114,7 @@ def forward(self, prediction_scores, masked_lm_labels, router_loss=None, mtp_logits=None):
 
         def compute_loss(preds, labels):
             with paddle.amp.auto_cast(False):
-                masked_lm_loss = self.loss_func(preds.astype("float32"), labels.unsqueeze(2))
+                masked_lm_loss = FastCrossEntropyFunction.apply(preds, labels.unsqueeze(2))
                 binary_sequence = paddle.where(
                     masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss)
                 )
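LMHeadFunction relies on the caller maintaining a main_grad buffer on the
weight: backward writes the weight gradient into that buffer and deliberately
returns None for the weight. A minimal, self-contained sketch of the same
pattern, not part of the patch: the MatmulMainGrad name and the manually
attached main_grad buffer are illustrative, and a plain matmul stands in for
the fused fused_linear_param_grad_add kernel.

    import paddle


    class MatmulMainGrad(paddle.autograd.PyLayer):
        @staticmethod
        def forward(ctx, x, weight):
            ctx.save_for_backward(x, weight)
            return paddle.matmul(x, weight)

        @staticmethod
        def backward(ctx, dout):
            x, weight = ctx.saved_tensor()
            dx = paddle.matmul(dout, weight, transpose_y=True)
            # Accumulate dW into the caller-managed buffer; the patch uses
            # paddle._C_ops.fused_linear_param_grad_add for this step.
            weight.main_grad += paddle.matmul(x, dout, transpose_x=True)
            # Return None for the weight: its gradient lives in main_grad.
            return dx, None


    x = paddle.randn([4, 8])
    x.stop_gradient = False
    w = paddle.create_parameter(shape=[8, 16], dtype="float32")
    w.main_grad = paddle.zeros_like(w)  # normally attached by the training strategy
    out = MatmulMainGrad.apply(x, w)
    out.sum().backward()
    print(x.grad.shape, float(w.main_grad.abs().sum()))
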
diff --git a/paddlenlp/transformers/deepseek_v2/modeling_pp.py b/paddlenlp/transformers/deepseek_v2/modeling_pp.py
index 40f022d09a9b..01812bd92039 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling_pp.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling_pp.py
@@ -1141,6 +1141,32 @@ def build_overlapped_nodes(forward_chunk, backward_chunk):
     overlap_node = OverlapedScheduleChunk(forward_overlap_layers, backward_overlap_layers, use_fuion=DSV3_USE_FP8_GEMM)
     return forward_pre_node, backward_pre_node, overlap_node, forward_post_node, backward_post_node
 
+
+class EmbeddingFunction(paddle.autograd.PyLayer):
+    @staticmethod
+    def forward(ctx, x, weight):
+        out = paddle.nn.functional.embedding(
+            x,
+            weight=weight,
+            padding_idx=None,
+            max_norm=None,
+            norm_type=2.0,
+            sparse=False,
+            scale_grad_by_freq=False,
+        )
+
+        ctx.save_for_backward(x, weight)
+        return out
+
+    @staticmethod
+    def backward(ctx, dout):
+        x, weight = ctx.saved_tensor()
+
+        # Accumulate the embedding gradient in place instead of allocating a
+        # new gradient tensor for the (potentially large) embedding table.
+        if hasattr(weight, "main_grad"):
+            paddle.incubate.nn.functional.embedding_grad_add_to_(x, weight.main_grad, dout)
+        else:
+            paddle.incubate.nn.functional.embedding_grad_add_to_(x, weight.grad, dout)
+        return None, None
+
 
 class DeepseekV2EmbeddingPipe(nn.Layer):
     def __init__(self, config: DeepseekV2Config):
@@ -1171,7 +1197,7 @@ def forward(self, args):
             _type_: _description_
         """
         input_ids, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
-        inputs_embeds = self.embed_tokens(input_ids)
+        inputs_embeds = EmbeddingFunction.apply(input_ids, self.embed_tokens.weight)
         batch_size, seq_length = input_ids.shape
 
         if self.config.num_nextn_predict_layers > 0:
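
As a quick sanity check of the new loss path, the fused forward of
FastCrossEntropyFunction should agree with the public cross-entropy API the
criterion previously went through. A forward-only sketch, assuming the
patched paddlenlp module is importable; the shapes and tolerance are
illustrative, not part of the patch:

    import paddle
    import paddle.nn.functional as F

    from paddlenlp.transformers.deepseek_v2.modeling import FastCrossEntropyFunction

    paddle.seed(0)
    logits = paddle.randn([2, 5, 32], dtype="float32")
    logits.stop_gradient = False
    labels = paddle.randint(0, 32, shape=[2, 5, 1], dtype="int64")

    fused = FastCrossEntropyFunction.apply(logits, labels)
    ref = F.cross_entropy(logits, labels, reduction="none", ignore_index=-100)
    print(bool(paddle.allclose(fused, ref, atol=1e-5)))  # expect True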