Tensorflow
모델을 파일에 저장하는 코드
cnn.save("my_cnn.h5")
save함수는 신경망의 구조 정보와 가중치 정보를 저장한다.
대용량 데이터를 저장하는데 널리 쓰이는 HDF5파일 형식을 사용하기 때문에 확장자를 h5로 지정한다.
모델을 다시 불러다 쓰는 코드
#신경망 구조와 가중치를 저장하고 있는 파일을 읽어옴
cnn=tf.keras.models.load_model('my_cnn.h5')
이렇게 해도 되고, 아니면 해당 API를 선언하고 사용해도됨
from tensorflow.keras.models import load_model
model= load_model("파일명")
텐서플로는 학습하고 있는 모델객체가 (변수명이) cnn이면
cnn.save() 로 모델을 저장하는데 파이토치는 다르다.
Pytorch
모델을 파일에 저장하는 코드
torch.save({'epoch': epoch + 1,
'state_dict': netG.state_dict()},
'model/netG_streetview.pth')
torch.save()함수로 모델을 저장하는데, 이때 인자로 모델객체가 들어간다.
netG가 모델객체였고 여기서 state_dict()를 이용해 저장을한다.
이때, 저장을 딕셔너리형태로 해도됨 (물론 딕셔너리 저장하는거 텐서플로도 가능함)
모델을 다시 불러다 쓰는 코드
netG.load_state_dict(torch.load(opt.netG, map_location=lambda storage, location: storage)['state_dict'])
이때 opt.netG는 모델이 저장되어있는 위치이다!!
위에서 딕셔너리로 에폭과 state_dict을 저장하였는데
모델에 부르기 위해선 state_dict이 필요하므로 딕셔너리에서 state_dict을 뽑아 state_dict함수로 netG 모델객체에 넣어준다.
저장해둔 에폭을 쓰고 싶다면
resume_epoch = torch.load(opt.netG)['epoch']
이렇게 쓰면 된다.