[Torch] Fold aten.to.dtype on splat constants. #4306
base: main
Conversation
Not sure who can review; maybe you would know, @vivekkhandelwal1 @zjgarvey?
Force-pushed from 3bf4e4b to 1d7b55b.
test/Dialect/Torch/canonicalize.mlir
Outdated
// int32 splat → float32
%int_splat = torch.vtensor.literal(dense<42> : tensor<2x3xsi32>) : !torch.vtensor<[2,3],si32>
%int6 = torch.constant.int 6 // torch.float32
// CHECK: %[[R1:.*]] = torch.vtensor.literal({{.*}} : tensor<2x3xf32>) : !torch.vtensor<[2,3],f32>
Can you put the actual value, which I think here will be 42.0?
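For example, the CHECK line could pin the folded value directly (a sketch; the exact printed form of the f32 splat, e.g. 4.200000e+01, depends on MLIR's float formatting):

```mlir
// CHECK: %[[R1:.*]] = torch.vtensor.literal(dense<4.200000e+01> : tensor<2x3xf32>) : !torch.vtensor<[2,3],f32>
```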
test/Dialect/Torch/canonicalize.mlir
Outdated
-> !torch.vtensor<[4,4],si32>

// int64 splat (max int32) → int32 (trunc)
%int64_splat = torch.vtensor.literal(dense<2147483647> : tensor<10xsi64>) : !torch.vtensor<[10],si64>
Can this value be int32max+1 to ensure that truncation does happen in the IR being locked down?
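For instance (a sketch, assuming 2147483648 = int32max+1 as the new splat value; truncating it to 32 bits wraps it to the int32 minimum):

```mlir
// int64 splat (int32max + 1) → int32 (trunc wraps to INT32_MIN)
%int64_splat = torch.vtensor.literal(dense<2147483648> : tensor<10xsi64>) : !torch.vtensor<[10],si64>
// CHECK: torch.vtensor.literal(dense<-2147483648> : tensor<10xsi32>) : !torch.vtensor<[10],si32>
```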
test/Dialect/Torch/canonicalize.mlir
Outdated
// float32 splat → float64
%float32_splat = torch.vtensor.literal(dense<2.71828> : tensor<5x5xf32>) : !torch.vtensor<[5,5],f32>
%int7 = torch.constant.int 7 // torch.float64
// CHECK: %[[R4:.*]] = torch.vtensor.literal({{.*}} : tensor<5x5xf64>) : !torch.vtensor<[5,5],f64>
Let's capture the actual value here too, and in other such places.
Done for all test cases.
test/Dialect/Torch/canonicalize.mlir
Outdated
// int32 splat → float32
%int_splat = torch.vtensor.literal(dense<42> : tensor<2x3xsi32>) : !torch.vtensor<[2,3],si32>
%int6 = torch.constant.int 6 // torch.float32
// CHECK: %[[R1:.*]] = torch.vtensor.literal({{.*}} : tensor<2x3xf32>) : !torch.vtensor<[2,3],f32>
Since we are not locking down the values returned in the output IR, I think we should add `CHECK-NOT: torch.aten.to.dtype` as well to ensure that the op is actually being folded.
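Something like the following, placed after the positive checks (a sketch of the suggested pattern):

```mlir
// CHECK: %[[R1:.*]] = torch.vtensor.literal({{.*}} : tensor<2x3xf32>) : !torch.vtensor<[2,3],f32>
// CHECK-NOT: torch.aten.to.dtype
```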
@@ -1762,6 +1762,78 @@ func.func @torch.aten.to.dtype$no_fold$unk_dtype(%arg0: !torch.tensor) -> !torch
  return %0 : !torch.tensor
}

// CHECK-LABEL: @torch.aten.to.dtype$fold_splat(
Can you add some e2e tests to ensure that torch's rounding logic is accurately captured in this implementation?
Also, please fix the CI failures; we cannot merge until the CI pipelines are green.
I added e2e tests in test_conversions.py. Note that one test in the fx_importer_stablehlo group fails due to an error in the stablehlo conversion: stablehlo sign-extends all int constants regardless of whether they are signed or not, and now that we fold constants it encounters an unsigned int that it will wrongly sign-extend. I have a PR that fixes it, so I guess we'll need to wait for that to go in first.
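To illustrate the failure mode (a hypothetical example, not the actual failing test): once the fold materializes an unsigned splat literal, a lowering that sign-extends unconditionally misreads the bit pattern:

```mlir
// Folded unsigned literal produced by this patch:
%x = torch.vtensor.literal(dense<255> : tensor<1xui8>) : !torch.vtensor<[1],ui8>
// Sign-extending the ui8 bit pattern 0xFF during conversion yields -1
// instead of 255 when the constant is widened, e.g. to a 32-bit integer.
```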
Force-pushed from 9b8168c to 42edabd.
This commit teaches `AtenToDtypeOp::fold` to constant-fold dtype conversions when the operand is a splat `DenseElementsAttr`. Folding is done according to torch's rounding behavior, i.e.:

* Bool: 0 and -0.0 → false; nonzero/NaN/±Inf → true.
* Float → Int: round toward zero.
* Int → Float: sign-aware, rmNearestTiesToEven.
* Float ↔ Float: use builtin `mlir::FloatType::getFloatSemantics()`.
* Int ↔ Int: use `zextOrTrunc` / `sextOrTrunc` based on source signedness.

Folding is only performed when `non_blocking == false`, `copy == false`, and `memory_format` is None.
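As a quick illustration of the rounding rules above (a sketch; these literals and expected folds are illustrative, not copied from the patch's tests; dtype codes 3 and 11 are torch.int32 and torch.bool):

```mlir
// Float → Int rounds toward zero: -2.7 folds to -2 under torch.int32 (dtype 3).
%a = torch.vtensor.literal(dense<-2.7> : tensor<1xf32>) : !torch.vtensor<[1],f32>
// expected fold: torch.vtensor.literal(dense<-2> : tensor<1xsi32>) : !torch.vtensor<[1],si32>

// Float → Bool: -0.0 folds to false under torch.bool (dtype 11); NaN/±Inf fold to true.
%b = torch.vtensor.literal(dense<-0.0> : tensor<1xf32>) : !torch.vtensor<[1],f32>
// expected fold: torch.vtensor.literal(dense<false> : tensor<1xi1>) : !torch.vtensor<[1],i1>
```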
Force-pushed from 42edabd to 3abbd48.