Commit 886fedcd authored by Александр Гемуев's avatar Александр Гемуев
Browse files

sending bug / locator serialization

fixed a logical error - sending a message of request/post/push to only one step of the locator
parent e3f748de
#pragma once
#include <string>
#include "locator.h"
class CyclicLocator : public Locator
{
private:
int pos_;
static int locator_class_id_;
public:
virtual ~CyclicLocator() {}
int locator_class_id() const noexcept override;
static void set_locator_class_id(int locator_class_id);
~CyclicLocator() override = default;
explicit CyclicLocator(int pos);
int get_next_rank(Comm &) const noexcept override;
size_t get_serialization_size() const override;
CyclicLocator(int pos);
size_t serialize(void *buf, size_t buf_size) const override;
virtual int get_next_rank(Comm &) const noexcept;
size_t deserialize(const void *buf, size_t buf_size) override;
};
#pragma once
#include <memory>
#include "locator_creator.h"
#include "cyclic_locator.h"
#include "locator_factory.h"
class CyclicLocatorCreator: public LocatorCreator{
public:
CyclicLocatorCreator() = default;
std::unique_ptr<Locator> create() override;
~CyclicLocatorCreator() override = default;
};
#pragma once
#include <memory>
#include "serializable.h"
class Comm;
class LocatorFactory;
class Locator
class Locator : public Serializable
{
public:
virtual ~Locator() {}
virtual int locator_class_id() const noexcept = 0;
~Locator() override = default;
virtual int get_next_rank(Comm &) const = 0;
static size_t get_locator_serialization_size(const Locator &locator);
static void serialize_locator(void *&buf,
size_t &buf_size,
const Locator &locator);
virtual int get_next_rank(Comm &) const=0;
static std::unique_ptr<Locator> deserialize_locator(
LocatorFactory &locator_factory,
const void *&buf,
size_t &buf_size);
};
#pragma once
#include <memory>
#include "locator.h"
class LocatorCreator
{
public:
virtual std::unique_ptr<Locator> create() = 0;
virtual ~LocatorCreator() = default;
};
#pragma once
#include <memory>
#include <mutex>
#include <stdexcept>
#include <vector>
#include "locator_creator.h"
class LocatorFactory final
{
private:
std::vector<std::unique_ptr<LocatorCreator>> creators_;
public:
LocatorFactory() = default;
~LocatorFactory() = default;
std::unique_ptr<Locator> create_locator(int id);
int register_locator(std::unique_ptr<LocatorCreator> &&creator);
};
......@@ -10,6 +10,8 @@
#include "id.h"
#include "idle_stopper.h"
#include "locator.h"
#include "locator_factory.h"
#include "tags.h"
#include "thread_pool.h"
class Comm;
......@@ -22,6 +24,7 @@ class RTS
Comm *comm_;
const Config *conf_;
const FP *fp_;
LocatorFactory* loc_factory_;
IdleStopper<int> *stopper_;
bool finished_flag_;
ThreadPool pool_;
......@@ -31,7 +34,7 @@ class RTS
public:
virtual ~RTS();
RTS(Comm &, const Config &, const FP &);
RTS(Comm &, const Config &, const FP &, LocatorFactory &);
int run();
......@@ -82,6 +85,9 @@ private:
void _check_requests(const Id &id);
// needs to be forwarded to next node
bool _resend(const Locator& loc, Tags tag, void* buf, size_t size);
void _post(const Id &id, const DF &, int req_count);
void _request(const Id &id, std::function<void (const DF &)> cb);
void _destroy(const Id &id);
......
......@@ -8,7 +8,7 @@
class Serializable
{
public:
~Serializable() {}
virtual ~Serializable() = default;
virtual size_t get_serialization_size() const=0;
virtual size_t serialize(void *buf, size_t buf_size) const=0;
......
#include "cyclic_locator.h"
#include "cyclic_locator.h"
#include "comm.h"
CyclicLocator::CyclicLocator(int pos)
: pos_(pos)
{
}
int CyclicLocator::locator_class_id_ = -1;
CyclicLocator::CyclicLocator(int pos) : pos_(pos) {}
int CyclicLocator::get_next_rank(Comm &comm) const noexcept
{
if (pos_>=0) {
return pos_%comm.size();
int comm_size_int = static_cast<int>(comm.size());
if (pos_ >= 0) {
return pos_ % comm_size_int;
} else {
int rest=(-pos_)%comm.size();
return (comm.size()-rest)%comm.size();
return comm_size_int - ((-pos_) % comm_size_int);
}
}
size_t CyclicLocator::get_serialization_size() const
{
return sizeof(int);
}
size_t CyclicLocator::serialize(void *buf, size_t buf_size) const
{
assert(buf_size >= get_serialization_size());
size_t orig_buf_size = buf_size;
put<int>(buf, buf_size, pos_);
return orig_buf_size - buf_size;
}
size_t CyclicLocator::deserialize(const void *buf, size_t buf_size)
{
assert(buf_size >= get_serialization_size());
size_t orig_buf_size = buf_size;
pos_ = get<int>(buf, buf_size);
return orig_buf_size - buf_size;
}
int CyclicLocator::locator_class_id() const noexcept
{
return locator_class_id_;
}
void CyclicLocator::set_locator_class_id(int locator_class_id)
{
locator_class_id_ = locator_class_id;
}
#include "cyclic_locator_creator.h"
std::unique_ptr<Locator> CyclicLocatorCreator::create()
{
return std::unique_ptr<Locator>(new CyclicLocator(0));
}
#include "locator.h"
#include "locator_factory.h"
void Locator::serialize_locator(void *&buf,
size_t &buf_size,
const Locator &locator)
{
put<int>(buf, buf_size, locator.locator_class_id());
shift(buf, buf_size, locator.serialize(buf, buf_size));
}
std::unique_ptr<Locator> Locator::deserialize_locator(LocatorFactory& locator_factory,
const void *&buf,
size_t &buf_size)
{
int locator_class_id = get<int>(buf, buf_size);
std::unique_ptr<Locator> locator =
locator_factory.create_locator(locator_class_id);
shift(buf, buf_size, locator->deserialize(buf, buf_size));
return locator;
}
size_t Locator::get_locator_serialization_size(const Locator &locator)
{
return sizeof(int) + locator.get_serialization_size();
}
#include "locator_factory.h"
std::unique_ptr<Locator> LocatorFactory::create_locator(int id)
{
return creators_[id]->create();
}
int LocatorFactory::register_locator(std::unique_ptr<LocatorCreator> &&creator)
{
creators_.push_back(std::move(creator));
return static_cast<int>(creators_.size()) - 1;
}
......@@ -7,6 +7,14 @@
#include "mpi_comm.h"
#include "rts.h"
#include "cyclic_locator.h"
#include "cyclic_locator_creator.h"
static void init_loc_factory(LocatorFactory& locator_factory){
CyclicLocator::set_locator_class_id(locator_factory.register_locator(
std::unique_ptr<LocatorCreator>(new CyclicLocatorCreator())));
}
RTS *rts;
void init_mpi(int &argc, char **&argv)
......@@ -15,7 +23,7 @@ void init_mpi(int &argc, char **&argv)
MPI_Init_thread(&argc, &argv, desired, &provided);
if (provided!=desired) {
ABORT("Mpi thread safety level not provided: "
ABORT("Mpi thread safety level not provided: "
+ std::to_string(provided) + "<"
+ std::to_string(desired))
}
......@@ -51,7 +59,9 @@ int main(int argc, char **argv)
}
FP fp(conf.get_fp_path());
rts=new RTS(*comm, conf, fp);
LocatorFactory locator_factory;
init_loc_factory(locator_factory);
rts=new RTS(*comm, conf, fp, locator_factory);
auto old_handler=signal(SIGINT, ctrl_c_handler);
......
......@@ -20,8 +20,8 @@ RTS::~RTS()
delete stopper_;
}
RTS::RTS(Comm &comm, const Config &conf, const FP &fp)
: comm_(&comm), conf_(&conf), fp_(&fp), need_jobs_(false)
RTS::RTS(Comm &comm, const Config &conf, const FP &fp, LocatorFactory &loc_factory)
: comm_(&comm), conf_(&conf), fp_(&fp), loc_factory_(&loc_factory), need_jobs_(false)
{
comm_->set_handler([this](int from, int tag, void *buf, size_t size) {
on_recv(from, tag, buf, size);
......@@ -135,20 +135,21 @@ void RTS::post(const Id &id, const DF &val, const Locator &loc,
if (rank==comm_->rank()) {
_post(id, val, req_count);
} else {
// MSG format: [id] [df] [count]
size_t size=
id.get_serialization_size()
+val.get_serialization_size()
+sizeof(int);
void *buf=operator new(size);
// MSG format: [loc] [id] [df] [count]
size_t size = Locator::get_locator_serialization_size(loc) +
id.get_serialization_size() +
val.get_serialization_size() + sizeof(int);
void *buf = operator new(size);
size_t s = size;
void *b = buf;
size_t s=size;
void *b=buf;
Locator::serialize_locator(b, s, loc);
shift(b, s, id.serialize(b, s));
shift(b, s, val.serialize(b, s));
put<int>(b, s, req_count);
assert(s==0);
assert(s == 0);
comm_->send(rank, TAG_POST, buf, size, [buf](){
operator delete(buf);
......@@ -166,21 +167,26 @@ void RTS::request(const Id &id, const Locator &loc,
if (rank==comm_->rank()) {
_request(id, cb);
} else {
// Format: [id] [rCB]
size_t size=id.get_serialization_size() + sizeof(void*);
void *buf=operator new(size);
// Format: [] [locator] [id] [rCB] [original request sender`s rank]
const size_t size = Locator::get_locator_serialization_size(loc) +
id.get_serialization_size() + sizeof(void *) + sizeof(int);
void * const buf = operator new(size);
void *b=buf; size_t s=size;
void *b = buf;
size_t s = size;
Locator::serialize_locator(b, s, loc);
shift(b, s, id.serialize(b, s));
std::function<void (const DF &)> *fptr=nullptr;
fptr=new std::function<void (const DF &)>(
[cb, fptr](const DF &df){
std::function<void(const DF &)> *fptr = nullptr;
fptr = new std::function<void(const DF &)>([cb, fptr](const DF &df) {
cb(df);
delete fptr;
});
put<void*>(b, s, (void*)fptr);
put<void *>(b, s, (void *)fptr);
put<int>(b, s, comm_->rank());
comm_->send(rank, TAG_REQUEST, buf, size, [buf](){
operator delete(buf);
......@@ -265,16 +271,22 @@ void RTS::push(const Id &dfid, const DF &val, const Id &cfid,
if (rank==comm_->rank()) {
_push(dfid, val, cfid);
} else {
// Format dfid, val, cfid
size_t size=dfid.get_serialization_size()
+val.get_serialization_size()
+cfid.get_serialization_size();
void *buf=operator new(size);
void *b=buf; size_t s=size;
// Format [loc] dfid, val, cfid
size_t size = Locator::get_locator_serialization_size(loc) +
dfid.get_serialization_size() +
val.get_serialization_size() +
cfid.get_serialization_size();
void *buf = operator new(size);
void *b = buf;
size_t s = size;
Locator::serialize_locator(b, s, loc);
shift(b, s, dfid.serialize(b, s));
shift(b, s, val.serialize(b, s));
shift(b, s, cfid.serialize(b, s));
assert(s==0);
assert(s == 0);
comm_->send(rank, TAG_PUSH, buf, size, [buf]() {
operator delete(buf);
});
......@@ -383,67 +395,80 @@ void RTS::on_recv(int src, int tag, void *buf, size_t size)
break;
}
case TAG_POST: {
// MSG format: [id] [df] [count]
const void *b=buf; size_t s=size;
// MSG format: [loc] [id] [df] [count]
const void *b = buf;
size_t s = size;
std::unique_ptr<Locator> locator =
Locator::deserialize_locator(*loc_factory_, b, s);
if(_resend(*locator, TAG_POST, buf, size)){
break;
}
Id id;
shift(b, s, id.deserialize(b, s));
DF val;
shift(b, s, val.deserialize(b, s));
int req_count = get<int>(b, s);
int req_count=get<int>(b, s);
assert(s==0);
assert(s == 0);
operator delete(buf);
_post(id, val, req_count);
break;
}
case TAG_REQUEST: {
// Format: [id] [rCB]
const void *b=buf; size_t s=size;
// Format: [loc] [id] [rCB] [origin request sender`s rank]
const void *b = buf;
size_t s = size;
std::unique_ptr<Locator> locator =
Locator::deserialize_locator(*loc_factory_, b, s);
if(_resend(*locator, TAG_REQUEST, buf, size)){
break;
}
Id id;
shift(b, s, id.deserialize(b, s));
void *rptr=get<void *>(b, s);
void *rptr = get<void *>(b, s);
int origin_sender_rank = get<int>(b, s);
operator delete(buf);
_request(id, [rptr, src, this](const DF &df){
_request(id, [rptr, this, origin_sender_rank](const DF &df) {
// Format: [rCB] [df]
size_t size=sizeof(void *)
+df.get_serialization_size();
void *buf=operator new(size);
size_t size = sizeof(void *) + df.get_serialization_size();
void *buf = operator new(size);
void *b=buf; size_t s=size;
void *b = buf;
size_t s = size;
put<void *>(b, s, rptr);
shift(b, s, df.serialize(b, s));
assert(s==0);
assert(s == 0);
comm_->send(src, TAG_RESPONSE, buf, size, [buf](){
operator delete(buf);
});
comm_->send(origin_sender_rank, TAG_RESPONSE, buf, size,
[buf]() { operator delete(buf); });
});
break;
}
case TAG_RESPONSE: {
// Format: [rCB] [df]
const void *b=buf; size_t s=size;
// Format: [rCB] [df]
const void *b = buf;
size_t s = size;
void *rptr=get<void *>(b, s);
void *rptr = get<void *>(b, s);
DF df;
shift(b, s, df.deserialize(b, s));
assert(s==0);
assert(s == 0);
operator delete(buf);
std::function<void (const DF &)> *fptr=
static_cast<std::function<void (const DF &)> *>(rptr);
std::function<void(const DF &)> *fptr =
static_cast<std::function<void(const DF &)> *>(rptr);
(*fptr)(df);
......@@ -459,15 +484,22 @@ void RTS::on_recv(int src, int tag, void *buf, size_t size)
break;
}
case TAG_PUSH: {
// Format dfid, val, cfid
// Format [loc] dfid, val, cfid
const void *b = buf;
size_t s = size;
std::unique_ptr<Locator> locator =
Locator::deserialize_locator(*loc_factory_, b, s);
if(_resend(*locator, TAG_PUSH, buf, size)){
break;
}
Id dfid, cfid;
DF val;
const void *b=buf;
size_t s=size;
shift(b, s, dfid.deserialize(b, s));
shift(b, s, val.deserialize(b, s));
shift(b, s, cfid.deserialize(b, s));
assert(s==0);
assert(s == 0);
_push(dfid, val, cfid);
operator delete(buf);
break;
......@@ -687,3 +719,16 @@ void RTS::_push(const Id &dfid, const DF &val, const Id &cfid)
it->second.push_back(std::make_pair(dfid, val));
}
}
bool RTS::_resend(const Locator& loc, Tags tag, void* buf, size_t size)
{
int next_rank = loc.get_next_rank(*comm_);
if(next_rank != comm_->rank()){
comm_->send(next_rank, tag, buf, size,
[buf](){
operator delete(buf);
});
return true;
}
return false;
}
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment