Loading [MathJax]/jax/output/CommonHTML/jax.js

기계학습/밑바닥딥러닝3 오독오독 씹기

Chapter 2. 자연스로운 코드로(step 16)

H_erb Salt 2020. 12. 16. 16:51
step_16

15. 복잡한 계산 그래프(이론)

pass

16. 복잡한 계산 그래프(구현)

  • 이전 장(Step 15)에서 설명한 복잡한 계산 그래프에 관한 이론(노트로는 정리 안함)을 구현함
  • 가장 먼저 순전파 시 '세대'를 설정하는 부분부터 시작
  • 그런 다음, 역전파 시 최근 세대의 함수부터 꺼내도록 함.
  • 이런 방식으로 복잡한 계산 그래프라도 올바른 순서로 역전파로 이루어짐

16.0 현재까지의 구현

In [1]:
import numpy as np
In [2]:
def as_array(x):
    if np.isscalar(x):
        return np.array(x)
    return x

def square(x):
    return Square()(x) 

def exp(x):
    return Exp()(x)

def add(x0, x1):
    return Add()(x0, x1)
In [3]:
class Variable:
    def __init__(self, data):
        if data is not None: 
            if not isinstance(data, np.ndarray):
                raise TypeError('{}은(는) 지원하지 않아요. ndarray로 입력하세요.'.format(type(data)))
                
        self.data = data
        self.grad = None
        self.creator = None
        
    def set_creator(self, func):
        self.creator = func
        
    def backward(self):
        if self.grad is None:
            self.grad = np.ones_like(self.data)
        
        funcs = [self.creator]
        while funcs:
            f = funcs.pop()
            gys = [output.grad for output in f.outputs]
            gxs = f.backward(*gys)
            if not isinstance(gxs, tuple):
                gxs = (gxs,)
                
            for x, gx in zip(f.inputs, gxs):
                if x.grad is None:
                    x.grad = gx
                else:
                    x.grad = x.grad + gx
                
                if x.creator is not None:
                    funcs.append(x.creator)
                    
    def cleargrad(self):
        self.grad = None
In [4]:
class Function:
    def __call__(self, *inputs):
        xs = [x.data for x in inputs]
        ys = self.forward(*xs)
        
        if not isinstance(ys, tuple): 
            ys = (ys,)
            
        outputs = [Variable(as_array(y)) for y in ys]
        
        for output in outputs:
            output.set_creator(self)
        
        self.inputs = inputs
        self.outputs = outputs
        
        return outputs if len(outputs) > 1 else outputs[0]
    
    def forward(self, xs):
        raise NotImplementedError()
        
    def backward(self, gys):
        raise NotImplementedError()
In [5]:
class Exp(Function):
    def forward(self, x):
        y = np.exp(x)

        return y

    def backward(self, dy):
        x = self.input.data
        dx = np.exp(x) * dy

        return dx
    
class Square(Function):
    def forward(self, x):
        y = x ** 2
        return y
    
    def backward(self, gy):
        x = self.inputs[0].data
        gx = 2 * x * gy
        return gx
    
class Add(Function):
    def forward(self, x0, x1):
        y = x0 + x1
        return y
    
    def backward(self, gy):
        return gy, gy

16.1 세대 추가

  • 먼저 Variable 클래스와 Function 클래스에 인스턴스 변수 generation을 추가함
  • 몇 번째 '세대'의 함수(혹은 변수)인지 나타내는 변수. Variable 클래스부터 시작
In [8]:
class Variable:
    def __init__(self, data):
        if data is not None: 
            if not isinstance(data, np.ndarray):
                raise TypeError('{}은(는) 지원하지 않아요. ndarray로 입력하세요.'.format(type(data)))
                
        self.data = data
        self.grad = None
        self.creator = None
        self.generation = 0 # 세대 수를 기록하는 변수
        
    def set_creator(self, func):
        self.creator = func
        self.generation = func.generation + 1 # 세대를 기록함(부모 세대 +1)
        
    def backward(self):
        if self.grad is None:
            self.grad = np.ones_like(self.data)
        
        funcs = [self.creator]
        while funcs:
            f = funcs.pop()
            
            gys = [output.grad for output in f.outputs]
            gxs = f.backward(*gys)
            
            if not isinstance(gxs, tuple):
                gxs = (gxs,)
                
            for x, gx in zip(f.inputs, gxs):
                if x.grad is None:
                    x.grad = gx
                else:
                    x.grad = x.grad + gx
                
                if x.creator is not None:
                    funcs.append(x.creator)
                    
    def cleargrad(self):
        self.grad = None
  • Variable 클래스는 generation을 0으로 초기화함.
  • 그리고 set_creator 메서드가 호출될 때 부모 함수의 세대보다 1만큼 큰 값을 설정함
  • 다음은 Function 클래스. Function 클래스의 generation은 입력 변수와 같은 값으로 설정함
  • 입력 변수가 둘 이상이라면 가장 큰 generation의 수를 선택함. 예를 들어, 입력 변수가 2개고 각각의 generation이 3과 4라면 함수의 generation은 4로 설정
  • 이를 반영한 코드
In [9]:
class Function(object):
    def __call__(self, *inputs):
        xs = [x.data for x in inputs]
        ys = self.forward(*xs)
        
        if not isinstance(ys, tuple):
            ys = (ys,)
            
        outputs = [Variable(as_array(y)) for y in ys]
        
        self.generation = max([x.generation for x in inputs])
        
        for output in outputs:
            output.set_creator(self)
            
        self.inputs = inputs
        self.outputs = outputs
        
        return outputs if len(outputs) > 1 else outputs[0]
    
    def forward(self, xs):
        raise NotImplementedError()
        
    def backward(self, gys):
        raise NotImplementedError()

16.2 세대 순으로 꺼내기

  • 지금까지의 수정을 반영하여 일반적인 계산(순전파)을 하면 모든 변수와 함수에 세대가 설정됨
    • 이전 단계에서 이야기 한 것 처럼, Variable 클래스의 backward 메서드 안에서는 처리할 함수의 후보들을 funcs 리스트에 보관함. 따라서 funcs에서 세대가 큰 함수부터 꺼내게 하면 올바른 순서로 역전파할 수행 가능
  • 이어서 함수를 세대 순으로 꺼낼 차례. 그 준비 작업으로 Dummy DeZero 함수를 사용하여 간단한 실험
In [10]:
generations = [2, 0, 1, 4, 2]
funcs = []

for g in generations:
    f = Function() # 더미 함수 클래스
    f.generation = g
    funcs.append(f)

[f.generation for f in funcs]
Out[10]:
[2, 0, 1, 4, 2]
  • 이와 같이 더미 함수를 준비하고 funcs 리스트에 추가함. 그런 다음 이 리스트에서 세대가 가장 큰 함수를 꺼냄
In [11]:
funcs.sort(key=lambda x: x.generation)
[f.generation for f in funcs]
Out[11]:
[0, 1, 2, 2, 4]
In [12]:
f = funcs.pop()
f.generation
Out[12]:
4
  • 코드에서 보듯이 리스트의 sort 메서드를 이용하여 generation을 오름차순으로 정렬함

16.3 Variable 클래스의 backward

  • Variable 클래스의 backward 메서드를 구현함.
In [13]:
class Variable:
    def __init__(self, data):
        if data is not None: 
            if not isinstance(data, np.ndarray):
                raise TypeError('{}은(는) 지원하지 않아요. ndarray로 입력하세요.'.format(type(data)))
                
        self.data = data
        self.grad = None
        self.creator = None
        self.generation = 0
        
    def set_creator(self, func):
        self.creator = func
        self.generation = func.generation + 1
        
        
    def backward(self):
        if self.grad is None:
            self.grad = np.ones_like(self.data)
            
        funcs = [] # 해당 부분 추가
        seen_set = set()
        
        def add_func(f):
            if f not in seen_set:
                funcs.append(f)
                seen_set.add(f)
                funcs.sort(key=lambda x: x.generation)
        
        add_func(self.creator) # 여기까지
        
        while funcs:
            f = funcs.pop()
            
            gys = [output.grad for output in f.outputs]
            gxs = f.backward(*gys)
            
            if not isinstance(gxs, tuple):
                gxs = (gxs,)
                
            for x, gx in zip(f.inputs, gxs):
                if x.grad is None:
                    x.grad = gx
                else:
                    x.grad = x.grad + gx
                
                if x.creator is not None:
                    add_func(x.creator) # 수정 전: funcs.append(x.creator)
                    
    def cleargrad(self):
        self.grad = None
  • 가장 큰 변화는 새로 추가된 add_func 함수. 그동안 'DeZero 함수'를 리스트에 추가할 때 funcs.append(f)를 호출했는데, 대신 add_func 함수를 호출하도록 변경
  • add_func 함수가 DeZero 함수 리스트를 세대 순으로 정렬하는 역할. 그 결과, funcs.pop()은 자동으로 세대가 가장 큰 DeZero 함수를 꺼내게 됨
  • 참고로, add_func 함수를 backward 메서드 안에 중첩 함수로 정의함. 중첩 함수는 주로 다음 두 조건을 충족할 때 적합함
    • 감싸는 메서드(backward 메서드)안에서만 이용
    • 감싸는 메서드(backward 메서드)에 정의된 변수(funcs과 seen_set)를 사용해야 함
  • 또한, 그 앞에서는 seen_set()이라는 집합을 이용함. funcs 리스트에 같은 함수를 중복 추가하는 일을 막기 위함. 덕분에 backward 메서드가 잘못되어 여러 번 불리는 일은 발생하지 않음

16.4 동작 확인

  • 이상으로 세대가 큰 함수부터 꺼낼 수 있게 됨. 아무리 복잡한 계산 그래프의 역전파도 올바른 순서로 진행할 수 있음
In [14]:
'''###################################
흐름과는 상관없이 오류 때문에 재생한 코드
###################################'''

class Square(Function):
    def forward(self, x):
        y = x ** 2
        return y
    
    def backward(self, gy):
        x = self.inputs[0].data
        gx = 2 * x * gy
        return gx
    

class Add(Function):
    def forward(self, x0, x1):
        y = x0 + x1
        return y
    
    def backward(self, dy):
        return dy, dy
In [15]:
x = Variable(np.array(2.0))
a = square(x)
y = add(square(a), square(a))
y.backward()

print(y.data)
print(x.grad)
32.0
64.0
  • 결과를 보면 x의 미분은 64.0. 수식으로 확인하면 계산 그래프는 y=(x2)2+(x2)2이므로 간단히 y=2x4을 미분하는 문제.
  • 이 때, y=8x3 이므로, x=2.0 일 때의 미분은 64.0
  • 이상으로 복잡한 계산 그래프도 다룰 수 있게 됨
  • 다음 단계에서는 DeZero의 성능, 특히 메모리 사용량에 대해 알아봄
In [ ]: