본문 바로가기
AI/딥러닝 프레임워크 개발

15~16단계) 복잡한 계산 그래프의 이론 및 구현

by 채채씨 2021. 6. 15.
728x90
반응형
1. 복잡한 계산 그래프(이론 편)

1) 복잡한 계산 그래프

지금까지는 아래와 같은 일직선 계산 그래프에 대해 구현했다.

 

일직선 계산 그래프

 

이제는 아래와 같이 변수와 함수가 복잡하게 연결된 그래프를 다루어볼 것이다.

 

복잡한 계산 그래프

 

현재의 DeZero는 이런 복잡한 연결의 역전파를 제대로 할 수 없다.

 

 

2) 역전파의 올바른 순서

아래 계산 그래프의 역전파 순서에 대해 생각해보자.

 

 

올바른 역전파를 계산한다면 그 순서는 아래와 같을 것이다.

 

올바른 역전파의 순서

 

 

2) 현재의 DeZero 역전파 순서

그러나 아래와 같이 구현되어 있는 현재 DeZero의 역전파 흐름은 올바른 순서와 다르다.

 

class Variable:

	#생략
    
    def backward(self):
    	if self.grad is None:
        	self.grad = np.ones_like(self.data)
            
        funcs = [self.creator]
        while funcs: 
        	f = funcs.pop()  ##### 이 부분 #####
            gys = [output.grad for output in f.outputs]
            gxs = f.backward(*gys)
            if not isinstance(gxs, tuple):
            	gxs = (gxs,)
                
            for x, gx in zip(f.inputs, gxs):
            	if x.grad is None:
                	x.grad = gx
                else:
                    x.grad += gx
                    
                if x.creator is not None:
                	funcs.append(x.creator)  ##### 이 부분 #####

 

funcs의 리스트를 잘 보면, while문의 마지막 코드인 funcs.append(x.creator)로 처리할 함수의 후보를 funcs리스트에 추가하고, funcs.pop()으로 다음에 처리할 함수를 funcs리스트 끝에서 꺼낸다. 이 경우, 역전파는 아래와 같이 진행된다.

 

 

 

3) 함수의 우선순위

지금까지는 아무생각 없이 마지막 원소를 꺼내어 계산했다면, 이제는 funcs리스트에서 순서에 맞는 적절한 함수를 꺼내야 한다. 즉, 함수에 '우선순위'를 주어야 한다. 우선순위를 주기 위해, 순전파 계산 본 '함수'와 '변수'가 만들어지는 과정을 이용하여, 창조자-피조물 혹은 부모-자식 관계같은 이 관계를 기준으로 '세대(generation)'를 기록할 것이다.

 

함수와 변수 '세대'

 

세대가 우선순위인 것이다. 세대 수가 큰 것부터 처리하면 부모보다 자식이 먼저 처리되는 것을 확실히 할 수 있다.

 


 

2. 복잡한 계산 그래프(구현 편)

1) 세대설정

먼저 Variable클래스와 Function클래스에 generation이라는 변수를 추가하여 몇 번째 세대의 함수인지를 기록할 것이다.

 

■ Variable 클래스

class Variable:
	def __init__(self, data):
    	if data is not None:
        	if not isinstance(data, np.array):
            	raise TypeError('{}은(는) 지원하지 않습니다.'.format(type(data)))
                
        self.data = data
        self.grad = None
        self.crator = None
        self.generation = 0 #세대수 변수
        
    def set_creator(self, func):
    	self.creator = func
        self.generation = func.generation + 1 #세대수 기록 (부모 세대 +1)
       
    #생략

 

 

set_creator메서드가 호출될 때 부모 함수의 세대보다 1 큰 값으로 세대수를 설정하였다.

 

 

■ Function 클래스

Function클래스의 generation은 입력 변수의 generation과 같은 값으로 설정할 것이다. 입력 변수가 두 개 이상이라면 더 큰 generation을 선택한다.

class Function(object):
	def __call__(self, *inputs):
    	xs = [x.data for x in inputs]
        ys = self.forward(*xs)
        if not isinstance(ys, tuple):
        	ys = (ys,)
        outputs = [Variable(as_array(y)) for y in ys]
        
        self.generation = max([x.generation for x in inputs]) #더 큰 세대수를 받는다
        for output in outputs:
        	output.set_creator(self)
        self.inputs = inputs
        self.outputs = outputs
        return outputs if len(outputs) > 1 else outputs[0]

 

 

2) 세대 순으로 꺼내기

 

 

generations = [2, 0, 1, 4, 2]
funcs = []

for g in generations:
	f = Function()
    f.generation = g
    funcs.append(f)
    
[f.generation for f in funcs] #[2, 0, 1, 4, 2]

세대가 가장 큰 함수를 꺼내면

funcs.sort(key=lambda x: x.generation)
[f.generation for f in funcs] #[0, 1, 2, 2, 4]

f = funcs.pop()
f.generation #4

 

■ Variable 클래스의 Backward 메서드

class Variable:

	#생략
    
    def backward(self):
    	if self.grad is None:
        	self.grad = np.ones_like(self.data)
            
        funcs = []
        seen_set = set()
        
        def add_func(f):
        	if f not in seen_set:
            	funcs.append(f)
                seen_set.add(f)
                funcs.sort(key=lambda x: x.generation)
                
        add_func(self.creator)
        
        while funcs:
        	f = funcs.pop()
            gys = [output.grad for output in f.outputs]
            gxs = f.backward(*gys)
            if not isinstance(gxs, tuple):
            	gxs = (gxs,)
                
            for x, gx in zip(f.inputs, gxs):
            	if x.grad is None:
                	x.grad = gx
                else:
                	x.grad += gx
                    
                if x.creator is not None:
                	add_func(x.creator) #수정 전: funcs.append(x.creator)

 

 

3) 동작 확인

아무리 복잡한 계산 그래프의 역전파도 올바르게 계산할 수 있게 되었다.

 

 

x = Variable(np.array(2.0))
a = Square(x)
y = add(square(a), square(a))
y.backward()

print(y.data) #32.0
print(x.data) #64.0

x=2일 때, 미분값 64이므로 올바른 값을 도출하고 있다.

 


728x90
반응형

댓글