ps-lite_part5_kvworker和kvsever

3,161次阅读
没有评论

共计 5762 个字符,预计需要花费 15 分钟才能阅读完成。

先回顾一下之前写到哪里了

  1. 介绍ps-lite的基本概念 https://www.deeplearn.me/4302.html
  2. 介绍ps-lite核心组成 postOffice https://www.deeplearn.me/4303.html
  3. 介绍ps-lite 通信模块van https://www.deeplearn.me/4306.html
  4. 介绍ps-lite 中介 customer https://www.deeplearn.me/4308.html

这篇文章主要讲一下server 和woker,在扒拉一下ps架构的一张图

ps-lite_part5_kvworker和kvsever

一般意义上来说:

  1. server负责梯度和参数的更新
  2. woker端负责前向和后向的计算

这也是之前有customer出现的缘故,server和worker集中去计算,负责通信的任务就交给customer。在上一节讲customer在哪里被创建的时候就提到kvworker和kvserver,这里在着重讲一下吧!

在这之前还是要补充一点kvworker 和kvserver都继承 SimpleApp,那么SimpleApp 又是啥?

SimpleApp:KVServer和KVWorker的父类,它提供了简单的Request, Wait, Response,Process功能;KVServer和KVWorker分别根据自己的使命重写了这些功能;

kvwoker

构造函数

 explicit KVWorker(int app_id, int customer_id) : SimpleApp() {
    using namespace std::placeholders;
    slicer_ = std::bind(&KVWorker<Val>::DefaultSlicer, this, _1, _2, _3);
    obj_ = new Customer(app_id, customer_id, std::bind(&KVWorker<Val>::Process, this, _1));
  }

这里关于构造函数的定义也在上一节提到了,此处略过哈!

PULL函数

从开始的图你也看到worker需要从server拉取参数数据,那么肯定需要pull。

 int Pull(const std::vector<Key>& keys,
           std::vector<Val>* vals,
           std::vector<int>* lens = nullptr,
           int cmd = 0,
           const Callback& cb = nullptr,
           int priority = 0) {
    SArray<Key> skeys(keys);
    int ts = AddPullCB(skeys, vals, lens, cmd, cb);
    KVPairs<Val> kvs;
    kvs.keys = skeys;
    kvs.priority = priority;
    Send(ts, false, true, cmd, kvs);
    return ts;
  }

这里面有两个需要关注的调用,AddPullCB 和 Send,依次来看下这两个函数的定义和功能

AddPullCB 是添加一个callback,这个callback等所有server返回结果之后在执行,可以认为是一个阻塞等操作。

int KVWorker<Val>::AddPullCB(
// C* vals和D* lens指向由调用者指定的结构体。
// 等所有server都返回后,从所有server拉来的数据
    const SArray<Key>& keys, C* vals, D* lens, int cmd,
// Callback& cb代表在所有server回复后要执行的额外的回调
// 一般我们都是在pull后就立刻阻塞等待,所以cb一般为空
    const Callback& cb) {
// ************** 创建request,返回的ts是该request_id
  int ts = obj_->NewRequest(kServerGroup);

// ************** 添加callback,等所有server都回复后再执行
  AddCallback(ts, [this, ts, keys, vals, lens, cb]() mutable {
      ......
      // 容纳ts(即request_id)所接受数据的缓冲区
      auto& kvs = recv_kvs_[ts];
      ......

      // total_keys是根据kvs统计出来的接收到的key的总数
      // keys是当初请求的所有keys,检查二者是否相等
      ......
      CHECK_EQ(total_key, keys.size()) << "lost some servers?";

// ************** 将所有server返回的数据,合并,填充到用户指定的输出位置
      // vals和lens都指向调用者传入的结构体
      // p_vals和p_lens都是指向输出区的指针
      Val* p_vals = vals->data();
      ......
        p_lens = lens->data();
      ......
      // 遍历从各台server接收到的内容,填充到输出区p_vals和p_lens
      for (const auto& s : kvs) {
        memcpy(p_vals, s.vals.data(), s.vals.size() * sizeof(Val));
        p_vals += s.vals.size();
        if (p_lens) {
          memcpy(p_lens, s.lens.data(), s.lens.size() * sizeof(int));
          p_lens += s.lens.size();
        }
      }
      ......
      recv_kvs_.erase(ts);//清空本次请求的接收缓冲区
      ......
      if (cb) cb();// 如果有额外的callback,执行之
    });

  return ts;
}

send的操作才是真正的去请求server,下面看下send的定义

void KVWorker<Val>::Send(int timestamp, bool push, bool pull, int cmd, const KVPairs<Val>& kvs) {
  // ****************** 决定要向哪些server发送请求
  SlicedKVs sliced;// 存储分配结果
  slicer_(kvs, Postoffice::Get()->GetServerKeyRanges(), &sliced);

  // ****************** 有些server不包含本次请求要求的keys,提前处理
  int skipped = 0;// 本次请求不涉及的servers的总数
  //这里调用first参数,需要去追溯一下SlicedKVs 的定义
  // using SlicedKVs = std::vector<std::pair<bool, KVPairs<Val>>>;
  // bool 参数决定是否需要去这个server节点拉取数据,不需要直接跳过
  for (size_t i = 0; i < sliced.size(); ++i) {
    if (!sliced[i].first) ++skipped;
  }
  // 内部不过是tracker_[timestamp].second += skipped
  // 假设这些不涉及的servers已经返回了
  obj_->AddResponse(timestamp, skipped);

  ......

  // ****************** 向所有涉及到的server发送请求
  for (size_t i = 0; i < sliced.size(); ++i) {
    const auto& s = sliced[i];
    if (!s.first) continue;//本次请求不需要访问的server节点直接跳过

    Message msg;
    msg.meta.app_id = obj_->app_id();
    msg.meta.customer_id = obj_->customer_id();
    msg.meta.request     = true;
    msg.meta.push        = push;
    msg.meta.pull        = pull;
    msg.meta.head        = cmd;
    msg.meta.timestamp   = timestamp;
    msg.meta.recver      = Postoffice::Get()->ServerRankToID(i);
    msg.meta.priority    = kvs.priority;

    const auto& kvs = s.second;//分配到当前节点上的key-value pairs
    if (kvs.keys.size()) {
      msg.AddData(kvs.keys);
      msg.AddData(kvs.vals);
      if (kvs.lens.size()) {
        msg.AddData(kvs.lens);
      }
    }
    //通过van通信模块发送请求
    Postoffice::Get()->van()->Send(msg);
  }
}

至此再回去看pull 应该就差不多了,除了pull之外还有一个zpull ,全称是zero pull,说是实现了零拷贝,起到一个加速的作用,这里就不细看了。

PUSH

说完pull 就是push了,woker的push 就是要把梯度传给server,让server 去更新参数。

  int ZPush(const SArray<Key>& keys,
            const SArray<Val>& vals,
            const SArray<int>& lens = {},
            int cmd = 0,
            const Callback& cb = nullptr,
            int priority = 0) {
    int ts = obj_->NewRequest(kServerGroup);
    AddCallback(ts, cb);
    KVPairs<Val> kvs;
    kvs.keys = keys;
    kvs.vals = vals;
    kvs.lens = lens;
    kvs.priority = priority;
    // send 将这些梯度传递到指定的server上
    Send(ts, true, false, cmd, kvs);
    return ts;
  }

同时也还有一个zpush,本质上实现的功能是一致的。

差不多 woker 就这些事情,接下来讲下server ,其实都差不多,因为只是各自干的事情内容又一点不一样而已。

kvserver

构造函数

explicit KVServer(int app_id) : SimpleApp() {
    using namespace std::placeholders;
    obj_ = new Customer(app_id, app_id, std::bind(&KVServer<Val>::Process, this, _1));
  }

Server 主要是处理参数更新和数据查询

  1. 参数更新:根据梯度更新相应的神经网络参数
  2. 数据查询:worker需要拉取参数去执行前向传播

完成上述需求主要依靠两个函数

Process

这个主要是来处理woker push 过来的数据

template <typename Val>
void KVServer<Val>::Process(const Message& msg) {
  if (msg.meta.simple_app) {
    SimpleApp::Process(msg); return;
  }
  KVMeta meta;
  meta.cmd       = msg.meta.head;
  meta.push      = msg.meta.push;
  meta.pull      = msg.meta.pull;
  meta.sender    = msg.meta.sender;
  meta.timestamp = msg.meta.timestamp;
  meta.customer_id = msg.meta.customer_id;
  //KVPairs 保存的就是传递的数据
  KVPairs<Val> data;
  int n = msg.data.size();
  if (n) {
    CHECK_GE(n, 2);
    data.keys = msg.data[0];
    data.vals = msg.data[1];
    if (n > 2) {
      CHECK_EQ(n, 3);
      data.lens = msg.data[2];
      CHECK_EQ(data.lens.size(), data.keys.size());
    }
  }
  CHECK(request_handle_);
  //这个request_handle_是用户自定义的处理逻辑函数,主要是梯度更新参数的规则等
  request_handle_(meta, data, this);
}

这里给出test里面的一个实例

void StartServer() {
  if (!IsServer()) return;
  auto server = new KVServer<float>(0);
  //这一步就是在设置 request_handle_
  server->set_request_handle(KVServerDefaultHandle<float>());
  RegisterExitCallback([server](){ delete server; });
}

Response

故名思义就是将数据回复给worker,好像没啥要讲的。。。

template <typename Val>
void KVServer<Val>::Response(const KVMeta& req, const KVPairs<Val>& res) {
  //res里存储的就是worker需要数据,这里只是在包装以 Message 封装一下,最后在通过send回复给worker
  Message msg;
  msg.meta.app_id = obj_->app_id();
  msg.meta.customer_id = req.customer_id;
  msg.meta.request     = false;
  msg.meta.push        = req.push;
  msg.meta.pull        = req.pull;
  msg.meta.head        = req.cmd;
  msg.meta.timestamp   = req.timestamp;
  msg.meta.recver      = req.sender;
  if (res.keys.size()) {
    msg.AddData(res.keys);
    msg.AddData(res.vals);
    if (res.lens.size()) {
      msg.AddData(res.lens);
    }
  }
  Postoffice::Get()->van()->Send(msg);
}
正文完
请博主喝杯咖啡吧!
post-qrcode
 
admin
版权声明:本站原创文章,由 admin 2023-03-23发表,共计5762字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请注明出处。
评论(没有评论)
验证码