Paper Notes (8): "Personalized Federated Learning using Hypernetworks"

    This is an ICML 2021 paper. I read both the paper and the code, so below I briefly walk through the paper's ideas alongside the code.

    Motivation

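    The paper argues that the main difficulty in personalized federated learning (PFL) is giving every client a personalized model while keeping communication cost as low as possible. The main contributions it lists address this directly. They concern communication cost and model complexity, and the ability to produce appropriately sized models for devices with different compute resources. The paper also reports good results.

    The authors decouple communication cost from model complexity by training a hypernetwork on the server. This hypernetwork generates the model parameters each client needs.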
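    Here is a compact way to write this down. It is my own summary in the paper's spirit, not a quote. The symbols are:

    - \(n\): the number of clients.
    - \(h\): the hypernetwork, with parameters \(\varphi\).
    - \(v_i\): a learnable embedding of client \(i\).
    - \(\theta_i\): the weights generated for client \(i\).
    - \(\mathcal{L}_i\): the loss on client \(i\)'s data.

```latex
\theta_i = h(v_i;\varphi), \qquad
\min_{\varphi,\, v_1,\dots,v_n}\; \frac{1}{n}\sum_{i=1}^{n} \mathcal{L}_i\bigl(h(v_i;\varphi)\bigr)
```

    Only \(\theta_i\) and its update travel between server and client, while \(\varphi\) never leaves the server.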

    Model Construction

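    The hypernetwork in the paper is a multi-head network: each head outputs the weight tensor of one layer of the target network. Concretely, for CIFAR-10 the hypernetwork looks like this: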

    from collections import OrderedDict

    import torch.nn as nn
    from torch.nn.utils import spectral_norm


    class CNNHyper(nn.Module):
        def __init__(
                self, n_nodes, embedding_dim, in_channels=3, out_dim=10, n_kernels=16, hidden_dim=100,
                spec_norm=False, n_hidden=1):
            '''
            The hypernetwork kept on the server; it generates the weights of the target (client) network.
            
            Args:
              n_nodes: int, total number of nodes (users or clients)
              embedding_dim: int, dimension of the client embedding
              in_channels: int, number of channels of the input image or data
              out_dim: int, number of classes
              n_kernels: int, number of kernels in the first conv layer (the second has 2 * n_kernels)
              hidden_dim: int, dimension of the final hidden layer of the hypernetwork MLP
              spec_norm: bool, whether to apply spectral normalization
              n_hidden: int, number of hidden layers in the MLP
            '''
            super().__init__()
    
            self.in_channels = in_channels
            self.out_dim = out_dim
            self.n_kernels = n_kernels
            self.embeddings = nn.Embedding(num_embeddings=n_nodes, embedding_dim=embedding_dim)
    		
            # Multilayer perceptron
            layers = [
                spectral_norm(nn.Linear(embedding_dim, hidden_dim)) if spec_norm else nn.Linear(embedding_dim, hidden_dim),
            ]
            for _ in range(n_hidden):
                layers.append(nn.ReLU(inplace=True))
                layers.append(
                    spectral_norm(nn.Linear(hidden_dim, hidden_dim)) if spec_norm else nn.Linear(hidden_dim, hidden_dim),
                )
    
            self.mlp = nn.Sequential(*layers)
    		
            # output heads: each one produces the flattened weights of one target-network tensor
            self.c1_weights = nn.Linear(hidden_dim, self.n_kernels * self.in_channels * 5 * 5)
            self.c1_bias = nn.Linear(hidden_dim, self.n_kernels)
            self.c2_weights = nn.Linear(hidden_dim, 2 * self.n_kernels * self.n_kernels * 5 * 5)
            self.c2_bias = nn.Linear(hidden_dim, 2 * self.n_kernels)
            self.l1_weights = nn.Linear(hidden_dim, 120 * 2 * self.n_kernels * 5 * 5)
            self.l1_bias = nn.Linear(hidden_dim, 120)
            self.l2_weights = nn.Linear(hidden_dim, 84 * 120)
            self.l2_bias = nn.Linear(hidden_dim, 84)
            self.l3_weights = nn.Linear(hidden_dim, self.out_dim * 84)
            self.l3_bias = nn.Linear(hidden_dim, self.out_dim)
    
            if spec_norm:
                self.c1_weights = spectral_norm(self.c1_weights)
                self.c1_bias = spectral_norm(self.c1_bias)
                self.c2_weights = spectral_norm(self.c2_weights)
                self.c2_bias = spectral_norm(self.c2_bias)
                self.l1_weights = spectral_norm(self.l1_weights)
                self.l1_bias = spectral_norm(self.l1_bias)
                self.l2_weights = spectral_norm(self.l2_weights)
                self.l2_bias = spectral_norm(self.l2_bias)
                self.l3_weights = spectral_norm(self.l3_weights)
                self.l3_bias = spectral_norm(self.l3_bias)
    
        def forward(self, idx):
            emd = self.embeddings(idx)
            features = self.mlp(emd)
    
            weights = OrderedDict({
                "conv1.weight": self.c1_weights(features).view(self.n_kernels, self.in_channels, 5, 5),
                "conv1.bias": self.c1_bias(features).view(-1),
                "conv2.weight": self.c2_weights(features).view(2 * self.n_kernels, self.n_kernels, 5, 5),
                "conv2.bias": self.c2_bias(features).view(-1),
                "fc1.weight": self.l1_weights(features).view(120, 2 * self.n_kernels * 5 * 5),
                "fc1.bias": self.l1_bias(features).view(-1),
                "fc2.weight": self.l2_weights(features).view(84, 120),
                "fc2.bias": self.l2_bias(features).view(-1),
                "fc3.weight": self.l3_weights(features).view(self.out_dim, 84),
                "fc3.bias": self.l3_bias(features).view(-1),
            })
            return weights
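    The same "one head per target tensor" idea can be written generically. The sketch below is my own illustration, not code from the paper's repository. TinyTarget, TinyHyper and all sizes are hypothetical.

    Instead of hard-coding shapes as CNNHyper does, it reads the target's state_dict to decide how many heads to create and what shape each output becomes. The generated OrderedDict can then be passed straight to load_state_dict.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


class TinyTarget(nn.Module):
    """Hypothetical 2-layer MLP standing in for the client model."""

    def __init__(self, in_dim=4, hidden=8, out_dim=3):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden)
        self.fc2 = nn.Linear(hidden, out_dim)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


class TinyHyper(nn.Module):
    """Hypothetical hypernetwork: one linear head per tensor of the target."""

    def __init__(self, n_nodes, target, embed_dim=5, hidden_dim=16):
        super().__init__()
        self.embeddings = nn.Embedding(n_nodes, embed_dim)
        self.mlp = nn.Sequential(nn.Linear(embed_dim, hidden_dim), nn.ReLU())
        # Record the name and shape of every tensor in the target's state_dict.
        self.shapes = OrderedDict((k, v.shape) for k, v in target.state_dict().items())
        # Module names may not contain ".", so "fc1.weight" becomes "fc1_weight".
        self.heads = nn.ModuleDict({
            k.replace(".", "_"): nn.Linear(hidden_dim, s.numel())
            for k, s in self.shapes.items()
        })

    def forward(self, idx):
        features = self.mlp(self.embeddings(idx))  # shape [1, hidden_dim]
        # Each head's flat output is reshaped to the matching target tensor.
        return OrderedDict(
            (k, self.heads[k.replace(".", "_")](features).view(s))
            for k, s in self.shapes.items()
        )


target = TinyTarget()
hnet = TinyHyper(n_nodes=10, target=target)
weights = hnet(torch.tensor([3]))   # weights for client 3
target.load_state_dict(weights)     # copy the generated tensors into the client model
out = target(torch.randn(2, 4))
```

    Deriving the heads from state_dict() keeps the two networks in sync automatically. The returned keys must match the target's parameter names exactly, because load_state_dict is strict by default.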
    

    The target network on the client side:

    import torch.nn as nn
    import torch.nn.functional as F


    class CNNTarget(nn.Module):
        def __init__(self, in_channels=3, n_kernels=16, out_dim=10):
            super(CNNTarget, self).__init__()
    
            self.conv1 = nn.Conv2d(in_channels, n_kernels, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(n_kernels, 2 * n_kernels, 5)
            self.fc1 = nn.Linear(2 * n_kernels * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, out_dim)
    
        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = x.view(x.shape[0], -1)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
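    Where does 2 * n_kernels * 5 * 5 in fc1 come from? Assume 32x32 CIFAR inputs. Each 5x5 convolution without padding shrinks the side by 4, and each 2x2 max-pool halves it, so the side goes 32 → 28 → 14 → 10 → 5.

    The check below rebuilds the same conv/pool stack as a standalone module; it is not imported from the repository:

```python
import torch
import torch.nn as nn

n_kernels = 16  # the default in CNNTarget
# Same convolution / pooling stack as CNNTarget, without the fully connected layers.
features = nn.Sequential(
    nn.Conv2d(3, n_kernels, 5), nn.ReLU(), nn.MaxPool2d(2, 2),              # 32 -> 28 -> 14
    nn.Conv2d(n_kernels, 2 * n_kernels, 5), nn.ReLU(), nn.MaxPool2d(2, 2),  # 14 -> 10 -> 5
)
x = torch.zeros(1, 3, 32, 32)  # one CIFAR-sized image (3 x 32 x 32)
out = features(x)
print(out.shape)                # torch.Size([1, 32, 5, 5])
print(out.flatten(1).shape[1])  # 800 == 2 * n_kernels * 5 * 5
```

    With a different input resolution, the flattened size would no longer match fc1 and the forward pass would fail.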
    

    Optimization

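    Let's go straight to the optimization procedure. The part I cared most about is the client feature vector \(v_i\). It is obtained simply by passing the client's node_id (i.e. its index) through an embedding layer.

    The whole code base creates only two model instances: one hypernetwork and one target network. Each round selects a single client. The target network loads the weights that the hypernetwork computes from that client's embedding, and is then trained locally.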
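    One round for the sampled client \(i\) can be written out as follows. This uses the sign convention of the training code, where delta_theta = inner_state - final_state. The symbols are:

    - \(\tilde\theta_i\): the weights after \(K\) = inner_steps local SGD steps.
    - \(\alpha\): the server learning rate lr.
    - \(\alpha_v\): the embedding learning rate embed_lr (SGD branch only).

```latex
\begin{aligned}
\theta_i &= h(v_i;\varphi), \qquad \tilde\theta_i = \theta_i \text{ after } K \text{ local SGD steps on client } i,\\
\Delta\theta_i &= \theta_i - \tilde\theta_i,\\
\varphi &\leftarrow \varphi - \alpha\,(\nabla_{\varphi}\theta_i)^{\top}\Delta\theta_i,\\
v_i &\leftarrow v_i - \alpha_v\,(\nabla_{v_i}\theta_i)^{\top}\Delta\theta_i .
\end{aligned}
```

    \(\tilde\theta_i\) is treated as a constant. Under that assumption, \((\nabla_\varphi\theta_i)^\top\Delta\theta_i\) is exactly the gradient of \(\tfrac12\lVert h(v_i;\varphi)-\tilde\theta_i\rVert^2\), so the hypernetwork is pulled toward the locally fine-tuned weights.

    In the code, \(v_i\) is a row of hnet.embeddings, so both updates come from the same torch.autograd.grad call. The formula shows plain gradient descent; the code additionally uses momentum and weight decay (or Adam) and gradient clipping.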

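    import json
    import logging
    import random
    from collections import OrderedDict, defaultdict
    from pathlib import Path

    import numpy as np
    import torch
    from tqdm import trange

    # CNNHyper and CNNTarget are the classes above. BaseNodes (per-client data loaders)
    # and eval_model come from the original repository and are not reproduced here.


    def train(data_name: str, data_path: str, classes_per_node: int, num_nodes: int,
              steps: int, inner_steps: int, optim: str, lr: float, inner_lr: float,
              embed_lr: float, wd: float, inner_wd: float, embed_dim: int, hyper_hid: int,
              n_hidden: int, n_kernels: int, bs: int, device, eval_every: int, save_path: Path,
              seed: int) -> None:
        '''
        The optimization process.
        
        Args:
          data_name: str, 'cifar10' or 'cifar100'
          data_path: str, path to the data
          classes_per_node: int, number of classes assigned to each node
          num_nodes: int, total number of nodes (users)
          steps: int, number of communication rounds
          inner_steps: int, number of local gradient steps per round
          optim: str, server optimizer, 'sgd' or 'adam'
          lr: float, learning rate of the server (hypernetwork)
          inner_lr: float, learning rate of the node
          embed_lr: float, learning rate of the embedding layer (used only with SGD)
          wd: float, weight decay of the server (used only with SGD)
          inner_wd: float, weight decay of the node
          embed_dim: int, output dimension of the embedding layer (-1 = set automatically)
          hyper_hid: int, output dimension of the final hidden layer of the hypernetwork
          n_hidden: int, number of hidden layers
          n_kernels: int, number of kernels in the CNN
          bs: int, batch size
          device: torch device to run on
          eval_every: int, evaluate every this many rounds
          save_path: Path, directory where the results JSON is written
          seed: int, random seed (only used in the output file name here)
        '''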
    
        ###############################
        # init nodes, hnet, local net #
        ###############################
        nodes = BaseNodes(data_name, data_path, num_nodes, classes_per_node=classes_per_node,
                          batch_size=bs)
    	
        # when embed_dim == -1, derive the embedding dim from the number of nodes
        if embed_dim == -1:
            logging.info("auto embedding size")
            embed_dim = int(1 + num_nodes / 4)
    	
        # Build the model
        if data_name == "cifar10":
            hnet = CNNHyper(num_nodes, embed_dim, hidden_dim=hyper_hid, n_hidden=n_hidden, n_kernels=n_kernels)
            net = CNNTarget(n_kernels=n_kernels)
        elif data_name == "cifar100":
            hnet = CNNHyper(num_nodes, embed_dim, hidden_dim=hyper_hid,
                            n_hidden=n_hidden, n_kernels=n_kernels, out_dim=100)
            net = CNNTarget(n_kernels=n_kernels, out_dim=100)
        else:
            raise ValueError("choose data_name from ['cifar10', 'cifar100']")
    
        hnet = hnet.to(device)
        net = net.to(device)
    
        ##################
        # init optimizer #
        ##################
        embed_lr = embed_lr if embed_lr is not None else lr
        optimizers = {
            'sgd': torch.optim.SGD(
                [
                    {'params': [p for n, p in hnet.named_parameters() if 'embed' not in n]},
                    {'params': [p for n, p in hnet.named_parameters() if 'embed' in n], 'lr': embed_lr}
                ], lr=lr, momentum=0.9, weight_decay=wd
            ),
            'adam': torch.optim.Adam(params=hnet.parameters(), lr=lr)
        }
        optimizer = optimizers[optim]
        criteria = torch.nn.CrossEntropyLoss()
    
        ################
        # init metrics #
        ################
        last_eval = -1
        best_step = -1
        best_acc = -1
        test_best_based_on_step, test_best_min_based_on_step = -1, -1
        test_best_max_based_on_step, test_best_std_based_on_step = -1, -1
        step_iter = trange(steps)
    
        results = defaultdict(list)
        for step in step_iter:
            hnet.train()
    
            # select a client at random
            node_id = random.choice(range(num_nodes))
    
            # produce & load local network weights
            weights = hnet(torch.tensor([node_id], dtype=torch.long).to(device))
            net.load_state_dict(weights)
    
            # init inner optimizer
            inner_optim = torch.optim.SGD(
                net.parameters(), lr=inner_lr, momentum=.9, weight_decay=inner_wd
            )
    
            # storing theta_i for later calculating delta theta
            inner_state = OrderedDict({k: tensor.data for k, tensor in weights.items()})
    
            # NOTE: evaluation on sent model
            with torch.no_grad():
                net.eval()
                batch = next(iter(nodes.test_loaders[node_id]))
                img, label = tuple(t.to(device) for t in batch)
                pred = net(img)
                prvs_loss = criteria(pred, label)
                prvs_acc = pred.argmax(1).eq(label).sum().item() / len(label)
                net.train()
    
            # inner updates -> obtaining theta_tilda
            for i in range(inner_steps):
                net.train()
                inner_optim.zero_grad()
                optimizer.zero_grad()
    
                batch = next(iter(nodes.train_loaders[node_id]))
                img, label = tuple(t.to(device) for t in batch)
    
                pred = net(img)
    
                loss = criteria(pred, label)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 50)
    
                inner_optim.step()
    
            optimizer.zero_grad()
    
            final_state = net.state_dict()
    
            # calculating delta theta
            delta_theta = OrderedDict({k: inner_state[k] - final_state[k] for k in weights.keys()})
    
            # calculating phi gradient
            hnet_grads = torch.autograd.grad(
                list(weights.values()), hnet.parameters(), grad_outputs=list(delta_theta.values())
            )
    
            # update hnet weights
            for p, g in zip(hnet.parameters(), hnet_grads):
                p.grad = g
    
            torch.nn.utils.clip_grad_norm_(hnet.parameters(), 50)
            optimizer.step()
    
            step_iter.set_description(
                f"Step: {step+1}, Node ID: {node_id}, Loss: {prvs_loss:.4f},  Acc: {prvs_acc:.4f}"
            )
    		
            # evaluation
            if step % eval_every == 0:
                last_eval = step
                step_results, avg_loss, avg_acc, all_acc = eval_model(nodes, num_nodes, hnet, net, criteria, device, split="test")
                logging.info(f"\nStep: {step+1}, AVG Loss: {avg_loss:.4f},  AVG Acc: {avg_acc:.4f}")
    
                results['test_avg_loss'].append(avg_loss)
                results['test_avg_acc'].append(avg_acc)
    
                _, val_avg_loss, val_avg_acc, _ = eval_model(nodes, num_nodes, hnet, net, criteria, device, split="val")
                if best_acc < val_avg_acc:
                    best_acc = val_avg_acc
                    best_step = step
                    test_best_based_on_step = avg_acc
                    test_best_min_based_on_step = np.min(all_acc)
                    test_best_max_based_on_step = np.max(all_acc)
                    test_best_std_based_on_step = np.std(all_acc)
    
                results['val_avg_loss'].append(val_avg_loss)
                results['val_avg_acc'].append(val_avg_acc)
                results['best_step'].append(best_step)
                results['best_val_acc'].append(best_acc)
                results['best_test_acc_based_on_val_beststep'].append(test_best_based_on_step)
                results['test_best_min_based_on_step'].append(test_best_min_based_on_step)
                results['test_best_max_based_on_step'].append(test_best_max_based_on_step)
                results['test_best_std_based_on_step'].append(test_best_std_based_on_step)
    	
        if step != last_eval:
            _, val_avg_loss, val_avg_acc, _ = eval_model(nodes, num_nodes, hnet, net, criteria, device, split="val")
            step_results, avg_loss, avg_acc, all_acc = eval_model(nodes, num_nodes, hnet, net, criteria, device, split="test")
            logging.info(f"\nStep: {step + 1}, AVG Loss: {avg_loss:.4f},  AVG Acc: {avg_acc:.4f}")
    
            results['test_avg_loss'].append(avg_loss)
            results['test_avg_acc'].append(avg_acc)
    
            if best_acc < val_avg_acc:
                best_acc = val_avg_acc
                best_step = step
                test_best_based_on_step = avg_acc
                test_best_min_based_on_step = np.min(all_acc)
                test_best_max_based_on_step = np.max(all_acc)
                test_best_std_based_on_step = np.std(all_acc)
    
            results['val_avg_loss'].append(val_avg_loss)
            results['val_avg_acc'].append(val_avg_acc)
            results['best_step'].append(best_step)
            results['best_val_acc'].append(best_acc)
            results['best_test_acc_based_on_val_beststep'].append(test_best_based_on_step)
            results['test_best_min_based_on_step'].append(test_best_min_based_on_step)
            results['test_best_max_based_on_step'].append(test_best_max_based_on_step)
            results['test_best_std_based_on_step'].append(test_best_std_based_on_step)
    
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)
        with open(str(save_path / f"results_{inner_steps}_inner_steps_seed_{seed}.json"), "w") as file:
            json.dump(results, file, indent=4)
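    The least obvious line in the loop is the torch.autograd.grad call with grad_outputs=delta_theta. It computes a vector-Jacobian product \((\nabla_\varphi\theta_i)^\top\Delta\theta_i\) over all hypernetwork parameters, including the embeddings. That product equals the gradient of \(\tfrac12\lVert\theta_i-\tilde\theta_i\rVert^2\) when \(\tilde\theta_i\) is held fixed.

    The toy check below compares the two routes. It uses a fixed linear map W as a hypothetical stand-in for the hypernetwork; W, phi and theta_tilde are illustrative names only.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for the hypernetwork: the "generated weights" are
# theta = W @ phi (one flat vector instead of a dict of tensors).
W = torch.randn(6, 3)
phi = torch.randn(3, requires_grad=True)
theta = W @ phi

# Pretend local training turned theta into theta_tilde; it is a constant here.
theta_tilde = theta.detach() - 0.1 * torch.randn(6)
delta = theta.detach() - theta_tilde  # same sign convention as delta_theta in train()

# Route 1: what train() does, a vector-Jacobian product through grad_outputs.
(g_vjp,) = torch.autograd.grad([theta], [phi], grad_outputs=[delta], retain_graph=True)

# Route 2: ordinary gradient of 0.5 * ||theta(phi) - theta_tilde||^2 w.r.t. phi.
loss = 0.5 * (theta - theta_tilde).pow(2).sum()
(g_loss,) = torch.autograd.grad([loss], [phi])

print(torch.allclose(g_vjp, g_loss))  # True
```

    The VJP route never has to form the Jacobian or a scalar loss. It only needs the generated tensors and the client's returned difference, which is what the server has.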
    

    Summary

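    1. How the client feature \(v_i\) is obtained struck me as the oddest part. Using an embedding is certainly straightforward. But generating it on the server from nothing but the node_id makes the server and the clients feel almost adversarial, since clients do have their own descriptive data, such as demographic features. It would seem more logical to have each client produce its \(v_i\) from such data;
    2. The claimed decoupling of communication cost from model complexity feels ambiguous to me. What is transmitted is the same as in plain FedAvg (the client model's weights and their update). The server can indeed train a very deep hypernetwork. However, if the client's local model becomes more complex, the communication cost still grows with it;
    3. The released code does not demonstrate generating different models for devices with different compute resources.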
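    Point 2 can be put in numbers with a back-of-the-envelope count. It assumes the CIFAR-10 defaults (in_channels=3, n_kernels=16, out_dim=10) and hidden_dim=100, which is the CNNHyper constructor default; the value actually used is whatever hyper_hid is set to. The helper functions are hypothetical and only count the tensors listed in CNNHyper's heads.

```python
def target_param_count(in_channels=3, n_kernels=16, out_dim=10):
    """Parameters of CNNTarget (weights + biases), i.e. what is sent to a client."""
    sizes = [
        n_kernels * in_channels * 5 * 5, n_kernels,        # conv1
        2 * n_kernels * n_kernels * 5 * 5, 2 * n_kernels,  # conv2
        120 * 2 * n_kernels * 5 * 5, 120,                  # fc1
        84 * 120, 84,                                      # fc2
        out_dim * 84, out_dim,                             # fc3
    ]
    return sum(sizes)


def hyper_head_param_count(hidden_dim=100, **target_kwargs):
    """Parameters of CNNHyper's ten output heads only (embedding and MLP excluded).

    Each head is nn.Linear(hidden_dim, k) with hidden_dim * k + k parameters,
    and the k values add up to the size of the target network.
    """
    return (hidden_dim + 1) * target_param_count(**target_kwargs)


theta_size = target_param_count()
heads_size = hyper_head_param_count()
print(theta_size)                      # 121182
print(heads_size)                      # 12239382
print(round(heads_size / theta_size))  # 101
```

    The server-side heads alone are about 101 times larger than the client model. Yet each round still moves roughly 121k numbers in each direction, just as FedAvg would for this CNN. The hypernetwork's size does not add to communication, but the client model's size still does.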
Original post: https://www.cnblogs.com/DemonHunter/p/15616063.html