Conversation

Rahulchaube1

Summary:

  • Updated the decode() function to support softmax toggling, top-k sampling, and top-p (nucleus) sampling.
  • Added a dummy transformer model to test the decode function and ensure it works with the new features.

Key Changes:

  • decode() function enhancements:
    • Added softmax toggle for logit manipulation.
    • Implemented top-k and top-p filtering for better sampling (see the sketch after this list).
    • Integrated a dummy model to test and validate the decode() function.
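
A minimal sketch of what that filtering logic can look like in PyTorch; the function name decode_step and the argument names apply_softmax, top_k, and top_p are illustrative assumptions, not the PR's actual API:

import torch
import torch.nn.functional as F

def decode_step(logits, apply_softmax=True, top_k=0, top_p=1.0):
    # logits: 1-D tensor of raw next-token scores, shape (vocab_size,)
    if top_k > 0:
        # Top-k: mask every logit below the k-th largest one.
        kth_value = torch.topk(logits, top_k).values[-1]
        logits = torch.where(logits < kth_value,
                             torch.full_like(logits, float('-inf')), logits)
    if top_p < 1.0:
        # Top-p (nucleus): keep the smallest prefix of sorted tokens whose
        # cumulative probability exceeds top_p; mask the rest.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        mask = cum_probs > top_p
        mask[1:] = mask[:-1].clone()  # shift right so the token that crosses the threshold survives
        mask[0] = False               # always keep the single most likely token
        sorted_logits[mask] = float('-inf')
        logits = torch.full_like(logits, float('-inf')).scatter(0, sorted_idx, sorted_logits)
    # Softmax toggle: return probabilities, or the filtered raw logits.
    return F.softmax(logits, dim=-1) if apply_softmax else logits

# Example: sample one token from the filtered distribution.
probs = decode_step(torch.randn(32), top_k=10, top_p=0.9)
next_token = torch.multinomial(probs, num_samples=1)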


@KAVYANSHTYAGI left a comment


Suggestions
Automated Assertions:
Replace print() with simple checks like:

assert logits_or_probs.shape[-1] == vocab_size # Replace with actual size if known
assert torch.all(logits_or_probs >= 0) # If softmax applied
This would make the test self-validating.
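
For instance, a self-contained version of those checks; vocab_size and the randomly generated output below are placeholders for the dummy model's real values:

import torch

vocab_size = 32  # placeholder; use the dummy model's actual vocabulary size
logits_or_probs = torch.randn(2, vocab_size).softmax(dim=-1)  # stand-in for the model output

assert logits_or_probs.shape[-1] == vocab_size
assert torch.all(logits_or_probs >= 0)  # holds once softmax is applied
assert torch.allclose(logits_or_probs.sum(dim=-1), torch.ones(2))  # each row sums to 1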

Add CPU/GPU Flex Toggle:
It's great that device='cpu' is used, but a dynamic device selector like:

device = 'cuda' if torch.cuda.is_available() else 'cpu'
helps avoid hardcoded-device issues later on.
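
Applied in the test, the selector then drives both the model and its inputs; the nn.Linear below is a purely illustrative stand-in for the PR's dummy transformer:

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(16, 32).to(device)       # trivial stand-in for the dummy model
input_ids = torch.randn(1, 16, device=device)    # create test inputs on the same device
logits = model(input_ids)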

Filename Suggestion:
test_kernel.py is a bit generic. Something like test_decode_sampling.py might better reflect the intent.

Move to tests/ folder if available:
To keep inference/ clean, consider relocating this file under a tests/ or examples/ directory, if one exists.
