Commit 990b599f authored by a-chmil's avatar a-chmil
Browse files

Added callback on_empty in thread_pool

parent 15e1b703
......@@ -57,7 +57,9 @@ sub main()
req_count B[i][j]=N;
};
cf calc[i][j]: calc_mat(A, B, C[i][j], i, j, N);
cf calc[i][j]: calc_mat(A, B, C[i][j], i, j, N) @{
stealable;
};
empty() @ {
request C[i][j];
......
......@@ -69,7 +69,7 @@ public:
bool migrate(const Locator &loc);
bool check_steal();
bool check_steal();
void destroy(const Id &id, const Locator &);
......
......@@ -29,6 +29,7 @@ public:
std::string get_help() const;
std::string get_version() const;
bool dynamic_balance() const;
unsigned int get_worker_threads_count() const noexcept;
unsigned int get_comm_request_threads_count() const noexcept;
......@@ -38,4 +39,5 @@ private:
RunMode mode_;
std::string fp_path_;
std::vector<std::string> argv_;
bool dynamic_balance_;
};
......@@ -27,6 +27,7 @@ class RTS
ThreadPool pool_;
std::mutex m_;
std::condition_variable cv_;
bool need_jobs_;
public:
virtual ~RTS();
......@@ -56,9 +57,9 @@ public:
void unexpect_pushes(const Id &cfid);
int get_steal_req();
int get_steal_req();
void steal_req();
void steal_req(int tag);
void change_load(int delta) {
stopper_->change_works(delta);
......@@ -77,7 +78,7 @@ private:
std::map<Id, std::vector<std::pair<Id, DF> > > pushed_;
std::map<Id, std::function<void (const Id &, const DF &)> > waiters_;
std::set<int> steal_requests_;
std::set<int> steal_requests_;
void _check_requests(const Id &id);
......
......@@ -10,5 +10,6 @@ enum Tags {
TAG_RESPONSE,
TAG_DESTROY,
TAG_PUSH,
TAG_STEAL
TAG_STEAL,
TAG_STEAL_REVOKE
};
......@@ -24,6 +24,10 @@ public:
void on_empty(std::function<void()>);
void on_submit(std::function<void()>);
bool has_jobs();
virtual std::string to_string() const;
private:
......@@ -31,7 +35,8 @@ private:
std::condition_variable cv_;
std::vector<std::thread*> threads_;
std::queue<std::function<void()> > jobs_;
std::function<void()> on_empty_handler_;
std::function<void()> on_empty_handler_;
std::function<void()> on_submit_handler_;
size_t running_jobs_;
bool stop_flag_;
......
......@@ -119,6 +119,7 @@ def parse_args(args):
conf['DEBUG']=False
conf['CLEANUP']=True
conf['TIME']=False
conf['BALANCE']=False
VERBOSE_FLAG=False
COMPILE_ONLY_FLAG=False
......@@ -155,6 +156,8 @@ def parse_args(args):
conf['CLEANUP']=False
elif arg=='-t':
conf['TIME']=True
elif arg=='-b':
conf['BALANCE']=True
else:
if arg[cur].startswith('-'):
warn(1, "suspicious program name: '%s' (mistyped a key?)" \
......@@ -681,6 +684,8 @@ def main():
rts='rts.dbg' if conf['DEBUG'] else 'rts'
cmd=[os.path.join(conf['LUNA_HOME'], 'bin', rts),
os.path.join(conf['BUILD_DIR'], 'libucodes.so')] + conf['ARGV']
if conf['BALANCE']:
cmd+=' -b '
env=dict(os.environ)
env['LD_LIBRARY_PATH']=env.get('LD_LIBRARY_PATH', '') \
+ ':' + os.path.join(conf['LUNA_HOME'], 'lib')
......
......@@ -6,7 +6,7 @@
#include "common.h"
Config::Config(int argc, char **argv)
: mode_(UNSET)
: mode_(UNSET), dynamic_balance_(false)
{
assert(argc>=1);
program_name_=argv[0];
......@@ -22,6 +22,8 @@ Config::Config(int argc, char **argv)
} else if (arg=="--version") {
mode_=VERSION;
break;
} else if (arg=="-b"){
dynamic_balance_ = true;
} else {
if (mode_==UNSET) {
mode_=NORMAL;
......@@ -87,6 +89,11 @@ std::string Config::get_version() const
return os.str();
}
bool Config::dynamic_balance() const
{
return dynamic_balance_;
}
unsigned int Config::get_worker_threads_count() const noexcept
{
return DEFAULT_WORKER_THREADS_COUNT;
......
......@@ -21,7 +21,7 @@ RTS::~RTS()
}
RTS::RTS(Comm &comm, const Config &conf, const FP &fp)
: comm_(&comm), conf_(&conf), fp_(&fp)
: comm_(&comm), conf_(&conf), fp_(&fp), need_jobs_(false)
{
comm_->set_handler([this](int from, int tag, void *buf, size_t size) {
on_recv(from, tag, buf, size);
......@@ -38,7 +38,24 @@ RTS::RTS(Comm &comm, const Config &conf, const FP &fp)
comm_->bcast(TAG_STOP);
}
);
comm_->barrier();
if(conf.dynamic_balance()) {
pool_.on_empty([this]() {
if (!need_jobs_) {
need_jobs_ = true;
steal_req(TAG_STEAL);
}
});
pool_.on_submit([this]() {
if (need_jobs_) {
need_jobs_ = false;
steal_req(TAG_STEAL_REVOKE);
}
});
}
comm_->barrier();
}
int RTS::run()
......@@ -46,10 +63,6 @@ int RTS::run()
double start_time=wtime();
comm_->barrier();
std::unique_lock<std::mutex> lk(m_);
pool_.on_empty([this](){
steal_req();
});
pool_.start(conf_->get_worker_threads_count());
finished_flag_=false;
......@@ -309,19 +322,15 @@ void RTS::unexpect_pushes(const Id &cfid)
waiters_.erase(it);
}
void RTS::steal_req()
void RTS::steal_req(int tag)
{
int rank = comm_->rank();
int req_rank = (rank + 1) % comm_->size();
if(req_rank != rank) {
comm_->send(comm_->next_rank(), TAG_STEAL, nullptr, 0);
}
comm_->bcast(tag);
}
int RTS::get_steal_req()
{
std::lock_guard<std::mutex> lk(m_);
if(!steal_requests_.empty()) {
if(pool_.has_jobs() && !steal_requests_.empty()) {
int rank=*steal_requests_.begin();
steal_requests_.erase(rank);
return rank;
......@@ -454,10 +463,24 @@ void RTS::on_recv(int src, int tag, void *buf, size_t size)
break;
}
case TAG_STEAL: {
if(src==comm_->rank()){
return;
}
std::lock_guard <std::mutex> lock(m_);
LOG("Steal request from " + std::to_string(src))
steal_requests_.insert(src);
operator delete(buf);
break;
}
case TAG_STEAL_REVOKE: {
if(src==comm_->rank()){
return;
}
std::lock_guard <std::mutex> lock(m_);
LOG("Steal revoke request from " + std::to_string(src))
steal_requests_.erase(src);
operator delete(buf);
break;
}
default:
ABORT("Tag not implemented: " + std::to_string(tag));
......@@ -542,8 +565,6 @@ void RTS::_submit(CF *cf)
+ std::to_string(ret));
}
} while (ret==CONTINUE);
});
}
......
......@@ -5,7 +5,7 @@
#include "common.h"
ThreadPool::ThreadPool()
: running_jobs_(0), stop_flag_(false), on_empty_handler_([](){})
: on_empty_handler_([](){}), on_submit_handler_([](){}), running_jobs_(0), stop_flag_(false)
{
stop();
}
......@@ -13,7 +13,6 @@ ThreadPool::ThreadPool()
void ThreadPool::start(size_t threads_num)
{
std::lock_guard<std::mutex> lk(m_);
if (stop_flag_) {
throw std::runtime_error("start while stopping ThreadPool");
}
......@@ -28,7 +27,6 @@ void ThreadPool::start(size_t threads_num)
void ThreadPool::stop()
{
std::unique_lock<std::mutex> lk(m_);
if (stop_flag_) {
throw std::runtime_error("stop while stopping ThreadPool");
}
......@@ -55,33 +53,43 @@ void ThreadPool::stop()
void ThreadPool::submit(std::function<void()> job)
{
std::lock_guard<std::mutex> lk(m_);
jobs_.push(job);
on_submit_handler_();
cv_.notify_one();
}
void ThreadPool::on_empty(std::function<void()> on_empty_handler)
{
std::lock_guard<std::mutex> lk(m_);
on_empty_handler_ = on_empty_handler;
}
void ThreadPool::on_submit(std::function<void()> on_submit_handler)
{
std::lock_guard<std::mutex> lk(m_);
on_submit_handler_ = on_submit_handler;
}
std::string ThreadPool::to_string() const
{
std::lock_guard<std::mutex> lk(m_);
return std::to_string(threads_.size()) + "Th "
+ std::to_string(jobs_.size()) + "Jb "
+ std::to_string(running_jobs_) + " RJ "
+ (stop_flag_? "S": "");
}
bool ThreadPool::has_jobs()
{
std::lock_guard<std::mutex> lk(m_);
return !jobs_.empty();
}
void ThreadPool::routine()
{
std::unique_lock<std::mutex> lk(m_);
while (!stop_flag_ || !jobs_.empty() || running_jobs_>0) {
if (jobs_.empty()) {
on_empty_handler_();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment