1. 목표
딥러닝 이미지 모델로 유명한 모델인 U-Net을 pytorch로 구현하는 것을 목표로 했다. 평소에 해보고 싶었던 이미지 세그먼트를 수행했고, 데이터는 ISBI 2012 EM Segmentation Challenge에 사용된 membrane 데이터셋을 사용했다. 자동차 번호판 인식을 해보고 싶었는데 그건 다음에 시도해보도록 하겠다. 데이터셋은 train / val / test 로 모두 나눠서 저장했다.
모든 소스 코드는 깃허브에 업로드했다.
[GitHub - Hyunmok-Park/Torch_U_Net
Contribute to Hyunmok-Park/Torch_U_Net development by creating an account on GitHub.
github.com](https://github.com/Hyunmok-Park/Torch_U_Net)
2. 데이터셋
데이터는 위의 dataset 폴더에 저장했으며 정규화 전처리 정도만 추가해서 사용했다. 흑백 이미지이기 때문에 채널수가 1개라서 의도적으로 차원을 확장해서 (1, 512, 512) 크기를 맞춰주었다.
import os.path
import random
import shutil
from glob import glob
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchvision.transforms as transforms
class Dataset(Dataset):
def __init__(self, data_path, mode='train', transform=None):
self.transform = transform
self.labels = [file for file in sorted(glob(f"{data_path}/{mode}/label_*.npy"))]
self.inputs = [file for file in sorted(glob(f"{data_path}/{mode}/input_*.npy"))]
def __len__(self):
return len(self.labels)
def __getitem__(self, index):
labels = np.load(self.labels[index])
inputs = np.load(self.inputs[index])
# 정규화
labels = labels/255.0
inputs = inputs/255.0
if labels.ndim == 2:
labels = labels[np.newaxis, :, :]
if inputs.ndim == 2:
inputs = inputs[np.newaxis, :, :]
return inputs, labels
def load_dataset(data_path='dataset', mode='train', batch_size=4, img_size=512):
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(0.5, 0.5)
])
dataset = Dataset(data_path, mode, transform)
data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
return data_loader
3. 모델 소개
모델은 인코더, 디코더, 인코더와 디코더를 잇는 브릿지로 크게 3가지 단계로 구성했다. 인코더와 디코더에서는 반복되는 구조를 하나의 클래스로 작성해서 번거로움을 줄였다.
인코더, 디코더를 구성하는 Conv2d 는 모두 kernel-size=3, stride=1, padding=1 로 고정했고, pool의 kernel-size=2로 고정했다. 다만 padding=0으로 설정했어야 한다고 배웠는데 이 부분은 추가적인 확인이 필요했다. 현재는 padding=0 으로 설정할 경우, 이미지의 크기가 맞지 않아서 모델 학습이 불가능했다.
3.1 Encoder block
class EncoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias, pool_kernel_size):
super().__init__()
self.convlayer1 = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias
)
self.convlayer2 = nn.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias
)
self.batchnorm1 = nn.BatchNorm2d(num_features=out_channels)
self.batchnorm2 = nn.BatchNorm2d(num_features=out_channels)
self.layer = nn.Sequential(
self.convlayer1,
self.batchnorm1,
nn.ReLU(),
self.convlayer2,
self.batchnorm2,
nn.ReLU(),
)
self.pool = nn.MaxPool2d(kernel_size=pool_kernel_size)
def forward(self, inputs):
output_for_decoder = self.layer(inputs)
output_for_next = self.pool(output_for_decoder)
return output_for_next, output_for_decoder
3.2 Decoder block
class DecoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias, pool_kernel_size):
super().__init__()
self.convlayer1 = nn.Conv2d(
in_channels=out_channels * 2,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias
)
self.convlayer2 = nn.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias
)
self.batchnorm1 = nn.BatchNorm2d(num_features=out_channels)
self.batchnorm2 = nn.BatchNorm2d(num_features=out_channels)
self.pool = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=pool_kernel_size, stride=2, padding=0)
self.layer = nn.Sequential(
self.convlayer1,
self.batchnorm1,
nn.ReLU(),
self.convlayer2,
self.batchnorm2,
nn.ReLU()
)
def forward(self, decoder_output, encoder_output):
inputs = self.pool(decoder_output)
inputs = torch.cat([encoder_output, inputs], dim=1)
output = self.layer(inputs)
return output
3.3 U-Net
class UNet(nn.Module):
def __init__(self):
super().__init__()
self.e_block1 = EncoderBlock(1, 64, 3, 1, 1, True, 2)
self.e_block2 = EncoderBlock(64, 128, 3, 1, 1, True, 2)
self.e_block3 = EncoderBlock(128, 256, 3, 1, 1, True, 2)
self.e_block4 = EncoderBlock(256, 512, 3, 1, 1, True, 2)
self.bridge = nn.Sequential(
nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(1024),
nn.ReLU(),
nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(1024),
nn.ReLU()
)
self.d_block1 = DecoderBlock(1024, 512, 3, 1, 1, True, 2)
self.d_block2 = DecoderBlock(512, 256, 3, 1, 1, True, 2)
self.d_block3 = DecoderBlock(256, 128, 3, 1, 1, True, 2)
self.d_block4 = DecoderBlock(128, 64, 3, 1, 1, True, 2)
self.final_conv = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0)
def forward(self, inputs):
e_output1, d_inputs1 = self.e_block1(inputs)
e_output2, d_inputs2 = self.e_block2(e_output1)
e_output3, d_inputs3 = self.e_block3(e_output2)
e_output4, d_inputs4 = self.e_block4(e_output3)
e_output = self.bridge(e_output4)
d_output1 = self.d_block1(e_output, d_inputs4)
d_output2 = self.d_block2(d_output1, d_inputs3)
d_output3 = self.d_block3(d_output2, d_inputs2)
d_output4 = self.d_block4(d_output3, d_inputs1)
output = self.final_conv(d_output4)
return output
4. 학습
학습 코드는 평소와 마찬가지로 config을 입력받아서 진행하도록 작성하였다. 모델 저장 기준으로 validation 데이터셋에 대한 결과를 추가했다. 테스트 단계에서 epoch 기준으로 저장한 모델과 validation 결과를 기준으로 저장한 모델을 비교한 결과, validation 모델이 조금은 더 좋아보였다.
from datetime import datetime
import os
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim
from tqdm import tqdm
import numpy as np
from data_factory.data_loader import load_dataset
from model.UNet import UNet
from torchvision.utils import save_image
def train(config):
EXP_NAME = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
os.makedirs(f"result/{EXP_NAME}", exist_ok=True)
os.makedirs(f"result/{EXP_NAME}/test", exist_ok=True)
train_loader = load_dataset(config['data_path'], 'train', config['batch_size'], config['img_size'])
val_loader = load_dataset(config['data_path'], 'val', config['batch_size'], config['img_size'])
test_loader = load_dataset(config['data_path'], 'test', 1, config['img_size'])
model = UNet().to(config['device'])
criterion = nn.BCEWithLogitsLoss().to(config['device'])
optimizer = optim.Adam(params=model.parameters(), lr=config['learning_rate'])
train_loss = []
val_loss = []
min_val_loss = np.inf
for epoch in tqdm(range(config['epoch']), desc='EPOCH'):
train_loss_batch = []
val_loss_batch = []
###########
# TRAIN #
###########
model.train()
for batch, (inputs, label) in enumerate(train_loader):
optimizer.zero_grad()
inputs = inputs.to(config['device']).type(torch.float32)
label = label.to(config['device']).type(torch.float32)
output = model.forward(inputs)
loss = criterion(label, output)
loss.backward()
optimizer.step()
train_loss_batch.append(loss.item())
torch.save(model.state_dict(), f'result/{EXP_NAME}/checkpoint.pth')
train_loss.append(np.mean(train_loss_batch))
##########
# EVAL #
##########
model.eval()
for batch, (inputs, label) in enumerate(val_loader):
inputs = inputs.to(config['device']).type(torch.float32)
label = label.to(config['device']).type(torch.float32)
output = model.forward(inputs)
loss = criterion(label, output)
val_loss_batch.append(loss.item())
if min_val_loss > np.mean(val_loss_batch):
min_val_loss = np.mean(val_loss_batch)
torch.save(model.state_dict(), f'result/{EXP_NAME}/checkpoint_best_val.pth')
val_loss.append(np.mean(val_loss_batch))
##########
# TEST #
##########
to_class = lambda x: 1.0 * (x > 0.5)
model.load_state_dict(torch.load(f'result/{EXP_NAME}/checkpoint.pth'))
model.eval()
for idx, (inputs, label) in enumerate(test_loader):
inputs = inputs.to(config['device']).type(torch.float32)
label = label.to(config['device']).type(torch.float32)
output = model.forward(inputs)
inputs = inputs.detach().cpu().numpy().reshape(config['img_size'], config['img_size'])
label = label.detach().cpu().numpy().reshape(config['img_size'], config['img_size'])
output = output.detach().cpu().numpy().reshape(config['img_size'], config['img_size'])
output = to_class(output)
f, ax = plt.subplots(1, 3, figsize=(10, 4))
ax[0].imshow(inputs, cmap='gray')
ax[1].imshow(label, cmap='gray')
ax[2].imshow(output, cmap='gray')
f.savefig(f"result/{EXP_NAME}/test/{idx}.png")
plt.close()
model.load_state_dict(torch.load(f'result/{EXP_NAME}/checkpoint_best_val.pth'))
model.eval()
for idx, (inputs, label) in enumerate(test_loader):
inputs = inputs.to(config['device']).type(torch.float32)
label = label.to(config['device']).type(torch.float32)
output = model.forward(inputs)
inputs = inputs.detach().cpu().numpy().reshape(config['img_size'], config['img_size'])
label = label.detach().cpu().numpy().reshape(config['img_size'], config['img_size'])
output = output.detach().cpu().numpy().reshape(config['img_size'], config['img_size'])
output = to_class(output)
f, ax = plt.subplots(1, 3, figsize=(10, 4))
ax[0].imshow(inputs, cmap='gray')
ax[1].imshow(label, cmap='gray')
ax[2].imshow(output, cmap='gray')
f.savefig(f"result/{EXP_NAME}/test/{idx}_val.png")
plt.close()
f, ax = plt.subplots(1, 1, figsize=(10, 2))
ax.plot(train_loss, color='blue')
ax.plot(val_loss, color='red')
f.savefig(f"result/{EXP_NAME}/train_loss.png")
plt.close()
if __name__ == '__main__':
config = {}
config['num_channel'] = 3
config['img_size'] = 512
config['data_path'] = 'dataset'
config['batch_size'] = 4
config['learning_rate'] = 0.0001
config['epoch'] = 200
config['device'] = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
train(config)
5. 결과
왼쪽의 결과가 validation 기준, 오른쪽이 epoch 기준으로 저장한 모델이다. 하나의 이미지에서는 왼쪽부터 실제 입력, 정답, 테스트 결과이다. 어느 쪽이 더 좋다고 단언하기는 어렵지만 왼쪽의 모델이 조금 더 세밀하게 조각을 나누지 않았나싶다.
6. 마치며
Pytorch 라이브러리를 사용해서 U-Net을 직접 구현해보고 학습까지 진행했다. 이미지쪽은 워낙 아는 바가 없어서 구현에서 실행까지 생각보다 시간을 오래 잡아먹었던 것 같다. 얕게나마 알아본 결과, 실제 U-net에서는 다양한 학습 테크닉을 추가했던 것 같다. 이미지라는게 다양한 기법을 통해 학습 이미지를 약간 변형해서 성능을 높일 수 있는 것으로 알고 있다. 다음에는 다른 데이터셋을 사용해서 성능을 확인해보고자 한다.
'3. Dev > 모델 구현' 카테고리의 다른 글
Pytorch로 GAN 구현하기(+ mnist 데이터) (0) | 2023.01.27 |
---|---|
Pytorch로 GAN 구현하기 (1) | 2023.01.24 |
Pytorch로 Transformer 구현하기 (0) | 2022.10.20 |
Pytorch로 DQN 구현하기(+ 팩맨) (0) | 2022.10.20 |