60 changes: 56 additions & 4 deletions paddlenlp/transformers/deepseek_v2/modeling.py
@@ -209,6 +209,33 @@ def assign_kv_heads(num_kv_heads: int, num_gpus: int):
assignment_list[i * num_card_per_heads + j].append(i)
return assignment_list

class LMHeadFunction(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, x, weight, transpose_y):
        out = paddle.matmul(x, weight, transpose_y=transpose_y)

        ctx.save_for_backward(x, weight, transpose_y)
        return out

    @staticmethod
    def backward(ctx, dout):
        if dout.dtype == paddle.float32:
            dout = dout.cast(paddle.bfloat16)

        x, weight, transpose_y = ctx.saved_tensor()

        # Input gradient via a plain matmul; the weight gradient is accumulated
        # directly into weight.main_grad by the fused kernel, so no dweight is returned.
        dx = paddle.matmul(dout, weight, transpose_y=not transpose_y)
        if transpose_y:
            with paddle.amp.auto_cast(False):
                paddle._C_ops.fused_linear_param_grad_add(
                    dout.reshape([-1, dout.shape[-1]]), x.reshape([-1, x.shape[-1]]), weight.main_grad, None, True, False
                )
        else:
            with paddle.amp.auto_cast(False):
                paddle._C_ops.fused_linear_param_grad_add(
                    x.reshape([-1, x.shape[-1]]), dout.reshape([-1, dout.shape[-1]]), weight.main_grad, None, True, False
                )
        return dx, None
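
A small forward-only sketch (editor's illustration, not part of the change) of how the PyLayer above is driven; shapes and names are arbitrary, and the backward path additionally assumes a GPU build plus an fp32 weight.main_grad buffer of the kind the fleet mixed-precision hooks attach:

import paddle

# Forward should reproduce a plain matmul; backward would require weight.main_grad
# and the GPU-only fused_linear_param_grad_add kernel, so it is not exercised here.
x = paddle.randn([2, 8, 128], dtype="float32")
weight = paddle.randn([128, 512], dtype="float32")

out = LMHeadFunction.apply(x, weight, False)
ref = paddle.matmul(x, weight)
print(paddle.allclose(out, ref))  # expected: True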

def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True):
is_fleet_init = True
@@ -235,8 +262,9 @@ def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_out

return paddle.distributed.collective._c_concat(logits, group=model_parallel_group)

-else:
-    logits = paddle.matmul(x, y, transpose_y=transpose_y)
+else:
+    logits = LMHeadFunction.apply(x, y, transpose_y=transpose_y)

return logits


@@ -2445,7 +2473,8 @@ def forward(
hidden_states = self.hnorm(hidden_states)
nextn_hidden_state = self.enorm(nextn_hidden_state)

-hidden_states = self.eh_proj(paddle.concat([hidden_states, nextn_hidden_state], axis=-1))
+concat_h = paddle.concat([hidden_states, nextn_hidden_state], axis=-1)
+hidden_states = LMHeadFunction.apply(concat_h, self.eh_proj.weight, False)

layer_outputs = super(DeepseekV2MTPLayer, self).forward(
hidden_states,
@@ -3034,6 +3063,29 @@ def forward(
mtp_outputs=mtp_outputs,
)

class FastCrossEntropyFunction(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, preds, labels):
        # Fused softmax + cross-entropy; keep the softmax output for the backward pass.
        softmax_val, loss = paddle._C_ops.cross_entropy_with_softmax(
            preds, labels, False, True, False, -100, -1
        )

        ctx.save_for_backward(labels, softmax_val)
        return loss

    @staticmethod
    def backward(ctx, dout):
        labels, softmax_val = ctx.saved_tensor()

        # Fused backward that computes the logits gradient and downcasts it in one kernel.
        preds_grad = paddle.incubate.nn.functional.cross_entropy_with_softmax_bwd_w_downcast(
            labels, softmax_val.cast(paddle.float32), dout.cast(paddle.float32)
        )

        return preds_grad, None
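
A hedged sketch (illustrative only) of how the fused loss is meant to line up with the stock cross-entropy that compute_loss below used previously; it assumes a CUDA device, since both the fused forward kernel and the incubate downcast backward target GPU, and the reference call uses the same ignore_index with no reduction:

import paddle
import paddle.nn.functional as F

preds = paddle.randn([2, 16, 1000], dtype="float32")               # [batch, seq, vocab]
labels = paddle.randint(0, 1000, shape=[2, 16, 1], dtype="int64")  # hard labels, trailing dim kept

fast_loss = FastCrossEntropyFunction.apply(preds, labels)
ref_loss = F.cross_entropy(preds, labels, reduction="none", ignore_index=-100)
print(paddle.allclose(fast_loss, ref_loss, atol=1e-5))  # should print True, up to float error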

class DeepseekV2PretrainingCriterion(nn.Layer):
"""
@@ -3062,7 +3114,7 @@ def forward(self, prediction_scores, masked_lm_labels, router_loss=None, mtp_log

def compute_loss(preds, labels):
with paddle.amp.auto_cast(False):
-masked_lm_loss = self.loss_func(preds.astype("float32"), labels.unsqueeze(2))
+masked_lm_loss = FastCrossEntropyFunction.apply(preds, labels.unsqueeze(2))
binary_sequence = paddle.where(
masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss)
)
28 changes: 27 additions & 1 deletion paddlenlp/transformers/deepseek_v2/modeling_pp.py
@@ -1141,6 +1141,32 @@ def build_overlapped_nodes(forward_chunk, backward_chunk):
overlap_node = OverlapedScheduleChunk(forward_overlap_layers, backward_overlap_layers, use_fuion=DSV3_USE_FP8_GEMM)
return forward_pre_node, backward_pre_node, overlap_node, forward_post_node, backward_post_node

class EmbeddingFunction(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, x, weight):
        out = paddle.nn.functional.embedding(
            x,
            weight=weight,
            padding_idx=None,
            max_norm=None,
            norm_type=2.0,
            sparse=False,
            scale_grad_by_freq=False,
        )

        ctx.save_for_backward(x, weight)
        return out

    @staticmethod
    def backward(ctx, dout):
        x, weight = ctx.saved_tensor()

        # Accumulate the embedding gradient in place: into main_grad when the
        # mixed-precision hooks have attached one, otherwise into grad.
        if hasattr(weight, "main_grad"):
            paddle.incubate.nn.functional.embedding_grad_add_to_(x, weight.main_grad, dout)
        else:
            paddle.incubate.nn.functional.embedding_grad_add_to_(x, weight.grad, dout)

        return None, None
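
A forward-only sketch (illustrative, with made-up sizes) of the swap performed in DeepseekV2EmbeddingPipe.forward below; the lookup should match nn.Embedding, while the custom backward further assumes the incubate embedding_grad_add_to_ kernel and, in training, a main_grad buffer on the weight:

import paddle

embed = paddle.nn.Embedding(1024, 64)
input_ids = paddle.randint(0, 1024, shape=[2, 8], dtype="int64")

out_ref = embed(input_ids)                                  # standard lookup
out_new = EmbeddingFunction.apply(input_ids, embed.weight)  # PyLayer lookup
print(paddle.allclose(out_ref, out_new))  # expected: True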

class DeepseekV2EmbeddingPipe(nn.Layer):
def __init__(self, config: DeepseekV2Config):
@@ -1171,7 +1197,7 @@ def forward(self, args):
_type_: _description_
"""
input_ids, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
-inputs_embeds = self.embed_tokens(input_ids)
+inputs_embeds = EmbeddingFunction.apply(input_ids, self.embed_tokens.weight)

batch_size, seq_length = input_ids.shape
if self.config.num_nextn_predict_layers > 0: