
Commit 3abbd48

[Torch] Fold aten.to.dtype on splat constants.
This commit teaches `AtenToDtypeOp::fold` to constant-fold dtype conversions when the operand is a splat `DenseElementsAttr`. Folding is done according to torch's rounding behavior:

* Bool: 0 and -0.0 → false; nonzero/NaN/±Inf → true.
* Float → Int: round toward zero.
* Int → Float: sign-aware, rmNearestTiesToEven.
* Float ↔ Float: use builtin `mlir::FloatType::getFloatSemantics()`.
* Int ↔ Int: use `zextOrTrunc` / `sextOrTrunc` based on source signedness.

Folding is only performed when `non_blocking == false`, `copy == false`, and `memory_format` is None.
1 parent dee5158 commit 3abbd48
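For reference, the torch-level semantics the folder mirrors can be checked with a small PyTorch snippet (illustrative only, not part of this commit); the commented outputs line up with the test expectations added below:

import torch

# Reference behavior for the rules above (ordinary Tensor.to results):
print(torch.tensor([3.9]).to(torch.int32))          # tensor([3], dtype=torch.int32) -- rounds toward zero
print(torch.tensor([-3.9]).to(torch.int32))         # tensor([-3], dtype=torch.int32)
print(torch.tensor([float("nan")]).to(torch.bool))  # tensor([True]) -- NaN is truthy
print(torch.tensor([-0.0]).to(torch.bool))          # tensor([False]) -- both zeros are falsy
print(torch.tensor([2147483648]).to(torch.int32))   # tensor([-2147483648], dtype=torch.int32) -- wraps on narrowing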

4 files changed: +355 -21 lines changed


lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 86 additions & 10 deletions
@@ -19,6 +19,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LLVM.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "llvm/ADT/APSInt.h"
 #include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/Support/Casting.h"
@@ -892,26 +893,101 @@ OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) {
   // The non_blocking arg must be `False`.
   if (!matchPattern(getNonBlocking(), m_TorchConstantBool(&nonBlocking)) ||
       nonBlocking)
-    return nullptr;
+    return {};
   // The copy arg must be `False`.
   if (!matchPattern(getCopy(), m_TorchConstantBool(&copyArg)) || copyArg)
-    return nullptr;
+    return {};
   // The memory_format arg must be `none`.
   if (!isa<Torch::NoneType>(getMemoryFormat().getType()))
-    return nullptr;
+    return {};

   auto inputType = cast<BaseTensorType>(getSelf().getType());
   auto resType = cast<BaseTensorType>(getType());
-  // If the types aren't equal, then we can't fold.
-  if (inputType != resType)
-    return nullptr;
+
+  // Fold when both the input tensor and result are of the same type.
   // If the type does not have a statically known dtype, then we cannot fold.
   // For example, folding `tensor<*,unk>` to `tensor<*,unk>` would be wrong,
   // since the `unk` could be dynamically different for the operand and result.
-  if (!inputType.hasDtype())
-    return nullptr;
-  // Fold when both the input tensor and result are of the same type.
-  return getOperand(0);
+  if (inputType == resType && inputType.hasDtype())
+    return getOperand(0);
+
+  // Fold conversion of splat values.
+  auto elems = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
+  if (!elems || !elems.isSplat())
+    return {};
+
+  auto outVTy = dyn_cast<ValueTensorType>(getType());
+  if (!outVTy)
+    return {};
+
+  auto outShaped = outVTy.toBuiltinTensor();
+  if (!outShaped.hasStaticShape())
+    return {};
+
+  Type srcEltTy = inputType.getDtype();
+  Type dstEltTy = outVTy.getDtype();
+
+  // Handle integer destination.
+  if (auto dstI = dyn_cast<IntegerType>(dstEltTy)) {
+    // any -> bool(i1).
+    if (dstI.isSignlessInteger(1)) {
+      bool truthy = false;
+      if (isa<mlir::FloatType>(srcEltTy)) {
+        const APFloat &floatVal = elems.getSplatValue<APFloat>();
+        truthy = !floatVal.isZero();
+      } else {
+        const APInt &intVal = elems.getSplatValue<APInt>();
+        truthy = !intVal.isZero();
+      }
+      return DenseElementsAttr::get(outShaped, APInt(/*numBits=*/1, truthy));
+    }
+    // float -> intN
+    if (auto srcF = dyn_cast<mlir::FloatType>(srcEltTy)) {
+      APSInt result(dstI.getWidth(), /*isUnsigned=*/dstI.isUnsignedInteger());
+      bool isExact = false;
+      APFloat f = elems.getSplatValue<APFloat>();
+      APFloat::opStatus st =
+          f.convertToInteger(result, APFloat::rmTowardZero, &isExact);
+      if (st == APFloat::opOK || st == APFloat::opInexact)
+        return DenseElementsAttr::get(outShaped, APInt(result));
+      return {}; // NaN/Inf/out-of-range: preserve runtime semantics.
+    }
+    // intM -> intN
+    const APInt &v = elems.getSplatValue<APInt>();
+    auto isUnsigned = cast<IntegerType>(srcEltTy).isUnsignedInteger();
+    auto isSignless = cast<IntegerType>(srcEltTy).isSignlessInteger();
+    APInt casted = isUnsigned || isSignless ? v.zextOrTrunc(dstI.getWidth())
+                                            : v.sextOrTrunc(dstI.getWidth());
+    return DenseElementsAttr::get(outShaped, casted);
+  }
+
+  // Handle float destination.
+  if (auto dstF = dyn_cast<mlir::FloatType>(dstEltTy)) {
+    const llvm::fltSemantics &dstSem = dstF.getFloatSemantics();
+
+    // int -> float
+    if (auto srcI = dyn_cast<IntegerType>(srcEltTy)) {
+      APFloat f(dstSem);
+      APFloat::opStatus st = f.convertFromAPInt(
+          elems.getSplatValue<APInt>(),
+          /*isSigned=*/!srcI.isUnsignedInteger() && !srcI.isSignlessInteger(),
+          APFloat::rmNearestTiesToEven);
+      if (st == APFloat::opOK || st == APFloat::opInexact)
+        return DenseElementsAttr::get(outShaped, f);
+      return {};
+    }
+
+    // floatX -> floatY
+    APFloat f = elems.getSplatValue<APFloat>();
+    bool losesInfo = false;
+    APFloat::opStatus st =
+        f.convert(dstSem, APFloat::rmNearestTiesToEven, &losesInfo);
+    if (st == APFloat::opOK || st == APFloat::opInexact)
+      return DenseElementsAttr::get(outShaped, f);
+    return {};
+  }
+
+  return {};
 }

 //===----------------------------------------------------------------------===//
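The `return {};` guard on the float → int path above (commented "NaN/Inf/out-of-range: preserve runtime semantics") is the one case where a splat is deliberately left unfolded. An illustrative PyTorch snippet (not from this commit) shows why declining is the safe choice:

import torch

# Casting NaN or an out-of-range float to an integer dtype has no single
# portable result; the runtime value can differ across platforms/backends,
# so folding it at compile time could change observable behavior.
print(torch.tensor([float("nan")]).to(torch.int32))  # unspecified integer value
print(torch.tensor([3.0e10]).to(torch.int32))        # out of int32 range: unspecified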

projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py

Lines changed: 176 additions & 0 deletions
@@ -370,3 +370,179 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: PrimsConvertElementTypeModule())
 def PrimsConvertElementTypeModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5))
+
+
+# ==============================================================================
+
+
+class ToDtypeConstIntFromDoubleModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([1.1], dtype=torch.float64)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.to(
+            self.const,
+            dtype=torch.int64,
+        )
+
+
+@register_test_case(module_factory=lambda: ToDtypeConstIntFromDoubleModule())
+def ToDtypeConstIntFromDoubleModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class ToDtypeConstInt32FromInt64Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([2147483648], dtype=torch.int64)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.to(
+            self.const,
+            dtype=torch.int32,
+        )
+
+
+@register_test_case(module_factory=lambda: ToDtypeConstInt32FromInt64Module())
+def ToDtypeConstInt32FromInt64Module_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class ToDtypeConstFloat16FromFloat64Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([1.2345], dtype=torch.float64)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.to(
+            self.const,
+            dtype=torch.float16,
+        )
+
+
+@register_test_case(module_factory=lambda: ToDtypeConstFloat16FromFloat64Module())
+def ToDtypeConstFloat16FromFloat64Module_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class ToDtypeConstBFloat16FromFloat32Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([-0.5101], dtype=torch.float32)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.to(
+            self.const,
+            dtype=torch.float16,
+        )
+
+
+@register_test_case(module_factory=lambda: ToDtypeConstBFloat16FromFloat32Module())
+def ToDtypeConstBFloat16FromFloat32Module_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class ToDtypeConstBoolFromInt32ZeroModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([0], dtype=torch.int32)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.to(
+            self.const,
+            dtype=torch.bool,
+        )
+
+
+@register_test_case(module_factory=lambda: ToDtypeConstBoolFromInt32ZeroModule())
+def ToDtypeConstBoolFromInt32ZeroModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class ToDtypeConstBoolFromInt32NonZeroIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([32], dtype=torch.int32)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.to(
+            self.const,
+            dtype=torch.bool,
+        )
+
+
+@register_test_case(module_factory=lambda: ToDtypeConstBoolFromInt32NonZeroIntModule())
+def ToDtypeConstBoolFromInt32NonZeroIntModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class ToDtypeConstBoolFromFloat32NonZeroNanModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([float("nan")], dtype=torch.float32)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.to(
+            self.const,
+            dtype=torch.bool,
+        )
+
+
+@register_test_case(
+    module_factory=lambda: ToDtypeConstBoolFromFloat32NonZeroNanModule()
+)
+def ToDtypeConstBoolFromFloat32NonZeroNanModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class ToDtypeConstFloat32FromBoolModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([True], dtype=torch.bool)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.to(
+            self.const,
+            dtype=torch.float32,
+        )
+
+
+@register_test_case(module_factory=lambda: ToDtypeConstFloat32FromBoolModule())
+def ToDtypeConstFloat32FromBoolModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class ToDtypeConstInt32FromBoolModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([True], dtype=torch.bool)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.to(
+            self.const,
+            dtype=torch.int32,
+        )
+
+
+@register_test_case(module_factory=lambda: ToDtypeConstInt32FromBoolModule())
+def ToDtypeConstInt32FromBoolModule_basic(module, tu: TestUtils):
+    module.forward()
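For a quick sanity check outside the e2e harness, one of the modules added above can be instantiated directly (illustrative snippet, not part of the test suite):

# Illustrative only: call one of the new modules by hand; 1.1 (float64)
# converts to int64 by rounding toward zero.
m = ToDtypeConstIntFromDoubleModule()
print(m.forward())  # tensor([1])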

test/Dialect/Torch/canonicalize.mlir

Lines changed: 88 additions & 0 deletions
@@ -1762,6 +1762,94 @@ func.func @torch.aten.to.dtype$no_fold$unk_dtype(%arg0: !torch.tensor) -> !torch
   return %0 : !torch.tensor
 }

+// CHECK-LABEL: @torch.aten.to.dtype$fold_splat(
+func.func @torch.aten.to.dtype$fold_splat() -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[4,4],si32>, !torch.vtensor<[10],si32>, !torch.vtensor<[5,5],f64>, !torch.vtensor<[3,3],f16>, !torch.vtensor<[2,2],bf16>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si16>, !torch.vtensor<[2],i1>, !torch.vtensor<[2],i1>) {
+  // CHECK-NOT: torch.aten.to.dtype
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+
+  // int32 splat → float32
+  %int_splat = torch.vtensor.literal(dense<42> : tensor<2x3xsi32>) : !torch.vtensor<[2,3],si32>
+  %int6 = torch.constant.int 6 // torch.float32
+  // CHECK: %[[R1:.*]] = torch.vtensor.literal(dense<4.200000e+01> : tensor<2x3xf32>) : !torch.vtensor<[2,3],f32>
+  %result1 = torch.aten.to.dtype %int_splat, %int6, %false, %false, %none
+      : !torch.vtensor<[2,3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
+      -> !torch.vtensor<[2,3],f32>
+
+  // float32 splat → int32 (rmTowardZero)
+  %float_splat = torch.vtensor.literal(dense<3.14159> : tensor<4x4xf32>) : !torch.vtensor<[4,4],f32>
+  %int3 = torch.constant.int 3 // torch.int32
+  // CHECK: %[[R2:.*]] = torch.vtensor.literal(dense<3> : tensor<4x4xsi32>) : !torch.vtensor<[4,4],si32>
+  %result2 = torch.aten.to.dtype %float_splat, %int3, %false, %false, %none
+      : !torch.vtensor<[4,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
+      -> !torch.vtensor<[4,4],si32>
+
+  // int64 splat (max int32 + 1) → int32 (trunc)
+  %int64_splat = torch.vtensor.literal(dense<2147483648> : tensor<10xsi64>) : !torch.vtensor<[10],si64>
+  // CHECK: %[[R3:.*]] = torch.vtensor.literal(dense<-2147483648> : tensor<10xsi32>) : !torch.vtensor<[10],si32>
+  %result3 = torch.aten.to.dtype %int64_splat, %int3, %false, %false, %none
+      : !torch.vtensor<[10],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none
+      -> !torch.vtensor<[10],si32>
+
+  // float32 splat → float64
+  %float32_splat = torch.vtensor.literal(dense<2.71828> : tensor<5x5xf32>) : !torch.vtensor<[5,5],f32>
+  %int7 = torch.constant.int 7 // torch.float64
+  // CHECK: %[[R4:.*]] = torch.vtensor.literal(dense<2.7182800769805908> : tensor<5x5xf64>) : !torch.vtensor<[5,5],f64>
+  %result4 = torch.aten.to.dtype %float32_splat, %int7, %false, %false, %none
+      : !torch.vtensor<[5,5],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
+      -> !torch.vtensor<[5,5],f64>
+
+  // float64 splat → float16
+  %float64_splat = torch.vtensor.literal(dense<1.2> : tensor<3x3xf64>) : !torch.vtensor<[3,3],f64>
+  %int5 = torch.constant.int 5 // torch.float16
+  // CHECK: %[[R5:.*]] = torch.vtensor.literal(dense<1.200200e+00> : tensor<3x3xf16>) : !torch.vtensor<[3,3],f16>
+  %result5 = torch.aten.to.dtype %float64_splat, %int5, %false, %false, %none
+      : !torch.vtensor<[3,3],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none
+      -> !torch.vtensor<[3,3],f16>
+
+  // float32 splat → bfloat16
+  %float32_bf16 = torch.vtensor.literal(dense<-0.51> : tensor<2x2xf32>) : !torch.vtensor<[2,2],f32>
+  %int15 = torch.constant.int 15 // torch.bfloat16
+  // CHECK: %[[R6:.*]] = torch.vtensor.literal(dense<-5.117190e-01> : tensor<2x2xbf16>) : !torch.vtensor<[2,2],bf16>
+  %result6 = torch.aten.to.dtype %float32_bf16, %int15, %false, %false, %none
+      : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
+      -> !torch.vtensor<[2,2],bf16>
+
+  // int32 splat → int64 (sign-extend)
+  %int32_ext = torch.vtensor.literal(dense<-1000> : tensor<4xsi32>) : !torch.vtensor<[4],si32>
+  %int4 = torch.constant.int 4 // torch.int64
+  // CHECK: %[[R7:.*]] = torch.vtensor.literal(dense<-1000> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %result7 = torch.aten.to.dtype %int32_ext, %int4, %false, %false, %none
+      : !torch.vtensor<[4],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
+      -> !torch.vtensor<[4],si64>
+
+  // int32 splat → int16 (trunc)
+  %int32_trunc = torch.vtensor.literal(dense<32768> : tensor<3xsi32>) : !torch.vtensor<[3],si32>
+  %int2 = torch.constant.int 2 // torch.int16
+  // CHECK: %[[R8:.*]] = torch.vtensor.literal(dense<-32768> : tensor<3xsi16>) : !torch.vtensor<[3],si16>
+  %result8 = torch.aten.to.dtype %int32_trunc, %int2, %false, %false, %none
+      : !torch.vtensor<[3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
+      -> !torch.vtensor<[3],si16>
+
+  // int32 splat → bool (i1), non-zero
+  %int40_splat = torch.vtensor.literal(dense<40> : tensor<2xsi32>) : !torch.vtensor<[2],si32>
+  %int11 = torch.constant.int 11 // torch.bool
+  // CHECK: %[[R9:.*]] = torch.vtensor.literal(dense<true> : tensor<2xi1>) : !torch.vtensor<[2],i1>
+  %result9 = torch.aten.to.dtype %int40_splat, %int11, %false, %false, %none
+      : !torch.vtensor<[2],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
+      -> !torch.vtensor<[2],i1>
+
+  // float32 splat → bool (i1), zero
+  %float_zero = torch.vtensor.literal(dense<0.0> : tensor<2xf32>) : !torch.vtensor<[2],f32>
+  // CHECK: %[[R11:.*]] = torch.vtensor.literal(dense<false> : tensor<2xi1>) : !torch.vtensor<[2],i1>
+  %result10 = torch.aten.to.dtype %float_zero, %int11, %false, %false, %none
+      : !torch.vtensor<[2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
+      -> !torch.vtensor<[2],i1>
+
+  return %result1, %result2, %result3, %result4, %result5, %result6, %result7, %result8, %result9, %result10
+      : !torch.vtensor<[2,3],f32>, !torch.vtensor<[4,4],si32>, !torch.vtensor<[10],si32>, !torch.vtensor<[5,5],f64>, !torch.vtensor<[3,3],f16>, !torch.vtensor<[2,2],bf16>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si16>, !torch.vtensor<[2],i1>, !torch.vtensor<[2],i1>
+}
+
 // CHECK-LABEL: func.func @torch.aten.to.other$basic(
 // CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {
 // CHECK: %[[NONE:.*]] = torch.constant.none
