Commit fe465f3

Drop intel_npu::SyncInferRequest and move logic to intel_npu::ZeroInferRequest
1 parent 99d9d9a · commit fe465f3

12 files changed: +676 −634 lines changed

src/plugins/intel_npu/src/backend/include/zero_device.hpp

Lines changed: 4 additions & 2 deletions
@@ -32,8 +32,10 @@ class ZeroDevice : public IDevice {
     std::map<ov::element::Type, float> getGops() const override;
     ov::device::Type getDeviceType() const override;
 
-    std::shared_ptr<SyncInferRequest> createInferRequest(const std::shared_ptr<const ICompiledModel>& compiledModel,
-                                                         const Config& config) override;
+    std::shared_ptr<ov::IInferRequest> createInferRequest(const std::shared_ptr<const ICompiledModel>& compiledModel,
+                                                          const Config& config,
+                                                          std::function<void(void)>& inferAsyncF,
+                                                          std::function<void(void)>& getResultF) override;
     void updateInfo(const Config& config) override {
         log.setLevel(config.get<LOG_LEVEL>());
     }
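For orientation only, here is a minimal, self-contained sketch of the factory shape this new declaration implies: the device hands back a generic infer-request pointer and publishes the two execution stages through std::function out-parameters. IRequestLike, IDeviceLike, ToyRequest, and ToyDevice are invented stand-ins, not types from this commit.

#include <functional>
#include <iostream>
#include <memory>

// Stand-ins for ov::IInferRequest / intel_npu::IDevice; the real interfaces carry far more state.
struct IRequestLike {
    virtual ~IRequestLike() = default;
};

struct IDeviceLike {
    virtual ~IDeviceLike() = default;
    virtual std::shared_ptr<IRequestLike> createInferRequest(std::function<void(void)>& inferAsyncF,
                                                             std::function<void(void)>& getResultF) = 0;
};

struct ToyRequest : IRequestLike {
    void infer_async() { std::cout << "submit inference\n"; }
    void get_result() { std::cout << "wait for result\n"; }
};

struct ToyDevice : IDeviceLike {
    std::shared_ptr<IRequestLike> createInferRequest(std::function<void(void)>& inferAsyncF,
                                                     std::function<void(void)>& getResultF) override {
        auto request = std::make_shared<ToyRequest>();
        // Capturing the shared_ptr by value keeps the request alive for as long as the callbacks exist.
        inferAsyncF = [request]() { request->infer_async(); };
        getResultF = [request]() { request->get_result(); };
        return request;
    }
};

int main() {
    ToyDevice device;
    std::function<void(void)> inferAsyncF, getResultF;
    auto request = device.createInferRequest(inferAsyncF, getResultF);
    inferAsyncF();  // stage 1: enqueue the work
    getResultF();   // stage 2: block until outputs are ready
}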

src/plugins/intel_npu/src/backend/include/zero_infer_request.hpp

Lines changed: 172 additions & 9 deletions
@@ -8,7 +8,6 @@
 #include <ze_graph_ext.h>
 
 #include "intel_npu/common/npu.hpp"
-#include "intel_npu/common/sync_infer_request.hpp"
 #include "intel_npu/utils/logger/logger.hpp"
 #include "intel_npu/utils/zero/zero_remote_tensor.hpp"
 #include "intel_npu/utils/zero/zero_utils.hpp"
@@ -18,23 +17,162 @@
 
 namespace intel_npu {
 
-class ZeroInferRequest final : public SyncInferRequest {
+class ZeroInferRequest final : public ov::IInferRequest {
 public:
     explicit ZeroInferRequest(const std::shared_ptr<ZeroInitStructsHolder>& initStructs,
                               const std::shared_ptr<const ICompiledModel>& compiledModel,
                               const Config& config);
 
+    /**
+     * @brief Gets an input/output tensor for inference.
+     * @note If the tensor with the specified @p port is not found, an exception is thrown.
+     * @param port Port of the tensor to get.
+     * @return Tensor for the port @p port.
+     */
     ov::SoPtr<ov::ITensor> get_tensor(const ov::Output<const ov::Node>& port) const override;
+
+    /**
+     * @brief Sets an input/output tensor to infer.
+     * @param port Port of the input or output tensor.
+     * @param tensor Reference to a tensor. The element_type and shape of a tensor must match
+     * the model's input/output element_type and size.
+     */
     void set_tensor(const ov::Output<const ov::Node>& port, const ov::SoPtr<ov::ITensor>& tensor) override;
+
+    /**
+     * @brief Gets an input/output tensor for inference.
+     * @note If the tensor with the specified @p port is not found, an exception is thrown.
+     * @param port Port of the batched tensors to get.
+     * @return Vector of batched tensors for the input port @p port or empty vector if port is output.
+     */
+    std::vector<ov::SoPtr<ov::ITensor>> get_tensors(const ov::Output<const ov::Node>& port) const override;
+
+    /**
+     * @brief Sets batched input tensors to infer
+     * @param port Port of the batched input tensor.
+     * @param tensors Vector of references to batched tensors. The element_type and shape of each must match.
+     * @note Batched tensors for outputs are not supported.
+     * @note If a single-element vector is provided for the @p tensors param, a fallback to the "set_tensor" function will occur.
+     */
     void set_tensors(const ov::Output<const ov::Node>& port,
                      const std::vector<ov::SoPtr<ov::ITensor>>& tensors) override;
 
+    /**
+     * @brief Gets inputs for infer request
+     *
+     * @return vector of input ports
+     */
+    const std::vector<ov::Output<const ov::Node>>& get_inputs() const override;
+
+    /**
+     * @brief Gets outputs for infer request
+     *
+     * @return vector of output ports
+     */
+    const std::vector<ov::Output<const ov::Node>>& get_outputs() const override;
+
+    /**
+     * @brief Gets pointer to compiled model (usually the synchronous request holds the compiled model)
+     *
+     * @return Pointer to the compiled model
+     */
+    const std::shared_ptr<const ov::ICompiledModel>& get_compiled_model() const override;
+
+    /**
+     * @brief Calls "infer_async" then "get_result"
+     */
     void infer() override;
-    void infer_async() override;
 
-    void get_result() override;
+    /**
+     * @brief Used for executing the inference.
+     */
+    void infer_async();
+
+    /**
+     * @brief Used for retrieving the prediction's result.
+     */
+    void get_result();
+
+    /**
+     * @brief Used for retrieving the current values of the network's variables.
+     * @return Vector of each state value
+     */
+    std::vector<ov::SoPtr<ov::IVariableState>> query_state() const override;
+
+    /**
+     * @brief Initializes the tensor values corresponding to the state variables.
+     * @details The initial values are usually all 0s.
+     */
+    void initialize_states();
 
 private:
+    /**
+     * @see ov::ISyncInferRequest
+     */
+    struct FoundPort {
+        size_t idx;
+        enum class Type { NOT_FOUND = 0, INPUT, OUTPUT } type;
+
+        bool found() {
+            return type != Type::NOT_FOUND;
+        }
+        bool is_input() {
+            return type == Type::INPUT;
+        }
+        bool is_output() {
+            return type == Type::OUTPUT;
+        }
+    };
+
+    /**
+     * @brief Finds input or output port
+     * @return Structure which contains the index of the input/output, or reports that the port wasn't found
+     * @see ov::ISyncInferRequest
+     */
+    FoundPort find_port(const ov::Output<const ov::Node>& port) const;
+
+    /**
+     * @brief Basic checks for input/output tensor
+     *
+     * @param port Input/Output port
+     * @param tensor Input/Output tensor
+     */
+    void check_tensor(const ov::Output<const ov::Node>& port, const ov::SoPtr<ov::ITensor>& tensor) const;
+
+    /**
+     * @brief Basic checks for input tensors
+     *
+     * @param port Input port
+     * @param tensors Input tensors
+     */
+    void check_batched_tensors(const ov::Output<const ov::Node>& port,
+                               const std::vector<ov::SoPtr<ov::ITensor>>& tensors) const;
+
+    /**
+     * @brief Checks that all tensors are valid. Throws an exception if they are not.
+     */
+    void check_tensors() const override;
+
+    /**
+     * @brief Allocates a tensor on host and stores the reference inside multiple attributes.
+     * @param descriptor Tensor's metadata
+     * @param index The index which the allocated tensor shall use.
+     * @param isInput Determines the containers in which the newly allocated tensors will be stored.
+     * @param allocator If provided, the tensor uses the custom allocator instead of using the default one.
+     * @param batchSize If provided, the value of the shape on the 0th axis is overridden with this value.
+     * @return Pointer towards the allocated tensor
+     */
+    std::shared_ptr<ov::ITensor> allocate_tensor(const IODescriptor& descriptor,
+                                                 const size_t index,
+                                                 const bool isInput,
+                                                 const ov::Allocator& allocator = {},
+                                                 const std::optional<std::size_t> batchSize = std::nullopt) const;
+
+    bool is_batched_input(size_t idx) const;
+
+    ov::SoPtr<ov::ITensor>& get_user_input(size_t index) const;
+    std::vector<ov::SoPtr<ov::ITensor>>& get_user_inputs(size_t index) const;
+
     std::vector<ov::ProfilingInfo> get_profiling_info() const override;
 
     /**
@@ -55,23 +193,32 @@ class ZeroInferRequest final : public SyncInferRequest {
                         const size_t index,
                         const bool isInput);
 
-    void check_network_precision(const ov::element::Type_t precision) const override;
+    /**
+     * @brief Checks if the provided precision value is supported by the current backend; should throw an error
+     * otherwise.
+     * @param precision The precision value to be checked.
+     */
+    void check_network_precision(const ov::element::Type_t precision) const;
     void create_pipeline();
 
     std::shared_ptr<ov::ITensor>& get_level_zero_input(size_t index, size_t tensorNo = 0) const;
     std::vector<std::shared_ptr<ov::ITensor>>& get_level_zero_inputs(size_t index) const;
 
-    std::shared_ptr<ov::ITensor> create_tensor(ov::element::Type type,
-                                               const ov::Shape& shape,
-                                               const ov::Allocator& allocator = {}) const override;
+    std::shared_ptr<ZeroTensor> create_tensor(ov::element::Type type,
+                                              const ov::Shape& shape,
+                                              const ov::Allocator& allocator = {}) const;
 
-    void add_state(const IODescriptor& descriptor, size_t tensorIndex) const override;
+    void add_state(const IODescriptor& descriptor, size_t tensorIndex) const;
 
     void update_pipeline_if_memory_changed();
     void update_states_if_memory_changed();
 
     const std::shared_ptr<ZeroInitStructsHolder> _initStructs;
     const std::shared_ptr<IGraph> _graph;
+    NetworkMetadata _metadata;
+    // This is an intel_npu::ICompiledModel pointer, but we need to use the OV base class because
+    // ov::IInferRequest::get_compiled_model returns a reference to shared_ptr!
+    std::shared_ptr<const ov::ICompiledModel> _compiledModel;
     const Config _config;
     Logger _logger;
 
@@ -83,6 +230,12 @@ class ZeroInferRequest final : public SyncInferRequest {
     mutable std::vector<std::vector<std::shared_ptr<ov::ITensor>>> _levelZeroInputTensors;
     mutable std::vector<std::shared_ptr<ov::ITensor>> _levelZeroOutputTensors;
 
+    // In case set_tensors is called, we receive a vector with N tensors otherwise only 1 tensor is needed
+    mutable std::vector<std::vector<ov::SoPtr<ov::ITensor>>> _userInputTensors;
+    mutable std::vector<ov::SoPtr<ov::ITensor>> _userOutputTensors;
+
+    mutable std::vector<ov::SoPtr<ov::IVariableState>> _variableStates;
+
     std::shared_ptr<const zeroMemory::HostMemAllocator> _inputAllocator;
     std::shared_ptr<const zeroMemory::HostMemAllocator> _outputAllocator;
 
@@ -91,6 +244,16 @@ class ZeroInferRequest final : public SyncInferRequest {
     bool _pipelineIsCreated = false;
     bool _dynamicBatchValueChanged = false;
    bool _externalMemoryStandardAllocationSupported = false;
+
+    /**
+     * @see ov::ISyncInferRequest
+     */
+    mutable std::unordered_map<size_t, FoundPort> _cachedPorts;
+
+    /**
+     * @see ov::ISyncInferRequest
+     */
+    mutable std::mutex _cacheMutex;
 };
 
 }  // namespace intel_npu
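The FoundPort struct, the _cachedPorts map, and _cacheMutex mirror the port-lookup caching that ov::ISyncInferRequest performs internally. A condensed sketch of that pattern, using invented toy types (PortResolver, string-named ports) instead of the plugin's real ov::Output<const ov::Node> ports: resolve a port once by linear search, then serve repeat queries from a mutex-guarded cache.

#include <cstddef>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

struct FoundPort {
    size_t idx;
    enum class Type { NOT_FOUND = 0, INPUT, OUTPUT } type;
    bool found() const { return type != Type::NOT_FOUND; }
};

class PortResolver {
public:
    PortResolver(std::vector<std::string> inputs, std::vector<std::string> outputs)
        : _inputs(std::move(inputs)), _outputs(std::move(outputs)) {}

    FoundPort find_port(const std::string& name) const {
        const size_t key = std::hash<std::string>{}(name);
        {
            std::lock_guard<std::mutex> lock(_cacheMutex);
            auto it = _cachedPorts.find(key);
            if (it != _cachedPorts.end()) {
                return it->second;  // fast path: this port was already resolved
            }
        }
        FoundPort result{0, FoundPort::Type::NOT_FOUND};
        for (size_t i = 0; i < _inputs.size(); ++i) {
            if (_inputs[i] == name) {
                result = {i, FoundPort::Type::INPUT};
            }
        }
        for (size_t i = 0; i < _outputs.size(); ++i) {
            if (_outputs[i] == name) {
                result = {i, FoundPort::Type::OUTPUT};
            }
        }
        if (result.found()) {
            std::lock_guard<std::mutex> lock(_cacheMutex);
            _cachedPorts[key] = result;  // remember the answer for later lookups
        }
        return result;
    }

private:
    std::vector<std::string> _inputs;
    std::vector<std::string> _outputs;
    mutable std::unordered_map<size_t, FoundPort> _cachedPorts;
    mutable std::mutex _cacheMutex;
};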

src/plugins/intel_npu/src/backend/include/zero_memory.hpp

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ class HostMemAllocator {
      * @param handle Pointer to allocated data
      * @return false if handle cannot be released, otherwise - true.
      */
-    bool deallocate(void* handle, const size_t bytes, size_t alignment = utils::STANDARD_PAGE_SIZE) noexcept;
+    virtual bool deallocate(void* handle, const size_t bytes, size_t alignment = utils::STANDARD_PAGE_SIZE) noexcept;
 
     bool is_equal(const HostMemAllocator& other) const;
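A reduced illustration of what the added virtual buys, assuming (as the surrounding changes suggest) that allocators are held and released through the base HostMemAllocator-style type: a derived allocator can supply its own release path and still be reached through a base pointer. BaseAllocator and TrackingAllocator are invented names, not the plugin's classes.

#include <cstdlib>
#include <iostream>
#include <memory>

class BaseAllocator {
public:
    virtual ~BaseAllocator() = default;
    virtual void* allocate(size_t bytes) { return std::malloc(bytes); }
    virtual bool deallocate(void* handle, size_t /*bytes*/) noexcept {
        std::free(handle);
        return true;
    }
};

class TrackingAllocator : public BaseAllocator {
public:
    bool deallocate(void* handle, size_t bytes) noexcept override {
        std::cout << "releasing " << bytes << " bytes\n";  // custom bookkeeping before the real release
        return BaseAllocator::deallocate(handle, bytes);
    }
};

int main() {
    std::shared_ptr<BaseAllocator> allocator = std::make_shared<TrackingAllocator>();
    void* data = allocator->allocate(64);
    allocator->deallocate(data, 64);  // dispatches to TrackingAllocator::deallocate through the base pointer
}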

src/plugins/intel_npu/src/backend/src/zero_device.cpp

Lines changed: 18 additions & 3 deletions
@@ -169,8 +169,23 @@ ov::device::Type ZeroDevice::getDeviceType() const {
     return ov::device::Type::INTEGRATED;
 }
 
-std::shared_ptr<SyncInferRequest> ZeroDevice::createInferRequest(
+std::shared_ptr<ov::IInferRequest> ZeroDevice::createInferRequest(
     const std::shared_ptr<const ICompiledModel>& compiledModel,
-    const Config& config) {
-    return std::make_shared<ZeroInferRequest>(_initStructs, compiledModel, config);
+    const Config& config,
+    std::function<void(void)>& inferAsyncF,
+    std::function<void(void)>& getResultF) {
+    auto inferRequest = std::make_shared<ZeroInferRequest>(_initStructs, compiledModel, config);
+    inferAsyncF = [&inferRequest]() {
+        if (!inferRequest) {
+            OPENVINO_THROW("ZeroInferRequest object was destroyed!");
+        }
+        inferRequest->infer_async();
+    };
+    getResultF = [&inferRequest]() {
+        if (!inferRequest) {
+            OPENVINO_THROW("ZeroInferRequest object was destroyed!");
+        }
+        inferRequest->get_result();
+    };
+    return inferRequest;
 }
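A hypothetical usage sketch of the two callbacks ZeroDevice now hands back: the point of exposing them separately is that the submission stage can be dispatched independently of the blocking result fetch. The lambdas below are trivial placeholders standing in for whatever createInferRequest fills in.

#include <functional>
#include <future>
#include <iostream>

int main() {
    // Placeholders for the callbacks produced by createInferRequest().
    std::function<void(void)> inferAsyncF = []() { std::cout << "enqueue inference\n"; };
    std::function<void(void)> getResultF = []() { std::cout << "collect outputs\n"; };

    // Stage 1 can run on a worker thread; stage 2 is deferred until the caller wants the result.
    auto submitted = std::async(std::launch::async, inferAsyncF);
    // ... the application may do unrelated work here ...
    submitted.get();  // make sure the submission stage completed
    getResultF();     // now synchronize on the device and read the results
}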
