
Few-shot Learning with Graph Neural Networks

Summary

For few-shot learning, the model also uses similarity information between samples (i.e., it does not stop at learning each sample's label independently).
Each sample is viewed as a node of a graph, and each edge is interpreted as a similarity kernel between two samples.
The edges, i.e., the similarity kernels, are trainable (i.e., not pre-defined as a simple inner product, etc.).
Inspired by message passing algorithms, node features are updated at each time step with messages received from neighboring nodes.
Also applicable to semi-supervised learning and, further, to active learning.
Shows state-of-the-art performance on Omniglot and Mini-ImageNet with fewer parameters (as of 2017).

Keywords

Few shot learning
Graph neural network
Semi-supervised learning
Active learning with Attention

1. Introduction

Supervised end-to-end learning has been extremely successful in computer vision, speech, and machine translation tasks.
However, some tasks (e.g., few-shot learning) cannot achieve high performance with conventional methods.
New supervised learning setup
Input-output setup:
With i.i.d. samples of collections of images and their associated label similarity
cf) conventional setup: i.i.d. samples of images and their associated labels
Authors' model can be extended to semi-supervised and active learning
Semi-supervised learning:
Learning from a mixture of labeled and unlabeled examples
Active learning:
The learner has the option to request those missing labels that will be most helpful for the prediction task
ICML 2019 active learning tutorial
Annotated by JH Gu

2. Closely related works and ideas

[Research article] Matching Networks for One Shot Learning - Vinyals et al. (2016)
[Research article] Prototypical Networks for Few-shot Learning - Snell et al. (2017)
[Review article] Geometric Deep Learning - Bronstein et al. (2017)
[Research article] Message passing - Gilmer et al. (2017)

3. Problem set-up

Authors view the task as a supervised interpolation problem on a graph
Nodes: Images
Edges: Similarity kernels → TRAINABLE

General set-up

Input-output pairs $(\mathcal{T}_i, Y_i)_i$ drawn i.i.d. from a distribution $\mathcal{P}$ of partially labeled image collections
$s$: # labeled samples
$r$: # unlabeled samples
$t$: # samples to classify
$K$: # classes
$\mathcal{P}_l(\mathbb{R}^N)$: class-specific image distribution over $\mathbb{R}^N$
Targets $Y_i$ are associated with $\bar{x}_1, ..., \bar{x}_t \in \mathcal{T}_i$
Learning objective:
$\min_\Theta \frac{1}{L} \sum_{i \leq L} \ell(\Phi(\mathcal{T}_i; \Theta), Y_i) + \mathcal{R}(\Theta)$
($\mathcal{R}$ is a standard regularization term)
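For concreteness, here is a minimal sketch of how one such partially labeled collection $\mathcal{T}_i$ could be sampled; the function `sample_episode`, the `image_pool` structure, and the class-balancing logic are illustrative assumptions, not the authors' code.

```python
import random

def sample_episode(image_pool, K=5, q=1, r=0):
    """Sample one partially labeled collection T_i (illustrative sketch, not the authors' code).
    s = q*K labeled images, r unlabeled images, t = 1 image to classify.
    image_pool: dict mapping class label -> list of images (assumed structure)."""
    classes = random.sample(sorted(image_pool), K)
    target_class = random.choice(classes)
    labeled, unlabeled, query = [], [], None
    for c in classes:
        extra = 1 if c == target_class else 0
        imgs = random.sample(image_pool[c], q + extra)
        labeled += [(x, c) for x in imgs[:q]]             # q labeled examples per class
        if c == target_class:
            query = imgs[-1]                              # held-out query image (t = 1)
    for _ in range(r):                                    # r unlabeled images from the same classes
        unlabeled.append(random.choice(image_pool[random.choice(classes)]))
    return labeled, unlabeled, query, target_class
```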

Few shot learning setting

$r=0,\ t=1,\ s=qK \longrightarrow q\text{-shot},\ K\text{-way}$

Semi-supervised learning setting

$r > 0,\ t=1$
The model can use the auxiliary (unlabeled) images $\{\tilde{x}_1, ..., \tilde{x}_r\}$ to improve prediction accuracy, by leveraging the fact that these samples are drawn from the common class distributions.

Active learning setting

The learner has the ability to request labels from the auxiliary images $\{\tilde{x}_1, ..., \tilde{x}_r\}$.

4. Model

$\phi(x)$: CNN
$h(l)$: one-hot encoded label (for the labeled set), or uniform distribution (for the unlabeled set)

4.1. Set and Graph Input Representations

The goal of few shot learning:
To propagate label information from labeled samples towards the unlabeled query image
→ The propagation can be formalized as a posterior inference over a graphical model
$G_\mathcal{T} = (V, E)$
Similarity measure is not pre-specified, but learned!
cf.) In a Siamese network, the similarity measure is fixed (L1 distance)!
(Note: the corresponding sentence in this paper (Few-shot Learning with GNN) has an odd structure and is written confusingly.)

4.2. Graph Neural Networks

We are given an input signal $F \in \mathbb{R}^{V \times d}$ on the vertices of a weighted graph $G$.
Then we consider a family (a set) $\mathcal{A}$ of graph intrinsic linear operators.
$\mathcal{A} = \{\tilde{A}^{(k)}, \mathbf{1}\}$
Linear operator
e.g.) The simplest linear operator is the adjacency operator $A$, where $(AF)_i = \sum_{j \sim i} w_{i,j} F_j$ ($w_{i,j}$ is the associated edge weight)
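A tiny numeric example of the adjacency operator acting on a node feature matrix (the values are made up purely for illustration):

```python
import torch

# 3 nodes, d = 2 features per node (made-up values)
F = torch.tensor([[1., 0.],
                  [0., 1.],
                  [1., 1.]])
# weighted adjacency matrix: w_ij is the edge weight between nodes i and j
A = torch.tensor([[0.0, 0.5, 0.5],
                  [0.5, 0.0, 0.5],
                  [0.5, 0.5, 0.0]])
AF = A @ F   # (AF)_i = sum_j w_ij * F_j: each node aggregates its neighbors' features
print(AF)    # tensor([[0.5, 1.0], [1.0, 0.5], [0.5, 0.5]])
```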

GNN layer

A GNN layer $\text{Gc}(\cdot)$ receives as input a signal $\mathbf{x}^{(k)} \in \mathbb{R}^{V\times d_k}$ and produces $\mathbf{x}^{(k+1)} \in \mathbb{R}^{V\times d_{k+1}}$
$\mathbf{x}^{(k+1)} = \text{Gc}(\mathbf{x}^{(k)}) = \rho\Big(\sum_{B\in\mathcal{A}} B\mathbf{x}^{(k)}\theta^{(k)}_{B,l}\Big)$
$\mathbf{x}^{(k)}$: representation vectors of the nodes at time step $k$
$\theta$: trainable parameters
$\rho$: Leaky ReLU
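A minimal PyTorch sketch of one such layer; the two-operator family, the dimensions, and the Leaky ReLU slope are assumptions for illustration, not the authors' exact implementation.

```python
import torch
import torch.nn as nn

class GraphConv(nn.Module):
    """One GNN layer: x^(k+1) = LeakyReLU( sum over B in A of  B @ x^(k) @ theta_B ).  (sketch)"""
    def __init__(self, d_in, d_out, n_operators=2):
        super().__init__()
        # one trainable linear map theta_B per operator in the family A
        self.linears = nn.ModuleList([nn.Linear(d_in, d_out, bias=False)
                                      for _ in range(n_operators)])
        self.rho = nn.LeakyReLU(0.2)   # slope 0.2 is an assumption

    def forward(self, x, operators):
        # x: (V, d_in) node features; operators: list of (V, V) matrices, e.g. [A_tilde, I]
        out = sum(lin(B @ x) for B, lin in zip(operators, self.linears))
        return self.rho(out)

# usage sketch, with the family A = {A_tilde, identity}:
# layer = GraphConv(d_in=64, d_out=48)
# x_next = layer(x, [A_tilde, torch.eye(x.shape[0])])
```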

Construction of edge feature matrix, inspired by message passing algorithm

$\tilde{A}^{(k)}_{i,j} = \varphi_{\tilde{\theta}}(\mathbf{x}^{(k)}_i, \mathbf{x}^{(k)}_j)$
$\tilde{A}^{(k)}_{i,j}$: learned edge features from the nodes' current hidden representations (at time step $k$)
$\varphi$: a metric, a symmetric function parameterized by a neural network
$\varphi_{\tilde{\theta}}(\mathbf{x}^{(k)}_i, \mathbf{x}^{(k)}_j) = \text{MLP}_{\tilde{\theta}}(\text{abs}(\mathbf{x}^{(k)}_i - \mathbf{x}^{(k)}_j))$
$\tilde{A}^{(k)}$ is then normalized by a row-wise softmax
→ And added to the family $\mathcal{A} = \{\tilde{A}^{(k)}, \mathbf{1}\}$
$\mathbf{1}$: identity matrix, which provides a self-edge to aggregate the vertex's own features
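A sketch of this trainable similarity kernel; the hidden width and MLP depth are assumptions, and the paper's actual MLP may differ.

```python
import torch
import torch.nn as nn

class EdgeMLP(nn.Module):
    """Trainable similarity kernel: A_tilde[i, j] = MLP(|x_i - x_j|), row-normalized by softmax. (sketch)"""
    def __init__(self, d, hidden=64):   # hidden width is an assumption
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(d, hidden), nn.ReLU(),
                                 nn.Linear(hidden, 1))

    def forward(self, x):
        # x: (V, d) current node representations x^(k)
        diff = (x.unsqueeze(0) - x.unsqueeze(1)).abs()   # (V, V, d) pairwise abs(x_i - x_j)
        logits = self.mlp(diff).squeeze(-1)              # (V, V) learned edge scores
        return torch.softmax(logits, dim=-1)             # row-wise softmax -> A_tilde^(k)
```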

Construction of initial node features

$\mathbf{x}^{(0)}_i = (\phi(x_i), h(l_i))$
$\phi$: convolutional neural network
$h(l) \in \mathbb{R}^K_+$: a one-hot encoding of the label
For images with unknown labels, $\tilde{x}_j$ (unlabeled data) and $\bar{x}_j$ (test data), $h(l_j)$ is set to the uniform distribution.
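A small sketch of this initial node feature construction, assuming the CNN embeddings $\phi(x_i)$ are already computed and using `None` as an illustrative marker for unknown labels:

```python
import torch

def initial_node_features(embeddings, labels, K):
    """x_i^(0) = (phi(x_i), h(l_i)). (illustrative sketch)
    embeddings: (V, d) CNN outputs phi(x_i); labels: list of class indices,
    with None marking unlabeled / query images (assumed convention)."""
    label_field = torch.full((len(labels), K), 1.0 / K)   # uniform distribution for unknown labels
    for i, l in enumerate(labels):
        if l is not None:
            label_field[i] = torch.zeros(K)
            label_field[i, l] = 1.0                        # one-hot h(l_i) for labeled samples
    return torch.cat([embeddings, label_field], dim=-1)    # (V, d + K)
```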

5. Training

5.1. Few-shot and Semi-supervised learning

The final layer of the GNN is a softmax mapping. We then use the cross-entropy loss:
$\ell(\Phi(\mathcal{T}; \Theta), Y) = -\sum_k y_k \log P(Y_* = y_k \mid \mathcal{T})$
The semi-supervised setting is trained identically, but the initial label fields of the $\tilde{x}_j$ are filled with the uniform distribution.
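In PyTorch-style pseudocode this is a standard cross-entropy on the query node's output (placeholder values; `F.cross_entropy` applies the softmax internally, which is equivalent to a softmax output followed by cross-entropy):

```python
import torch
import torch.nn.functional as F

# logits_query: (t, K) GNN outputs at the query node(s); y_true: (t,) class indices
logits_query = torch.randn(1, 5)              # placeholder values, 5-way task
y_true = torch.tensor([2])
loss = F.cross_entropy(logits_query, y_true)  # = -sum_k y_k log P(Y* = y_k | T)
```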

5.2. Active learning (with attention)

In active learning, the model has the intrinsic ability to query for one of the labels from $\{\tilde{x}_1, ..., \tilde{x}_r\}$.
The network will learn to ask for the most informative label to classify the sample xˉ\bar{x}.
The querying is done after the first layer of GNN by using a softmax attention over the unlabeled nodes of the graph.

Attention

We apply a function $g(\mathbf{x}^{(1)}_i) \in \mathbb{R}^1$ that maps each unlabeled node's vector to a scalar value.
A softmax is applied over the $\{1, ..., r\}$ scalar values obtained after applying $g$:
$r$: # unlabeled samples
$\text{Attention} = \text{Softmax}(g(\mathbf{x}^{(1)}_{\{1,...,r\}}))$
To query only one sample, we set all elements to zero except for one → $\text{Attention}'$
At training time, the model randomly samples one value based on its multinomial probability.
At test time, the model just keeps the maximum value.
Then we multiply this with the label vectors:
$w \cdot h(l_{i^*}) = \langle \text{Attention}', h(l_{\{1,...,r\}}) \rangle$
($w$ is a scaling factor)
This value is then summed into the current representation.
$\mathbf{x}^{(1)}_{i^*} = [\text{Gc}(\mathbf{x}^{(0)}_{i^*}), \mathbf{x}^{(0)}_{i^*}] = [\text{Gc}(\mathbf{x}^{(0)}_{i^*}), (\phi(x_{i^*}), h(l_{i^*}))]$
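A hedged sketch of this attention-based query step; the function names, the `(r, K)` label tensor, and the way the queried label is returned are illustrative assumptions, not the paper's exact formulation.

```python
import torch

def query_label(x_unlabeled, labels_unlabeled, g, training=True):
    """Softmax-attention label query over the r unlabeled nodes. (illustrative sketch)
    x_unlabeled: (r, d) representations after the first GNN layer;
    labels_unlabeled: (r, K) one-hot labels, revealed only for the queried node;
    g: module mapping a node vector to a scalar score."""
    scores = g(x_unlabeled).squeeze(-1)         # (r,) one scalar per unlabeled node
    attn = torch.softmax(scores, dim=0)         # Attention over {1, ..., r}
    if training:
        idx = torch.multinomial(attn, 1)        # sample one node from the multinomial
    else:
        idx = attn.argmax(dim=0, keepdim=True)  # at test time keep the maximum
    attn_prime = torch.zeros_like(attn).scatter_(0, idx, attn[idx])  # Attention': one nonzero entry
    # scaled label vector of the queried node, to be summed into its representation
    return attn_prime @ labels_unlabeled        # (K,) = <Attention', h(l_{1..r})>
```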

6. Results

6.1. Few-shot learning

Omniglot

# of parameters: ~5M (TCML) vs. ~300K (3-layer GNN)
Omniglot:
1,623 characters × 20 examples per character

Mini-ImageNet

# of parameters: ~11M (TCML) vs. ~400K (3-layer GNN)
Mini-ImageNet:
Originally introduced by Vinyals et al. (2016)
Divided into 64 training, 16 validation, and 20 testing classes, each containing 600 examples.

6.2. Semi-supervised learning

Omniglot

Mini-ImageNet

6.3. Active learning

Random: Network chooses a random sample to be labeled, instead of one that maximally reduces the loss of the classification task $\mathcal{T}$