본문 바로가기
AI/딥러닝 프레임워크 개발

17~18단계) 메모리 관리 방식 , 순환 참조 , 메모리 절약 모드

by 채채씨 2021. 6. 19.
728x90
반응형
1. 메모리 관리 방식 

CPython의 메모리 관리는 두 가지 방식으로 진행된다. 참조(reference) 수를 세는 방식과 세대(generation)를 기준으로 쓸모없어진 객체를 회수하는 방식이다. 앞으로 전자를 참조 카운트로, 후자를 GC(Garbage Collection)라고 부를 것이다. 먼저 참조 카운트 방식을 살펴본 후 GC를 볼 것이다.

 

 

1) 참조 카운트 방식

모든 객체는 참조 카운트가 0인 상태로 생성된다. 다른 객체가 참조할 때마다 1씩 증가하고 객체에 대한 참조가 끊길 때마다 1씩 감소하다가 0이 되면 해당 객체는 메모리에서 삭제된다. 예를 들어, 대입 연산자를 사용하거나 함수에 인수로 전달하거나 컨테이너 타입 객체에 추가할 때 참조 카운트가 증가한다.

 

class obj:
	pass
    
def f(x):
	print(x)
    
a = obj() #변수에 대입: 참조 카운트 1
f(a) #함수에 전달: 참조 카운트 2
#함수 완료: 빠져나오면 참조 카운트 1
a = None #대입 해제: 참조 카운트 0

 

a = obj()
b = obj()
c = obj()

a.b = b
b.c = c

a = b = c = None

 

객체 관계도

 

a가 b를 참조하고, b가 c를 참조하도록 설정하고 나서, a = b = c = None을 실행하면 오른쪽처럼 바뀐다. a의 참조 카운트가 0이되면서 삭제되고, 잇따라 b의 참조 카운트도 0이 되어 삭제되고 c의 참조 카운트도 0이 되어 삭제된다. 이것이 파이썬의 참조 카운트 메모리 관리 방법이다. 

 

 

2) GC

참조 카운트로 해결할 수 없는 문제로 순환 참조가 있다. 이 순환 참조의 메모리 관리를 위한 방식이 GC(Garbage Collection)이다.

a = obj()
b = obj()
c = obj()

a.b = b
b.c = c
c.a = a

a = b = c = None

 

순환 참조가 발생한 객체 관계도


오른쪽 그림을 보면 a, b, c의 참조 카운트는 모두 1이다. 그러나 사용자는 어느 객체에도 접근할 수 없다. 즉 불필요한 객체라는 의미이다. 그러나 a = b = c = None을 실행하여도 순환 참조의 참조 카운트는 0이 되지 않아서 메모리에서 삭제되지 않는다. 이때 필요한 메모리 관리 방식이 GC이다.

 

GC는 메모리가 부족해지는 시점에 파이썬 인터프리터에 의해 자동으로 호출된다. (gc모듈을 import하여 gc.collect()로 명시적 호출도 가능) 따라서 일반 파이썬 프로그래밍에서는 순환 참조를 의식할 필요가 없지만, 그렇다고 메모리 관리를 GC에 넘기다 보면 순환 참조가 없을 때와 비교해서 메모리 사용량이 커지게 된다. 신경망에서는 메모리가 중요한 자원이므로 DeZero를 개발할 때는 되도록 순환 참조를 만들지 않도록 할 것이다.

 

DeZero에는 순환 참조가 존재한다. '변수'와 '함수'를 연결하는 방식에 순환 참조가 있다.

 

Variable과 Function사이의 순환 참조

 

이러한 순환 참조는 표준 파이썬 모듈인 weakref로 해결할 수 있다.

 

※weakref 모듈

파이썬에서 weakref.ref 함수를 사용하여 약한 참조(weak reference)를 만들 수 있다. 약한 참조란 다른 객체를 참조하되 참조 카운트는 증가시키지 않는 기능이다.

import weakref
import numpy as np

a = np.array([1, 2, 3])
b = weakref.ref(a)

b() #[1 2 3]

a는 일반적인 방식으로 참조하고 b는 약한 참조를 갖도록 하였다. 참조된 데이터에 접근하려면 b()라고 쓰면 된다.

 

a = None을 실행하면 결과는 아래와 같다.

a = None
b #<weakref at 0x103b7f048; dead>

ndarray 인스턴스는 메모리에서 삭제된다. b는 약한 참조이므로 참조 카운트에는 영향을 주지 못하지만, b를 출력했을 때 dead라는 문자를 보고 삭제된 것을 알 수 있다. (파이썬 인터프리터에 한한 현상)

 

이제, 이 weakref구조를 DeZero에 도입할 것이다.

 

 

3) Function 클래스에 도입

import weakref

class function:
	def __call__(self, *inputs):
    	xs = [x.data for x in inputs]
        ys = self.forward(*xs)
        if not isinstance(ys, tuple)
        	ys = (ys,)
        outputs = [Variable(as_array(y)) for y in ys]
        
        self.generation = max([x.generation for x in inputs])
        for output in outputs:
        	output.set_creator(self)
        self.inputs = inputs
        self.outputs = [weakref.ref(output) for output in outputs] #weakref적용
        return outputs if len(outputs) > 1 else outputs[0]

self.outputs가 대상을 약한 참조로 가리키게 변경하였다. 이 변경의 여파로 다른 클래스에서 Function클래스의 outputs를 참조하는 코드도 수정해야 한다. 

 

 

4) Variable 클래스의 backward 메서드에 도입

class Variable:

	#생략
    
    def backward(self):
    	
        #생략
        
        while funcs:
        f = funcs.pop()
        #수정 전: gys = [output.grad for output in f.outputs]
        gys = [output().grad for output in f.outputs]
        
        #생략

참조한 데이터에 접근하기 위해서는 괄호를 붙인다고 언급하였다. 따라서 [output.grad for ~] 부분을 [output().grad for ~]로 수정하였다. 이상 DeZero의 순환 참조 문제가 해결되었다.

 


 

2. 메모리 절약 모드

지금까지 파이썬의 메모리 관리 방식을 알아보았다. 이번에는 DeZero의 메모리 사용을 개선할 수 있는 구조 두 가지를 도입할 것이다. 첫 번째는 역전파 시 불필요한 미분 결과를 즉시 삭제하는 것이고, 두 번째는 역전파가 필요 없는 경우용 모드를 제공하는 것이다.

 

1) 필요 없는 미분값 삭제

x0 = Variable(np.array(1.0))
x1 = Variable(np.array(1.0))
t = add(x0, x1)
y = add(x0, t)
y.backward()

print(y.grad, t.grad) #1.0 1.0
print(x0.grad, x1.grad) #2.0 1.0

 

역전파로 구하고자 하는 미분값은 말단 변수인 x0, x1인데 y, t같은 중간 변수의 미분값도 저장되는 것을 볼 수 있다. 말단 변수만 구하고 나머지 필요 없는 미분값을 삭제하기 위해 아래와 같이 처리할 수 있다.

 

class Variable:
	
    #생략
    
    def backward(self, retain_grad = False):
    	if self.grad is None:
        	self.grad = np.ones_like(self.data)
            
        funcs = []
        seen_set = set()
        
        def add_func(f):
        	if f not in seen_set:
            	funcs.append(f)
                seen_set.add(f)
                funcs.sort(key=lambda x: x.generation)
                
       	add_func(self.creator)
        
        while funcs:
        	f = funcs.pop()
            gys = [output().grad for output in f.outputs]
            gxs = f.backward(*gys)
            if not isinstance(gxs, tuple):
            	gxs = (gxs,)
                
            for x, gx in zip(f.inputs, gxs):
            	if x.grad is None:
                	x.grad = gx
                else:
                	x.grad += gx
                    
                if x.creator is not None:
                	add_func(x.creator)
                    
            if not retain_grad:
            	for y in f.outputs:
                	y().grad = None #y는 약한 참조

 

메서드 인수에 retain_grad를 추가한다. retain_grad가 True이면 모든 미분값이 유지되고, False이면 중간 변수의 미분값은 None으로 재설정하고 말단 변수의 미분값만 유지한다.

 

x0 = Variable(np.array(1.0))
x1 = Variable(np.array(1.0))
t = add(x0, x1)
y = add(x0, t)
y.backward()

print(y.grad, t.grad) #None None
print(x0.grad, x1.grad) #2.0 1.0

여기까지 DeZero 메모리 사용의 첫 번째 개선이 완료되었다.

 

 

2) 역전파가 필요 없는 경우용 모드

DeZero에서 미분을 하려면 순전파를 수행한 뒤 역전파를 해야 한다. 역전파 때는 순전파의 계산 결과가 필요하므로, 그 결과값을 저장해둔다. 그 로직은 아래 Function 클래스의 주석처리한 부분에 있다.

 

class Function:
	def __call__(self, *inputs):
    	xs = [x.data for x in inputs]
        ys = self.forward(*xs)
        if not isinstance(ys, tuple):
        	ys = (ys,)
        outputs = [Variable(as_array(y)) for y in ys]
        
        self.generation = max([x.generation for x in inputs])
        for output in outputs:
        	output.set_creator(self)
        self.inputs = inputs #################이 부분에서 순전파 결과값 저장###############
        self.outputs = [weakref.ref(output) for output in outputs]
        return outputs if len(outputs) > 1 else outputs[0]

 

함수의 입력을 인스턴스 변수 inputs으로 참조하였다. 역전파하는 경우라면 참조할 변수를 inputs에 보관해야 한다. (신경망에는 학습과 추론이 있는데, 학습 시에는 미분값을 구해야 하지만, 추론 시에는 단순 순전파만 하므로 중간 계산 결과를 즉시 버려야 메모리 사용량을 줄일 수 있음) 

 

 

역전파가 필요 없는 경우 메모리 절약을 위해, Config 클래스를 활용하여 '역전파 활성 모드'와 '역전파 비활성 모드'를 전환하는 구조를 사용할 것이다.

class Config:
	enable_backprop = True

enable_backprop은 불리언 타입으로 역전파 가능 여부를 의미한다. True이면 '역전파 활성 모드'이다.

 

이제 Config 클래스를 Function에서 참조하게 하여 모드를 전환할 수 있도록 할 것이다.

class Function:
	def __call__(self, *inputs):
    	xs = [x.data for x in inputs]
        ys = self.forward(*xs)
        if not isinstance(ys, tuple):
        	ys = (ys,)
        outputs = [Variable(as_variable(y)) for y in ys]
        
        if Config.enable_backprop:
        	self.generation = max([x.generation for x in inputs]) #세대 설정
            for output in outputs:
            	output.set_creator(self) #연결 설정
            self.inputs = inputs
            self.outputs = [weakref.ref(output) for output in outputs]
            
        return outputs if len(outputs) > 1 else outputs[0]

이제 Config.enable_backprop이 True일 때만 역전파 코드가 실행된다.

 

모드 전환을 적용하면 아래와 같다.

Config.enable_backprop = True
x = Variable(np.ones((100, 100, 100))) #형상이 (100, 100, 100)인 텐서
y = square(square(square(x))) #중간 계산 결과가 유지되면서 원소별 제곱이 세 번 적용됨
y.backward()

Config.enable_backprop = False
x = Variable(np.ones((100, 100, 100)))
y = square(square(square(x))) #중간 계산 결과는 사용 후 즉시 삭제됨

 

※with 문을 활용한 모드 전환

파이썬에는 후처리를 자동으로 수행하고자 할 때 사용할 수 있는 with 구문이 있다. 대표적인 예는 파일의 open과 close이다.

f = open('sample.txt', 'w')
f.write('hello world!')
f.close()

open()으로 파일을 열고, 무언가 쓰고, close()로 닫는다. 매번 close()하는 것이 귀찮고 잊을 때가 있는데 이때 with를 사용할 수 있다.

with open('sample.txt', 'w') as f:
	f.write('hello world!')

with블록에 들어갈 때 파일이 열리고(전처리) 빠져나올 때 자동으로 닫힌다(후처리).

 

 

이러한 with문의 원리를 이용하여 '역전파 비활성 모드'로 전환할 것이다.

with using_config('enable_backprop', False):
	x = Variable(np.array(2.0))
    y = square(x)

with 블록을 들어가는 using_config('enable_backprop', False): 안에서만 '역전파 비활성 모드'가 되고 with블록을 나오면 '역전파 활성 모드'로 돌아간다. (실전에서 학습 도중에 모델평가를 하기 위해 '역전파 비활성 모드'로 일시적으로 전환하는 방법을 자주 사용함)

 

with문을 사용하여 모드 전환을 구현하기 위해 contextlib모듈을 사용할 것이다. 먼저 contextlib모듈에 대해 설명하자면,

import contextlib

@contextlib.contextmanager
def config_test():
	print('start') #전처리
    try:
    	yield
    finally:
    	print('done') #후처리
        
with config_test():
	print('process...')
    
    
#실행 결과
start
process...
done

 

@contextlib.contextmanager 데코레이터를 달면 문맥을 판단하는 함수가 만들어진다. 이 함수에서 yield 전에는 전처리

로직을, 후에는 후처리 로직을 작성한다. 

 

이 contextlib을 바탕으로 using_config함수를 아래와 같이 구현할 수 있다.

import contextlib

@contextlib.contextmanager
def using_config(name, value):
	old_value = getattr(Config, name)
    setattr(Config, name, value)
    try:
    	yield
    finally:
    	setattr(Config, name, old_value)

 

실제로 사용해보면,

with using_config('enable_backprop', 'False'):
	x = Variable(np.array(2.0))
    y - square(x)

 

그러나 매번 with using_config('enable_backprop', False): 라는 긴 코드를 적기 힘드니 no_grad라는 함수를 만들 것이다.

def no_grad():
	return using_config('enable_backprop', False)
    
with no_grad():
	x = Variable(np.array(2.0))
    y = square(x)

이제 기울기가 필요 없을 때는 no_grad함수를 호출하면 된다. 즉, 단순 순전파 계산만 필요할 때는 no_grad를 통해 '함수 전환'을 사용할 수 있다.

728x90
반응형

댓글