9
9
#include < climits>
10
10
#include < cstddef>
11
11
#include < cstdint>
12
+ #include < iterator>
12
13
#include < map>
13
14
#include < memory>
14
15
#include < numeric>
@@ -442,56 +443,49 @@ bool tokenize_node(const std::shared_ptr<ov::Node>& node, const SnippetsTokeniza
442
443
return true ;
443
444
}
444
445
445
- std::shared_ptr<ov::snippets::op::Subgraph> tokenize_ordered_nodes (const ov::NodeVector& ordered_ops) {
446
+ std::shared_ptr<ov::snippets::op::Subgraph> tokenize_ordered_nodes (const ov::NodeVector& ordered_ops,
447
+ bool are_shared_internal_params_allowed) {
446
448
OPENVINO_ASSERT (!ordered_ops.empty (), " Nothing to be tokenized!" );
447
449
448
- ov::OutputVector body_inputs, subgraph_inputs;
450
+ ov::OutputVector subgraph_inputs;
449
451
ov::ParameterVector body_parameters;
450
- ov::ResultVector body_results;
451
- std::vector<std::set<Input<Node>>> subgraph_result_inputs;
452
-
453
452
auto create_body_inputs = [&](const std::shared_ptr<ov::Node>& node) -> void {
454
453
for (size_t i = 0 ; i < node->get_input_size (); ++i) {
455
454
const auto input = node->input (i);
456
455
const auto parent = input.get_source_output ().get_node_shared_ptr ();
457
456
const auto constant = ov::as_type_ptr<ov::op::v0::Constant>(parent);
458
457
if (constant && (ov::shape_size (input.get_shape ()) == 1 || ov::is_type<ov::op::v0::FakeQuantize>(node) ||
459
458
op::Subgraph::constant_input_should_be_inside_body (node))) {
460
- // If Constant has one consumer - target node, we add Constant to body_inputs
461
- // If Constant has several consumers, we should check that all these consumers are inside Subgraph body
462
- // and if all of them are inside body, we can explicitly add Constant to the body_inputs, otherwise we
463
- // should make a copy and add copy of Constant to body_inputs For example, this case is especially valid
464
- // for Transposes nodes
465
- // (several Transposes have the same order so there can be the common Constant with this order)
466
- if (constant->get_output_target_inputs (0 ).size () == 1 ) {
467
- body_inputs.push_back (input.get_source_output ());
468
- } else {
459
+ // If not all Constant consumers are inside Subgraph body,
460
+ // we should make a copy of this Constant for Subgraph body.
461
+ if (constant->get_output_target_inputs (0 ).size () > 1 ) {
469
462
const auto constant_consumers = constant->get_output_target_inputs (0 );
470
- bool all_consumers_are_inside =
471
- std::all_of (constant_consumers.begin (),
463
+ bool has_external_consumers =
464
+ std::any_of (constant_consumers.begin (),
472
465
constant_consumers.end (),
473
466
[&ordered_ops](const ov::Input<ov::Node>& input) {
474
467
return std::find (ordered_ops.begin (),
475
468
ordered_ops.end (),
476
- input.get_node ()->shared_from_this ()) ! = ordered_ops.end ();
469
+ input.get_node ()->shared_from_this ()) = = ordered_ops.end ();
477
470
});
478
- if (all_consumers_are_inside) {
479
- body_inputs.push_back (input.get_source_output ());
480
- } else {
471
+ if (has_external_consumers) {
481
472
const auto constant_copy = constant->clone_with_new_inputs ({});
482
473
node->set_argument (input.get_index (), constant_copy);
483
- body_inputs.emplace_back (constant_copy);
484
474
}
485
475
}
486
476
} else if (std::find (ordered_ops.begin (), ordered_ops.end (), parent) == ordered_ops.end ()) {
487
- auto parameter =
488
- std::make_shared<ov::opset1::Parameter>(input.get_element_type (), input.get_partial_shape ());
489
- body_parameters.push_back (parameter);
490
- body_parameters.back ()->set_friendly_name (input.get_node ()->get_friendly_name ());
491
- body_inputs.push_back (parameter->output (0 ));
492
-
493
- subgraph_inputs.push_back (input.get_source_output ());
494
-
477
+ const auto & parent_output = input.get_source_output ();
478
+ auto it = std::find (subgraph_inputs.begin (), subgraph_inputs.end (), parent_output);
479
+ if (!are_shared_internal_params_allowed || it == subgraph_inputs.end ()) {
480
+ auto new_param =
481
+ std::make_shared<ov::op::v0::Parameter>(input.get_element_type (), input.get_partial_shape ());
482
+ new_param->set_friendly_name (input.get_node ()->get_friendly_name ());
483
+ subgraph_inputs.push_back (parent_output);
484
+ body_parameters.push_back (new_param);
485
+ it = subgraph_inputs.end () - 1 ;
486
+ }
487
+ const auto param_index = static_cast <size_t >(std::distance (subgraph_inputs.begin (), it));
488
+ const auto & parameter = body_parameters[param_index];
495
489
node->input (i).replace_source_output (parameter);
496
490
}
497
491
}
@@ -511,20 +505,17 @@ std::shared_ptr<ov::snippets::op::Subgraph> tokenize_ordered_nodes(const ov::Nod
511
505
}
512
506
}
513
507
508
+ ov::ResultVector body_results;
509
+ std::vector<std::set<Input<Node>>> subgraph_result_inputs;
514
510
for (const auto & output : last_node->outputs ()) {
511
+ // Note: since we need to save only original consumers,
512
+ // subgraph_result_inputs must be taken before result creation
515
513
subgraph_result_inputs.push_back (output.get_target_inputs ());
516
- }
517
- for (const auto & output : last_node->outputs ()) {
518
514
body_results.push_back (std::make_shared<ov::opset1::Result>(last_node->output (output.get_index ())));
519
515
}
520
516
521
- if (body_results.size () != subgraph_result_inputs.size ()) {
522
- OPENVINO_THROW (" body results and node results size mismatch during subgraph collapse" );
523
- }
524
-
525
517
auto body = op::create_body (last_node->get_friendly_name (), body_results, body_parameters);
526
518
auto subgraph = std::make_shared<op::Subgraph>(subgraph_inputs, body);
527
- // Copy runtime info from last node to subgraph - to copy topological order
528
519
copy_runtime_info (last_node, subgraph);
529
520
subgraph->set_friendly_name (last_node->get_friendly_name ());
530
521
0 commit comments