
Conversation

@mdazz (Contributor) commented Sep 5, 2025:

This commit teaches AtenToDtypeOp::fold to constant-fold dtype conversions when the operand is a splat DenseElementsAttr.

Folding is done according to torch's rounding behavior, i.e.

  • Bool: 0 and -0.0 → false; nonzero/NaN/±Inf → true.
  • Float → Int: round toward zero.
  • Int → Float: sign-aware, rmNearestTiesToEven.
  • Float ↔ Float: use builtin mlir::FloatType::getFloatSemantics().
  • Int ↔ Int: use zextOrTrunc / sextOrTrunc based on source signedness.

Folding is only performed when non_blocking == false, copy == false, and memory_format is None.
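
For reference, the rounding behavior above can be checked directly against PyTorch; a minimal illustrative sketch (not part of the patch) of the cases the folder is meant to mirror:

```python
import torch

# Bool: zero and -0.0 convert to False; any nonzero value, NaN, and +/-Inf to True.
assert not torch.tensor(-0.0).to(torch.bool).item()
assert torch.tensor(float("nan")).to(torch.bool).item()

# Float -> Int: round toward zero (truncation).
assert torch.tensor(2.9).to(torch.int32).item() == 2
assert torch.tensor(-2.9).to(torch.int32).item() == -2

# Int -> Float: round to nearest, ties to even; 2**24 + 1 is not representable in f32.
assert torch.tensor(2**24 + 1, dtype=torch.int64).to(torch.float32).item() == 2.0**24
```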

@mdazz (Author) commented Sep 5, 2025:

Not sure who can review this; maybe you would know, @vivekkhandelwal1 @zjgarvey?

// int32 splat → float32
%int_splat = torch.vtensor.literal(dense<42> : tensor<2x3xsi32>) : !torch.vtensor<[2,3],si32>
%int6 = torch.constant.int 6 // torch.float32
// CHECK: %[[R1:.*]] = torch.vtensor.literal({{.*}} : tensor<2x3xf32>) : !torch.vtensor<[2,3],f32>
Reviewer (Member) commented:
Can you put in the actual value, which I think will be 42.0 here?


// int64 splat (max int32) → int32 (trunc)
%int64_splat = torch.vtensor.literal(dense<2147483647> : tensor<10xsi64>) : !torch.vtensor<[10],si64>
Reviewer (Member) commented:
Can this value be int32max + 1, to ensure that truncation does happen in the IR being locked down?
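
For context, a quick PyTorch check (illustrative only; the exact literal in the locked-down IR is what should be confirmed) of what int32max + 1 becomes when truncated to int32:

```python
import torch

# int32max + 1 = 2**31 does not fit in int32; the conversion truncates the bits,
# so the splat in the folded IR should wrap around to int32min.
src = torch.tensor(2**31, dtype=torch.int64)
assert src.to(torch.int32).item() == -(2**31)   # -2147483648
```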

// float32 splat → float64
%float32_splat = torch.vtensor.literal(dense<2.71828> : tensor<5x5xf32>) : !torch.vtensor<[5,5],f32>
%int7 = torch.constant.int 7 // torch.float64
// CHECK: %[[R4:.*]] = torch.vtensor.literal({{.*}} : tensor<5x5xf64>) : !torch.vtensor<[5,5],f64>
Reviewer (Member) commented:
Let's capture the actual value here too, and in other such places.
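
One way to obtain the exact f64 literal without transcribing digits by hand; an illustrative sketch, not part of the PR:

```python
import torch

# The f64 value to lock down is the f32-rounded 2.71828 widened (losslessly) to f64;
# printing it avoids guessing the digits manually.
v = torch.tensor(2.71828, dtype=torch.float32).to(torch.float64)
print(v.item())  # the literal to capture in the CHECK line
```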

@mdazz (Author) replied:
Done for all testcases.

// int32 splat → float32
%int_splat = torch.vtensor.literal(dense<42> : tensor<2x3xsi32>) : !torch.vtensor<[2,3],si32>
%int6 = torch.constant.int 6 // torch.float32
// CHECK: %[[R1:.*]] = torch.vtensor.literal({{.*}} : tensor<2x3xf32>) : !torch.vtensor<[2,3],f32>
Reviewer (Member) commented:
Since we are not locking down the values returned in the output IR, I think we should also add CHECK-NOT: torch.aten.to.dtype to ensure that the op is actually being folded.

@@ -1762,6 +1762,78 @@ func.func @torch.aten.to.dtype$no_fold$unk_dtype(%arg0: !torch.tensor) -> !torch
return %0 : !torch.tensor
}

// CHECK-LABEL: @torch.aten.to.dtype$fold_splat(
Reviewer (Member) commented:
Can you add some e2e tests to ensure that torch's rounding logic is accurately captured in this implementation?
Also, please fix the CI failures; we cannot merge until the CI pipelines are green.

@mdazz (Author) replied:
I added e2e tests in test_conversions.py. Note that one test in the fx_importer_stablehlo group fails due to a bug in the StableHLO conversion: StableHLO sign-extends all integer constants regardless of whether they are signed, and now that we fold constants it encounters an unsigned int that it wrongly sign-extends. I have a PR that fixes it, so I guess we'll need to wait for that to go in first.
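
To illustrate the mismatch described above, with a hypothetical unsigned 8-bit splat value (not the actual failing test):

```python
# An unsigned 8-bit splat with the high bit set, e.g. bit pattern 0b11001000 (= 200).
raw = 0b11001000

zero_extended = raw                                    # 200  (correct for ui8)
sign_extended = raw - 0x100 if raw & 0x80 else raw     # -56  (the erroneous result)
assert (zero_extended, sign_extended) == (200, -56)
```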

@mdazz (Author) commented:
@sahas3 #4313 was merged, and the CI now passes for this PR :)

@mdazz force-pushed the mdazz/add-todtype-folder branch 2 times, most recently from 9b8168c to 42edabd on September 15, 2025 at 10:06
@mdazz force-pushed the mdazz/add-todtype-folder branch from 42edabd to 3abbd48 on September 16, 2025 at 08:45