Skip to content

Commit 5beb837

Browse files
authored
[Snippets] Improve Subgraph tokenization (#31900)
### Details: - *Error-prone constructor is removed from Subgraph node* - *`tokenize_ordered_nodes` helper: avoid Parameters duplication in case of shared external input* ### Tickets: - *N/A*
1 parent a4f2b55 commit 5beb837

File tree

20 files changed

+152
-85
lines changed

20 files changed

+152
-85
lines changed

src/common/snippets/include/snippets/op/subgraph.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,6 @@ class Subgraph : public ov::op::util::SubGraphOp {
9292

9393
explicit Subgraph(const OutputVector& args, const std::shared_ptr<ov::Model>& body);
9494

95-
explicit Subgraph(const NodeVector& args, const std::shared_ptr<ov::Model>& body);
96-
9795
bool visit_attributes(AttributeVisitor& visitor) override;
9896

9997
void validate_and_infer_types() override;

src/common/snippets/include/snippets/utils/tokenization_utils.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,13 @@ bool tokenize_node(const std::shared_ptr<ov::Node>& node,
3131
* 1. The user is responsible for providing a valid count of parameters, results and hidden virtual ports (constants)
3232
* 2. The list of nodes cannot contain Subgraph ops
3333
* @param ordered_ops node list which should be tokenized
34+
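* @param are_shared_internal_params_allowed if true, nodes consuming the same external output share a single body Parameter; if false, a separate Parameter is created for every such input.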
35+
* Note: Snippets support only shared internal parameters that are used as is by all of their consumers.
36+
* This means that e.g. if the shared parameter is used by 2 MatMuls on B input,
37+
* both MatMuls must have the same transpose_b value.
38+
* It is the user's responsibility to ensure that the shared internal parameters can be used.
3439
* @return tokenized subgraph
3540
*/
36-
std::shared_ptr<ov::snippets::op::Subgraph> tokenize_ordered_nodes(const ov::NodeVector& ordered_ops);
41+
std::shared_ptr<ov::snippets::op::Subgraph> tokenize_ordered_nodes(const ov::NodeVector& ordered_ops,
42+
bool are_shared_internal_params_allowed = false);
3743
} // namespace ov::snippets::utils

src/common/snippets/src/op/subgraph.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,9 +219,6 @@ Subgraph::Subgraph(const OutputVector& args, const std::shared_ptr<ov::Model>& b
219219
m_shape_infer = std::make_shared<OVShapeInfer>(body);
220220
}
221221

222-
Subgraph::Subgraph(const NodeVector& args, const std::shared_ptr<ov::Model>& body)
223-
: Subgraph(as_output_vector(args), body) {}
224-
225222
std::shared_ptr<Node> Subgraph::clone_with_new_inputs(const OutputVector& inputs) const {
226223
INTERNAL_OP_SCOPE(Subgraph);
227224
return make_shared<Subgraph>(inputs, body().clone());

src/common/snippets/src/pass/gated_mlp_tokenization.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include <memory>
88

9+
#include "openvino/core/except.hpp"
910
#include "openvino/core/node.hpp"
1011
#include "openvino/core/node_output.hpp"
1112
#include "openvino/core/type.hpp"
@@ -116,7 +117,13 @@ TokenizeGatedMLPSnippets::TokenizeGatedMLPSnippets(const SnippetsTokenization::C
116117
}
117118

118119
const auto ordered_ops = ov::NodeVector{fc_gate, fc_up, act, mul, fc_down};
119-
const auto subgraph = ov::snippets::utils::tokenize_ordered_nodes(ordered_ops);
120+
const bool allow_shared_params = [&]() {
121+
const auto mm_gate = ov::as_type_ptr<ov::op::v0::MatMul>(fc_gate);
122+
const auto mm_up = ov::as_type_ptr<ov::op::v0::MatMul>(fc_up);
123+
OPENVINO_ASSERT(mm_gate && mm_up, "fc_gate and fc_up must have MatMul type");
124+
return mm_gate->get_transpose_a() == mm_up->get_transpose_a();
125+
}();
126+
const auto subgraph = ov::snippets::utils::tokenize_ordered_nodes(ordered_ops, allow_shared_params);
120127

121128
// mark the Subgraph as Completed to not allow Snippets to include any nodes into this Subgraph in common
122129
// Tokenization

src/common/snippets/src/utils/tokenization_utils.cpp

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <climits>
1010
#include <cstddef>
1111
#include <cstdint>
12+
#include <iterator>
1213
#include <map>
1314
#include <memory>
1415
#include <numeric>
@@ -442,56 +443,49 @@ bool tokenize_node(const std::shared_ptr<ov::Node>& node, const SnippetsTokeniza
442443
return true;
443444
}
444445

445-
std::shared_ptr<ov::snippets::op::Subgraph> tokenize_ordered_nodes(const ov::NodeVector& ordered_ops) {
446+
std::shared_ptr<ov::snippets::op::Subgraph> tokenize_ordered_nodes(const ov::NodeVector& ordered_ops,
447+
bool are_shared_internal_params_allowed) {
446448
OPENVINO_ASSERT(!ordered_ops.empty(), "Nothing to be tokenized!");
447449

448-
ov::OutputVector body_inputs, subgraph_inputs;
450+
ov::OutputVector subgraph_inputs;
449451
ov::ParameterVector body_parameters;
450-
ov::ResultVector body_results;
451-
std::vector<std::set<Input<Node>>> subgraph_result_inputs;
452-
453452
auto create_body_inputs = [&](const std::shared_ptr<ov::Node>& node) -> void {
454453
for (size_t i = 0; i < node->get_input_size(); ++i) {
455454
const auto input = node->input(i);
456455
const auto parent = input.get_source_output().get_node_shared_ptr();
457456
const auto constant = ov::as_type_ptr<ov::op::v0::Constant>(parent);
458457
if (constant && (ov::shape_size(input.get_shape()) == 1 || ov::is_type<ov::op::v0::FakeQuantize>(node) ||
459458
op::Subgraph::constant_input_should_be_inside_body(node))) {
460-
// If Constant has one consumer - target node, we add Constant to body_inputs
461-
// If Constant has several consumers, we should check that all these consumers are inside Subgraph body
462-
// and if all of them are inside body, we can explicitly add Constant to the body_inputs, otherwise we
463-
// should make a copy and add copy of Constant to body_inputs For example, this case is especially valid
464-
// for Transposes nodes
465-
// (several Transposes have the same order so there can be the common Constant with this order)
466-
if (constant->get_output_target_inputs(0).size() == 1) {
467-
body_inputs.push_back(input.get_source_output());
468-
} else {
459+
// If not all Constant consumers are inside Subgraph body,
460+
// we should make a copy of this Constant for Subgraph body.
461+
if (constant->get_output_target_inputs(0).size() > 1) {
469462
const auto constant_consumers = constant->get_output_target_inputs(0);
470-
bool all_consumers_are_inside =
471-
std::all_of(constant_consumers.begin(),
463+
bool has_external_consumers =
464+
std::any_of(constant_consumers.begin(),
472465
constant_consumers.end(),
473466
[&ordered_ops](const ov::Input<ov::Node>& input) {
474467
return std::find(ordered_ops.begin(),
475468
ordered_ops.end(),
476-
input.get_node()->shared_from_this()) != ordered_ops.end();
469+
input.get_node()->shared_from_this()) == ordered_ops.end();
477470
});
478-
if (all_consumers_are_inside) {
479-
body_inputs.push_back(input.get_source_output());
480-
} else {
471+
if (has_external_consumers) {
481472
const auto constant_copy = constant->clone_with_new_inputs({});
482473
node->set_argument(input.get_index(), constant_copy);
483-
body_inputs.emplace_back(constant_copy);
484474
}
485475
}
486476
} else if (std::find(ordered_ops.begin(), ordered_ops.end(), parent) == ordered_ops.end()) {
487-
auto parameter =
488-
std::make_shared<ov::opset1::Parameter>(input.get_element_type(), input.get_partial_shape());
489-
body_parameters.push_back(parameter);
490-
body_parameters.back()->set_friendly_name(input.get_node()->get_friendly_name());
491-
body_inputs.push_back(parameter->output(0));
492-
493-
subgraph_inputs.push_back(input.get_source_output());
494-
477+
const auto& parent_output = input.get_source_output();
478+
auto it = std::find(subgraph_inputs.begin(), subgraph_inputs.end(), parent_output);
479+
if (!are_shared_internal_params_allowed || it == subgraph_inputs.end()) {
480+
auto new_param =
481+
std::make_shared<ov::op::v0::Parameter>(input.get_element_type(), input.get_partial_shape());
482+
new_param->set_friendly_name(input.get_node()->get_friendly_name());
483+
subgraph_inputs.push_back(parent_output);
484+
body_parameters.push_back(new_param);
485+
it = subgraph_inputs.end() - 1;
486+
}
487+
const auto param_index = static_cast<size_t>(std::distance(subgraph_inputs.begin(), it));
488+
const auto& parameter = body_parameters[param_index];
495489
node->input(i).replace_source_output(parameter);
496490
}
497491
}
@@ -511,20 +505,17 @@ std::shared_ptr<ov::snippets::op::Subgraph> tokenize_ordered_nodes(const ov::Nod
511505
}
512506
}
513507

508+
ov::ResultVector body_results;
509+
std::vector<std::set<Input<Node>>> subgraph_result_inputs;
514510
for (const auto& output : last_node->outputs()) {
511+
// Note: since we need to save only original consumers,
512+
// subgraph_result_inputs must be taken before result creation
515513
subgraph_result_inputs.push_back(output.get_target_inputs());
516-
}
517-
for (const auto& output : last_node->outputs()) {
518514
body_results.push_back(std::make_shared<ov::opset1::Result>(last_node->output(output.get_index())));
519515
}
520516

521-
if (body_results.size() != subgraph_result_inputs.size()) {
522-
OPENVINO_THROW("body results and node results size mismatch during subgraph collapse");
523-
}
524-
525517
auto body = op::create_body(last_node->get_friendly_name(), body_results, body_parameters);
526518
auto subgraph = std::make_shared<op::Subgraph>(subgraph_inputs, body);
527-
// Copy runtime info from last node to subgraph - to copy topological order
528519
copy_runtime_info(last_node, subgraph);
529520
subgraph->set_friendly_name(last_node->get_friendly_name());
530521

src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/subgraph_serialize.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ TEST_F(SubgraphSnippetSerializationTest, smoke_SerializeSubgraph) {
3535
auto add = std::make_shared<Add>(ininput0, ininput1);
3636
auto subgraph_body =
3737
std::make_shared<ov::Model>(ov::OutputVector{add}, ov::ParameterVector{ininput0, ininput1});
38-
auto subgraph = std::make_shared<ov::snippets::op::Subgraph>(ov::NodeVector{input0, input1}, subgraph_body.get()->clone());
38+
auto subgraph = std::make_shared<ov::snippets::op::Subgraph>(ov::OutputVector{input0, input1}, subgraph_body.get()->clone());
3939
return std::make_shared<ov::Model>(ov::OutputVector{subgraph}, ov::ParameterVector{input0, input1});
4040
})();
4141
ov::Core core;
@@ -84,7 +84,7 @@ TEST_F(SubgraphSnippetSerializationTest, smoke_SerializeSubgraphWithScalarConst)
8484
auto internal_add = std::make_shared<Add>(internal_input, internal_constant);
8585
auto subgraph_body =
8686
std::make_shared<ov::Model>(ov::OutputVector{internal_add}, ov::ParameterVector{internal_input});
87-
auto subgraph = std::make_shared<ov::snippets::op::Subgraph>(ov::NodeVector{add}, subgraph_body.get()->clone());
87+
auto subgraph = std::make_shared<ov::snippets::op::Subgraph>(ov::OutputVector{add}, subgraph_body.get()->clone());
8888
return std::make_shared<ov::Model>(ov::OutputVector{subgraph}, ov::ParameterVector{input});
8989
})();
9090
ov::Core core;

src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha_transposed_b.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,23 @@ INSTANTIATE_TEST_SUITE_P(
5050
::testing::Values(CPUTestUtils::empty_plugin_config)),
5151
MHA::getTestCaseName);
5252

53+
std::vector<std::vector<ov::test::InputShape>> shared_kv_shapes = {{
54+
{PartialShape{-1, -1, -1, -1}, {{1, 3, 64, 128}}},
55+
{PartialShape{-1, -1, -1, -1}, {{1, 3, 64, 128}}},
56+
}};
57+
58+
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_SharedKV,
59+
MHASharedKV,
60+
::testing::Combine(::testing::ValuesIn(shared_kv_shapes),
61+
::testing::ValuesIn(precision_f32(2)),
62+
::testing::Values(ov::element::f32),
63+
::testing::Values(false),
64+
::testing::Values(MHA::default_thread_count),
65+
::testing::Values(expected_num_nodes),
66+
::testing::Values(1),
67+
::testing::Values(ov::test::utils::DEVICE_CPU),
68+
::testing::Values(CPUTestUtils::empty_plugin_config)),
69+
MHA::getTestCaseName);
5370
} // namespace
5471
} // namespace snippets
5572
} // namespace test

src/tests/functional/plugin/shared/include/snippets/mha.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,11 @@ class MHAWithDynamicMul : public testing::WithParamInterface<ov::test::snippets:
134134
void init_params(std::vector<InputShape>& input_shapes, ov::element::Type& prc, ov::AnyMap& additional_config) override;
135135
};
136136

137+
class MHASharedKV : public MHA {
138+
protected:
139+
std::shared_ptr<SnippetsFunctionBase> get_subgraph() const override;
140+
};
141+
137142
} // namespace snippets
138143
} // namespace test
139144
} // namespace ov

src/tests/functional/plugin/shared/src/snippets/mha.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,10 @@ std::shared_ptr<SnippetsFunctionBase> MHAWithDynamicMul::get_subgraph() const {
246246
return std::make_shared<ov::test::snippets::MHAWithDynamicMulFunction>(inputDynamicShapes, m_input_types);
247247
}
248248

249+
std::shared_ptr<SnippetsFunctionBase> MHASharedKV::get_subgraph() const {
250+
return std::make_shared<ov::test::snippets::MHASharedKVFunction>(inputDynamicShapes, m_input_types);
251+
}
252+
249253
TEST_P(MHA, CompareWithRefImpl) {
250254
SKIP_IF_CURRENT_TEST_IS_DISABLED()
251255
run();
@@ -331,6 +335,12 @@ TEST_P(MHAWithDynamicMul, CompareWithRefImpl) {
331335
validateNumSubgraphs();
332336
}
333337

338+
TEST_P(MHASharedKV, CompareWithRefImpl) {
339+
SKIP_IF_CURRENT_TEST_IS_DISABLED()
340+
run();
341+
validateNumSubgraphs();
342+
}
343+
334344
} // namespace snippets
335345
} // namespace test
336346
} // namespace ov

src/tests/ov_helpers/ov_snippets_models/include/subgraph_mha.hpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,32 @@ class MHARankUpgradeToReductionFunction : public SnippetsFunctionBase {
494494
std::shared_ptr<ov::Model> initReference() const override;
495495
};
496496

497+
/* Graph:
498+
* input0 input1
499+
* \ / \
500+
* MatMul0 \
501+
* | \
502+
* Softmax \
503+
* \ /
504+
* MatMul1
505+
* Note: This is an MHA pattern with shared K and V inputs, duplicating one of the Python TF tests
506+
*/
507+
class MHASharedKVFunction : public SnippetsFunctionBase {
508+
public:
509+
explicit MHASharedKVFunction(const std::vector<PartialShape>& inputShapes,
510+
const std::vector<ov::element::Type>& precisions)
511+
: SnippetsFunctionBase(inputShapes),
512+
precisions(precisions) {
513+
OPENVINO_ASSERT(input_shapes.size() == 2, "Got invalid number of input shapes");
514+
OPENVINO_ASSERT(precisions.size() == 2, "Got invalid number of input precisions");
515+
}
516+
517+
protected:
518+
std::shared_ptr<ov::Model> initOriginal() const override;
519+
520+
const std::vector<ov::element::Type> precisions;
521+
};
522+
497523
} // namespace snippets
498524
} // namespace test
499525
} // namespace ov

0 commit comments

Comments
 (0)