[BentoML] 2. Simple GNN with BentoML
Honestree
2022. 10. 20. 17:31
1. Data generation and preprocessing
- Only grid-shaped graphs are used
- Undirected probabilistic graphical model
- Data generation code
``` python
import os
import pickle
import numpy as np
from scipy.sparse import coo_matrix

from utils.topology import NetworkTopology, get_msg_graph
from model.gt_inference import Enumerate


def mkdir(dir_name):
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)
        print('made directory {}'.format(dir_name))


def main():
    num_nodes_I = 9
    std_J_A = 1 / 3
    std_b_A = 0.25
    save_dir = './data_temp'

    print('Generating training graphs!')
    try:
        mkdir(save_dir)
        mkdir(os.path.join(save_dir, "train"))
    except OSError:
        pass

    for sample_id in range(100):
        seed_train = int(str(1111) + str(sample_id))
        topology = NetworkTopology(num_nodes=num_nodes_I, seed=seed_train)
        npr = np.random.RandomState(seed=seed_train)
        graph = {}
        G, W = topology.generate(topology='grid')

        # symmetric coupling matrix, masked by the grid adjacency W
        J = npr.normal(0, std_J_A, size=[num_nodes_I, num_nodes_I])
        J = (J + J.transpose()) / 2.0
        J = J * W
        b = npr.normal(0, std_b_A, size=[num_nodes_I, 1])

        # Enumerate: exact per-node marginals used as ground truth
        model = Enumerate(W, J, b)
        prob_gt = model.inference()
        graph['prob_gt'] = prob_gt  # shape N x 2
        graph['J'] = coo_matrix(J)  # shape N x N
        graph['b'] = b              # shape N x 1
        graph['seed_train'] = seed_train
        graph['stdJ'] = std_J_A
        graph['stdb'] = std_b_A

        # message graph: one entry per directed edge
        msg_node, msg_adj = get_msg_graph(G)
        msg_node, msg_adj = np.array(msg_node), np.array(msg_adj)
        idx_msg_edge = np.transpose(np.nonzero(msg_adj))
        J_msg = J[msg_node[:, 0], msg_node[:, 1]].reshape(-1, 1)
        graph['msg_node'] = msg_node
        graph['idx_msg_edge'] = idx_msg_edge
        graph['J_msg'] = J_msg

        file_name = os.path.join(
            save_dir, "train",
            'graph_{}_nn{}_{:07d}.p'.format('grid', num_nodes_I, sample_id))
        with open(file_name, 'wb') as f:
            pickle.dump(graph, f)


if __name__ == '__main__':
    main()
```
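Once the script has run, a generated sample can be loaded back to check the stored fields (a quick sanity check, assuming the default paths above; the file name follows the format string in the script):

``` python
import pickle

with open('./data_temp/train/graph_grid_nn9_0000000.p', 'rb') as f:
    graph = pickle.load(f)

print(graph['prob_gt'].shape)   # (9, 2)  per-node ground-truth marginals
print(graph['J'].shape)         # (9, 9)  sparse coupling matrix
print(graph['msg_node'].shape)  # (24, 2) directed edges of the 3x3 grid
print(graph['J_msg'].shape)     # (24, 1) coupling per directed edge
```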
- Train / val data: the validation set is generated by changing only the random seed
- Train / val = 100 / 10 samples
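The only difference between the splits is the seed scheme. For instance (the 2222 prefix for validation is an assumption; any prefix disjoint from the training one works):

``` python
# training seeds follow the int(str(1111) + str(sample_id)) scheme above
train_seeds = [int(str(1111) + str(i)) for i in range(100)]
# a validation split only needs a different prefix (2222 is an assumption)
val_seeds = [int(str(2222) + str(i)) for i in range(10)]
assert not set(train_seeds) & set(val_seeds)  # the splits never share a seed
```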
1.1 Data structure
- Probabilistic graphical model
$$ p(\mathbf{x}) = \frac{1}{Z} \exp\left( \mathbf{b} \cdot \mathbf{x} + \mathbf{x} \cdot \mathbf{J} \cdot \mathbf{x} \right) $$
- Each node has a state (x) of -1 / 1, and the joint distribution can be computed with the formula above
- However, the formula does not give the probability of an individual node; extra computation is needed for that
- Enumerate: computes the individual probabilities by keeping a probability table over every possible state (a minimal sketch follows this list)
- Ising model
- b: self-bias / J: coupling strength / x: node state
- The goal of the neural network is to recover the per-node probabilities from J and b alone, using Enumerate only to compute the target values
- b: vector of per-node bias values
- msg_node: the graph's connectivity listed in node order (one entry per directed edge)
- J_msg: the coupling strength of each edge, aligned with msg_node
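The post does not show Enumerate itself; below is a minimal sketch of the idea under the formula above, brute-forcing all $2^N$ joint states (feasible only for small graphs such as the 9-node grid). The function name and the column order of the returned marginals are assumptions:

``` python
import itertools
import numpy as np

def brute_force_marginals(J, b):
    """Hypothetical stand-in for Enumerate: exact per-node marginals of
    p(x) = exp(b.x + x.J.x) / Z over x in {-1, +1}^N, by full enumeration."""
    N = b.shape[0]
    # all 2^N joint states, shape 2^N x N
    states = np.array(list(itertools.product([-1, 1], repeat=N)))
    # unnormalized log-probability of every joint state: b.x + x.J.x
    log_p = states @ b.flatten() + np.einsum('si,ij,sj->s', states, J, states)
    p = np.exp(log_p - log_p.max())
    p /= p.sum()  # divide by the partition function Z
    # marginal P(x_i = s) for s in (-1, +1); the column order is an assumption
    return np.stack([(p[:, None] * (states == s)).sum(0) for s in (-1, 1)], axis=1)
```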
2. Writing the training model
2.1 Simple GNN
- Gated Graph Neural Network (GG-NN)
- For verification, a predict function is added that simply takes the inputs and returns the output
``` python
# gnn.py
import numpy as np
import torch
import torch.nn as nn

EPS = float(np.finfo(np.float32).eps)

__all__ = ['NodeGNN']


class NodeGNN(nn.Module):

    def __init__(self):
        """ A simplified implementation of NodeGNN """
        super(NodeGNN, self).__init__()
        self.hidden_dim = 16
        self.num_prop = 5
        self.aggregate_type = 'sum'

        # message function: both endpoint states plus the 8 edge features
        # (2 * hidden_dim + 8 inputs)
        self.msg_func = nn.Sequential(*[
            nn.Linear(2 * self.hidden_dim + 8, 64),
            nn.ReLU(),
            nn.Linear(64, self.hidden_dim)
        ])

        # update function
        self.update_func = nn.GRUCell(
            input_size=self.hidden_dim, hidden_size=self.hidden_dim)

        # output function
        self.output_func = nn.Sequential(*[
            nn.Linear(self.hidden_dim + 2, 64),
            nn.ReLU(),
            nn.Linear(64, 2),
        ])

        self.loss_func = nn.KLDivLoss(reduction='batchmean')

    def forward(self, J_msg, b, msg_node, target=None):
        num_node = b.shape[0]
        num_edge = msg_node.shape[0]

        edge_in = msg_node[:, 0]
        edge_out = msg_node[:, 1].contiguous()

        # edge features built from the Ising parameters (both signs of b and J)
        ff_in = torch.cat([b[edge_in], -b[edge_in], J_msg, -J_msg], dim=1)
        ff_out = torch.cat([-b[edge_out], b[edge_out], -J_msg, J_msg], dim=1)

        state = torch.zeros(num_node, self.hidden_dim).to(b.device)

        def _prop(state_prev):
            # 1. compute messages
            state_in = state_prev[edge_in, :]    # shape |E| x D
            state_out = state_prev[edge_out, :]  # shape |E| x D
            msg = self.msg_func(
                torch.cat([state_in, ff_in, state_out, ff_out], dim=1))  # shape |E| x D

            # 2. aggregate messages per receiving node
            scatter_idx = edge_out.view(-1, 1).expand(-1, self.hidden_dim)
            msg_agg = torch.zeros(num_node, self.hidden_dim).to(b.device)  # shape |V| x D
            msg_agg = msg_agg.scatter_add(0, scatter_idx, msg)
            avg_norm = torch.zeros(num_node).to(b.device).scatter_add_(
                0, edge_out, torch.ones(num_edge).to(b.device))
            msg_agg /= (avg_norm.view(-1, 1) + EPS)  # mean over incoming edges

            # 3. update state
            state_new = self.update_func(msg_agg, state_prev)  # GRU update
            return state_new

        # propagation
        for tt in range(self.num_prop):
            state = _prop(state)

        # output
        y = self.output_func(torch.cat([state, b, -b], dim=1))
        y = torch.log_softmax(y, dim=1)
        loss = self.loss_func(y, target) if target is not None else None
        return y, loss

    def predict(self, J_msg, b, msg_node, prob_gt):
        # the runner delivers batched inputs; take the first (only) sample.
        # J_msg and b are real-valued, so they are cast to float -- casting
        # them to long would truncate the couplings and biases to integers
        J_msg = J_msg[0].float()
        b = b[0].float()
        msg_node = msg_node[0].long()  # integer indices for gathering
        prob_gt = prob_gt[0].float()

        # reuse forward() instead of duplicating the propagation loop
        y, loss = self.forward(J_msg, b, msg_node, target=prob_gt)

        res = dict()
        res["prob"] = np.exp(y.detach().cpu().numpy())
        res["loss"] = loss.detach().cpu().numpy()
        return res
```
3. Model training
- Write a runner class that manages the neural network training loop
``` python
# inference_runner.py
from __future__ import (division, print_function)
import os
import numpy as np
from collections import defaultdict
from tqdm import tqdm

import torch
import torch.utils.data
import torch.optim as optim

from model.gnn import NodeGNN
from dataset.dataloader import *
import bentoml

EPS = float(np.finfo(np.float32).eps)

__all__ = ['NeuralInferenceRunner']


class NeuralInferenceRunner(object):

    def train(self):
        print("=== START TRAINING ===")

        # create data loaders
        train_dataset = MyDataloader(split='train')
        val_dataset = MyDataloader(split='val')
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=10,
            shuffle=True,
            collate_fn=train_dataset.collate_fn)
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=10,
            shuffle=True,
            collate_fn=val_dataset.collate_fn)

        # create model
        model = NodeGNN()

        # create optimizer
        params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = optim.Adam(params, lr=0.001)

        # reset gradient
        optimizer.zero_grad()

        # ========================= Training Loop ========================== #
        best_val_loss = np.inf
        for epoch in range(10):
            print("=== EPOCH : {} ===".format(epoch))

            # ===================== validation ============================= #
            model.eval()
            val_losses = []
            for data in tqdm(val_loader, desc="VALIDATION"):
                with torch.no_grad():
                    _, loss = model(data['J_msg'], data['b'], data['msg_node'],
                                    target=data['prob_gt'])
                    val_losses.append(loss.item())
            # track the best validation loss across epochs
            best_val_loss = min(best_val_loss, float(np.mean(val_losses)))

            # ====================== training ============================== #
            model.train()
            for data in tqdm(train_loader, desc="TRAINING"):
                optimizer.zero_grad()
                _, loss = model(data['J_msg'], data['b'], data['msg_node'],
                                target=data['prob_gt'])
                loss.backward()
                optimizer.step()

            snapshot(model, optimizer)

        # register the trained model with BentoML so the service below can
        # load it as simple_gnn:latest
        bentoml.pytorch.save('simple_gnn', model)
        return best_val_loss

    def predict(self):
        return 0


def snapshot(model, optimizer):
    model_snapshot = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(model_snapshot, os.path.join("model_snapshot.pth"))
```
- Training
$ python run_exp_local.py
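run_exp_local.py itself is not shown in the post; a minimal sketch of what it might contain (the import path is an assumption):

``` python
# run_exp_local.py -- a minimal sketch; the import path is an assumption
from inference_runner import NeuralInferenceRunner

if __name__ == '__main__':
    runner = NeuralInferenceRunner()
    best_val_loss = runner.train()
    print('best validation loss: {}'.format(best_val_loss))
```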
4. BentoML service
- A BentoML service that accepts either a JSON file or a Pickle file as input
- Write bentofile.yaml
``` yaml
# bentofile.yaml
service: "service:svc"
labels:
  owner: bentoml-team
  stage: demo
include:
  - "*.py"
python:
  packages:
    - torch
```
- Write the service

``` python
# service.py
import bentoml
import numpy as np
from bentoml.io import JSON

runner = bentoml.pytorch.load_runner(
    "simple_gnn:latest",
    predict_fn_name="predict"
)

svc = bentoml.Service("simple_gnn", runners=[runner])


@svc.api(input=JSON(), output=JSON())
def predict(input_arr: dict) -> dict:
    # unpack the JSON body into numpy arrays
    J_msg = np.array(input_arr['J_msg'])
    b = np.array(input_arr['b'])
    msg_node = np.array(input_arr['msg_node'])
    prob_gt = np.array(input_arr['prob_gt'])
    res = runner.run(J_msg, b, msg_node, prob_gt)
    return res
```
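The service above only covers the JSON path. For the Pickle files produced by the data generation script, a second endpoint could be added to service.py along these lines (a sketch; bentoml.io.File is BentoML's file descriptor, but the exact file-handling details may differ across versions):

``` python
# hypothetical addition to service.py for Pickle input
import pickle
from bentoml.io import File

@svc.api(input=File(), output=JSON())
def predict_pickle(f) -> dict:
    graph = pickle.loads(f.read())  # a sample saved by the generation script
    return runner.run(graph['J_msg'], graph['b'],
                      graph['msg_node'], graph['prob_gt'])
```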
$ bentoml build
$ bentoml serve simple_gnn:latest --reload
- Running the service
- Using JsonInput as the input type caused an import error, so for now the inputs are parsed with numpy
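With the server running, the test input below can be posted with curl (3000 is BentoML's default port, and the /predict route comes from the API function name; input.json is a hypothetical file holding the request body shown under "Input"):

$ curl -X POST -H "Content-Type: application/json" -d @input.json http://127.0.0.1:3000/predict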
- Test
- Input
{ "J_msg":[[ 0.32481338], [ 0.10008402], [ 0.32481338], [-0.29213125], [ 0.13292147], [-0.29213125], [ 0.27904898], [ 0.10008402], [ 0.09995882], [ 0.18141661], [ 0.13292147], [ 0.09995882], [-0.14101666], [ 0.07256332], [ 0.27904898], [-0.14101666], [-0.00166374], [ 0.18141661], [ 0.01135302], [ 0.07256332], [ 0.01135302], [ 0.17131865], [-0.00166374], [ 0.17131865]], "b":[[-0.11840663], [-0.24949627], [ 0.04009551], [ 0.58055465], [-0.19123979], [-0.0364058 ], [ 0.01783222], [ 0.10532287], [-0.27045844]], "msg_node":[[0, 1], [0, 3], [1, 0], [1, 2], [1, 4], [2, 1], [2, 5], [3, 0], [3, 4], [3, 6], [4, 1], [4, 3], [4, 5], [4, 7], [5, 2], [5, 4], [5, 8], [6, 3], [6, 7], [7, 4], [7, 6], [7, 8], [8, 5], [8, 7]], "prob_gt":[[0.42438148, 0.57561852], [0.35530727, 0.64469273], [0.55791833, 0.44208167], [0.7471271 , 0.2528729 ], [0.41531799, 0.58468201], [0.51015097, 0.48984903], [0.55322386, 0.44677614], [0.52453344, 0.47546656], [0.37538241, 0.62461759]] }
- Result
{ "prob": [ [ 0.49200597405433655, 0.5079939961433411 ], [ 0.49200597405433655, 0.5079939961433411 ], [ 0.49200597405433655, 0.5079939961433411 ], [ 0.49200597405433655, 0.5079939961433411 ], [ 0.49200597405433655, 0.5079939961433411 ], [ 0.49200597405433655, 0.5079939961433411 ], [ 0.49200597405433655, 0.5079939961433411 ], [ 0.49200597405433655, 0.5079939961433411 ], [ 0.49200597405433655, 0.5079939961433411 ] ], "loss": 0.02680271677672863 }