**Describe the bug**
`get_batch` doesn't support cp broadcasting. Running with cp > 1 on the current code yields an incorrect loss.

**To Reproduce**
Run the same SFT experiment twice, once with cp=1 and once with cp=2.

**Expected behavior**
cp=1 and cp=2 should yield very similar LM loss.

**Screenshots**