论文115:Reinforced GNNs for multiple instance learning (TNNLS‘24)

2024-06-01 1333阅读

文章目录

  • 1 要点
  • 2 预备知识
    • 2.1 MIL
    • 2.2 MIL-GNN
    • 2.3 Markov博弈
    • 2.4 深度Q-Learning
    • 3 方法
      • 3.1 观测生成与交互
      • 3.2 动作选择和指导
      • 3.3 奖励计算
      • 3.4 状态转移和终止
      • 3.5 多智能体训练

        1 要点

        题目:用于MIL的强化GNN

        代码:https://github.com/RingBDStack/RGMIL

        背景:MIL是一种监督学习变体,它处理包含多个实例的包,其中训练阶段只有包级别的标签可用。MIL在现实世界的应用中很多,尤其是在医学领域;

        挑战:现有的GNN在MIL中通常需要过滤实例间的低置信度边,并使用新的包结构来调整图神经网络架构。这样的调整过程频繁且忽视了结构和架构之间的相关性;

        RGMIL框架:首次在MIL任务中利用多智能体深度强化学习 (MADRL)。MADRL允许灵活定义或扩展影响包图或GNN的因素,并同步控制它们;

        贡献:

        1. 引入MADRL到MIL中,实现对包结构和GNN架构的自动化和同步控制;
        2. 使用边阈值和GNN层数作为因素案例来构建RGMIL,探索了以前在MIL研究中被忽视的边密度和聚合范围之间的相关性;
        3. 实验结果表明,RGMIL在多个MIL数据集上实现了最佳性能,并且具有出色的可解释性;

        细节:

        1. RGMIL将训练过程建模为一个完全合作的马尔可夫博弈 (MG);
        2. 通过两个智能体搜索边过滤阈值和GNN层数;
        3. 利用反向分解网络 (VDN) 来衡量智能体的贡献和相关性;
        4. 引入图注意力网络 (GAT) 并设计参数共享机制以提高效率;

        符号表:

        符号含义
        B \mathcal{B} B包集合
        G \mathcal{G} G与包相对应的、图的集合
        Y \mathcal{Y} Y包标签
        M \mathcal{M} MMarkov博弈的七元组
        S \mathcal{S} S M \mathcal{M} M的状态空间
        O \mathcal{O} O M \mathcal{M} M的观测空间
        A \mathcal{A} A M \mathcal{M} M的动作空间
        L \mathcal{L} L智能体或者GNN模型的训练损失
        N N N包数量
        M M M包内实例数量
        L L LGNN层的数量
        T T T时间步的数量
        I I I智能体的数量
        D D D特征表示的维度
        A \mathbf{A} A与图相对应的邻接矩阵
        F \mathbf{F} F与图相对应的实例特征矩阵
        E \mathbf{E} E与图相对应的包图特征矩阵
        Z \mathbf{Z} Z特征变换矩阵
        C \mathbf{C} C重要性系数矩阵
        i ; j ; k ; l ; t i;j;k;l;t i;j;k;l;t索引变量
        s ; o ; a ; r s;o;a;r s;o;a;r状态、观测、动作、奖励
        v v v注意力机制特征向量
        γ \gamma γ折扣系数
        α \alpha α智能体学习率
        μ \mu μ动作或者奖励的窗口大小
        λ \lambda λ终止条件的奖励阈值
        & ; % \&;\% &;%逻辑和取余运算
        ⊕ \oplus ⊕拼接操作
        ∥ ⋅ ∥ \|\cdot\| ∥⋅∥矩阵的Norm函数
        σ ( ⋅ ) \sigma(\cdot) σ(⋅)激活函数
        π ( ⋅ ) \pi(\cdot) π(⋅)智能体状态-动作函数
        RWD ( ⋅ ) \text{RWD}(\cdot) RWD(⋅)奖励函数
        TRN ( ⋅ ) \text{TRN}(\cdot) TRN(⋅)状态转移函数
        AGG ( ⋅ ) \text{AGG}(\cdot) AGG(⋅)特征聚合函数
        POL ( ⋅ ) \text{POL}(\cdot) POL(⋅)特征池化函数
        EVL ( ⋅ ) \text{EVL}(\cdot) EVL(⋅)分类性能评估函数

        2 预备知识

        2.1 MIL

        令 B = { B i ∣ i = 1 , … , N } \mathcal{B}=\{\mathcal{B}_i|i=1,\dots,N\} B={Bi​∣i=1,…,N}表示包含多个包 B i = { B i , j ∣ j = 1 … , M } \mathcal{B}_i=\{\mathcal{B}_{i,j}|j=1\dots,M\} Bi​={Bi,j​∣j=1…,M},其中 N N N和 M M M分别表示包和包中实例的数量 (通常 M M M是变化的)。每个包对应一个两类包标签 Y i = max ⁡ ( Y i , 1 , … , Y i , M ) \mathcal{Y}_i=\max(\mathcal{Y}_{i,1},\dots,\mathcal{Y}_{i,M}) Yi​=max(Yi,1​,…,Yi,M​),其中 Y i , j ∈ { 0 , 1 } \mathcal{Y}_{i,j}\in\{0,1\} Yi,j​∈{0,1}是假设的实例标签。尽管数据集中少量的实例具有真实的标签,然而在MIL的训练过程中,实例标签是不可用的。因此,MIL的目标是学习一个将包映射为标签的映射函数 B → Y \mathcal{B\to Y} B→Y,其中 Y = { Y i ∣ i = 1 , … , N } \mathcal{Y}=\{ \mathcal{Y}_i | i=1,\dots,N \} Y={Yi​∣i=1,…,N}。

        2.2 MIL-GNN

        对于MIL-GNN,其首先需要将所有的包转换为一个图的集合 G = { G i ∣ i = 1 , … , N } \mathcal{G}=\{ \mathcal{G}_i|i=1,\dots,N \} G={Gi​∣i=1,…,N},其中每个包对应一个图 G i = ( A i , F i ) \mathcal{G}_i=(\mathbf{A}_i,\mathbf{F}_i) Gi​=(Ai​,Fi​),此外,每个实例可以看作是一个节点。每个邻接矩阵 A i ∈ R M × M \mathbf{A}_i\in\mathbb{R}^{M\times M} Ai​∈RM×M使用原始节点特征构建,并通过阈值来过滤边,其每个元素表示一跳邻域信息。 F i ∈ R M × D \mathbf{F}_i\in\mathbb{R}^{M\times D} Fi​∈RM×D表示实例节点的特征矩阵。

        基于此, L L L层GNN被用于传递节点特征信息,其中对于第 i i i个图 G i \mathcal{G}_i Gi​,其在第 l l l层的聚合过程表示为:

        F i l = σ ( AGG l ( A i , F i l − 1 ) ) (1) \tag{1} \mathbf{F}_i^l=\sigma\left( \text{AGG}^l (\mathbf{A}_i,\mathbf{F}_i^{l-1})\right) Fil​=σ(AGGl(Ai​,Fil−1​))(1)其中 AGG l ( ⋅ ) \text{AGG}^l(\cdot) AGGl(⋅)表示在第 l l l层的聚合函数,例如卷积和注意力、 σ ( ⋅ ) \sigma(\cdot) σ(⋅)表示激活函数、 F i l \mathbf{F}_i^l Fil​是更新后的特征矩阵。

        接下来,一个节点特征池化函数 POL ( ⋅ ) \text{POL}(\cdot) POL(⋅)被用于GNN的最后一层,以获取最终的图级别特征矩阵 E ( i ) ∈ R 1 × D \mathbf{E}(i)\in\mathbb{R}^{1\times D} E(i)∈R1×D:

        E ( i ) = POL ( { F i L ( j ) ∣ j = 1 , … , M } ) (2) \tag{2} \mathbf{E}(i)=\text{POL}(\{ \mathbf{F}_i^L(j) |j=1,\dots,M \}) E(i)=POL({FiL​(j)∣j=1,…,M})(2)其中 F i L ( j ) ∈ R 1 × D \mathbf{F}_i^L(j)\in\mathbb{R}^{1\times D} FiL​(j)∈R1×D是实例节点 B i , j \mathcal{B}_{i,j} Bi,j​的特征向量。最后, E ( i ) \mathbf{E}(i) E(i)传递给一个包图分类器。因此,在MIL-GNN中,其映射过程为 B → G → Y \mathcal{B\to G\to Y} B→G→Y。

        2.3 Markov博弈

        在多智能体强化学习 (MARL) 中,Markov博弈 (MG) 是从Markov决策过程 (MDP) 扩展而来。特别地,一个MG包含多个能够共同影响奖励和状态转移的智能体。根据是否所有的智能体都能完全获得全局状态信息,已有的MG可被看作是完全或者部分可观测,其中后者则更为普遍。

        部分可观测的MG可以被抽象为一个七元组 M = \mathcal{M}= M=,其中:

        1. S \mathcal{S} S:MG的全局状态空间;
        2. A i \mathcal{A}_i Ai​:第 i i i个智能体的动作空间。在每个时间步 t ∈ [ 1 , T ] t\in[1,T] t∈[1,T],每个智能体根据其独有的状态动作函数 π i ( ⋅ ) \pi_i(\cdot) πi​(⋅)来选择动作 a i t ∈ A i a_i^t\in\mathcal{A}_i ait​∈Ai​;
        3. 每个智能体会从全局状态获得一个独立的部分观察 o i t ∈ O i o_i^t\in\mathcal{O}_i oit​∈Oi​,因此, π i ( ⋅ ) \pi_i(\cdot) πi​(⋅)可以表示为 S → O i → A i \mathcal{S\to O_i\to A_i} S→Oi​→Ai​;
        4. 每个智能体使用其奖励函数 RED i ( ⋅ ) \text{RED}_i(\cdot) REDi​(⋅)获得即时奖励 r i t r_i^t rit​,这种博弈也被称为分散的部分可观测MDP (Dec-POMDP),旨在最大化累积奖励 ∑ t = 1 T γ ( t − 1 ) r ∗ t \sum^T_{t=1}\gamma^{(t−1)}r^{*t} ∑t=1T​γ(t−1)r∗t,其中 γ γ γ表示控制后续奖励的折扣系数;
        5. 状态转移函数 TRN ( ⋅ ) \text{TRN}(\cdot) TRN(⋅)将当前状态 s t s^t st与联合动作 a ∗ t a^{*t} a∗t映射到下一个状态 s ( t + 1 ) s^{(t+1)} s(t+1),即 S × A ∗ → S \mathcal{S \times A^*\to S} S×A∗→S。

        2.4 深度Q-Learning

        作为基于价值的RL的基算法,Q-Learning非常适合实现单一智能体的顺序决策系统。QLearning包含一个状态-动作表 π ( ⋅ ) π(·) π(⋅),它记录了各种状态下所有可能动作的 Q Q Q值。初始化后,智能体不断与环境交互,并通过Bellman方程更新 π ( ⋅ ) π(·) π(⋅)直到收敛。 π ( ⋅ ) π(·) π(⋅)的更新过程可以表示如下:

        x = x + α [ r t + γ max ⁡ a π ( s t + 1 , a ) − x ] ] s.t.  x = π ( s t , a t ) (3) \tag{3} \begin{aligned} & x = x + \alpha \left[ r_t + \gamma \max_{a} \pi(s_{t+1}, a) - x \right] ]\\ & \text{s.t. } x = \pi(s_t, a_t) \end{aligned} ​x=x+α[rt​+γamax​π(st+1​,a)−x]]s.t. x=π(st​,at​)​(3)其中: π ( s t , a t ) \pi(s_t, a_t) π(st​,at​)是预测的Q值,以及在状态 s t s_t st​下选择动作 a t a_t at​的预期奖励、 r t r_t rt​表示时间步 t t t的即时奖励、 max ⁡ a π ( s t + 1 , a ) \max_a \pi(s_{t+1}, a) maxa​π(st+1​,a)是下一个状态 s t + 1 s_{t+1} st+1​的最大Q值,以及 α \alpha α是 π ( ⋅ ) \pi(·) π(⋅)的学习率。

        在实际应用中,许多环境的状态空间是无限的,记录所有状态-动作对的值是不可行的。受深度学习的启发,许多工作引入了深度神经网络 (DNN) 来近似返回值,其中深度Q-Learning (DQN) 是传统Q-Learning的直接扩展:

        1. DQN使用DNN构建动作-价值函数 π π π (亦称为 Q Q Q函数),该函数将每个状态向量映射到 Q Q Q值向量 π ( s ) ∈ R 1 × ∣ A ∣ \pi(s) \in \mathbb{R}^{1 \times |A|} π(s)∈R1×∣A∣,其中 ∣ A ∣ |A| ∣A∣表示动作空间 A A A的大小;
        2. DQN应用经验回放和目标网络技术来更新函数 π ( ⋅ ) \pi(·) π(⋅)。例如,给定过去时间步 t t t的经验记录,其元组形式为 ⟨ s t , a t , r t , s t + 1 ⟩ \langle s_t, a_t, r_t, s_{t+1} \rangle ⟨st​,at​,rt​,st+1​⟩,则 π π π的时序差分损失可以计算如下:

          L π = E s , a , r , s ′ [ ( π ‾ ( s t , a t ) − π ( s t , a t ) ) 2 ] s.t.  π ‾ ( s t , a t ) = r t + γ max ⁡ a π ‾ ( s t + 1 , a ) (4) \tag{4} \begin{aligned} &L_\pi = \mathbb{E}_{s,a,r,s'} \left[ \left( \overline{\pi}(s_t, a_t) - \pi(s_t, a_t) \right)^2 \right]\\ &\text{s.t. } \overline{\pi}(s_t, a_t) = r_t + \gamma \max_a \overline{\pi}(s_{t+1}, a) \end{aligned} ​Lπ​=Es,a,r,s′​[(π(st​,at​)−π(st​,at​))2]s.t. π(st​,at​)=rt​+γamax​π(st+1​,a)​(4)其中: π ( ⋅ ) \pi(·) π(⋅)表示评估网络,其用于预测状态 s t s_t st​和动作 a t a_t at​的 Q Q Q值的评估网络、 π ‾ ( ⋅ ) \overline{\pi}(·) π(⋅)是一个目标网络,其架构与 π ( ⋅ ) \pi(·) π(⋅)相同。只有 π ( ⋅ ) \pi(·) π(⋅)被优化,并且其训练参数周期性复制到 π ‾ ( ⋅ ) \overline{\pi}(·) π(⋅)。由于 π ‾ \overline{\pi} π不更新时目标 Q Q Q值是稳定的,因此 π ( ⋅ ) \pi(·) π(⋅)的训练稳定性是极好的;

        3. 为了权衡探索新动作的概率,DQN应用了 ϵ ϵ ϵ-贪婪算法。因此,它并不总是选择 π ( s ) \pi(s) π(s)中最大条目的对应动作,其可以表示如下:

          a = { random action , w.p. ϵ argmax a π ( s t , a ) , w.p. 1 − ϵ (5) \tag{5} \begin{aligned} a = \begin{cases} \text{random action}, & \text{w.p.} \quad\epsilon \\ \text{argmax}_a \pi(s_t, a), & \text{w.p.} \quad 1 - \epsilon \end{cases} \end{aligned} a={random action,argmaxa​π(st​,a),​w.p.ϵw.p.1−ϵ​​(5)其中, ϵ \epsilon ϵ表示随机选择动作的概率,即探索,而 1 − ϵ 1-\epsilon 1−ϵ表示选择当前基于 π π π的最优动作,即利用。通过这样做,DQN避免了在强化学习任务中的探索-利用困境,避开了局部最优,并促进了更好的 π π π函数的发现。

        3 方法

        本节介绍RGMIL的细节,包括:1) 用于提升博弈公平性的观测生成与交互;2) 用于提升GNN效率的动作选择和指导技术;3) 用于提升博弈稳定性的奖励计算;4) 用于确保博弈收敛的状态转移和终止技术;以及5) 多智能体训练。

        RGMIL的总览如图4所示,其中左子图对应章节3.1至3.4,右子图对应章节3.5。

        论文115:Reinforced GNNs for multiple instance learning (TNNLS‘24)

        图4:RGMIL总览。左右子图分别对应经验收集和代理优化:1) 每一个时间步,初始观测从当前的block导出;2) 观测作为代理的输入,用于选择当前的动作;3) 构建可信包图,并作为定制的GNN的输入;4) GNN训练后,通过动作组合来评估性能,并确定当前的奖励;5) 带有动作的转移函数作为输入,以生成下一次观测;6) 记录以上过程,到达一定数量后,由VDN执行代理优化

        3.1 观测生成与交互

        在RGMIL中,我们将其训练过程建模为一个合作的马尔可夫博弈 (MG),涉及两个智能体,分别用于搜索最佳的边过滤阈值和GNN层数:

        1. 利用一个改进的VDN来实现MG:
          • 将训练集划分为多个等大小的区块,其中一个区块作为验证集,其余区块用作构建MG状态空间 S S S;
          • 在第一个时间步之前,随机选择一个训练区块作为全局状态;
          • 由于边过滤阈值的选择通常与拓扑信息相关,我们随后指定当前状态中包图的结构特征作为第一个智能体的观察;
          • 通过包图的成对相似性建立实例节点的初始边。以属于当前区块的第 i i i个包 B i \mathcal{B}_i Bi​为例,它的包图 G i \mathcal{G}_i Gi​可以被抽象为一个邻接矩阵 A i \mathbf{A}_i Ai​以及一个特征矩阵 F i \mathbf{F}_i Fi​;
          • 给定初始矩阵 F i 0 \mathbf{F}^0_i Fi0​,初始邻接矩阵 A i \mathbf{A}_i Ai​的计算如下:

            A i ( j , j ′ ) = ∥ F i 0 ( j ) − F i 0 ( j ′ ) ∥ 2 (6) \tag{6} \mathbf{A}_i(j, j') = \|\mathbf{F}^0_i(j) - \mathbf{F}^0_i(j')\|_2 Ai​(j,j′)=∥Fi0​(j)−Fi0​(j′)∥2​(6)其中 ∥ ⋅ ∥ 2 \|\cdot\|_2 ∥⋅∥2​表示矩阵的二范数、 A i ( j , j ′ ) \mathbf{A}_i(j, j') Ai​(j,j′)编码了第 j j j个和第 j ′ j' j′个实例节点之间的欧式距离。

          • 因此,第一个智能体的观察计算如下:

            o 1 ( d ) = 1 N d ∑ i = 1 N d exp ⁡ ( − A i ) s.t.  M i = d , d ∈ [ 1 , max ⁡ M i ] (7) \tag{7} \begin{aligned} &o_1(d) = \frac{1}{N_d} \sum_{i=1}^{N_d} \exp(-\mathbf{A}_i)\\ & \text{s.t. } M_i = d, \quad d \in [1, \max M_i] \end{aligned} ​o1​(d)=Nd​1​i=1∑Nd​​exp(−Ai​)s.t. Mi​=d,d∈[1,maxMi​]​(7)其中 o 1 ( d ) o_1(d) o1​(d)表示向量 o 1 o_1 o1​的第 d d d个条目、 N d N_d Nd​是当前区块中包的数量,并且它包含的实例数量等于 d d d、 M i M_i Mi​是包图 G i G_i Gi​的实例节点数量;

          • 由于GNN层数控制特征聚合的迭代,随后从初始节点特征 F i 0 \mathbf{F}^0_i Fi0​中获取第二个智能体的观察:

            o 2 = 1 N ∑ i = 1 N ( 1 M i ∑ j = 1 M i F i 0 ( j ) ) (8) \tag{8} o_2 = \frac{1}{N} \sum_{i=1}^{N} \left( \frac{1}{M_i} \sum_{j=1}^{M_i} F^0_i(j) \right) o2​=N1​i=1∑N​(Mi​1​j=1∑Mi​​Fi0​(j))(8)其中 F i 0 ( j ) \mathbf{F}^0_i(j) Fi0​(j)是第 j j j个实例节点的特征向量、 N N N是当前区块中包图的总数;

          • 为了进一步探索边密度和聚合迭代之间的潜在相关性,引入了观察信息交互:

            o 1 = o 1 ⊕ σ ( ( o 1 ⊕ o 2 ) ( o 2 ⊕ o 1 ) T o 1 ) o 2 = o 2 ⊕ σ ( ( o 1 ⊕ o 2 ) ( o 2 ⊕ o 1 ) T o 2 ) (9) \tag{9} \begin{aligned} &o_1 = o_1 \oplus \sigma((o_1 \oplus o_2)(o_2 \oplus o_1)^T {o_1})\\ &o_2 = o_2 \oplus \sigma((o_1 \oplus o_2)(o_2 \oplus o_1)^T {o_2}) \end{aligned} ​o1​=o1​⊕σ((o1​⊕o2​)(o2​⊕o1​)To1​)o2​=o2​⊕σ((o1​⊕o2​)(o2​⊕o1​)To2​)​(9)其中 ⊕ ( ⋅ ) \oplus(\cdot) ⊕(⋅)是向量的连接操作。通过此操作,观察 o 1 o_1 o1​和 o 2 o_2 o2​具有相同的维度,并且都编码了来自对方的信息;

        RGMIL减轻了由于观察的特征维度或信息量的变化可能导致的MG中的不公平博弈。此外,为了提高这部分的效率,RGMIL只为每个数据区块一次性计算并记录这些初始邻接矩阵和观察。

        3.2 动作选择和指导

        当输入当前的观察向量 o i o_i oi​后,每个智能体将其映射为一个 Q Q Q值向量 π i ( o i ) ∈ R 1 × ∣ A i ∣ \pi_i(o_i) \in \mathbb{R}^{1 \times |\mathcal{A}_i|} πi​(oi​)∈R1×∣Ai​∣,并基于最大的 Q Q Q值条目或随机选择一个动作 a i a_i ai​ (如公式5):

        1. 第一个阈值动作 a 1 ∈ [ 0 , 1 ] a_1 \in [0, 1] a1​∈[0,1]是一个小数,而第二个层数动作 a 2 a_2 a2​是一个整数;
        2. 在 a 1 a_1 a1​的指导下,可以获得一个更可靠的邻接矩阵 A i \mathbf{A}_i Ai​:

          A i ( j , j ′ ) = { 1 , if  exp ⁡ ( − A i ( j , j ′ ) ) ≥ a 1 0 , if  exp ⁡ ( − A i ( j , j ′ ) )

VPS购买请点击我

免责声明:我们致力于保护作者版权,注重分享,被刊用文章因无法核实真实出处,未能及时与作者取得联系,或有版权异议的,请联系管理员,我们会立即处理! 部分文章是来自自研大数据AI进行生成,内容摘自(百度百科,百度知道,头条百科,中国民法典,刑法,牛津词典,新华词典,汉语词典,国家院校,科普平台)等数据,内容仅供学习参考,不准确地方联系删除处理! 图片声明:本站部分配图来自人工智能系统AI生成,觅知网授权图片,PxHere摄影无版权图库和百度,360,搜狗等多加搜索引擎自动关键词搜索配图,如有侵权的图片,请第一时间联系我们,邮箱:ciyunidc@ciyunshuju.com。本站只作为美观性配图使用,无任何非法侵犯第三方意图,一切解释权归图片著作权方,本站不承担任何责任。如有恶意碰瓷者,必当奉陪到底严惩不贷!

目录[+]