Skip to content

Commit 1953375

Browse files
committed
Fix model downloads
1 parent 6908052 commit 1953375

File tree

2 files changed

+17
-16
lines changed

2 files changed

+17
-16
lines changed

jukebox/make_models.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch as t
99
import jukebox.utils.dist_adapter as dist
1010
from jukebox.hparams import Hyperparams, setup_hparams
11-
from jukebox.utils.gcs_utils import download
11+
from jukebox.utils.remote_utils import download
1212
from jukebox.utils.torch_utils import freeze_model
1313
from jukebox.utils.dist_utils import print_all
1414
from jukebox.vqvae.vqvae import calculate_strides
@@ -23,15 +23,16 @@
2323

2424
def load_checkpoint(path):
2525
restore = path
26-
if restore[:5] == 'gs://':
27-
gs_path = restore
28-
local_path = os.path.join(os.path.expanduser("~/.cache"), gs_path[5:])
26+
remote_prefix = 'https://openaipublic.blob.core.windows.net/'
27+
if restore.startswith(remote_prefix):
28+
remote_path = restore
29+
local_path = os.path.join(os.path.expanduser("~/.cache"), remote_path[len(remote_prefix):])
2930
if dist.get_rank() % 8 == 0:
30-
print("Downloading from gce")
31+
print("Downloading from azure")
3132
if not os.path.exists(os.path.dirname(local_path)):
3233
os.makedirs(os.path.dirname(local_path))
3334
if not os.path.exists(local_path):
34-
download(gs_path, local_path)
35+
download(remote_path, local_path)
3536
restore = local_path
3637
dist.barrier()
3738
checkpoint = t.load(restore, map_location=t.device('cpu'))

jukebox/utils/gcs_utils.py renamed to jukebox/utils/remote_utils.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1-
import os
21
import sys
32
import subprocess
4-
from time import time
53

4+
def download(remote_path, local_path, async_download=False):
5+
args = ['wget', '-O', local_path, remote_path]
6+
print("Running ", " ".join(args))
7+
if async_download:
8+
subprocess.Popen(args)
9+
else:
10+
subprocess.call(args)
11+
12+
# GCE
613
def gs_download(gs_path, local_path, async_download=False):
714
args = ['gsutil',
815
'-o', 'GSUtil:parallel_thread_count=1',
@@ -27,16 +34,9 @@ def gs_upload(local_path, gs_path, async_upload=False):
2734
else:
2835
subprocess.call(args)
2936

30-
def download(gs_path, local_path, async_download=False):
31-
remote_path = gs_path.replace("gs://", "https://storage.googleapis.com/")
32-
args = ['wget', '-q', '-O', local_path, remote_path]
33-
if async_download:
34-
subprocess.Popen(args)
35-
else:
36-
subprocess.call(args)
37-
3837
def ls(regex):
3938
outputs = subprocess.check_output(['gsutil', 'ls', regex]).decode(sys.stdout.encoding)
4039
outputs = outputs.split('\n')
4140
outputs = [output for output in outputs if output is not '']
4241
return outputs
42+

0 commit comments

Comments
 (0)