[go: up one dir, main page]

Skip to content

Commit

Permalink
optimize thread yield in thread pool
Browse files Browse the repository at this point in the history
  • Loading branch information
chenqy4933 committed Aug 15, 2023
1 parent e2b820b commit 9bc347c
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 13 deletions.
47 changes: 34 additions & 13 deletions src/core/thread_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,28 @@ ThreadPool::ThreadPool(uint32_t threads_num)
while (!m_stop) {
while (m_active) {
//! if the thread should work
if (m_workers[i]->work_flag) {
// printf("thread %d work form %d to %d\n", i,
// i * m_task_per_thread,
// (i + 1) * m_task_per_thread);
if (m_workers[i]->work_flag.load(std::memory_order_acquire)) {
m_task(TaskId{
i * m_task_per_thread,
std::min((i + 1) * m_task_per_thread, m_nr_task),
i});
// printf("thread %d finished\n", i);
//! Flag worker is finished
m_workers[i]->work_flag = false;
m_workers[i]->work_flag.store(
false, std::memory_order_release);
}
//! Wait next task coming
//std::this_thread::yield();
for (int it = 0; it < WORKER_ACTIVE_WAIT; it++) {
if (m_workers[i]->work_flag.load(
std::memory_order_acquire)) {
break;
}
if (it < ACTIVE_WAIT_PAUSE_LIMIT || (it & 1)) {
INFER_PAUSE(16); // Spin lock's CPU-level yield
} else {
// Spin lock's OS-level yield
std::this_thread::yield();
}
}
}
{
std::unique_lock<std::mutex> lock(m_mutex);
Expand All @@ -53,12 +61,13 @@ void ThreadPool::add_task(const MultiThreadingTask& task, uint32_t nr_task) {
return;
} else {
active();
INFER_ASSERT(m_active, "thread pool is not actived.");
m_nr_task = nr_task;
//! Set the task number, task iter and task
m_task_per_thread = (nr_task + m_nr_threads - 1) / m_nr_threads;
m_task = std::move(task);
for (uint32_t i = 0; i < m_nr_threads - 1; i++) {
m_workers[i]->work_flag = true;
m_workers[i]->work_flag.store(true, std::memory_order_release);
}
//! Main thread working
uint32_t start = (m_nr_threads - 1) * m_task_per_thread;
Expand All @@ -71,17 +80,29 @@ void ThreadPool::add_task(const MultiThreadingTask& task, uint32_t nr_task) {

//! Block the calling thread until no worker in [0, m_nr_threads - 1) still
//! has its work_flag set.
//!
//! Each pass scans forward from the first worker last seen busy
//! (no_finished_id), so workers already observed idle are not re-checked.
//! When a busy worker is found, the caller spin-waits on that single flag for
//! at most MAIN_THREAD_ACTIVE_WAIT iterations before rescanning.
inline void ThreadPool::sync() {
bool no_finished = false;
uint32_t no_finished_id = 0;
do {
no_finished = false;
// NOTE(review): the next two lines repeat the loop header and condition that
// follow them (plain read vs. acquire load, start index 0 vs. no_finished_id).
// This looks like the removed and added diff lines shown together -- confirm
// against the real file before relying on this listing.
for (uint32_t i = 0; i < m_nr_threads - 1; ++i) {
if (m_workers[i]->work_flag) {
for (uint32_t i = no_finished_id; i < m_nr_threads - 1; ++i) {
if (m_workers[i]->work_flag.load(std::memory_order_acquire)) {
no_finished = true;
//! Remember the first busy worker so the next pass resumes from it.
no_finished_id = i;
break;
}
}
// Previous waiting strategy (a bare OS-level yield), kept for reference:
// if (no_finished) {
// std::this_thread::yield();
// }
if (no_finished) {
//! Bounded spin on the one busy worker found above.
for (int it = 0; it < MAIN_THREAD_ACTIVE_WAIT; it++) {
if (!m_workers[no_finished_id]->work_flag.load(
std::memory_order_acquire)) {
break;
}
//! The first ACTIVE_WAIT_PAUSE_LIMIT iterations, and every odd iteration
//! after that, use the CPU-level pause; the remaining (even) iterations
//! yield to the OS scheduler.
if ((it < ACTIVE_WAIT_PAUSE_LIMIT || (it & 1))) {
INFER_PAUSE(16);
} else {
std::this_thread::yield();
}
}
}
} while (no_finished);
}
inline void ThreadPool::active() {
Expand Down
34 changes: 34 additions & 0 deletions src/core/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,33 @@
#include "kern/kernel_define.h"
#include "utils.h"

// clang-format off
//! INFER_PAUSE(v): issue `v` CPU-level spin-wait hints in a row.
//!
//! Meant for the body of a busy-wait loop. The argument is evaluated exactly
//! once (except in the no-op fallback, where it is not evaluated at all); a
//! value <= 0 issues no hint. The macro expands to a single statement, so it
//! is safe inside an unbraced if/else.
//!
//! A build may predefine INFER_PAUSE (e.g. via compiler flags) to override
//! the per-platform selection below:
//!   - x86 / x86_64 : _mm_pause(), or "rep; nop" when SSE2 is not enabled
//!   - aarch64      : the "yield" instruction
//!   - 32-bit arm   : compiler memory barrier only
//!   - riscv        : "nop"
//!   - anything else: no-op, with a compile-time warning
//!
//! The loop counter deliberately avoids a double-underscore name: identifiers
//! containing "__" are reserved for the implementation.
#ifndef INFER_PAUSE
# if defined __GNUC__ && (defined __i386__ || defined __x86_64__)
# if !defined(__SSE2__)
static inline void non_sse_mm_pause() { __asm__ __volatile__ ("rep; nop"); }
# define _mm_pause non_sse_mm_pause
# else
# include <immintrin.h>
# endif
# define INFER_PAUSE(v) do { for (int infer_pause_delay_ = (v); infer_pause_delay_ > 0; --infer_pause_delay_) { _mm_pause(); } } while (0)
# elif defined __GNUC__ && defined __aarch64__
# define INFER_PAUSE(v) do { for (int infer_pause_delay_ = (v); infer_pause_delay_ > 0; --infer_pause_delay_) { asm volatile("yield" ::: "memory"); } } while (0)
# elif defined __GNUC__ && defined __arm__
# define INFER_PAUSE(v) do { for (int infer_pause_delay_ = (v); infer_pause_delay_ > 0; --infer_pause_delay_) { asm volatile("" ::: "memory"); } } while (0)
# elif defined __GNUC__ && defined __riscv
// PAUSE HINT is not part of RISC-V ISA yet, but is under discussion now. For details see:
// https://github.com/riscv/riscv-isa-manual/pull/398
// https://github.com/riscv/riscv-isa-manual/issues/43
// # define INFER_PAUSE(v) do { for (int infer_pause_delay_ = (v); infer_pause_delay_ > 0; --infer_pause_delay_) { asm volatile("pause"); } } while (0)
# define INFER_PAUSE(v) do { for (int infer_pause_delay_ = (v); infer_pause_delay_ > 0; --infer_pause_delay_) { asm volatile("nop"); } } while (0)
# else
# warning "Can't detect 'pause' (CPU-yield) instruction on the target platform. Specify INFER_PAUSE() definition via compiler flags."
# define INFER_PAUSE(...) do { /* no-op: works, but not effective */ } while (0)
# endif
#endif // INFER_PAUSE
// clang-format on

namespace inferllm {

/**
Expand Down Expand Up @@ -47,6 +74,13 @@ class ThreadPool {

uint32_t nr_threads() const { return m_nr_threads; }

//! Maximum number of spin iterations the main thread performs in sync() on
//! one unfinished worker before it rescans the workers.
static constexpr int MAIN_THREAD_ACTIVE_WAIT = 10000;
//! Maximum number of spin iterations a worker thread performs while polling
//! its work_flag for the next task.
static constexpr int WORKER_ACTIVE_WAIT = 2000;
//! Number of leading spin iterations that always use INFER_PAUSE; past this
//! count the wait loops use INFER_PAUSE on odd iterations and
//! std::this_thread::yield() on even ones.
static constexpr int ACTIVE_WAIT_PAUSE_LIMIT = 16;

private:
uint32_t m_nr_threads = 1;
//! All the sub task number
Expand Down

0 comments on commit 9bc347c

Please sign in to comment.