0
点赞
收藏
分享

微信扫一扫

【论文阅读】FedProto: Federated Prototype Learning across Heterogeneous Clients

先峰老师 2022-04-13 阅读 79

在进行联邦学习中解决异质性问题的调研,这篇文章发表于2022的AAAI,使用原型来解决这部分问题,同时还使用“模型异构“的概念(之前一般见到的都是系统异构)。

一、阅读目标

  1. 了解使用原型学习解决异质性的思路,以及可以借鉴的方向

  2. 区分模型异构与系统异构两个概念,并总结

二、问题回答

  1. 构建以类别为单位的原型,服务器端通过原型的传递代替模型的聚合,并将聚合的原型返回客户端用于训练
  2. 我认为二者是从属关系,模型异构应该属于系统异构的一个方面,但现在没有官方定义,大家仍然混用

三、摘要

注:1)整体想法听起来很像Model-Contrastive Federated Learning文章的思路(模型简称为MOON),MOON模型的设计是,在模型参数的传递之外,客户端和服务器之间传递模型hidden layer得到的特征向量,构造本地特征向量和全局特征向量之间的对比损失,令此对比损失参与客户端的训练过程,使得客户端的本地特征向量不断逼近全局特征向量 2)相比于MOON,FedProto的不同之处在于,他只传递原型,而不传递模型参数(由于知识的匮乏,不太懂这是怎么完成的)3)主要论点是怎么完成的(为什么原型的传递可以避免客户端之间本来会出现的不对齐)

四、引入

注:1)关于统计异质性和模型异质性的这三种解决方案,我的了解都不深,需要再重点看看 2)看完摘要,感觉这篇文章的立意确实很强,大多数联邦学习文章,都是从如何解决异构性的角度出发,而这篇文章是说,你异构就异构吧,我从别的地方下手,另辟蹊径解决问题 3)下面重点就是看他是使用何种方法得到这种原型表示了,以及这种原型表示能不能像他说得那样有着类似人类”认知“的强大表达能力(如果没有就是高开低走了)

五、相关工作

一般来说都不整理相关工作的,但最近正好要做关于异质性的调研,正好来研究下

注:1)超级大问题:这个原型是用来表示一个类的特征的平均,那么跟原数据的直接概率分布有什么不同? 2)不过就是传递网络中某一层之后的特征平均,为什么就能替代模型参数?这表达能力完全不是一个级别吧?而且server该如何聚合呢?client拿到server返回的数据又该如何利用呢? 3)没有类别概念的任务,是不是可以直接对client进行原型学习,然后在server端进行原型聚合?

六、方法

1. FedAvg

FedAvg中的目标函数:
a r g m i n w ∑ i = 1 m ∣ D i ∣ N L S ( F ( w ; x ) , y ) arg min_w \sum_{i=1}^m \frac{|D_i|}{N}L_S(F(w;x),y) argminwi=1mNDiLS(F(w;x),y)
D i D_i Di表示第i个客户端的数据集, N N N为整体数据量, F F F为共享模型, L S L_S LS表示损失函数,此目标函数用来最小化加权的客户端损失函数

2. 真实联邦场景

真实的联邦学习场景,会存在统计异质性和模型异质性,则目标函数变为:
a r g m i n w 1 , w 2 , . . . , w m ∑ i = 1 m ∣ D i ∣ N L S ( F i ( w i ; x ) , y ) arg min_{w_1,w_2, ..., w_m} \sum_{i=1}^m \frac{|D_i|}{N}L_S(F_i(w_i;x),y) argminw1,w2,...,wmi=1mNDiLS(Fi(wi;x),y)
模型异质性会导致目标函数的变化,统计异质性则不会;但统计异质性会对联邦优化的结果造成影响

3. 基于原型的聚合

在基于原型的问题中,仍然存在统计异质性和模型异质性。但不同的模型,都可以分成两部分:表示层(Representation layers) f i ( ϕ i ; x ) f_i(\phi_i;x) fi(ϕi;x)和决策层(Decision layers),后者通常指示网络的最后一层,那么表示层Representation layers则代表除了最后一层之外的其他层;

原型的表示是基于类别的,其中第 i i i个client中的第 j j j类别的原型表示为:
C i ( j ) = 1 ∣ D i , j ∣ ∑ ( x , y ) ∈ D i . j f i ( ϕ i ; x ) C_i^{(j)} = \frac{1}{|D_{i,j}|}\sum_{(x,y) \in D_{i.j}}f_i(\phi_i;x) Ci(j)=Di,j1(x,y)Di.jfi(ϕi;x)
client要对每个类别都聚合一个原型,假设共有 E E E类,则每个client向server传递 E E E个原型

目标函数可以写为:
a r g m i n { C ˉ ( j ) } j = 1 ∣ C ∣ ∑ i = 1 m ∣ D i ∣ N L S ( F i ( w i ; x ) , y ) + λ ∑ j = 1 ∣ C ∣ ∑ i = 1 m ∣ D i . j ∣ N j L R ( C ˉ i ( j ) , C i ( j ) ) arg min_{\{\bar C^{(j)}\}_{j=1}^{|C|}} \sum_{i=1}^m \frac{|D_i|}{N}L_S(F_i(w_i;x),y) + \lambda \sum_{j=1}^{|C|} \sum_{i=1}^m \frac{|D_{i.j}|}{N_j}L_R(\bar C_i^{(j)},C_i^{(j)}) argmin{Cˉ(j)}j=1Ci=1mNDiLS(Fi(wi;x),y)+λj=1Ci=1mNjDi.jLR(Cˉi(j),Ci(j))
C i ( j ) C_i^{(j)} Ci(j)是第 i i i个client的第 j j j类别的原型,则 C ˉ i ( j ) \bar C_i^{(j)} Cˉi(j)是server端对所有client的第 j j j类原型的聚合原型( i i i下标单单指示不同client的训练,但每一轮中不同client的 C ˉ i ( j ) \bar C_i^{(j)} Cˉi(j)是相同的)
client端,保持 C ˉ i ( j ) \bar C_i^{(j)} Cˉi(j)不变,优化 w i w_i wi C i ( j ) C_i^{(j)} Ci(j)
server端,通过下列公式聚合全局原型:
C ˉ ( j ) = 1 ∣ N j ∣ ∑ i ∈ N j ∣ D i , j ∣ N j C i ( j ) \bar C^{(j)} = \frac{1}{|N_j|}\sum_{i \in N_j}\frac{|D_{i,j}|}{N_j}C_i^{(j)} Cˉ(j)=Nj1iNjNjDi,jCi(j)
其中, N j N_j Nj表示所有包含 j j j类的client的集合,然后保证

模型测试阶段,如果出现一个新的client,那么首先使用预训练的模型,比如ImageNet上训练的ResNet18对表示层进行初始化,随机化决策层,然后使用这个公式来调整参数: a r g m i n j ∣ ∣ f ( ϕ ; x ) − C ( j ) ∣ ∣ 2 argmin_j||f(\phi;x)-C^{(j)}||_2 argminjf(ϕ;x)C(j)2
表示通过最小化表示向量和聚合原型之间的 L 2 L2 L2距离,然后就可以令新的client进行预测了

举报

相关推荐

0 条评论