Skip to content

Conversation

Hongbosherlock
Copy link
Contributor

@Hongbosherlock Hongbosherlock commented Sep 1, 2025

Motivation

Add scale_ub, because fbgemm_fp8 using it.

It's a part of a fix for fbgemm_fp8 #9466

ref to:vllm-project/vllm#6593

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @Hongbosherlock, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the per-token FP8 quantization mechanism by introducing an explicit upper bound for the scaling factor. This change is essential for integrating with specific quantization libraries like fbgemm_fp8, providing finer control over the quantization process and ensuring numerical stability and compatibility in low-precision computations.

Highlights

  • New Parameter for Per-Token FP8 Quantization: Introduced an optional scale_ub (scale upper bound) parameter to the sgl_per_token_quant_fp8 function across its C++ definition, CUDA kernel implementations, and Python binding. This parameter allows for explicit control over the maximum scaling factor used during per-token FP8 quantization.
  • Integration with FBGEMM_FP8: The addition of scale_ub is a crucial part of a fix to ensure compatibility and proper functionality with fbgemm_fp8 implementations, as indicated by the pull request description.
  • Quantization Logic Update: The CUDA kernels (per_token_quant_fp8_kernel and per_token_quant_fp8_small_batch_kernel) now incorporate the scale_ub value. If provided, the calculated scale for quantization will be capped at this upper bound, and a min_scaling_factor is also introduced to prevent excessively small scales.
  • Test Case Update: The existing test suite for sgl_per_token_quant_fp8 has been updated to include passing a scale_ub tensor, ensuring the new functionality is covered by tests.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds an optional scale_ub parameter to the per_token_quant_fp8 function for FP8 quantization, which is used to apply an upper bound on the scaling factor. The changes span across C++ kernel definition, CUDA implementation, Python bindings, and tests. My review has identified a critical race condition in the CUDA kernel per_token_quant_fp8_kernel due to incorrect use of shared memory. I've also pointed out that the test coverage is insufficient to validate the new scale computation logic. Additionally, there are some minor suggestions to improve code clarity by removing magic numbers. Please address the critical and high severity issues.

@@ -49,7 +51,14 @@ __global__ void per_token_quant_fp8_kernel(
float warp_max = warpReduceMax(max_value);

__shared__ float scale;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

There is a race condition on the __shared__ variable scale. This kernel uses multiple warps per block, and each warp calculates a different warp_max. All warps then try to write to the single scale variable in shared memory, causing a race condition.

To fix this, scale should be a register variable, not a shared memory variable. Since warp_max is uniform within a warp, each thread in the warp will correctly compute the same scale value in its own register.

  float scale;

const int64_t hidden_dim,
const int64_t num_tokens) {
const int warp_id = threadIdx.x / kWarpSize; // 0‑7 (8 warps)
const int lane_id = threadIdx.x & (kWarpSize - 1); // 0‑31
const int token_id = blockIdx.x * kTokensPerCTA + warp_id;
if (token_id >= num_tokens) return;
float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The magic number 512.f is used here without explanation. To improve code readability and maintainability, it would be beneficial to define it as a named constant with a comment explaining its purpose or origin (e.g., if it's a heuristic from another library like TensorRT-LLM).

  // This value is a heuristic from TensorRT-LLM to prevent overflow when inverting the scale.
  constexpr float MIN_SCALING_FACTOR_DENOM_MULTIPLIER = 512.f;
  float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * MIN_SCALING_FACTOR_DENOM_MULTIPLIER);

const int64_t hidden_dim,
const int64_t num_tokens) {
const int token_idx = blockIdx.x;
if (token_idx >= num_tokens) return;

const int tid = threadIdx.x;
const int block_dim = blockDim.x;
float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the other kernel, the magic number 512.f is used here without explanation. It's good practice to define it as a named constant with a comment for better readability and maintainability.

  // This value is a heuristic from TensorRT-LLM to prevent overflow when inverting the scale.
  constexpr float MIN_SCALING_FACTOR_DENOM_MULTIPLIER = 512.f;
  float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * MIN_SCALING_FACTOR_DENOM_MULTIPLIER);

@BBuf
Copy link
Collaborator

BBuf commented Sep 2, 2025

This change has soloved #9466 ?

@Hongbosherlock
Copy link
Contributor Author

This change has soloved #9466 ?

No, this is part of the fix for #9466. We still need to figure out some other problems.#9466 (comment)

@AniZpZ AniZpZ self-assigned this Sep 2, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants