@@ -1762,6 +1762,94 @@ func.func @torch.aten.to.dtype$no_fold$unk_dtype(%arg0: !torch.tensor) -> !torch
1762
1762
return %0 : !torch.tensor
1763
1763
}
1764
1764
1765
+ // CHECK-LABEL: @torch.aten.to.dtype$fold_splat(
1766
+ func.func @torch.aten.to.dtype$fold_splat () -> (!torch.vtensor <[2 ,3 ],f32 >, !torch.vtensor <[4 ,4 ],si32 >, !torch.vtensor <[10 ],si32 >, !torch.vtensor <[5 ,5 ],f64 >, !torch.vtensor <[3 ,3 ],f16 >, !torch.vtensor <[2 ,2 ],bf16 >, !torch.vtensor <[4 ],si64 >, !torch.vtensor <[3 ],si16 >, !torch.vtensor <[2 ],i1 >, !torch.vtensor <[2 ],i1 >) {
1767
+ // CHECK-NOT: torch.aten.to.dtype
1768
+ %false = torch.constant.bool false
1769
+ %none = torch.constant.none
1770
+
1771
+ // int32 splat → float32
1772
+ %int_splat = torch.vtensor.literal (dense <42 > : tensor <2 x3 xsi32 >) : !torch.vtensor <[2 ,3 ],si32 >
1773
+ %int6 = torch.constant.int 6 // torch.float32
1774
+ // CHECK: %[[R1:.*]] = torch.vtensor.literal(dense<4.200000e+01> : tensor<2x3xf32>) : !torch.vtensor<[2,3],f32>
1775
+ %result1 = torch.aten.to.dtype %int_splat , %int6 , %false , %false , %none
1776
+ : !torch.vtensor <[2 ,3 ],si32 >, !torch.int , !torch.bool , !torch.bool , !torch.none
1777
+ -> !torch.vtensor <[2 ,3 ],f32 >
1778
+
1779
+ // float32 splat → int32 (rmTowardZero)
1780
+ %float_splat = torch.vtensor.literal (dense <3.14159 > : tensor <4 x4 xf32 >) : !torch.vtensor <[4 ,4 ],f32 >
1781
+ %int3 = torch.constant.int 3 // torch.int32
1782
+ // CHECK: %[[R2:.*]] = torch.vtensor.literal(dense<3> : tensor<4x4xsi32>) : !torch.vtensor<[4,4],si32>
1783
+ %result2 = torch.aten.to.dtype %float_splat , %int3 , %false , %false , %none
1784
+ : !torch.vtensor <[4 ,4 ],f32 >, !torch.int , !torch.bool , !torch.bool , !torch.none
1785
+ -> !torch.vtensor <[4 ,4 ],si32 >
1786
+
1787
+ // int64 splat (max int32 + 1) → int32 (trunc)
1788
+ %int64_splat = torch.vtensor.literal (dense <2147483648 > : tensor <10 xsi64 >) : !torch.vtensor <[10 ],si64 >
1789
+ // CHECK: %[[R3:.*]] = torch.vtensor.literal(dense<-2147483648> : tensor<10xsi32>) : !torch.vtensor<[10],si32>
1790
+ %result3 = torch.aten.to.dtype %int64_splat , %int3 , %false , %false , %none
1791
+ : !torch.vtensor <[10 ],si64 >, !torch.int , !torch.bool , !torch.bool , !torch.none
1792
+ -> !torch.vtensor <[10 ],si32 >
1793
+
1794
+ // float32 splat → float64
1795
+ %float32_splat = torch.vtensor.literal (dense <2.71828 > : tensor <5 x5 xf32 >) : !torch.vtensor <[5 ,5 ],f32 >
1796
+ %int7 = torch.constant.int 7 // torch.float64
1797
+ // CHECK: %[[R4:.*]] = torch.vtensor.literal(dense<2.7182800769805908> : tensor<5x5xf64>) : !torch.vtensor<[5,5],f64>
1798
+ %result4 = torch.aten.to.dtype %float32_splat , %int7 , %false , %false , %none
1799
+ : !torch.vtensor <[5 ,5 ],f32 >, !torch.int , !torch.bool , !torch.bool , !torch.none
1800
+ -> !torch.vtensor <[5 ,5 ],f64 >
1801
+
1802
+ // float64 splat → float16
1803
+ %float64_splat = torch.vtensor.literal (dense <1.2 > : tensor <3 x3 xf64 >) : !torch.vtensor <[3 ,3 ],f64 >
1804
+ %int5 = torch.constant.int 5 // torch.float16
1805
+ // CHECK: %[[R5:.*]] = torch.vtensor.literal(dense<1.200200e+00> : tensor<3x3xf16>) : !torch.vtensor<[3,3],f16>
1806
+ %result5 = torch.aten.to.dtype %float64_splat , %int5 , %false , %false , %none
1807
+ : !torch.vtensor <[3 ,3 ],f64 >, !torch.int , !torch.bool , !torch.bool , !torch.none
1808
+ -> !torch.vtensor <[3 ,3 ],f16 >
1809
+
1810
+ // float32 splat → bfloat16
1811
+ %float32_bf16 = torch.vtensor.literal (dense <-0.51 > : tensor <2 x2 xf32 >) : !torch.vtensor <[2 ,2 ],f32 >
1812
+ %int15 = torch.constant.int 15 // torch.bfloat16
1813
+ // CHECK: %[[R6:.*]] = torch.vtensor.literal(dense<-5.117190e-01> : tensor<2x2xbf16>) : !torch.vtensor<[2,2],bf16>
1814
+ %result6 = torch.aten.to.dtype %float32_bf16 , %int15 , %false , %false , %none
1815
+ : !torch.vtensor <[2 ,2 ],f32 >, !torch.int , !torch.bool , !torch.bool , !torch.none
1816
+ -> !torch.vtensor <[2 ,2 ],bf16 >
1817
+
1818
+ // int32 splat → int64 (sign-extend)
1819
+ %int32_ext = torch.vtensor.literal (dense <-1000 > : tensor <4 xsi32 >) : !torch.vtensor <[4 ],si32 >
1820
+ %int4 = torch.constant.int 4 // torch.int64
1821
+ // CHECK: %[[R7:.*]] = torch.vtensor.literal(dense<-1000> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
1822
+ %result7 = torch.aten.to.dtype %int32_ext , %int4 , %false , %false , %none
1823
+ : !torch.vtensor <[4 ],si32 >, !torch.int , !torch.bool , !torch.bool , !torch.none
1824
+ -> !torch.vtensor <[4 ],si64 >
1825
+
1826
+ // int32 splat → int16 (trunc)
1827
+ %int32_trunc = torch.vtensor.literal (dense <32768 > : tensor <3 xsi32 >) : !torch.vtensor <[3 ],si32 >
1828
+ %int2 = torch.constant.int 2 // torch.int16
1829
+ // CHECK: %[[R8:.*]] = torch.vtensor.literal(dense<-32768> : tensor<3xsi16>) : !torch.vtensor<[3],si16>
1830
+ %result8 = torch.aten.to.dtype %int32_trunc , %int2 , %false , %false , %none
1831
+ : !torch.vtensor <[3 ],si32 >, !torch.int , !torch.bool , !torch.bool , !torch.none
1832
+ -> !torch.vtensor <[3 ],si16 >
1833
+
1834
+ // int32 splat → bool (i1), non-zero
1835
+ %int40_splat = torch.vtensor.literal (dense <40 > : tensor <2 xsi32 >) : !torch.vtensor <[2 ],si32 >
1836
+ %int11 = torch.constant.int 11 // torch.bool
1837
+ // CHECK: %[[R9:.*]] = torch.vtensor.literal(dense<true> : tensor<2xi1>) : !torch.vtensor<[2],i1>
1838
+ %result9 = torch.aten.to.dtype %int40_splat , %int11 , %false , %false , %none
1839
+ : !torch.vtensor <[2 ],si32 >, !torch.int , !torch.bool , !torch.bool , !torch.none
1840
+ -> !torch.vtensor <[2 ],i1 >
1841
+
1842
+ // float32 splat → bool (i1), zero
1843
+ %float_zero = torch.vtensor.literal (dense <0.0 > : tensor <2 xf32 >) : !torch.vtensor <[2 ],f32 >
1844
+ // CHECK: %[[R11:.*]] = torch.vtensor.literal(dense<false> : tensor<2xi1>) : !torch.vtensor<[2],i1>
1845
+ %result10 = torch.aten.to.dtype %float_zero , %int11 , %false , %false , %none
1846
+ : !torch.vtensor <[2 ],f32 >, !torch.int , !torch.bool , !torch.bool , !torch.none
1847
+ -> !torch.vtensor <[2 ],i1 >
1848
+
1849
+ return %result1 , %result2 , %result3 , %result4 , %result5 , %result6 , %result7 , %result8 , %result9 , %result10
1850
+ : !torch.vtensor <[2 ,3 ],f32 >, !torch.vtensor <[4 ,4 ],si32 >, !torch.vtensor <[10 ],si32 >, !torch.vtensor <[5 ,5 ],f64 >, !torch.vtensor <[3 ,3 ],f16 >, !torch.vtensor <[2 ,2 ],bf16 >, !torch.vtensor <[4 ],si64 >, !torch.vtensor <[3 ],si16 >, !torch.vtensor <[2 ],i1 >, !torch.vtensor <[2 ],i1 >
1851
+ }
1852
+
1765
1853
// CHECK-LABEL: func.func @torch.aten.to.other$basic(
1766
1854
// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {
1767
1855
// CHECK: %[[NONE:.*]] = torch.constant.none
0 commit comments