
[BentoML] 2. Simple GNN with BentoML

Honestree 2022. 10. 20. 17:31

1. Data Generation and Preprocessing

  • Only basic grid-shaped graphs are used
  • Undirected probabilistic graphical model
  • Data generation code
``` python
import os
import pickle
import numpy as np
from utils.topology import NetworkTopology, get_msg_graph
from model.gt_inference import Enumerate
from scipy.sparse import coo_matrix


def mkdir(dir_name):
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)
        print('made directory {}'.format(dir_name))


def main():
    num_nodes_I = 9
    std_J_A = 1 / 3
    std_b_A = 0.25
    save_dir = './data_temp'
    print('Generating training graphs!')

    try:
        mkdir(save_dir)
        mkdir(os.path.join(save_dir, "train"))
    except OSError:
        pass

    for sample_id in range(100):
        seed_train = int(str(1111) + str(sample_id))
        topology = NetworkTopology(num_nodes=num_nodes_I, seed=seed_train)
        npr = np.random.RandomState(seed=seed_train)
        graph = {}
        G, W = topology.generate(topology='grid')

        # symmetric coupling matrix, masked by the grid adjacency W
        J = npr.normal(0, std_J_A, size=[num_nodes_I, num_nodes_I])
        J = (J + J.transpose()) / 2.0
        J = J * W

        # per-node bias
        b = npr.normal(0, std_b_A, size=[num_nodes_I, 1])

        # Enumerate: exact marginals by exhaustive enumeration
        model = Enumerate(W, J, b)
        prob_gt = model.inference()

        graph['prob_gt'] = prob_gt  # shape N x 2
        graph['J'] = coo_matrix(J)  # shape N x N
        graph['b'] = b  # shape N x 1
        graph['seed_train'] = seed_train
        graph['stdJ'] = std_J_A
        graph['stdb'] = std_b_A

        msg_node, msg_adj = get_msg_graph(G)
        msg_node, msg_adj = np.array(msg_node), np.array(msg_adj)
        idx_msg_edge = np.transpose(np.nonzero(msg_adj))
        J_msg = J[msg_node[:, 0], msg_node[:, 1]].reshape(-1, 1)

        graph['msg_node'] = msg_node
        graph['idx_msg_edge'] = idx_msg_edge
        graph['J_msg'] = J_msg

        file_name = os.path.join(
            save_dir, "train",
            'graph_{}_nn{}_{:07d}.p'.format('grid', num_nodes_I, sample_id))
        with open(file_name, 'wb') as f:
            pickle.dump(graph, f)


if __name__ == '__main__':
    main()
```
  • Train / val data: the validation set is generated the same way, changing only the random seed
    • Train / val = 100 / 10 samples
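
To sanity-check the output format, one generated pickle can be loaded directly. The shapes below follow from the 3x3 grid (9 nodes, 24 directed edges) and match the test input in section 4; the file name comes from the format string in the script above:

``` python
# Inspect one generated sample (path follows the script above).
import pickle

with open('./data_temp/train/graph_grid_nn9_0000000.p', 'rb') as f:
    graph = pickle.load(f)

print(graph['prob_gt'].shape)   # (9, 2)  exact per-node marginals
print(graph['b'].shape)         # (9, 1)  per-node bias
print(graph['msg_node'].shape)  # (24, 2) directed edges of the 3x3 grid
print(graph['J_msg'].shape)     # (24, 1) coupling per directed edge
```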

1.1 Data Structure

  • Probabilistic graphical model

$$ p(\mathbf{x}) = \frac{1}{Z} \exp\left( \mathbf{b} \cdot \mathbf{x} + \mathbf{x} \cdot \mathbf{J} \cdot \mathbf{x} \right) $$

  • Each node has a state x of -1 / 1, and the joint distribution can be computed with the formula above
    • However, this formula does not give the marginal probability of an individual node; that requires additional computation
      • Enumerate: computes individual marginals by keeping a probability table over all possible states (see the sketch after this list)
    • Ising model
      • b: self-bias / J: coupling strength / x: node state
  • The goal of the neural network is to compute the per-node marginals using only J and b, with Enumerate used only to compute the target values
  • b: vector of per-node bias values
  • msg_node: the graph's connectivity, listed as ordered node pairs (directed edges)
  • J_msg: the coupling strength of each edge, aligned with msg_node
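
For intuition, a brute-force version of Enumerate fits in a few lines. This is a sketch, not the actual model.gt_inference.Enumerate implementation; in particular, the column ordering of prob_gt (whether column 0 is the state -1) is an assumption:

``` python
import itertools
import numpy as np

def enumerate_marginals(J, b):
    """Brute-force per-node marginals of p(x) ∝ exp(b·x + x·J·x), x ∈ {-1,1}^N."""
    N = b.shape[0]
    # all 2^N joint states, shape 2^N x N
    states = np.array(list(itertools.product([-1.0, 1.0], repeat=N)))
    # unnormalized log-probability of every state
    log_w = states @ b.flatten() + np.einsum('si,ij,sj->s', states, J, states)
    w = np.exp(log_w - log_w.max())  # subtract max for numerical stability
    p = w / w.sum()                  # normalize: this is the 1/Z term
    # per-node marginals: columns assumed to be [P(x_i=-1), P(x_i=+1)]
    prob = np.stack(
        [(p[:, None] * (states == s)).sum(axis=0) for s in (-1.0, 1.0)], axis=1)
    return prob  # shape N x 2
```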

2. Building the Training Model

2.1 Simple GNN

  • Gated Graph Neural Network (GG-NN)
    • For verification, a predict function that simply takes the input and returns the output is added
``` python
# gnn.py

import numpy as np
import torch
import torch.nn as nn

EPS = float(np.finfo(np.float32).eps)
__all__ = ['NodeGNN']


class NodeGNN(nn.Module):

    def __init__(self):
        """ A simplified implementation of NodeGNN """
        super(NodeGNN, self).__init__()
        self.hidden_dim = 16
        self.num_prop = 5
        self.aggregate_type = 'sum'

        # message function
        self.msg_func = nn.Sequential(*[
            nn.Linear(2 * self.hidden_dim + 8, 64),
            nn.ReLU(),
            nn.Linear(64, self.hidden_dim)
        ])

        # update function
        self.update_func = nn.GRUCell(
            input_size=self.hidden_dim, hidden_size=self.hidden_dim)

        # output function
        self.output_func = nn.Sequential(*[
            nn.Linear(self.hidden_dim + 2, 64),
            nn.ReLU(),
            nn.Linear(64, 2),
        ])
        self.loss_func = nn.KLDivLoss(reduction='batchmean')

    def forward(self, J_msg, b, msg_node, target=None):
        num_node = b.shape[0]
        num_edge = msg_node.shape[0]

        edge_in = msg_node[:, 0]
        edge_out = msg_node[:, 1].contiguous()

        # per-edge input features: endpoint biases and couplings, with sign flips
        ff_in = torch.cat([b[edge_in], -b[edge_in], J_msg, -J_msg], dim=1)
        ff_out = torch.cat([-b[edge_out], b[edge_out], -J_msg, J_msg], dim=1)

        state = torch.zeros(num_node, self.hidden_dim).to(b.device)

        def _prop(state_prev):
            # 1. compute messages
            state_in = state_prev[edge_in, :]  # shape |E| x D
            state_out = state_prev[edge_out, :]  # shape |E| x D
            msg = self.msg_func(
                torch.cat([state_in, ff_in, state_out, ff_out], dim=1))  # |E| x D
            # 2. aggregate messages per receiving node (mean over incoming edges)
            scatter_idx = edge_out.view(-1, 1).expand(-1, self.hidden_dim)
            msg_agg = torch.zeros(num_node, self.hidden_dim).to(b.device)  # |V| x D
            msg_agg = msg_agg.scatter_add(0, scatter_idx, msg)
            avg_norm = torch.zeros(num_node).to(b.device).scatter_add_(
                0, edge_out, torch.ones(num_edge).to(b.device))
            msg_agg /= (avg_norm.view(-1, 1) + EPS)
            # 3. update state
            state_new = self.update_func(msg_agg, state_prev)  # GRU update
            return state_new

        # propagation
        for tt in range(self.num_prop):
            state = _prop(state)

        # output
        y = self.output_func(torch.cat([state, b, -b], dim=1))
        y = torch.log_softmax(y, dim=1)
        loss = self.loss_func(y, target) if target is not None else None
        return y, loss

    def predict(self, J_msg, b, msg_node, prob_gt):
        # the runner passes batched inputs; strip the batch dimension and fix
        # dtypes (msg_node holds indices, so cast to long; features stay float)
        J_msg = J_msg[0].float()
        b = b[0].float()
        msg_node = msg_node[0].long()
        prob_gt = prob_gt[0].float()

        # same propagation and output as forward
        y, loss = self.forward(J_msg, b, msg_node, target=prob_gt)

        res = dict()
        res["prob"] = np.exp(y.detach().cpu().numpy())
        res["loss"] = loss.detach().cpu().numpy()
        return res
```

3. Model Training

  • Write a runner class that manages the neural network training process (a sketch of the MyDataloader it uses follows the code)
``` python
# inference_runner.py

from __future__ import (division, print_function)
import os
import numpy as np
from collections import defaultdict
from tqdm import tqdm

import torch
import torch.utils.data
import torch.optim as optim
from model.gnn import NodeGNN
from dataset.dataloader import *
import bentoml

EPS = float(np.finfo(np.float32).eps)
__all__ = ['NeuralInferenceRunner']


class NeuralInferenceRunner(object):

    def train(self):
        print("=== START TRAINING ===")
        # create data loaders
        train_dataset = MyDataloader(split='train')
        val_dataset = MyDataloader(split='val')

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=10,
            shuffle=True,
            collate_fn=train_dataset.collate_fn)

        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=10,
            shuffle=True,
            collate_fn=val_dataset.collate_fn)

        # create model
        model = NodeGNN()

        # create optimizer
        params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = optim.Adam(params, lr=0.001)

        # reset gradient
        optimizer.zero_grad()

        # ========================= Training Loop ========================= #
        best_val_loss = np.inf
        for epoch in range(10):
            print("=== EPOCH : {} ===".format(epoch))
            # ===================== validation ===================== #
            model.eval()
            val_losses = []
            for data in tqdm(val_loader, desc="VALIDATION"):
                with torch.no_grad():
                    _, loss = model(data['J_msg'], data['b'], data['msg_node'],
                                    target=data['prob_gt'])
                    val_losses.append(loss.item())
            # track the best validation loss across epochs
            best_val_loss = min(best_val_loss, float(np.mean(val_losses)))

            # ====================== training ====================== #
            model.train()
            for data in tqdm(train_loader, desc="TRAINING"):
                optimizer.zero_grad()
                _, loss = model(data['J_msg'], data['b'], data['msg_node'],
                                target=data['prob_gt'])
                loss.backward()
                optimizer.step()

        snapshot(model, optimizer)
        # bentoml.pytorch.save("simple_gnn", model)
        return best_val_loss

    def predict(self):
        return 0


def snapshot(model, optimizer):
    model_snapshot = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(model_snapshot, os.path.join("model_snapshot.pth"))
```
  • Training
    $ python run_exp_local.py
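
Before the service in the next section can resolve simple_gnn:latest, the trained weights have to be registered in the BentoML model store. A sketch of that glue step; the snapshot path comes from the runner above, and the save call assumes the same pre-1.0 bentoml.pytorch API as the load_runner call in service.py (newer BentoML releases use save_model instead):

``` python
# Sketch: register the trained model so "simple_gnn:latest" resolves.
import bentoml
import torch
from model.gnn import NodeGNN

model = NodeGNN()
snapshot = torch.load("model_snapshot.pth")
model.load_state_dict(snapshot["model"])
bentoml.pytorch.save("simple_gnn", model)
```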

4. BentoML Service

  • A BentoML service that can take a JSON file or a pickle file as input
  • Write bentofile.yaml
``` yaml
# bentofile.yaml

service: "service:svc"
labels:
  owner: bentoml-team
  stage: demo
include:
  - "*.py"
python:
  packages:
    - torch
```

  • Write the service
``` python
# service.py

import bentoml
import numpy as np
from bentoml.io import JSON

runner = bentoml.pytorch.load_runner(
    "simple_gnn:latest",
    predict_fn_name="predict"
)

svc = bentoml.Service("simple_gnn", runners=[runner])

@svc.api(input=JSON(), output=JSON())
def predict(input_arr: dict):
    J_msg = np.array(input_arr['J_msg'])
    b = np.array(input_arr['b'])
    msg_node = np.array(input_arr['msg_node'])
    prob_gt = np.array(input_arr['prob_gt'])
    res = runner.run(J_msg, b, msg_node, prob_gt)
    return res
```

    $ bentoml build
    $ bentoml serve simple_gnn:latest --reload
  • Run the service
    • Setting the input type to JsonInput caused an import error, so the inputs are converted with numpy for now
  • Test (a curl sketch follows the result below)
    • Input
``` json
{
  "J_msg": [[ 0.32481338],
            [ 0.10008402],
            [ 0.32481338],
            [-0.29213125],
            [ 0.13292147],
            [-0.29213125],
            [ 0.27904898],
            [ 0.10008402],
            [ 0.09995882],
            [ 0.18141661],
            [ 0.13292147],
            [ 0.09995882],
            [-0.14101666],
            [ 0.07256332],
            [ 0.27904898],
            [-0.14101666],
            [-0.00166374],
            [ 0.18141661],
            [ 0.01135302],
            [ 0.07256332],
            [ 0.01135302],
            [ 0.17131865],
            [-0.00166374],
            [ 0.17131865]],
  "b": [[-0.11840663],
        [-0.24949627],
        [ 0.04009551],
        [ 0.58055465],
        [-0.19123979],
        [-0.0364058 ],
        [ 0.01783222],
        [ 0.10532287],
        [-0.27045844]],
  "msg_node": [[0, 1],
               [0, 3],
               [1, 0],
               [1, 2],
               [1, 4],
               [2, 1],
               [2, 5],
               [3, 0],
               [3, 4],
               [3, 6],
               [4, 1],
               [4, 3],
               [4, 5],
               [4, 7],
               [5, 2],
               [5, 4],
               [5, 8],
               [6, 3],
               [6, 7],
               [7, 4],
               [7, 6],
               [7, 8],
               [8, 5],
               [8, 7]],
  "prob_gt": [[0.42438148, 0.57561852],
              [0.35530727, 0.64469273],
              [0.55791833, 0.44208167],
              [0.7471271 , 0.2528729 ],
              [0.41531799, 0.58468201],
              [0.51015097, 0.48984903],
              [0.55322386, 0.44677614],
              [0.52453344, 0.47546656],
              [0.37538241, 0.62461759]]
}
```
    • Result

``` json
{
  "prob": [
    [0.49200597405433655, 0.5079939961433411],
    [0.49200597405433655, 0.5079939961433411],
    [0.49200597405433655, 0.5079939961433411],
    [0.49200597405433655, 0.5079939961433411],
    [0.49200597405433655, 0.5079939961433411],
    [0.49200597405433655, 0.5079939961433411],
    [0.49200597405433655, 0.5079939961433411],
    [0.49200597405433655, 0.5079939961433411],
    [0.49200597405433655, 0.5079939961433411]
  ],
  "loss": 0.02680271677672863
}
```