Skip to content

Commit 2ca6e64

Browse files
committed
Only fast-exit for non-shared-memory (non-shm) cases
1 parent e15b255 commit 2ca6e64

File tree

9 files changed

+27
-14
lines changed

9 files changed

+27
-14
lines changed

src/c++/perf_analyzer/infer_context.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ InferContext::AsyncCallbackFuncImpl(cb::InferResult* result)
298298
// Add the request record to thread request records vector with
299299
// proper locking
300300
std::lock_guard<std::mutex> lock(thread_stat_->mu_);
301-
if (exiting_) {
301+
if (exiting_ && fast_exit_) {
302302
return;
303303
}
304304

src/c++/perf_analyzer/infer_context.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,11 @@ class InferContext {
105105
void Init();
106106

107107
// Signal to the context to stop working and exit.
//
// Sets exiting_ to true and stores the caller's choice in fast_exit_.
//
// fast_exit: when true, the async callback (AsyncCallbackFuncImpl) returns
// right after taking the thread-stat lock once exiting_ is set, because it
// tests `exiting_ && fast_exit_`. When false, the callback continues past
// that check even though the context is exiting.
//
// NOTE(review): exiting_ is assigned before fast_exit_ and no lock is taken
// in this function, so a reader running concurrently could briefly observe
// exiting_ == true together with the previous fast_exit_ value -- confirm
// whether callers can hit that window and whether it matters.
108-
void Exit() { exiting_ = true; }
108+
void Exit(bool fast_exit)
109+
{
110+
exiting_ = true;
111+
fast_exit_ = fast_exit;
112+
}
109113

110114
// Send a single inference request to the server
111115
void SendInferRequest(bool delayed = false);
@@ -196,6 +200,7 @@ class InferContext {
196200
const uint32_t id_{0};
197201
const size_t thread_id_{0};
198202
bool exiting_{false};
203+
bool fast_exit_{false};
199204

200205
size_t GetNumActiveThreads() { return num_active_threads_; }
201206

src/c++/perf_analyzer/iworker.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ namespace triton { namespace perfanalyzer {
3333
// Pure-virtual interface for a perf_analyzer worker.
//
// Infer() - the worker's inference entry point. The tests touched by this
//           commit launch it with std::async and later wait on the returned
//           future.
// Exit(fast_exit) - asks the worker to stop. This commit replaces the
//           zero-argument Exit() with a version that takes fast_exit;
//           LoadWorker forwards the flag to each of its contexts and, when
//           fast_exit is false, waits for ongoing requests before its exit
//           handling reports completion.
class IWorker {
3434
public:
3535
virtual void Infer() = 0;
36-
virtual void Exit() = 0;
36+
virtual void Exit(bool fast_exit) = 0;
3737
};
3838

3939
}} // namespace triton::perfanalyzer

src/c++/perf_analyzer/load_manager.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@ LoadManager::LoadManager(
164164
const std::unordered_map<std::string, cb::RequestParameter>&
165165
request_parameters)
166166
: async_(async), streaming_(streaming), batch_size_(batch_size),
167-
max_threads_(max_threads), parser_(parser), factory_(factory),
168-
using_json_data_(false)
167+
max_threads_(max_threads), shared_memory_type_{shared_memory_type},
168+
parser_(parser), factory_(factory), using_json_data_(false)
169169
{
170170
on_sequence_model_ =
171171
((parser_->SchedulerType() == ModelParser::SEQUENCE) ||
@@ -248,9 +248,11 @@ LoadManager::InitManagerInputs(
248248
void
249249
LoadManager::StopWorkerThreads()
250250
{
251+
bool fast_exit = shared_memory_type_ == SharedMemoryType::NO_SHARED_MEMORY;
252+
251253
// FIXME do I need to acquire the lock first?
252254
for (auto& worker : workers_) {
253-
worker->Exit();
255+
worker->Exit(fast_exit);
254256
}
255257

256258
{

src/c++/perf_analyzer/load_manager.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ class LoadManager {
140140
size_t batch_size_;
141141
size_t max_threads_;
142142
bool on_sequence_model_;
143+
SharedMemoryType shared_memory_type_;
143144

144145
std::shared_ptr<ModelParser> parser_;
145146
std::shared_ptr<cb::ClientBackendFactory> factory_;

src/c++/perf_analyzer/load_worker.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,14 @@
3535
namespace triton { namespace perfanalyzer {
3636

3737
void
38-
LoadWorker::Exit()
38+
LoadWorker::Exit(bool fast_exit)
3939
{
4040
for (auto ctx : ctxs_) {
41-
ctx->Exit();
41+
ctx->Exit(fast_exit);
4242
}
4343

4444
exiting_ = true;
45+
fast_exit_ = fast_exit;
4546

4647
{
4748
std::lock_guard<std::mutex> lk(cb_mtx_);
@@ -67,6 +68,9 @@ LoadWorker::HandleExitConditions()
6768
{
6869
if (ShouldExit()) {
6970
CompleteOngoingSequences();
71+
if (!fast_exit_) {
72+
WaitForOngoingRequests();
73+
}
7074
return true;
7175
}
7276
return false;
@@ -86,7 +90,7 @@ LoadWorker::CompleteOngoingSequences()
8690
// Block the calling thread until there is nothing left to wait for.
//
// Polls in 50 ms sleeps. The removed loop condition waited only for
// GetNumOngoingRequests() to reach zero; the added condition also stops
// waiting as soon as fast_exit_ is true. fast_exit_ is re-read on every
// iteration, so the loop ends if the flag becomes true while this function
// is already waiting.
void
8791
LoadWorker::WaitForOngoingRequests()
8892
{
89-
while (GetNumOngoingRequests() != 0) {
93+
while (GetNumOngoingRequests() != 0 && !fast_exit_) {
8994
std::this_thread::sleep_for(std::chrono::milliseconds(50));
9095
}
9196
}

src/c++/perf_analyzer/load_worker.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ class LoadWorker : public IWorker {
6969

7070
virtual ~LoadWorker() = default;
7171

72-
virtual void Exit() override;
72+
virtual void Exit(bool fast_exit) override;
7373

7474
protected:
7575
// Return the total number of async requests that have started and not
@@ -120,6 +120,7 @@ class LoadWorker : public IWorker {
120120
void AsyncCallbackFinalize(uint32_t ctx_id);
121121

122122
bool exiting_ = false;
123+
bool fast_exit_ = false;
123124

124125
uint32_t id_;
125126

src/c++/perf_analyzer/test_concurrency_manager.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ TEST_CASE("concurrency_free_ctx_ids")
474474

475475
std::this_thread::sleep_for(std::chrono::milliseconds(15));
476476

477-
worker->Exit();
477+
worker->Exit(false);
478478
infer_future.get();
479479

480480
// The first sequence should only be called two times, once at the very start,
@@ -590,7 +590,7 @@ TEST_CASE("Concurrency - shared memory infer input calls")
590590

591591
std::this_thread::sleep_for(std::chrono::milliseconds(18));
592592

593-
worker->Exit();
593+
worker->Exit(false);
594594
infer_future.get();
595595

596596
const auto& actual_append_raw_calls{tcm.stats_->num_append_raw_calls};

src/c++/perf_analyzer/test_request_rate_manager.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -975,7 +975,7 @@ TEST_CASE("request_rate_streaming: test that streaming-specific logic works")
975975
std::dynamic_pointer_cast<IScheduler>(worker)->SetSchedule(schedule);
976976
std::future<void> infer_future{std::async(&IWorker::Infer, worker)};
977977

978-
worker->Exit();
978+
worker->Exit(false);
979979
infer_future.get();
980980

981981
CHECK(
@@ -1825,7 +1825,7 @@ TEST_CASE("Request rate - Shared memory infer input calls")
18251825

18261826
std::this_thread::sleep_for(milliseconds(18));
18271827

1828-
worker->Exit();
1828+
worker->Exit(false);
18291829
infer_future.get();
18301830

18311831
const auto& actual_append_raw_calls{trrm.stats_->num_append_raw_calls};

0 commit comments

Comments
 (0)