ps-lite_part2_Postoffice讲解

3,444次阅读
2 条评论

共计 12731 个字符,预计需要花费 32 分钟才能阅读完成。

晚上也在思考这篇文章该以怎样的结构来撰写,ps-lite虽然是轻量版,但是很多东西都是在复用,东西都混在一起,所以在解释理清思路的时候有点绕。思来想去还是按照ps程序启动的逻辑来讲,涉及到本文需要讲的点在重点描写。

ps启动

启动一个完整的 ps 服务,需要启动scheduler、sever和worker,启动的时候都会产生一个实际的物理进程,这个进程里都会包含PostOffice,负责管理全局信息。

ps-lite给出一个简单的demo,启动的顺序是先启动schedule,然后是server 、worker

#!/bin/bash
# set -x
if [ # -lt 3 ]; then
    echo "usage:0 num_servers num_workers bin [args..]"
    exit -1;
fi

export DMLC_NUM_SERVER=1
shift
export DMLC_NUM_WORKER=1
shift
bin=1
shift
arg="@"

# start the scheduler
export DMLC_PS_ROOT_URI='127.0.0.1'
export DMLC_PS_ROOT_PORT=8000
export DMLC_ROLE='scheduler'
{bin}{arg} &


# start servers
export DMLC_ROLE='server'
for ((i=0; i<{DMLC_NUM_SERVER}; ++i)); do
    export HEAPPROFILE=./S{i}
    {bin}{arg} &
done

# start workers
export DMLC_ROLE='worker'
for ((i=0; i<{DMLC_NUM_WORKER}; ++i)); do
    export HEAPPROFILE=./W{i}
    {bin}{arg} &
done

从上面的启动程序可以看到,ps-lite启动任务一些参数是来自环境变量的,你会看到shell脚本里充斥着export。

那我们现在看看它是怎么启动程序的?看一个demo

#include <cmath>
#include "ps/ps.h"

using namespace ps;

void StartServer() {
  if (!IsServer()) {
    return;
  }
  auto server = new KVServer<float>(0);
  server->set_request_handle(KVServerDefaultHandle<float>());
  RegisterExitCallback([server](){ delete server; });
}

void RunWorker() {
  if (!IsWorker()) return;
  KVWorker<float> kv(0, 0);

  // init
  int num = 10000;
  std::vector<Key> keys(num);
  std::vector<float> vals(num);

  int rank = MyRank();
  srand(rank + 7);
  for (int i = 0; i < num; ++i) {
    keys[i] = kMaxKey / num * i + rank;
    vals[i] = (rand() % 1000);
  }

  // push
  int repeat = 50;
  std::vector<int> ts;
  for (int i = 0; i < repeat; ++i) {
    ts.push_back(kv.Push(keys, vals));

    // to avoid too frequency push, which leads huge memory usage
    if (i > 10) kv.Wait(ts[ts.size()-10]);
  }
  for (int t : ts) kv.Wait(t);

  // pull
  std::vector<float> rets;
  kv.Wait(kv.Pull(keys, &rets));

  // pushpull
  std::vector<float> outs;
  for (int i = 0; i < repeat; ++i) {
    // PushPull on the same keys should be called serially
    kv.Wait(kv.PushPull(keys, vals, &outs));
  }

  float res = 0;
  float res2 = 0;
  for (int i = 0; i < num; ++i) {
    res += std::fabs(rets[i] - vals[i] * repeat);
    res2 += std::fabs(outs[i] - vals[i] * 2 * repeat);
  }
  CHECK_LT(res / repeat, 1e-5);
  CHECK_LT(res2 / (2 * repeat), 1e-5);
  LL << "error: " << res / repeat << ", " << res2 / (2 * repeat);
}

int main(int argc, char *argv[]) {
  // start system ,这里会根据实际的角色名执行相应的逻辑,比如一开始执行是scheduler的启动任务,
  // 会以scheduler的角色启动,后面的StartServer也是不会执行的,当然这里会在是server角色的时候还会在执行一次
  Start(0);
  // setup server nodes
  StartServer();
  // run worker nodes
  RunWorker();
  // stop system
  Finalize(0, true);
  return 0;
}

ps-lite在这里其实共用了一套代码,也就是说你启动scheduler、sever和worker 这些都会走一套启动代码,根据不同的角色名称去执行相应的代码逻辑,比如只有在角色scheduler的时候才会触发scheduler相关的代码逻辑。

PostOffice启动

接下来就先以scheduler 启动来介绍

  Start(0);
//实际调用的方法
inline void Start(int customer_id, const char* argv0 = nullptr) {
  Postoffice::Get()->Start(customer_id, argv0, true);
}

这里会有一个

Postoffice::Get()

调用Get方法是去获取PostOffice全局单例对象,这里想要强调的一点就是PostOffice是单例,即一个进程内只有这一个对象,全局变量。

无论scheduler还是worker 都会调用,那么你可以理解这里PostOffice单例是相对而言的,scheduler进程下有一个,sever进程下也有一个。

接下来再来看看 Start 函数做了哪些事情?

void Postoffice::Start(int customer_id, const char* argv0, const bool do_barrier) {
  start_mu_.lock();
  if (init_stage_ == 0) {
    // 初始化环境变量,主要是从 shell 执行脚本中获取相关参数比如 role 角色变量、server 数量和worker 数量
    InitEnvironment();
    // init glog
    if (argv0) {
      dmlc::InitLogging(argv0);
    } else {
      dmlc::InitLogging("ps-lite\0");
    }

    // init node info.
    for (int i = 0; i < num_workers_; ++i) {
      int id = WorkerRankToID(i);
      for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup,
                    kWorkerGroup + kScheduler,
                    kWorkerGroup + kServerGroup + kScheduler}) {
        node_ids_[g].push_back(id);
      }
    }

    for (int i = 0; i < num_servers_; ++i) {
      int id = ServerRankToID(i);
      for (int g : {id, kServerGroup, kWorkerGroup + kServerGroup,
                    kServerGroup + kScheduler,
                    kWorkerGroup + kServerGroup + kScheduler}) {
        node_ids_[g].push_back(id);
      }
    }

    for (int g : {kScheduler, kScheduler + kServerGroup + kWorkerGroup,
                  kScheduler + kWorkerGroup, kScheduler + kServerGroup}) {
      node_ids_[g].push_back(kScheduler);
    }
    init_stage_++;
  }
  start_mu_.unlock();

  // start van
  van_->Start(customer_id);

  start_mu_.lock();
  if (init_stage_ == 1) {
    // record start time
    start_time_ = time(NULL);
    init_stage_++;
  }
  start_mu_.unlock();
  // do a barrier here
  if (do_barrier) Barrier(customer_id, kWorkerGroup + kServerGroup + kScheduler);
}

现在来看看这个初始化环境的函数做了哪些事情?

void Postoffice::InitEnvironment() {
  const char* val = NULL;
  std::string van_type = GetEnv("DMLC_PS_VAN_TYPE", "zmq");
  // 核心的一个点就是创建了 Van,至于Van 是什么后续也会做相应的详细介绍
  van_ = Van::Create(van_type);
  //接下来都是在解析环境变量,对于我们而言就是判断这次启动的是哪个?
  val = CHECK_NOTNULL(Environment::Get()->find("DMLC_NUM_WORKER"));
  num_workers_ = atoi(val);
  val =  CHECK_NOTNULL(Environment::Get()->find("DMLC_NUM_SERVER"));
  num_servers_ = atoi(val);
  val = CHECK_NOTNULL(Environment::Get()->find("DMLC_ROLE"));
  std::string role(val);
  is_worker_ = role == "worker";
  is_server_ = role == "server";
  is_scheduler_ = role == "scheduler";
  verbose_ = GetEnv("PS_VERBOSE", 0);
}

ok,让我们再回到PostOffice的start函数里,假设我们启动的是scheduler角色下的程序

    for (int g : {kScheduler, kScheduler + kServerGroup + kWorkerGroup,
                  kScheduler + kWorkerGroup, kScheduler + kServerGroup}) {
      node_ids_[g].push_back(kScheduler);
    }
    init_stage_++;

那这个node_ids_ 是干啥的?这里就需要讨论下 node_id 的概念

Node管理

其实是可以分为两个部分:node group 和 single node_id

首先我们介绍下 node id 映射功能,就是如何在逻辑节点和物理节点之间做映射,如何把物理节点划分成各个逻辑组,如何用简便的方法做到给组内物理节点统一发消息。

  • 1,2,4分别标识Scheduler, ServerGroup, WorkerGroup。
  • SingleWorker:rank * 2 + 9;SingleServer:rank * 2 + 8。
  • 任意一组节点都可以用单个id标识,等于所有id之和。

概念

  • Rank 是一个逻辑概念,是每一个节点(scheduler,work,server)内部的唯一逻辑标示。
  • Node id 是物理节点的唯一标识,可以和一个 host + port 的二元组唯一对应。
  • Node Group 是一个逻辑概念,每一个 group 可以包含多个 node id。ps-lite 一共有三组 group : scheduler 组,server 组,worker 组。
  • Node group id 是 是节点组的唯一标示。
    • ps-lite 使用 1,2,4 这三个数字分别标识 Scheduler,ServerGroup,WorkerGroup。每一个数字都代表着一组节点,等于所有该类型节点 id 之和。比如 2 就代表server 组,就是所有 server node 的组合。
    • 为什么选择这三个数字?因为在二进制下这三个数值分别是 “001, 010, 100″,这样如果想给多个 group 发消息,直接把 几个 node group id 做 或操作 就行。
    • 即 1-7 内任意一个数字都代表的是Scheduler / ServerGroup / WorkerGroup的某一种组合。
    • 如果想把某一个请求发送给所有的 worker node,把请求目标节点 id 设置为 4 即可。
    • 假设某一个 worker 希望向所有的 server 节点 和 scheduler 节点同时发送请求,则只要把请求目标节点的 id 设置为 3 即可,因为 3 = 2 + 1 = kServerGroup + kScheduler。
    • 如果想给所有节点发送消息,则设置为 7 即可。

逻辑组的实现

三个逻辑组的定义如下:

/** \brief node ID for the scheduler */
static const int kScheduler = 1;
/**
 * \brief the server node group ID
 *
 * group id can be combined:
 * - kServerGroup + kScheduler means all server nodes and the scheuduler
 * - kServerGroup + kWorkerGroup means all server and worker nodes
 */
static const int kServerGroup = 2;
/** \brief the worker node group ID */
static const int kWorkerGroup = 4;
for (int i = 0; i < num_workers_; ++i) {
      int id = WorkerRankToID(i);
      for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup,
                    kWorkerGroup + kScheduler,
                    kWorkerGroup + kServerGroup + kScheduler}) {
        node_ids_[g].push_back(id);
      }

如下面代码所示,如果配置了 3 个worker,则 worker 的 rank 从 0 ~ 2,那么这几个 worker 实际对应的 物理 node ID 就会使用 WorkerRankToID 来计算出来。

node id 是物理节点的唯一标示,rank 是每一个逻辑概念(scheduler,work,server)内部的唯一标示。这两个标示由一个算法来确定。

Rank vs node id

node id 是物理节点的唯一标示,rank 是每一个逻辑概念(scheduler,work,server)内部的唯一标示。这两个标示由一个算法来确定。

如下面代码所示,如果配置了 3 个worker,则 worker 的 rank 从 0 ~ 2,那么这几个 worker 实际对应的 物理 node ID 就会使用 WorkerRankToID 来计算出来。

    for (int i = 0; i < num_workers_; ++i) {
      int id = WorkerRankToID(i);
      for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup,
                    kWorkerGroup + kScheduler,
                    kWorkerGroup + kServerGroup + kScheduler}) {
        node_ids_[g].push_back(id);
      }
    }

具体计算规则如下:

  /**
   * \brief convert from a worker rank into a node id
   * \param rank the worker rank
   */
  static inline int WorkerRankToID(int rank) {
    return rank * 2 + 9;
  }
  /**
   * \brief convert from a server rank into a node id
   * \param rank the server rank
   */
  static inline int ServerRankToID(int rank) {
    return rank * 2 + 8;
  }
  /**
   * \brief convert from a node id into a server or worker rank
   * \param id the node id
   */
  static inline int IDtoRank(int id) {
#ifdef _MSC_VER
#undef max
#endif
    return std::max((id - 8) / 2, 0);
  }


  • SingleWorker:rank * 2 + 9;
  • SingleServer:rank * 2 + 8;

而且这个算法保证server id为偶数,node id为奇数。

这样我们可以知道,1-7 的id表示的是node group,单个节点的id 就从 8 开始。

具体计算规则如下:

Group vs node

因为有时请求要发送给多个节点,所以ps-lite用了一个 map 来存储每个 node group / single node 对应的实际的node节点集合,即 确定每个id值对应的节点id集。

std::unordered_map<int, std::vector<int>> node_ids_ 

    for (int i = 0; i < num_workers_; ++i) {
      int id = WorkerRankToID(i);
      for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup,
                    kWorkerGroup + kScheduler,
                    kWorkerGroup + kServerGroup + kScheduler}) {
        node_ids_[g].push_back(id);
      }
    }

这 5 个id 相对应,即需要在 node_ids_ 这个映射表中对应的 4, 4 + 1, 4 + 2, 4 +1 + 2, 12 这五个 item 之中添加。就是上面代码中的内部 for 循环条件。即,node_ids_ [4], node_ids_ [5],node_ids_ [6],node_ids_ [7] ,node_ids_ [12] 之中,都需要把 12 添加到 vector 最后。

  • 12(本身)
  • 4(kWorkerGroup)
  • 4+1(kWorkerGroup + kScheduler)
  • 4+2(kWorkerGroup + kServerGroup)
  • 4+1+2,(kWorkerGroup + kServerGroup + kScheduler )

所以,为了实现 “设置 1-7 内任意一个数字 可以发送给其对应的 所有node” 这个功能,对于每一个新节点,需要将其对应多个id(node,node group)上,这些id组就是本节点可以与之通讯的节点。例如对于 worker 2 来说,其 node id 是 2 * 2 + 8 = 12,所以需要将它与

  • 1 ~ 7 的 id 表示的是 node group;
  • 后续的 id(8,9,10,11 …)表示单个的 node。其中双数 8,10,12… 表示 worker 0, worker 1, worker 2,… 即(2n + 8),9,11,13,…,表示 server 0, server 1,server 2,…,即(2n + 9);

还是花了不少的功夫在讲解node,那么这个node 的标记是用来干啥的?

这些node的标记实际上与我们的worker还有server都是对应的关心,所以通过这些node标记就可以快速找打,这样通信同步一些数据就方便。

在记录完node_id之后,开始调用Van的启动程序。Van其实是一个通信模块。Van的东西还是蛮多的,打算放在下一篇文章里讲了。

在继续就是讲到 Barrier

  if (do_barrier) Barrier(customer_id, kWorkerGroup + kServerGroup + kScheduler);

Barrier

同步

总的来讲,schedular节点通过计数的方式实现各个节点的同步。具体来说就是:

  • 每个节点在自己指定的命令运行完后会向schedular节点发送一个Control::BARRIER命令的请求并自己阻塞直到收到schedular对应的返回后才解除阻塞;
  • schedular节点收到请求后则会在本地计数,看收到的请求数是否和barrier_group的数量是否相等,相等则表示每个机器都运行完指定的命令了,此时schedular节点会向barrier_group的每个机器发送一个返回的信息,并解除其阻塞。

初始化

ps-lite 使用 Barrier 来控制系统的初始化,就是大家都准备好了再一起前进。这是一个可选项。具体如下:

  • Scheduler等待所有的worker和server发送BARRIER信息;
  • 在完成ADD_NODE后,各个节点会进入指定 group 的Barrier阻塞同步机制(发送 BARRIER 给 Scheduler),以保证上述过程每个节点都已经完成;
  • 所有节点(worker和server,包括scheduler) 等待scheduler收到所有节点 BARRIER 信息后的应答;
  • 最终所有节点收到scheduler 应答的Barrier message后退出阻塞状态;
等待 BARRIER 消息

Node会调用 Barrier 函数 告知Scheduler,随即自己进入等待状态。

注意,调用时候是

if (do_barrier) Barrier(customer_id, kWorkerGroup + kServerGroup + kScheduler);  

复制代码
void Postoffice::Barrier(int customer_id, int node_group) {
  if (GetNodeIDs(node_group).size() <= 1) return;
  auto role = van_->my_node().role;
  if (role == Node::SCHEDULER) {
    CHECK(node_group & kScheduler);
  } else if (role == Node::WORKER) {
    CHECK(node_group & kWorkerGroup);
  } else if (role == Node::SERVER) {
    CHECK(node_group & kServerGroup);
  }

  std::unique_lock<std::mutex> ulk(barrier_mu_);
  barrier_done_[0][customer_id] = false;
  Message req;
  req.meta.recver = kScheduler;
  req.meta.request = true;
  req.meta.control.cmd = Control::BARRIER;
  req.meta.app_id = 0;
  req.meta.customer_id = customer_id;
  req.meta.control.barrier_group = node_group; // 记录了等待哪些
  req.meta.timestamp = van_->GetTimestamp();
  van_->Send(req); // 给 scheduler 发给 BARRIER
  barrier_cond_.wait(ulk, [this, customer_id] { // 然后等待
      return barrier_done_[0][customer_id];
    });
}

这就是说,等待所有的 group,即 scheduler 节点也要给自己发送消息。

处理 BARRIER 消息

处理等待的动作在 Van 类之中,我们提前放出来。

具体ProcessBarrierCommand逻辑如下:

  • 如果 msg->meta.request 为true,说明是 scheduler 收到消息进行处理。
    • Scheduler会对Barrier请求进行增加计数。
    • 当 Scheduler 收到最后一个请求时(计数等于此group节点总数),则将计数清零,发送结束Barrier的命令。这时候 meta.request 设置为 false;
    • 向此group所有节点发送request==falseBARRIER消息。
  • 如果 msg->meta.request 为 false,说明是收到消息这个 respones,可以解除barrier了,于是进行处理,调用 Manage 函数 。
    • Manage 函数 将app_id对应的所有costomer的barrier_done_置为true,然后通知所有等待条件变量barrier_cond_.notify_all()
void Van::ProcessBarrierCommand(Message* msg) {
  auto& ctrl = msg->meta.control;
  if (msg->meta.request) {  // scheduler收到了消息,因为 Postoffice::Barrier函数 会在发送时候做设置为true。
    if (barrier_count_.empty()) {
      barrier_count_.resize(8, 0);
    }
    int group = ctrl.barrier_group;
    ++barrier_count_[group]; // Scheduler会对Barrier请求进行计数
    if (barrier_count_[group] ==
        static_cast<int>(Postoffice::Get()->GetNodeIDs(group).size())) { // 如果相等,说明已经收到了最后一个请求,所以发送解除 barrier 消息。
      barrier_count_[group] = 0;
      Message res;
      res.meta.request = false; // 回复时候,这里就是false
      res.meta.app_id = msg->meta.app_id;
      res.meta.customer_id = msg->meta.customer_id;
      res.meta.control.cmd = Control::BARRIER;
      for (int r : Postoffice::Get()->GetNodeIDs(group)) {
        int recver_id = r;
        if (shared_node_mapping_.find(r) == shared_node_mapping_.end()) {
          res.meta.recver = recver_id;
          res.meta.timestamp = timestamp_++;
          Send(res);
        }
      }
    }
  } else { // 说明这里收到了 barrier respones,可以解除 barrier了。具体见上面的设置为false处。
    Postoffice::Get()->Manage(*msg);
  }
}


Manage 函数就是解除了 barrier。

void Postoffice::Manage(const Message& recv) {
  CHECK(!recv.meta.control.empty());
  const auto& ctrl = recv.meta.control;
  if (ctrl.cmd == Control::BARRIER && !recv.meta.request) {
    barrier_mu_.lock();
    auto size = barrier_done_[recv.meta.app_id].size();
    for (size_t customer_id = 0; customer_id < size; customer_id++) {
      barrier_done_[recv.meta.app_id][customer_id] = true;
    }
    barrier_mu_.unlock();
    barrier_cond_.notify_all(); // 这里解除了barrier
  }
}

在上面的启动程序中可能没见到下面两个函数的调用,但是这也是 Postoffice 重要的成员组成

数据key分布式存储

到现在为止,邮车和customer都有了,信件本身无非就是embedding这些参数,但是这些参数的存放也是有讲究的,这也是在上一篇文章中提到的分布式存储,这个分布式是如何体现的?

const std::vector<Range>& Postoffice::GetServerKeyRanges() {
  server_key_ranges_mu_.lock();
  //循环遍历所有的server,配置server key 的范围
  //本质上就是根据server的数量均匀划分而已,就是这么简单
  if (server_key_ranges_.empty()) {
    for (int i = 0; i < num_servers_; ++i) {
      server_key_ranges_.push_back(Range(
          kMaxKey / num_servers_ * i,
          kMaxKey / num_servers_ * (i+1)));
    }
  }
  server_key_ranges_mu_.unlock();
  return server_key_ranges_;
}

通过以上的操作的确解决了数据分布式存储,而且可以明确在worker向server端拉取数据的时候要去哪个server拉数据的问题。

用户管理

现在大概知道了邮车,那么怎么知道要给哪些customer送信件呢?邮局需要管理一份用户的名单。

Customer* Postoffice::GetCustomer(int app_id, int customer_id, int timeout) const {
  Customer* obj = nullptr;
  for (int i = 0; i < timeout * 1000 + 1; ++i) {
    {
      std::lock_guard<std::mutex> lk(mu_);
      // app_id 是对应 kv存储的id,举个例子FM 里存在一阶weight app_id=0
      // 通过app_id 去寻找customer,一般 worker 会有多个thread 对应不同的customer
        //但是消费的都是同一个 kv,所以根据app_id可以找到对应的 customer
      const auto it = customers_.find(app_id);
      if (it != customers_.end()) {
        std::unordered_map<int, Customer*> customers_in_app = it->second;
        obj = customers_in_app[customer_id];
        break;
      }
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(1));
  }
  return obj;
}

这个 GetCustomer 的操作主要是在Van中的 ProcessDataMsg 调用,这里就是Van要把传递的信件交给customer,然后通过 GetCustomer 这个方式来获取相应的customer。

上面的函数列的是读取,还有 AddCustomer 和 RemoveCustomer 负责添加和删除。

正文完
请博主喝杯咖啡吧!
post-qrcode
 
admin
版权声明:本站原创文章,由 admin 2023-03-09发表,共计12731字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请注明出处。
评论(2 条评论)
验证码