원본 논문: https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf

  • 강화학습을 DNN에 적용한 최초의 논문

강화학습이란 ?

  • 현재 머신러닝은 크게 3가지로 분류된다.
    1. 비지도 학습(Unsupervised learning)
    2. 지도 학습(Supervised learning)
    3. 강화 학습(Reinforcement learning)
강화학습도 비지도학습의 일부가 아니냐는 주변분들의 의견이 있긴하지만 
크게 3가지로 나누는 것에 이견은 없는 것으로 보인다.
  • 강화학습은 쉽게 표현하면 연속적인 결정 을 내리는 과정이라고 볼 수 있다. 그리고 그 결정은 순수하게 보상이라는 것에 의해 결정된다.
  • 강화학습에는 행동의 주체 에이전트(agent), 행동(action), 보상(reward), 환경(environment) 가 존재한다.
  • 정책(policy), 가치함수(value function) 등 더 세부적인 내용이 있지만 결국 설계자가 지정한 보상을 최대로 하는 행동을 알아내고 수행하는 것이 강화학습의 작동원리이다.

  • 정책(policy), 가치함수(value function) 등 더 세부적인 내용이 있지만 결국 설계자가 지정한 보상을 최대로 하는 행동을 알아내고 수행하는 것이 강화학습의 작동원리이다.
 

What is reinforcement learning? The complete guide - deepsense.ai

Although machine learning is seen as a monolith, this cutting-edge technology is diversified, with various sub-types including machine learning, deep learning, and the state-of-the-art technology of deep reinforcement learning.

deepsense.ai

 

소개

  • 강화학습과 딥러닝을 결합하는 뼈대를 제공한 DQN 모델로 간단한 그리드 월드에서 열매(?)를 찾아가는 법을 학습시키는 모델을 만들어보았다.
  • Github 링크
 

GitHub - Hyunmok-Park/RL_snake

Contribute to Hyunmok-Park/RL_snake development by creating an account on GitHub.

github.com

모델 소개

  • 환경 : 그리드 월드
  • state : (현재 팩맨의 위치, 열매위치)
  • 보상 : 기본 이동시 -1, 열매를 획득한다면 +1
  • 행동 : 4방향 이동

규칙

  • 모델 초기에 그리드 월드의 가로, 세로 길이를 지정한다.
  • 팩맨은 게임 시작시에 (0,0) 위치에서 출발하며, 열매는 랜덤한 위치에 1개 형성된다.
  • 팩맨은 매 단계에서 4방향 중 한가지 방향을 선택해서 이동하며 열매를 획득한다면 +1 보상, 그렇지 못하면 -1 보상을 획득한다.
  • 추가로 장애물, 몸의 길이가 늘어나는 규칙도 추가할 예정

모델환경

  • 환경 클래스는 액션을 받아서 환경을 업데이트하고 그에 따른 보상을 리턴한다.
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

class myenv():
    def __init__(self, height=10, width=10):
        self.height = height
        self.width = width
        self.x = 0
        self.y = 0
        self.dx = [1, 0, -1, 0]
        self.dy = [0, 1, 0, -1]

        self.food_x = None
        self.food_y = None

        self.create_food()

    def reset(self):
        self.x = 0
        self.y = 0
        self.create_food()
        return self.x, self.y, self.food_x, self.food_y

    def step(self, action):

        if action == 0: #right
            if self.x == (self.width - 1):
                pass
            else:
                self.x = self.x + 1
        elif action == 1: #up
            if self.y == 0:
                pass
            else:
                self.y = self.y - 1
        elif action == 2: #left
            if self.x == 0:
                pass
            else:
                self.x = self.x - 1
        elif action == 3 : #down
            if self.y == (self.height - 1):
                pass
            else:
                self.y = self.y + 1

        if self.x == self.food_x and self.y == self.food_y:
            reward = 1
            done = True
        else:
            done = False
            reward = -1

        return (self.x, self.y, self.food_x, self.food_y), reward, done

    def create_food(self):
        done = False
        while not done:
            x = np.random.choice([i for i in range(self.width)])
            y = np.random.choice([i for i in range(self.height)])

            if x == self.food_x and y == self.food_y:
                done = False
            else:
                done = True

        self.food_x = x
        self.food_y = y

    def draw_world(self):
        return 0
  • Qnet
    • 간단한 MLP구조로 출력으로 4가지 action에 대한 value 값을 리턴한다.
import torch
import torch.nn as nn
import torch.nn.functional as F

import random

class Qnet(nn.Module):
    def __init__(self, hidden_dim):
        super(Qnet, self).__init__()

        self.nn = nn.Sequential(
            nn.Linear(4, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 4)
        )

    def forward(self, state):
        return self.nn(state)

    def sample_action(self, state, eps):
        action = self.forward(state)
        if random.random() < eps:
            return random.choice([i for i in range(4)])
        else:
            return action.argmax().item()
  • Buffer
    • DQN에서 사용하는 replay buffer로 간단한 deque를 사용해서 최대 10000개, 최소 2000개의 transition 데이터를 사용해 학습한다.
import collections
import random

import torch

class ReplayBuffer():
    def __init__(self, buffer_limit):
        self.buffer = collections.deque(maxlen=buffer_limit)

    def put(self, transition):
        self.buffer.append(transition)

    def sample(self, num_sample):
        mini_batch = random.sample(self.buffer, num_sample)
        s_list = []
        a_list = []
        r_list = []
        next_s_list = []
        done_list = []

        for tran in mini_batch:
            s_list.append(tran[0])
            a_list.append([tran[1]])
            r_list.append([tran[2]])
            next_s_list.append(tran[3])
            done_list.append([tran[4]])

        s_list = torch.tensor(s_list, dtype=torch.float)
        a_list = torch.tensor(a_list)
        r_list = torch.tensor(r_list)
        next_s_list = torch.tensor(next_s_list, dtype=torch.float)
        done_list = torch.tensor(done_list)

        return s_list, a_list, r_list, next_s_list, done_list

    def len(self):
        return len(self.buffer)
  • 메인 학습
    • 학습 및 테스트를 진행하는 코드
    • DQN은 offline 학습 구조를 따르기 때문에 q, q_target 이라는 2개의 별도의 qnet을 형성한다.
    • 이 모델에서는 20번의 에피소드가 끝날때마다 q_target 을 업데이트해주었다.
    • 감마값은 0.98로 설정
import torch
import torch.nn.functional as F
import torch.optim

import numpy as np
from tqdm import tqdm

from net import Qnet
from grid_world import myenv
from buffer import ReplayBuffer

def main():
    device = torch.device('mps')

    q = Qnet(128)
    q_target = Qnet(128)
    env = myenv(10, 10)
    q_target.load_state_dict(q.state_dict())
    memory = ReplayBuffer(buffer_limit=10000)

    update_interval = 20
    optimizer = torch.optim.Adam(q.parameters(), lr=0.001)

    for n_epi in tqdm(range(10000), desc="n_epi"):
        eps = max(0.3, 0.9 - 0.01 * (n_epi / 200))
        s = env.reset() #(x,y,f_x,f_y)
        done = False
        score = 0

        while not done:
            a = q.sample_action(torch.from_numpy(np.array(s)).float(), eps)
            next_s, reward, done = env.step(a)
            done_mask = 0.0 if done else 1.0
            memory.put((s, a, reward, next_s, done_mask))
            s = next_s
            score += reward
            if done:
                break

        if memory.len() > 2000:
            train(q, q_target, memory, optimizer)

        if n_epi % update_interval == 0:
            q_target.load_state_dict(q.state_dict())
            score = 0

    torch.save(q_target.state_dict(), "qnet")

    x, y, f_x, f_y = env.reset()
    q_target.eval()
    print(x, y, f_x, f_y)
    while True:
        action = q_target(torch.tensor([x, y, f_x, f_y]).float()).argmax().item()
        next_s, reward, done = env.step(action)
        print(action, next_s)
        x, y = next_s[0], next_s[1]
        if done:
            break

def train(q, q_target, memory, opt):
    for i in range(10):
        s_list, a_list, r_list, next_s_list, done_list = memory.sample(32)
        q_out = q(s_list)
        q_a = q_out.gather(1, a_list)
        max_q_prime = q_target(next_s_list).max(1)[0].unsqueeze(1)
        target = r_list + 0.98 * max_q_prime * done_list
        loss = F.smooth_l1_loss(q_a, target)

        opt.zero_grad()
        loss.backward()
        opt.step()

if __name__ == '__main__':
    main()

결과

  • 테스트 결과 (0,0)에서 출발한 팩맨이 (2,9) 위치의 열매를 찾기 위해 최단경로로 이동하는 것을 확인했다.
    • 첫째줄 : (팩맨_x, 팩맨_y, 열매_x, 열매_y)
    • 둘째줄부터 : action, (팩맨_x, 팩맨_y, 열매_x, 열매_y)
0 0 2 9
3 (0, 1, 2, 9)
3 (0, 2, 2, 9)
3 (0, 3, 2, 9)
3 (0, 4, 2, 9)
3 (0, 5, 2, 9)
3 (0, 6, 2, 9)
3 (0, 7, 2, 9)
0 (1, 7, 2, 9)
0 (2, 7, 2, 9)
3 (2, 8, 2, 9)
3 (2, 9, 2, 9)

근데 이렇게 만들면 안될수가 없다... 좌표값을 그대로 주는 것이 아니라 이미지 인식으로 해야한다.

'3. Dev > 모델 구현' 카테고리의 다른 글

Pytorch로 U-Net 구현하기  (0) 2023.02.12
Pytorch로 GAN 구현하기(+ mnist 데이터)  (0) 2023.01.27
Pytorch로 GAN 구현하기  (1) 2023.01.24
Pytorch로 Transformer 구현하기  (0) 2022.10.20

+ Recent posts