Skip to content

Commit c146212

Browse files
authored
Fix abs Operator int accuracy issue (#31920)
### Details: - Current Abs operator's implementation 1. Convert int32 to float 2. Apply sign mask 3. Convert float to int32 Step 3 will result in a precision loss. ### Tickets: - [CVS-171704](https://jira.devtools.intel.com/browse/CVS-171704)
1 parent 5beb837 commit c146212

File tree

5 files changed

+95
-17
lines changed

5 files changed

+95
-17
lines changed

src/plugins/intel_cpu/src/emitters/plugin/x64/jit_dnnl_ext_emitters.hpp

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -85,21 +85,6 @@ class jit_elu_emitter : public jit_dnnl_emitter {
8585
}
8686
};
8787

88-
class jit_abs_emitter : public jit_dnnl_emitter {
89-
public:
90-
jit_abs_emitter(dnnl::impl::cpu::x64::jit_generator_t* host,
91-
dnnl::impl::cpu::x64::cpu_isa_t host_isa,
92-
const std::shared_ptr<ov::Node>& n,
93-
ov::element::Type exec_prc = ov::element::f32)
94-
: jit_dnnl_emitter(host, host_isa, n, exec_prc) {
95-
kind = dnnl_eltwise_abs;
96-
alpha = 0.F;
97-
beta = 0.F;
98-
99-
set_injector();
100-
}
101-
};
102-
10388
class jit_clamp_emitter : public jit_dnnl_emitter {
10489
public:
10590
jit_clamp_emitter(dnnl::impl::cpu::x64::jit_generator_t* host,

src/plugins/intel_cpu/src/emitters/plugin/x64/jit_eltwise_emitters.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2852,4 +2852,75 @@ void jit_bitwise_xor_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs,
28522852
h->uni_vxorps(vmm_dst, vmm_src0, vmm_src1);
28532853
}
28542854

2855+
/// ABS ///
2856+
jit_abs_emitter::jit_abs_emitter(x64::jit_generator_t* host,
2857+
x64::cpu_isa_t host_isa,
2858+
const std::shared_ptr<ov::Node>& node)
2859+
: jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) {
2860+
prepare_table();
2861+
}
2862+
2863+
jit_abs_emitter::jit_abs_emitter(x64::jit_generator_t* host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc)
2864+
: jit_emitter(host, host_isa, exec_prc) {
2865+
prepare_table();
2866+
}
2867+
2868+
size_t jit_abs_emitter::get_inputs_num() const {
2869+
return 1;
2870+
}
2871+
2872+
void jit_abs_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
2873+
if (host_isa_ == x64::sse41) {
2874+
emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
2875+
} else if (host_isa_ == x64::avx2) {
2876+
emit_isa<x64::avx2>(in_vec_idxs, out_vec_idxs);
2877+
} else if (host_isa_ == x64::avx512_core) {
2878+
emit_isa<x64::avx512_core>(in_vec_idxs, out_vec_idxs);
2879+
} else {
2880+
OV_CPU_JIT_EMITTER_THROW("Unsupported ISA ", host_isa_);
2881+
}
2882+
}
2883+
2884+
template <x64::cpu_isa_t isa>
2885+
void jit_abs_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
2886+
using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
2887+
auto vmm_src = Vmm(in_vec_idxs[0]);
2888+
auto vmm_dst = Vmm(out_vec_idxs[0]);
2889+
2890+
auto uni_vpabsd = [this](Vmm vmm_dst, Vmm vmm_src) {
2891+
switch (exec_prc_) {
2892+
case ov::element::f32:
2893+
h->uni_vandps(vmm_dst, vmm_src, table_val("positive_mask"));
2894+
break;
2895+
case ov::element::i32:
2896+
if (isa == x64::sse41) {
2897+
h->pabsd(vmm_dst, vmm_src);
2898+
} else if (any_of(host_isa_, x64::avx2, x64::avx512_core)) {
2899+
h->vpabsd(vmm_dst, vmm_src);
2900+
} else {
2901+
OV_CPU_JIT_EMITTER_THROW("Unsupported ISA ", host_isa_);
2902+
}
2903+
break;
2904+
default:
2905+
OV_CPU_JIT_EMITTER_THROW("Unsupported precision");
2906+
}
2907+
};
2908+
2909+
if (isa == x64::sse41) {
2910+
h->uni_vmovups(vmm_dst, vmm_src);
2911+
uni_vpabsd(vmm_dst, vmm_dst);
2912+
} else {
2913+
uni_vpabsd(vmm_dst, vmm_src);
2914+
}
2915+
}
2916+
2917+
std::set<std::vector<element::Type>> jit_abs_emitter::get_supported_precisions(
2918+
[[maybe_unused]] const std::shared_ptr<ov::Node>& node) {
2919+
return {{element::f32}, {element::i32}};
2920+
}
2921+
2922+
void jit_abs_emitter::register_table_entries() {
2923+
push_arg_entry_of("positive_mask", 0x7fffffff, true);
2924+
}
2925+
28552926
} // namespace ov::intel_cpu

src/plugins/intel_cpu/src/emitters/plugin/x64/jit_eltwise_emitters.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -933,4 +933,25 @@ class jit_bitwise_xor_emitter : public jit_emitter {
933933
void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
934934
};
935935

936+
class jit_abs_emitter : public jit_emitter {
937+
public:
938+
jit_abs_emitter(dnnl::impl::cpu::x64::jit_generator_t* host,
939+
dnnl::impl::cpu::x64::cpu_isa_t host_isa,
940+
ov::element::Type exec_prc = ov::element::f32);
941+
jit_abs_emitter(dnnl::impl::cpu::x64::jit_generator_t* host,
942+
dnnl::impl::cpu::x64::cpu_isa_t host_isa,
943+
const std::shared_ptr<ov::Node>& n);
944+
945+
size_t get_inputs_num() const override;
946+
static std::set<std::vector<element::Type>> get_supported_precisions(
947+
const std::shared_ptr<ov::Node>& node = nullptr);
948+
949+
private:
950+
void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
951+
952+
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
953+
void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
954+
void register_table_entries() override;
955+
};
956+
936957
} // namespace ov::intel_cpu

src/plugins/intel_cpu/src/nodes/kernels/x64/jit_uni_eltwise_generic.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ std::shared_ptr<jit_emitter> jit_uni_eltwise_generic<isa>::create_eltwise_emitte
395395
OV_CASE(Algorithm::EltwiseElu, jit_dnnl_aux_emitter),
396396
OV_CASE(Algorithm::EltwiseTanh, jit_dnnl_aux_emitter),
397397
OV_CASE(Algorithm::EltwiseSigmoid, jit_dnnl_aux_emitter),
398-
OV_CASE(Algorithm::EltwiseAbs, jit_dnnl_aux_emitter),
398+
OV_CASE(Algorithm::EltwiseAbs, jit_abs_emitter),
399399
OV_CASE(Algorithm::EltwiseSqrt, jit_dnnl_aux_emitter),
400400
OV_CASE(Algorithm::EltwiseSoftRelu, jit_dnnl_aux_emitter),
401401
OV_CASE(Algorithm::EltwiseClamp, jit_dnnl_aux_emitter),
@@ -888,7 +888,7 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre
888888
OV_CASE(Algorithm::EltwiseElu, jit_dnnl_aux_emitter),
889889
OV_CASE(Algorithm::EltwiseTanh, jit_dnnl_aux_emitter),
890890
OV_CASE(Algorithm::EltwiseSigmoid, jit_dnnl_aux_emitter),
891-
OV_CASE(Algorithm::EltwiseAbs, jit_dnnl_aux_emitter),
891+
OV_CASE(Algorithm::EltwiseAbs, jit_abs_emitter),
892892
OV_CASE(Algorithm::EltwiseSqrt, jit_dnnl_aux_emitter),
893893
OV_CASE(Algorithm::EltwiseSoftRelu, jit_dnnl_aux_emitter),
894894
OV_CASE(Algorithm::EltwiseClamp, jit_dnnl_aux_emitter),

src/plugins/intel_cpu/tests/functional/shared_tests_instances/single_layer_tests/activation.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ const std::map<ActivationTypes, std::vector<std::vector<float>>> activationTypes
6363

6464
// List of operations that should be tested also with integer precision
6565
const std::map<ActivationTypes, std::vector<std::vector<float>>> intActivationTypes = {
66+
{ActivationTypes::Abs, {}},
6667
{ActivationTypes::Acosh, {}},
6768
{ActivationTypes::Asinh, {}},
6869
{ActivationTypes::Atan, {}},

0 commit comments

Comments
 (0)