Skip to content

Commit 4fc3b97

Browse files
authored
Merge pull request openai#108 from openai/fix-shd
Restore adam state and save training step for lr decaying to work
2 parents 7939619 + 1c0ad3f commit 4fc3b97

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

jukebox/make_models.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,18 @@ def load_checkpoint(path):
3838
print("Restored from {}".format(restore))
3939
return checkpoint
4040

41-
def save_checkpoint(logger, name, model, opt, metrics, hps):
    """Save a training checkpoint to ``{logger.logdir}/checkpoint_{name}.pth.tar``.

    Persists the hyperparameters (minus non-picklable/bulky entries), the
    model state dict, the optimizer state dict, the current training step,
    and any extra metric entries.

    Args:
        logger: object exposing ``iters`` (current training step) and
            ``logdir`` (output directory) — presumably the project's training
            logger; TODO confirm against jukebox.utils.logger.
        name: suffix used in the checkpoint filename.
        model: module whose ``state_dict()`` is saved.
        opt: optimizer whose ``state_dict()`` is saved, or ``None`` to skip.
        metrics: extra entries merged flat into the checkpoint dict.
        hps: mapping of hyperparameters to save (filtered copy is stored).
    """
    with t.no_grad():
        # Copy hps and drop entries that should not be serialized.
        save_hps = {**hps}
        save_hps = {k: v for k, v in save_hps.items()
                    if k not in ['metadata_v2', 'metadata_v3', 'alignments',
                                 'lyric_processor', 'midi_processor']}
        t.save({'hps': save_hps,
                'model': model.state_dict(),  # should also save bottleneck k's as buffers
                'opt': opt.state_dict() if opt is not None else None,
                # Saving the step lets restore fast-forward the lr scheduler
                # so lr decay resumes where it left off.
                'step': logger.iters,
                **metrics}, f'{logger.logdir}/checkpoint_{name}.pth.tar')
    return
5051

51-
def restore(hps, model, checkpoint_path):
52+
def restore_model(hps, model, checkpoint_path):
5253
model.step = 0
5354
if checkpoint_path != '':
5455
checkpoint = load_checkpoint(checkpoint_path)
@@ -60,6 +61,15 @@ def restore(hps, model, checkpoint_path):
6061
model.load_state_dict(checkpoint['model'])
6162
if 'step' in checkpoint: model.step = checkpoint['step']
6263

64+
def restore_opt(opt, shd, checkpoint_path):
    """Restore optimizer state and fast-forward the lr scheduler.

    No-op when ``checkpoint_path`` is falsy. Otherwise loads the checkpoint
    and, when the corresponding keys are present, restores the optimizer
    state dict and advances the scheduler to the saved training step so lr
    decay resumes from where training stopped.

    Args:
        opt: optimizer whose ``load_state_dict`` is called when the
            checkpoint contains an ``'opt'`` entry.
        shd: lr scheduler; ``shd.step(step)`` is called when the checkpoint
            contains a ``'step'`` entry.
        checkpoint_path: path to the checkpoint file, or ``''``/``None`` to skip.
    """
    if not checkpoint_path:
        return
    checkpoint = load_checkpoint(checkpoint_path)
    if 'opt' in checkpoint:
        opt.load_state_dict(checkpoint['opt'])
    if 'step' in checkpoint:
        # NOTE(review): assumes the scheduler's step() accepts an absolute
        # step/epoch argument to jump forward — confirm against the
        # scheduler returned by get_lr_scheduler.
        shd.step(checkpoint['step'])
72+
6373
def make_vqvae(hps, device='cuda'):
6474
from jukebox.vqvae.vqvae import VQVAE
6575
block_kwargs = dict(width=hps.width, depth=hps.depth, m_conv=hps.m_conv,
@@ -82,7 +92,7 @@ def make_vqvae(hps, device='cuda'):
8292
**block_kwargs)
8393

8494
vqvae = vqvae.to(device)
85-
restore(hps, vqvae, hps.restore_vqvae)
95+
restore_model(hps, vqvae, hps.restore_vqvae)
8696
if hps.train and not hps.prior:
8797
print_all(f"Loading vqvae in train mode")
8898
if hps.restore_vqvae != '':
@@ -166,7 +176,7 @@ def make_prior(hps, vqvae, device='cuda'):
166176
from jukebox.transformer.ops import _convert_conv_weights_to_fp16
167177
prior.apply(_convert_conv_weights_to_fp16)
168178
prior = prior.to(device)
169-
restore(hps, prior, hps.restore_prior)
179+
restore_model(hps, prior, hps.restore_prior)
170180
if hps.train:
171181
print_all(f"Loading prior in train mode")
172182
pass

jukebox/train.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from torch.nn.parallel import DistributedDataParallel
1313

1414
from jukebox.hparams import setup_hparams
15-
from jukebox.make_models import make_vqvae, make_prior, save_checkpoint
15+
from jukebox.make_models import make_vqvae, make_prior, restore_opt, save_checkpoint
1616
from jukebox.utils.logger import init_logging
1717
from jukebox.utils.audio_utils import audio_preprocess, audio_postprocess
1818
from jukebox.utils.torch_utils import zero_grad, count_parameters
@@ -86,6 +86,9 @@ def get_optimizer(model, hps):
8686
# lr scheduler
8787
shd = get_lr_scheduler(opt, hps)
8888

89+
restore_path = hps.restore_prior if hps.prior else hps.restore_vqvae
90+
restore_opt(opt, shd, restore_path)
91+
8992
# fp16 dynamic loss scaler
9093
scalar = None
9194
if hps.fp16:
@@ -266,7 +269,7 @@ def train(model, orig_model, opt, shd, scalar, ema, logger, metrics, data_proces
266269
orig_model.eval()
267270
name = 'latest' if hps.prior else f'step_{logger.iters}'
268271
if dist.get_rank() % 8 == 0:
269-
save_checkpoint(logger.logdir, name, orig_model, opt, dict(step=logger.iters), hps)
272+
save_checkpoint(logger, name, orig_model, opt, dict(step=logger.iters), hps)
270273
orig_model.train()
271274
if ema is not None: ema.swap()
272275

0 commit comments

Comments
 (0)