Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H

#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
Expand All @@ -26,8 +27,8 @@ namespace tosa {
// rounding mode
Value buildRescale(PatternRewriter &rewriter, Operation *op,
ShapedType output_type, Value input_val, double scale,
int64_t input_zp, int64_t output_zp, StringRef rounding_mode,
bool scale32);
int64_t input_zp, int64_t output_zp,
tosa::RoundingMode rounding_mode, bool scale32);

// Creates TOSA rescale op with int32 output
Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
Expand Down
69 changes: 48 additions & 21 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ class ConvertAtenBinaryOp : public OpConversionPattern<AtenOpT> {
// tosa.minimum
binaryOp = rewriter.create<TosaOpT>(
op->getLoc(), outTy, lhs, rhs,
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
/*nan_mode=*/
tosa::NanPropagationModeAttr::get(
rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
} else {
binaryOp =
tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs);
Expand Down Expand Up @@ -907,7 +909,9 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
// Use default NaN Propagation mode "PROPAGATE" for tosa.clamp
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, outTy, self, minFloatAttr, maxFloatAttr,
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
/*nan_mode=*/
tosa::NanPropagationModeAttr::get(rewriter.getContext(),
tosa::NanPropagationMode::PROPAGATE));
return success();
}

Expand Down Expand Up @@ -1237,7 +1241,9 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
.create<tosa::ArgMaxOp>(
op->getLoc(), getTypeConverter()->convertType(outputReduceTy),
input, reduceDimAttr,
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"))
/*nan_mode=*/
tosa::NanPropagationModeAttr::get(
rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE))
.getResult();
};

Expand Down Expand Up @@ -3925,7 +3931,9 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern<AtenOpT> {
op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(reducedShape),
selfElemType),
self, dimAttr, /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
self, dimAttr, /*nan_mode=*/
tosa::NanPropagationModeAttr::get(
rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
} else {
reduceOp = rewriter.create<TosaOpT>(
op->getLoc(),
Expand All @@ -3946,14 +3954,18 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern<AtenOpT> {
op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
indicesElemType),
negateOp, dimAttr, /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
negateOp, dimAttr, /*nan_mode=*/
tosa::NanPropagationModeAttr::get(
rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
} else {
// Use default NaN Propagation mode "PROPAGATE" for tosa.argmax
argMaxOp = rewriter.create<tosa::ArgMaxOp>(
op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
indicesElemType),
self, dimAttr, /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
self, dimAttr, /*nan_mode=*/
tosa::NanPropagationModeAttr::get(
rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
}

if (argMaxOp.getType() != indicesType) {
Expand Down Expand Up @@ -5202,7 +5214,9 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(

rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, outType, adaptor.getSelf(), minIntAttr, maxIntAttr,
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
/*nan_mode=*/
tosa::NanPropagationModeAttr::get(rewriter.getContext(),
tosa::NanPropagationMode::PROPAGATE));
} else {
FloatAttr minFloatAttr, maxFloatAttr;
if (outElemTy.isF16()) {
Expand Down Expand Up @@ -5231,7 +5245,9 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(

rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, outType, adaptor.getSelf(), minFloatAttr, maxFloatAttr,
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
/*nan_mode=*/
tosa::NanPropagationModeAttr::get(rewriter.getContext(),
tosa::NanPropagationMode::PROPAGATE));
}

return success();
Expand Down Expand Up @@ -5340,13 +5356,17 @@ LogicalResult ConvertAtenOp<AtenClampTensorOp>::matchAndRewrite(
// Use default NaN Propagation mode "PROPAGATE" for tosa.maximum
auto minThresholdCheck = rewriter.create<tosa::MaximumOp>(
op->getLoc(), resultType, self, min,
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
/*nan_mode=*/
tosa::NanPropagationModeAttr::get(rewriter.getContext(),
tosa::NanPropagationMode::PROPAGATE));

// yi = min(max(xi, min_valuei), max_valuei)
// Use default NaN Propagation mode "PROPAGATE" for tosa.minimum
auto result = rewriter.create<tosa::MinimumOp>(
op->getLoc(), resultType, minThresholdCheck, max,
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
/*nan_mode=*/
tosa::NanPropagationModeAttr::get(rewriter.getContext(),
tosa::NanPropagationMode::PROPAGATE));

rewriter.replaceOp(op, result);
return success();
Expand Down Expand Up @@ -5934,7 +5954,10 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
pooledOutput = rewriter
.create<TosaOpT>(
op->getLoc(), outputTy, input, kernel, stride, pad,
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"))
/*nan_mode=*/
tosa::NanPropagationModeAttr::get(
rewriter.getContext(),
tosa::NanPropagationMode::PROPAGATE))
.getResult();
} else if constexpr (std::is_same<TosaOpT, tosa::AvgPool2dOp>::value) {
TypeAttr accType;
Expand Down Expand Up @@ -6825,11 +6848,11 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "Only nearest and bilinear interpolation modes supported");

std::string mode;
tosa::ResizeMode mode;
if (pyMode == "bilinear") {
mode = "BILINEAR";
mode = tosa::ResizeMode::BILINEAR;
} else {
mode = "NEAREST_NEIGHBOR";
mode = tosa::ResizeMode::NEAREST_NEIGHBOR;
}

bool alignCorners;
Expand Down Expand Up @@ -6891,7 +6914,7 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewrite(
offset = 0;

// If nearest neighbours we need to guarantee we round up.
if (mode == "NEAREST_NEIGHBOR" && alignCorners) {
if (mode == tosa::ResizeMode::NEAREST_NEIGHBOR && alignCorners) {
offset += n / 2;
}

Expand All @@ -6911,7 +6934,8 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewrite(
tosa::getTosaConstShape(rewriter, op->getLoc(), {offset_y, offset_x});
auto border =
tosa::getTosaConstShape(rewriter, op->getLoc(), {border_y, border_x});
StringAttr modeAttr = rewriter.getStringAttr(mode);

auto modeAttr = tosa::ResizeModeAttr::get(rewriter.getContext(), mode);

auto resizeOpResult =
rewriter
Expand Down Expand Up @@ -8605,11 +8629,14 @@ LogicalResult ConvertAtenOp<AtenLogitOp>::matchAndRewrite(
// Clamp input to [eps, 1 - eps] when eps is not None
// Use default NaN Propagation mode "PROPAGATE" for tosa.clamp
if (!isEpsNone) {
zi = rewriter
.create<tosa::ClampOp>(
op->getLoc(), resultType, self, minFloatAttr, maxFloatAttr,
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"))
.getResult();
zi =
rewriter
.create<tosa::ClampOp>(
op->getLoc(), resultType, self, minFloatAttr, maxFloatAttr,
/*nan_mode=*/
tosa::NanPropagationModeAttr::get(
rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE))
.getResult();
}

auto one =
Expand Down
7 changes: 5 additions & 2 deletions lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
Expand Down Expand Up @@ -764,7 +765,9 @@ std::optional<Value> convertReduceOpCommon(
// and tosa.reduce_max
reduce_op = CreateOpAndInfer<T>(
rewriter, op->getLoc(), reduce_type, val, axis_attr,
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
/*nan_mode=*/
tosa::NanPropagationModeAttr::get(
rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
} else {
reduce_op = CreateOpAndInfer<T>(rewriter, op->getLoc(), reduce_type,
val, axis_attr);
Expand All @@ -777,7 +780,7 @@ std::optional<Value> convertReduceOpCommon(
RankedTensorType output_rescale_type =
RankedTensorType::get(shape_vec, output_type.getElementType());
val = buildRescale(rewriter, op, output_rescale_type, val, output_scale,
0, output_zp, "SINGLE_ROUND", true);
0, output_zp, tosa::RoundingMode::SINGLE_ROUND, true);
}

// Optionally squeeze out the reduced axes.
Expand Down
17 changes: 11 additions & 6 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ Value buildRescaleMultiplier(bool scale32, PatternRewriter &rewriter,
// rounding mode
Value buildRescale(PatternRewriter &rewriter, Operation *op,
ShapedType output_type, Value input_val, double scale,
int64_t input_zp, int64_t output_zp, StringRef rounding_mode,
bool scale32) {
int64_t input_zp, int64_t output_zp,
tosa::RoundingMode rounding_mode, bool scale32) {
int32_t multiplier;
int32_t shift;

Expand Down Expand Up @@ -70,7 +70,8 @@ Value buildRescale(PatternRewriter &rewriter, Operation *op,
auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
rewriter, op->getLoc(), output_type, input_val, multiplier_val, shift_val,
input_zp_val.value(), output_zp_val.value(),
rewriter.getBoolAttr(scale32), rewriter.getStringAttr(rounding_mode),
rewriter.getBoolAttr(scale32),
tosa::RoundingModeAttr::get(rewriter.getContext(), rounding_mode),
rewriter.getBoolAttr(false), rewriter.getBoolAttr(input_unsigned),
rewriter.getBoolAttr(output_unsigned));

Expand All @@ -87,7 +88,7 @@ Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
auto output_type = input_type.clone(rewriter.getI32Type());

return buildRescale(rewriter, op, output_type, input_val, input_scale,
input_zp, 0, "SINGLE_ROUND", true);
input_zp, 0, tosa::RoundingMode::SINGLE_ROUND, true);
}

// Creates a TOSA rescale op based on conv2d parameters.
Expand Down Expand Up @@ -146,7 +147,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
rewriter, op->getLoc(), output_type, conv_val, multiplier_val,
shift_val, input_zp_val.value(), output_zp_val.value(),
rewriter.getBoolAttr(scale32), rewriter.getStringAttr("DOUBLE_ROUND"),
rewriter.getBoolAttr(scale32),
tosa::RoundingModeAttr::get(rewriter.getContext(),
tosa::RoundingMode::DOUBLE_ROUND),
rewriter.getBoolAttr(false), rewriter.getBoolAttr(input_unsigned),
rewriter.getBoolAttr(output_unsigned));

Expand Down Expand Up @@ -188,7 +191,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
rewriter, op->getLoc(), output_type, conv_val, multiplier_val,
shift_val, input_zp_val.value(), output_zp_val.value(),
rewriter.getBoolAttr(scale32), rewriter.getStringAttr("DOUBLE_ROUND"),
rewriter.getBoolAttr(scale32),
tosa::RoundingModeAttr::get(rewriter.getContext(),
tosa::RoundingMode::DOUBLE_ROUND),
rewriter.getBoolAttr(true), rewriter.getBoolAttr(input_unsigned),
rewriter.getBoolAttr(output_unsigned));

Expand Down
4 changes: 2 additions & 2 deletions test/Conversion/TorchToTosa/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1559,7 +1559,7 @@ func.func @torch.aten.isclose$basic(%arg0: !torch.vtensor<[5,5],f32>, %arg1: !to
// CHECK-DAG: %[[VAL_8:.*]] = tosa.const_shape {values = dense<[4, 2, 4, 2]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK-DAG: %[[VAL_9:.*]] = tosa.const_shape {values = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {values = dense<2> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: %[[VAL_11:.*]] = tosa.resize %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] {mode = "BILINEAR"} : (tensor<1x135x240x16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x270x480x16xf32>
// CHECK: %[[VAL_11:.*]] = tosa.resize %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] {mode = BILINEAR} : (tensor<1x135x240x16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x270x480x16xf32>
// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_11]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x270x480x16xf32>) -> tensor<1x16x270x480xf32>
// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<1x16x270x480xf32> -> !torch.vtensor<[1,16,270,480],f32>
// CHECK: return %[[VAL_13]] : !torch.vtensor<[1,16,270,480],f32>
Expand Down Expand Up @@ -1588,7 +1588,7 @@ func.func @torch.aten.__interpolate.size_list_scale_list.bilinear(%arg0: !torch.
// CHECK-DAG: %[[VAL_8:.*]] = tosa.const_shape {values = dense<[4, 2, 4, 2]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK-DAG: %[[VAL_9:.*]] = tosa.const_shape {values = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {values = dense<2> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: %[[VAL_11:.*]] = tosa.resize %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] {mode = "NEAREST_NEIGHBOR"} : (tensor<1x135x240x16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x270x480x16xf32>
// CHECK: %[[VAL_11:.*]] = tosa.resize %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] {mode = NEAREST_NEIGHBOR} : (tensor<1x135x240x16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x270x480x16xf32>
// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_11]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x270x480x16xf32>) -> tensor<1x16x270x480xf32>
// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<1x16x270x480xf32> -> !torch.vtensor<[1,16,270,480],f32>
// CHECK: return %[[VAL_13]] : !torch.vtensor<[1,16,270,480],f32>
Expand Down
Loading