[tensorflow / pytorch] 모델 저장 /불러오기 코드

Tensorflow

모델을 파일에 저장하는 코드

cnn.save("my_cnn.h5")

save함수는 신경망의 구조 정보와 가중치 정보를 저장한다.

대용량 데이터를 저장하는데 널리 쓰이는 HDF5파일 형식을 사용하기 때문에 확장자를 h5로 지정한다.

 

모델을 다시 불러다 쓰는 코드

#신경망 구조와 가중치를 저장하고 있는 파일을 읽어옴
cnn=tf.keras.models.load_model('my_cnn.h5')

이렇게 해도 되고, 아니면 해당 API를 선언하고 사용해도됨

from tensorflow.keras.models import load_model
model= load_model("파일명")

 

텐서플로는 학습하고 있는 모델객체가 (변수명이) cnn이면

cnn.save() 로 모델을 저장하는데 파이토치는 다르다.

 

 

Pytorch

모델을 파일에 저장하는 코드

torch.save({'epoch': epoch + 1,
	'state_dict': netG.state_dict()},
	'model/netG_streetview.pth')

torch.save()함수로 모델을 저장하는데, 이때 인자로 모델객체가 들어간다.

netG가 모델객체였고 여기서 state_dict()를 이용해 저장을한다.

이때, 저장을 딕셔너리형태로 해도됨 (물론 딕셔너리 저장하는거 텐서플로도 가능함)

 

모델을 다시 불러다 쓰는 코드

netG.load_state_dict(torch.load(opt.netG, map_location=lambda storage, location: storage)['state_dict'])

이때 opt.netG는 모델이 저장되어있는 위치이다!!

위에서 딕셔너리로 에폭과 state_dict을 저장하였는데

모델에 부르기 위해선 state_dict이 필요하므로 딕셔너리에서 state_dict을 뽑아 state_dict함수로 netG 모델객체에 넣어준다.

저장해둔 에폭을 쓰고 싶다면

resume_epoch = torch.load(opt.netG)['epoch']

이렇게 쓰면 된다.