@@ -38,17 +38,18 @@ def load_checkpoint(path):
3838 print ("Restored from {}" .format (restore ))
3939 return checkpoint
4040
def save_checkpoint(logger, name, model, opt, metrics, hps):
    """Write a training checkpoint to ``<logger.logdir>/checkpoint_<name>.pth.tar``.

    The saved dict contains:
      * ``'hps'``   -- a copy of ``hps`` with the non-serialised entries
                       (metadata, alignments, lyric/midi processors) removed,
      * ``'model'`` -- ``model.state_dict()``,
      * ``'opt'``   -- ``opt.state_dict()``, or ``None`` when ``opt`` is None,
      * ``'step'``  -- ``logger.iters`` at the time of saving,
      * every key/value pair of ``metrics``.

    ``metrics`` is merged last, so a metric that reuses one of the key names
    above replaces that entry.
    """
    excluded = ['metadata_v2', 'metadata_v3', 'alignments',
                'lyric_processor', 'midi_processor']
    target = f'{logger.logdir}/checkpoint_{name}.pth.tar'
    with t.no_grad():
        kept_hps = {key: value
                    for key, value in {**hps}.items()
                    if key not in excluded}
        payload = {
            'hps': kept_hps,
            # should also save bottleneck k's as buffers
            'model': model.state_dict(),
            'opt': None if opt is None else opt.state_dict(),
            'step': logger.iters,
            **metrics,
        }
        t.save(payload, target)
5051
51- def restore (hps , model , checkpoint_path ):
52+ def restore_model (hps , model , checkpoint_path ):
5253 model .step = 0
5354 if checkpoint_path != '' :
5455 checkpoint = load_checkpoint (checkpoint_path )
@@ -60,6 +61,15 @@ def restore(hps, model, checkpoint_path):
6061 model .load_state_dict (checkpoint ['model' ])
6162 if 'step' in checkpoint : model .step = checkpoint ['step' ]
6263
def restore_opt(opt, shd, checkpoint_path):
    """Restore optimizer state and scheduler position from a checkpoint.

    Args:
        opt: optimizer whose ``load_state_dict`` receives the stored ``'opt'``
            entry.
        shd: object whose ``step`` method is called with the stored ``'step'``
            value (presumably a learning-rate scheduler -- confirm at caller).
        checkpoint_path: path handed to ``load_checkpoint``; when falsy
            (``''`` or ``None``) nothing is loaded and the function returns
            immediately.

    Returns:
        None.
    """
    if not checkpoint_path:
        return
    checkpoint = load_checkpoint(checkpoint_path)
    # save_checkpoint writes 'opt': None when it was called without an
    # optimizer, so the key being present does not guarantee a usable state.
    if 'opt' in checkpoint and checkpoint['opt'] is not None:
        opt.load_state_dict(checkpoint['opt'])
    if 'step' in checkpoint:
        shd.step(checkpoint['step'])
72+
6373def make_vqvae (hps , device = 'cuda' ):
6474 from jukebox .vqvae .vqvae import VQVAE
6575 block_kwargs = dict (width = hps .width , depth = hps .depth , m_conv = hps .m_conv ,
@@ -82,7 +92,7 @@ def make_vqvae(hps, device='cuda'):
8292 ** block_kwargs )
8393
8494 vqvae = vqvae .to (device )
85- restore (hps , vqvae , hps .restore_vqvae )
95+ restore_model (hps , vqvae , hps .restore_vqvae )
8696 if hps .train and not hps .prior :
8797 print_all (f"Loading vqvae in train mode" )
8898 if hps .restore_vqvae != '' :
@@ -166,7 +176,7 @@ def make_prior(hps, vqvae, device='cuda'):
166176 from jukebox .transformer .ops import _convert_conv_weights_to_fp16
167177 prior .apply (_convert_conv_weights_to_fp16 )
168178 prior = prior .to (device )
169- restore (hps , prior , hps .restore_prior )
179+ restore_model (hps , prior , hps .restore_prior )
170180 if hps .train :
171181 print_all (f"Loading prior in train mode" )
172182 pass
0 commit comments