Summary
โข
Neural network, ๊ทธ ์ค์์๋ GNN์ CLRS ์ฑ
์ ์๋ ์ ํต์ ์ธ ์ปดํจํฐ ๊ณผํ ์๊ณ ๋ฆฌ์ฆ์ ํ์ตํ ์ ์๋ค.
โข
ํน์ ์๊ณ ๋ฆฌ์ฆ ํ๋๋ง์ผ๋ก ์ด๋ค ๋ฌธ์ ๋ฅผ ํ๊ธฐ ์ด๋ ค์ธ ๋ Generalist Neural Algorithmic Learner๊ฐ ํ์ํ๋ค.
โข
Multi-task learning์ ์์ ํํ๋๋ฐ ์์ด chunking mechanism์ด ์ค์ํ๋ค.
Starting point: A benchmark to train neural computer scientists
โข
Neural network๊ฐ ์ ํต์ ์ธ CS algorithm์ ํ๋๋ก ํ์ต์ํฌ ์ ์์๊น?
โข
์ดํ์๋ ์ด ์ง์์ natural input์ ์ ์ฉํด์ real-world problem์ ํ ์ ์์ ๊ฒ์ด๋ค.
โข
๊ทธ๋ฆฌ๊ณ , ๊ทธ algorithm์ ์ฌ๋ฌ ๊ฐ ํ์ต์ํฌ ์ ์์๊น?
โข
์ด๋ฐ ๋ฌธ์ ๋ค์ recurrent architecture๋ก modeling ํ ์ ์๋ค.
โฆ
LSTM, Transformer, ConvNet, GNN ๋ฑ.
โฆ
์ ์๋ค์ GNN์ ์ฌ์ฉ.
Benchmark: Introduction to Algorithms: CLRS
Representation
โข
๋ชจ๋ ์๊ณ ๋ฆฌ์ฆ์ graph ํํ๋ก ํํ๋์๋ค.
โข
๊ฐ ์๊ณ ๋ฆฌ์ฆ์ ์ ํด์ง ์ซ์๋งํผ์ โprobeโ๋ก ํํ๋๋ค.
โข
์๋ฅผ ๋ค์ด, insertion sort ์๊ณ ๋ฆฌ์ฆ์ ์๋ 6๊ฐ์ง์ probe๋ก ๊ตฌ์ฑ๋๋ค.
โฆ
'pos': (Stage. INPUT, Location.NODE, Type.SCALAR) โ ๊ฐ node์ ID
โฆ
'key': (Stage.INPUT, Location. NODE, Type.SCALAR) โ ์ ๋ ฌํ value
โฆ
'pred': (Stage.OUTPUT, Location. NODE, Type.POINTER) โ ์ต์ข
node ์์
โฆ
'pred h': (Stage. HINT, Location. NODE, Type. POINTER) โ ์คํํ๋ฉด์ ๋ฐ๋๋ node ์์
โฆ
'i': (Stage.HINT, Location.NODE, Type.MASK_ONE) โ insertionํ index
โฆ
'j': (Stage.HINT, Location.NODE, Type.MASK_ONE) โ tracking ํ index
โข
Probe๋ input, output, hint ์ค ํ๋.
โข
Input๊ณผ output์ ์๊ณ ๋ฆฌ์ฆ ์คํ ์ ๊ณ ์ , hint๋ ์คํ๋๋ฉด์ ๋ฐ๋.
โฆ
๋ฐ๋ผ์, ๋ชจ๋ sorting algorithm์ input๊ณผ output์ด ๊ฐ๊ณ hint๋ง ๋ค๋ฆ.
Representation: Encoding
์ฒ์ positional ID (node์ ID)๋ encoder์ ์ํด vector ํํ๋ก ํํ๋จ.
Algorithm์ ๋์์ด ๋๋ value๋ encoder์ ์ํด vector ํํ๋ก ํํ๋๊ณ , pos์ ๋ํด์ง.
Hint๋ก ์ฌ์ฉ๋๋ pred_h๋ encoder์ ์ํด pointer๋ก mapping๋จ.
Insertion์ ํ์ํ index๋ encoder์ ์ํด ๋ฐ๋์ด ๋ํด์ง.
Representation: Decoding
Processing step์ algorithm์ ๋ฐ๋ผ ๋ฐ๋์ง ์๋๋ค! Processing parameter๋ ๊ณต์ ๋จ.
Training
Trainํ ๋๋ hint๋ฅผ ์ฌ์ฉํ๊ณ , testํ ๋๋ hint๋ฅผ ์ฌ์ฉํ์ง ์๊ณ input๋ง ์ฃผ๊ณ output์ ๋ง์ถ๋ค.
Trainํ ๋๋ output loss ๋ฟ๋ง ์๋๋ผ ์ค๊ฐ ๊ณผ์ ์ธ hint loss์ ๋ํด์๋ loss๋ฅผ ์ค๋ค.
Details
๋ช ๊ฐ์ง training detail์ด ์ค์ ํ์ต์ ๋งค์ฐ ์ค์ํ๋ค.
Node ๊ฐ์๊ฐ 16์ดํ์ธ graph์ ๋ํด์ trainํ๊ณ , node ๊ฐ์๊ฐ 64์ธ sample์ ๋ํด test ํ๋ค.
Train, test ์ ๋ชจ๋ trajectory์ length๋ ์ ํด์ ธ์๋ค.
In-distribution score์ ๋ฐ๋ผ early stopping์ ์ ์ฉํ๋ค.
Why even care about building a generalist?
์ generalist algorithmic solver๊ฐ ํ์ํ ๊ฒ์ผ๊น?
์ด๋ค input์ ๋ํด output์ ๋ง์ถ๋ ๊ฒ์ problem solving ์ธ๋ฐ, ์ฐ๋ฆฌ๊ฐ ๋ณดํต ์ด๋ป๊ฒ real-world์์ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ณ ์๋์ง ์ดํด๋ณด์.
โข
Example: Route recommendation
Google map์์ ์ต์ ์ route ์ถ์ฒ์ ํด์ผํ๋ ์ํฉ์ ๊ฐ์ ํด๋ณด์. ๊ทธ๋ฅ raw map ์ํ๋ก๋ ๋ฌธ์ ๋ฅผ ํ ์ ์๋ค. Graph ํํ๋ก ๋ฌธ์ ๋ฅผ ์ถ์ํํ๋ค์, CLRS์์ ๊ณต๋ถํ๋ ์๊ณ ๋ฆฌ์ฆ์ ์ ์ฉํด์, (์๋ง ๋๋ถ๋ถ Dijkstraโs algorithm์ ์จ์) ์ต๋จ ๊ฑฐ๋ฆฌ๋ฅผ ์์๋ด๊ณ , ๊ทธ๊ฒ์ ์ ๋ต์ผ๋ก ํ์ฌ raw map์ผ๋ก ๋ค์ ๋ณํํ ๊ฒ์ด๋ค.
ํ์ง๋ง real-world์์๋ ๊ผญ โ์ต๋จ๊ฑฐ๋ฆฌโ๊ฐ ์ต์ ์ route๋ ์๋ ์ ์๋ค. ๊ตํต์ํฉ, ๋์ด, ๋ฑ๋ฑ ๋ค์ํ factor๊ฐ ์กด์ฌํ๊ณ , ์ด๋ด ๊ฒฝ์ฐ์๋ algorithm์ด obviousํ์ง ์๋ค.
Details
์ ์๋ค์ neural algorithmic reasoning์ผ๋ก ํ๋ ๋ถ๋ถ์ bottleneck์ ํด๊ฒฐํ ์ ์๋ค๊ณ ์ฃผ์ฅํ๋ค.
๊ทธ๋ฆฌ๊ณ ์ ํ์ต๋ generalist processor๊ฐ ๋นจ๊ฐ bottleneck์ ํด๊ฒฐํ ์ ์๋ค๊ณ ์ฃผ์ฅํ๋ค.
โฆ
์ด generalist processor๋ ์ค์ํ ์๊ณ ๋ฆฌ์ฆ์ latent space๋ฅผ ์ ๊ณต์ ํ๋ค๋ฉด, ๋ ์ด์ ํน์ ์๊ณ ๋ฆฌ์ฆ์ ๊ณ ๋ฅด์ง ์์๋ ๋๊ณ , ์ต์ ์ ์ ๋ต์ ์ถ์ฒํด์ค ์ ์์ ๊ฒ์ด๋ค.
To get a generalist, first we need a good specialist!
๊ทธ๋ฌ๋ CLRS์ 30๊ฐ์ ์๊ณ ๋ฆฌ์ฆ์ ๋ํด ๋ชจ๋ ๋ค ์ ํ๋ generalist๋ฅผ naรฏveํ๊ฒ ํ์ตํ๋ ๊ฒ์ ์ด๋ ค์ ๋ค.
์ต๊ทผ NE++ (Xhonneux et al., NeurIPS'21) ์์, ์ด๋ฐ ์๊ณ ๋ฆฌ์ฆ๋ค์ ์๋ก ์ฐ๊ด์ด ๋์ ๊ฒ๋ค๋ผ๋ฆฌ (e.g. Prim + Dijkstra) ํ์ต๋ ๋๋ง ์ ํ์ต๋๋ค๋ ์๊ธฐ๊ฐ ์์๋ค.
DeepMind ์ ์๋ค์ด ๋ฐ๊ฒฌํ ํ์ต ์์:
ํ์ต์ด ๋ถ์์ ํ task ๋ค์ ๋ค๋ฅธ task์๋ ํฌ๊ฒ ์ํฅ์ ์ค๋ค.
๋ฐ๋ผ์, single-task ๋จ์๋ก ์์ ์ฑ์ ๋จผ์ ๋์ฌ ๋์์ผ ํ๋ค.
Bucket list of improvements
โข
Teacher forcing ์์ ๊ธฐ
โข
Training data augmentation (16๊ฐ ์ดํ์ node size ๊ธฐ์ค)
โข
Soft hint propagation (hint์ ๋ํด argmax๋ฅผ ์ทจํ์ง ์๊ณ , softmax๋ฅผ ์ทจํ๋ค)
โข
Static hint elimination (hint๊ฐ ๋ฐ๋์ง ์์ผ๋ฉด, input์ผ๋ก ์ฌ์ฉํ๋ค)
โข
Encoder initialization (Xavier) + Gradient clipping
โข
Randomized positional embedding
โข
Sinkhorn operator๋ฅผ ์ฌ์ฉํ Permutation decoder
โข
Processor์ gating mechanism
โข
Triplet reasoning
Results
Final step to the generalist
Multi-task learning์ ์ํด chunking mechanism์ด ์ค์ํ๋ค.
1.
Trajectory์ length๋ 16์ผ๋ก ์ค์ ํ๋ค.
2.
16๋ณด๋ค ์งง์ sample๋ค์ padding ๋์ง ์๊ณ ๊ทธ๋ฅ ๋ค์ task๊ฐ concat ๋๋ค.
์ฆ, padding ์์ด ์ญ ์ฐ๊ฒฐ๋ ํํ์ด๋ฉฐ, ์ด๋ ๊ฒ ๊ตฌ์ฑํ๋ฉด memory-efficient ํ๊ณ ํ์ต์ด ์์ ํ๋๋ ํจ๊ณผ๊ฐ ์์๋ค.
์ด๋ ๊ฒ ํด๋ ๋๋ ์ด์ ๋ CLRS-30 task๋ Markovian ์ด๊ธฐ ๋๋ฌธ์ด๋ค.