  • Distributed Machine Learning with Ray (Part 2)

    Basic idea
    The design follows the parameter server + multiple workers pattern.
    Synchronous mode
    The parameter server manages the network parameters centrally. In each iteration it sends the current parameters to every worker, and the workers iterate over the dataset in parallel, each computing the loss and gradients for its current batch.
    Once all workers have finished the current batch, each worker's gradients are sent back to the parameter server, which uses them to update the parameters.
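The synchronous aggregation step can be sketched in plain NumPy (the names `params`, `lr`, and `worker_grads` are illustrative, not from the post): the server sums the per-layer gradients from all workers, then takes one optimizer step.

```python
import numpy as np

# Parameters held by the parameter server: a single weight vector.
params = np.zeros(3)
lr = 0.1

# Gradients computed by two workers on the same broadcast parameters,
# one entry per layer (here just one layer).
worker_grads = [
    [np.array([1.0, 0.0, 1.0])],  # worker 0
    [np.array([0.0, 2.0, 1.0])],  # worker 1
]

# The server waits for all workers, sums gradients layer-wise, then steps.
summed = [np.stack(layer).sum(axis=0) for layer in zip(*worker_grads)]
params = params - lr * summed[0]
print(params)  # [-0.1 -0.2 -0.2]
```

This is exactly the `np.stack(...).sum(axis=0)` pattern used by `apply_gradients` in the parameter server class below, applied to a toy one-layer "model".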
    Asynchronous mode
    Unlike the synchronous mode, the parameter server does not wait for all workers to finish a batch before using all of their gradients to update the network parameters.
    Instead, as soon as any single worker finishes a batch, the server immediately updates the parameters with that worker's gradient and sends the new parameters back to that worker.
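The "act on whichever worker finishes first" behavior that `ray.wait` provides in the code below has a standard-library analogue in `concurrent.futures.as_completed`; a minimal sketch (the worker function and sleep durations are made up for illustration):

```python
import concurrent.futures
import time

def compute_gradients(worker_id, delay):
    time.sleep(delay)  # stand-in for computing one batch
    return worker_id

with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
    futures = {pool.submit(compute_gradients, w, d): w
               for w, d in [(0, 0.2), (1, 0.05)]}
    order = []
    # as_completed yields futures as they finish, like ray.wait:
    # the slow worker never blocks the fast worker's parameter update.
    for fut in concurrent.futures.as_completed(futures):
        order.append(fut.result())

print(order)  # worker 1 (shorter delay) finishes first: [1, 0]
```

The asynchronous training loop below follows the same shape, with `ray.wait` returning whichever gradient object is ready first.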
    1. Define the model

```python
import numpy as np
import ray
import torch
import torch.nn as nn
import torch.nn.functional as F

# Note: get_data_loader() -> (train_loader, test_loader) and
# evaluate(model, test_loader) -> accuracy are helper functions
# (data loading and accuracy evaluation) assumed to be defined
# elsewhere; the original post does not show them.

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 3))
        x = x.view(-1, 192)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)

    def get_weights(self):
        return {k: v.cpu() for k, v in self.state_dict().items()}

    def set_weights(self, weights):
        self.load_state_dict(weights)

    def get_gradients(self):
        grads = []
        for p in self.parameters():
            grad = None if p.grad is None else p.grad.data.cpu().numpy()
            grads.append(grad)
        return grads

    def set_gradients(self, gradients):
        for g, p in zip(gradients, self.parameters()):
            if g is not None:
                p.grad = torch.from_numpy(g)
```

    2. Define the parameter server

```python
@ray.remote
class ParameterServer(object):
    def __init__(self, lr):
        self.model = ConvNet()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=lr)

    def apply_gradients(self, *gradients):
        # Sum the gradients from all workers layer-wise, then step.
        summed_gradients = [
            np.stack(gradient_zip).sum(axis=0)
            for gradient_zip in zip(*gradients)
        ]
        self.optimizer.zero_grad()
        self.model.set_gradients(summed_gradients)
        self.optimizer.step()
        return self.model.get_weights()

    def get_weights(self):
        return self.model.get_weights()
```

    3. Define the worker

```python
@ray.remote
class DataWorker(object):
    def __init__(self):
        self.model = ConvNet()
        self.data_iterator = iter(get_data_loader()[0])

    def compute_gradients(self, weights):
        self.model.set_weights(weights)
        try:
            data, target = next(self.data_iterator)
        except StopIteration:
            # When the epoch ends, start a new epoch.
            self.data_iterator = iter(get_data_loader()[0])
            data, target = next(self.data_iterator)
        self.model.zero_grad()
        output = self.model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        return self.model.get_gradients()
```

    4. Synchronous training

```python
iterations = 200
num_workers = 2

ray.init(ignore_reinit_error=True)
ps = ParameterServer.remote(1e-2)
workers = [DataWorker.remote() for i in range(num_workers)]

model = ConvNet()
test_loader = get_data_loader()[1]

print("Running synchronous parameter server training.")
current_weights = ps.get_weights.remote()
for i in range(iterations):
    # Broadcast the current weights and collect one gradient per worker.
    gradients = [
        worker.compute_gradients.remote(current_weights) for worker in workers
    ]
    # Apply all gradients at once; blocks implicitly on all workers.
    current_weights = ps.apply_gradients.remote(*gradients)

    if i % 10 == 0:
        model.set_weights(ray.get(current_weights))
        accuracy = evaluate(model, test_loader)
        print("Iter {}: accuracy is {:.1f}".format(i, accuracy))

print("Final accuracy is {:.1f}.".format(accuracy))
ray.shutdown()
```

    5. Asynchronous training

```python
print("Running Asynchronous Parameter Server Training.")
ray.init(ignore_reinit_error=True)
ps = ParameterServer.remote(1e-2)
workers = [DataWorker.remote() for i in range(num_workers)]

current_weights = ps.get_weights.remote()
gradients = {}
for worker in workers:
    gradients[worker.compute_gradients.remote(current_weights)] = worker

for i in range(iterations * num_workers):
    # Take whichever worker's gradient is ready first.
    ready_gradient_list, _ = ray.wait(list(gradients))
    ready_gradient_id = ready_gradient_list[0]
    worker = gradients.pop(ready_gradient_id)

    # Update the parameters with this single gradient, then hand the
    # new weights back to the same worker.
    current_weights = ps.apply_gradients.remote(*[ready_gradient_id])
    gradients[worker.compute_gradients.remote(current_weights)] = worker

    if i % 10 == 0:
        model.set_weights(ray.get(current_weights))
        accuracy = evaluate(model, test_loader)
        print("Iter {}: accuracy is {:.1f}".format(i, accuracy))

print("Final accuracy is {:.1f}.".format(accuracy))
```
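The snippets above call two helpers that the post never defines: `get_data_loader()` (returning a train and a test `DataLoader`) and `evaluate(model, test_loader)` (returning accuracy as a percentage). A sketch consistent with how they are called above; note this uses synthetic random tensors shaped like MNIST (1x28x28, 10 classes, matching the `ConvNet` above) as a stand-in for the real MNIST loaders, so it runs without downloading any data:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def get_data_loader(n_train=512, n_test=256, batch_size=64):
    """Synthetic stand-in for the MNIST loaders: random 1x28x28 images
    with random labels in [0, 10). Returns (train_loader, test_loader)."""
    def make(n):
        images = torch.randn(n, 1, 28, 28)
        labels = torch.randint(0, 10, (n,))
        return DataLoader(TensorDataset(images, labels), batch_size=batch_size)
    return make(n_train), make(n_test)

def evaluate(model, test_loader):
    """Accuracy (percent) of model on test_loader."""
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data, target in test_loader:
            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    return 100.0 * correct / total
```

With random data the accuracy will hover around 10%; to reproduce the post's results, swap `get_data_loader` for real MNIST loaders (e.g. via `torchvision.datasets.MNIST`).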
  • Original post: https://www.cnblogs.com/zcsh/p/14206727.html