
A Generalist Neural Algorithmic Learner

Type
Application
Published year
2022
Journal / Conference
LoG
Keyword
GNN
Multi-task
Algorithmic learner
Status
Done
Language
🇺🇸
🇰🇷
Blog upload
Yes
Date
2023/01/13

Summary

• Neural networks, and GNNs in particular, can learn the classical computer science algorithms found in the CLRS textbook.
• A Generalist Neural Algorithmic Learner is needed when a problem is hard to solve with any single algorithm.
• A chunking mechanism was important for stabilizing multi-task learning.

Starting point: A benchmark to train neural computer scientists

clrs
deepmind
• Can a neural network be trained to execute classical CS algorithms?
• If so, that knowledge could later be applied to natural inputs to solve real-world problems.
• And can a single network learn several of these algorithms?
• These problems can be modeled with a recurrent architecture.
  ◦ LSTM, Transformer, ConvNet, GNN, etc.
  ◦ The authors use a GNN.
Benchmark: Introduction to Algorithms: CLRS

Representation

• Every algorithm is represented as a graph.
• Each algorithm is described by a fixed number of “probes”.
• For example, insertion sort consists of the following six probes.
  ◦ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR) → ID of each node
  ◦ 'key': (Stage.INPUT, Location.NODE, Type.SCALAR) → value to sort
  ◦ 'pred': (Stage.OUTPUT, Location.NODE, Type.POINTER) → final node order
  ◦ 'pred_h': (Stage.HINT, Location.NODE, Type.POINTER) → node order as it changes during execution
  ◦ 'i': (Stage.HINT, Location.NODE, Type.MASK_ONE) → index being inserted
  ◦ 'j': (Stage.HINT, Location.NODE, Type.MASK_ONE) → index being tracked
• Each probe is one of input, output, or hint.
• Inputs and outputs are fixed for a run of the algorithm; hints change as it executes.
  ◦ Consequently, all sorting algorithms share the same inputs and outputs and differ only in their hints.
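
The probe list above can be written down as a small specification table. The sketch below is only an illustration: the `Stage`, `Location`, and `Type` enums are defined locally to mirror the names in the list, and the helper `probes_at` is hypothetical, not part of any library.

```python
from enum import Enum


class Stage(Enum):
    INPUT = "input"
    OUTPUT = "output"
    HINT = "hint"


class Location(Enum):
    NODE = "node"  # only the location used by insertion sort is modeled here


class Type(Enum):
    SCALAR = "scalar"
    POINTER = "pointer"
    MASK_ONE = "mask_one"


# The six insertion-sort probes from the list above.
INSERTION_SORT = {
    "pos": (Stage.INPUT, Location.NODE, Type.SCALAR),
    "key": (Stage.INPUT, Location.NODE, Type.SCALAR),
    "pred": (Stage.OUTPUT, Location.NODE, Type.POINTER),
    "pred_h": (Stage.HINT, Location.NODE, Type.POINTER),
    "i": (Stage.HINT, Location.NODE, Type.MASK_ONE),
    "j": (Stage.HINT, Location.NODE, Type.MASK_ONE),
}


def probes_at(spec, stage):
    """Return the sorted names of the probes that belong to one stage (hypothetical helper)."""
    return sorted(name for name, (s, _, _) in spec.items() if s is stage)
```

Under this layout, another sorting algorithm would reuse the `pos`, `key`, and `pred` entries and only swap out the `Stage.HINT` rows.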

Representation: Encoding

์ฒ˜์Œ positional ID (node์˜ ID)๋Š” encoder์— ์˜ํ•ด vector ํ˜•ํƒœ๋กœ ํ‘œํ˜„๋จ.
Algorithm์˜ ๋Œ€์ƒ์ด ๋˜๋Š” value๋„ encoder์— ์˜ํ•ด vector ํ˜•ํƒœ๋กœ ํ‘œํ˜„๋˜๊ณ , pos์— ๋”ํ•ด์ง.
Hint๋กœ ์‚ฌ์šฉ๋˜๋Š” pred_h๋„ encoder์— ์˜ํ•ด pointer๋กœ mapping๋จ.
Insertion์— ํ•„์š”ํ•œ index๋„ encoder์— ์˜ํ•ด ๋ฐ”๋€Œ์–ด ๋”ํ•ด์ง.
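
As a rough illustration of "encode each probe, then add", the sketch below sums per-probe linear encodings into one embedding per node. Everything in it is an assumption made for the example: the random linear maps stand in for learned encoders, `HIDDEN` is an arbitrary width, and the pointer probe `pred_h` is left out for brevity.

```python
import random


def linear(in_dim, out_dim, rng):
    """Build a random linear map (stand-in for a learned encoder)."""
    w = [[rng.uniform(-0.1, 0.1) for _ in range(in_dim)] for _ in range(out_dim)]
    return lambda x: [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def encode_nodes(pos, key, i_mask, j_mask, encoders):
    """Sum the per-probe encodings into one embedding per node."""
    embeddings = []
    for n in range(len(pos)):
        parts = [
            encoders["pos"]([pos[n]]),
            encoders["key"]([key[n]]),
            encoders["i"]([i_mask[n]]),
            encoders["j"]([j_mask[n]]),
        ]
        # Elementwise sum across the four encoded probes.
        embeddings.append([sum(vals) for vals in zip(*parts)])
    return embeddings


rng = random.Random(0)
HIDDEN = 4  # hypothetical embedding width
encoders = {name: linear(1, HIDDEN, rng) for name in ("pos", "key", "i", "j")}
nodes = encode_nodes(
    pos=[0.0, 0.25, 0.5, 0.75],
    key=[0.3, 0.1, 0.9, 0.4],
    i_mask=[0.0, 1.0, 0.0, 0.0],
    j_mask=[1.0, 0.0, 0.0, 0.0],
    encoders=encoders,
)
```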

Representation: Decoding

The processing step does not change with the algorithm! The processor parameters are shared.

Training

During training the hints are used; at test time no hints are given, and the model must predict the output from the input alone.
The training loss covers not only the output but also the hints, i.e. the intermediate steps.
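
A minimal sketch of such a combined objective is shown below. It assumes a squared-error loss for both terms and averages the hint loss over time steps; both choices are simplifications for illustration, and the function names are hypothetical.

```python
def squared_error(pred, target):
    """Mean squared error between two equal-length lists of floats."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def trajectory_loss(output_pred, output_true, hint_preds, hint_trues):
    """Output loss plus the hint loss averaged over the time steps of one trajectory."""
    out = squared_error(output_pred, output_true)
    hint = sum(
        squared_error(p, t) for p, t in zip(hint_preds, hint_trues)
    ) / len(hint_preds)
    return out + hint
```

At test time only the output term would be evaluated, since no ground-truth hints are supplied.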

Details

A few training details turned out to matter a great deal in practice.
Training uses graphs with at most 16 nodes; testing uses samples with 64 nodes.
The trajectory length is also fixed for both training and testing.
Early stopping is applied based on the in-distribution score.

Why even care about building a generalist?

Why do we need a generalist algorithmic solver?
Producing the right output for a given input is problem solving, so let us look at how we usually solve problems in the real world.
• Example: Route recommendation
Suppose Google Maps has to recommend the best route. The problem cannot be solved on the raw map directly. We would abstract it into a graph, apply an algorithm studied in CLRS (most likely Dijkstra's algorithm) to find the shortest path, and then map that answer back onto the raw map.
In the real world, however, the shortest path is not necessarily the best route. Traffic, elevation, and many other factors come into play, and in that case the right algorithm is no longer obvious.

Details

The authors argue that neural algorithmic reasoning can resolve the bottleneck in the blue region.
They further argue that a well-trained generalist processor can resolve the red bottleneck.
  ◦ If this generalist processor shares the latent space of the important algorithms well, there is no longer any need to pick a specific algorithm, and it can recommend the best answer.

To get a generalist, first we need a good specialist!

However, naïvely training a generalist that does well on all 30 CLRS algorithms proved difficult.
NE++ (Xhonneux et al., NeurIPS'21) recently reported that such algorithms are learned well only when trained together with closely related ones (e.g. Prim + Dijkstra).
Training behavior observed by the DeepMind authors:
Tasks whose training is unstable strongly affect the other tasks as well.
Therefore, stability has to be improved at the single-task level first.

Bucket list of improvements

• Removing teacher forcing
• Training data augmentation (for node sizes of 16 or fewer)
• Soft hint propagation (apply softmax to the hints instead of argmax)
• Static hint elimination (a hint that never changes is used as an input)
• Encoder initialization (Xavier) + gradient clipping
• Randomized positional embedding
• Permutation decoder using the Sinkhorn operator
• Gating mechanism in the processor
• Triplet reasoning
$$t_{ijk} = \psi_t(h_i, h_j, h_k, e_{ij}, e_{ik}, e_{kj}, g)$$
$$h_{ij} = \phi_t(\max_k t_{ijk})$$
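
The two triplet-reasoning equations can be traced in plain Python. In this sketch `psi` and `phi` are simple stand-ins (an elementwise sum and a ReLU) for the learned functions $\psi_t$ and $\phi_t$, chosen only so the example runs; the data layout (lists of floats for $h$, $e$, $g$) is likewise an assumption.

```python
def psi(h_i, h_j, h_k, e_ij, e_ik, e_kj, g):
    """Stand-in for the learned triplet function psi_t (here: elementwise sum)."""
    return [sum(vals) for vals in zip(h_i, h_j, h_k, e_ij, e_ik, e_kj, g)]


def phi(t):
    """Stand-in for the learned readout phi_t (here: ReLU)."""
    return [max(0.0, x) for x in t]


def triplet_edge_update(h, e, g):
    """Compute h_ij = phi(max_k psi(h_i, h_j, h_k, e_ij, e_ik, e_kj, g)) for all i, j."""
    n = len(h)
    out = [[None] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            triplets = [
                psi(h[i], h[j], h[k], e[i][j], e[i][k], e[k][j], g)
                for k in range(n)
            ]
            pooled = [max(vals) for vals in zip(*triplets)]  # elementwise max over k
            out[i][j] = phi(pooled)
    return out
```

The triple loop makes the cubic cost in the number of nodes explicit: every edge representation $h_{ij}$ aggregates over all intermediate nodes $k$.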

Results

Final step to the generalist

The chunking mechanism was important for multi-task learning.
1. The trajectory length is set to 16.
2. Samples shorter than 16 are not padded; the next task is simply concatenated right after.
In other words, everything is chained together without padding, which is memory-efficient and had a stabilizing effect on training.
This works because the CLRS-30 tasks are Markovian.
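
The chunking idea can be sketched as follows: trajectories are laid end to end and the resulting stream is cut into fixed-length pieces. The function name, the `(sample_id, step_index, is_first)` tuple layout, and the decision to drop an incomplete final chunk are assumptions made for this illustration, not details taken from the paper.

```python
def chunk_trajectories(trajectories, chunk_len=16):
    """Concatenate trajectories back to back and cut the stream into fixed-length chunks.

    Each step is recorded as (sample_id, step_index, is_first), so a model could
    reset its hidden state wherever is_first is True.
    """
    stream = []
    for sample_id, traj in enumerate(trajectories):
        for t in range(len(traj)):
            stream.append((sample_id, t, t == 0))
    # Drop the incomplete tail; a streaming data loader would keep appending samples.
    usable = len(stream) - len(stream) % chunk_len
    return [stream[s:s + chunk_len] for s in range(0, usable, chunk_len)]


# Three trajectories of lengths 10, 20, and 5 give 35 steps, i.e. two full chunks.
chunks = chunk_trajectories([list(range(10)), list(range(20)), list(range(5))])
```

Because each step depends only on the current state, a chunk can start in the middle of a trajectory as long as the state from the previous chunk is carried over, which is why no padding is needed.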

Single generalist that matches the thirty specialists

Chunking helps significantly

Reference