-
Notifications
You must be signed in to change notification settings - Fork 305
/
Copy pathtorch_nightly_utils.py
227 lines (200 loc) · 7.18 KB
/
torch_nightly_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
"""
Return a list of recent PyTorch wheels published on download.pytorch.org.
Users can specify package name, python version, platform, and the number of days to return.
If one of the packages specified is missing on one day, the script will skip outputting the results for that day.
"""
import argparse
import functools
import os
import re
import subprocess
import sys
import urllib.parse
from collections import defaultdict
from datetime import date, timedelta
from pathlib import Path
from typing import Dict, List, Optional

import requests
from bs4 import BeautifulSoup

from cuda_utils import CUDA_VERSION_MAP, DEFAULT_CUDA_VERSION  # @manual
from python_utils import DEFAULT_PYTHON_VERSION, PYTHON_VERSION_MAP
# Values looked up from the project's version maps: the CUDA one is used as a
# URL path segment below, the Python one is compared against the Python tag of
# each wheel file name when filtering the index.
# NOTE: "VERISON" is a historical misspelling; the name is kept unchanged
# because other code may import it.
PYTORCH_CUDA_VERISON = CUDA_VERSION_MAP[DEFAULT_CUDA_VERSION]["pytorch_url"]
PYTORCH_PYTHON_VERSION = PYTHON_VERSION_MAP[DEFAULT_PYTHON_VERSION]["pytorch_url"]

# Directory URL the nightly wheel files are served from. Wheel file names taken
# from the index page are appended directly, so the trailing "/" is required.
torch_wheel_nightly_base = (
    f"https://download.pytorch.org/whl/nightly/{PYTORCH_CUDA_VERISON}/"
)
# HTML index page whose <a> links are parsed to discover available wheels.
torch_nightly_wheel_index = f"{torch_wheel_nightly_base}torch"
# Local HTML file that, when it exists and is non-empty, is parsed instead of
# downloading the index page.
torch_nightly_wheel_index_override = "torch_nightly.html"
def memoize(function):
    """Cache ``function``'s return values, keyed on the call arguments.

    The cache is unbounded and is never evicted, so every distinct argument
    combination stays cached for the lifetime of the process. All arguments
    must be hashable. A call that raises is not cached and will be re-executed
    the next time it is made with the same arguments.
    """
    # lru_cache(maxsize=None) is the stdlib unbounded memoizer: it accepts
    # keyword arguments as well as positional ones and applies
    # functools.wraps, so the decorated function keeps its __name__/__doc__.
    return functools.lru_cache(maxsize=None)(function)
@memoize
def get_wheel_index_data(
    py_version,
    platform_version,
    url=torch_nightly_wheel_index,
    override_file=torch_nightly_wheel_index_override,
):
    """Parse the nightly wheel index into ``{package: {version: wheel_url}}``.

    The index HTML is read from ``override_file`` when that file exists and is
    non-empty; otherwise it is downloaded from ``url``. Every ``<a>`` link whose
    text looks like ``<pkg>-<version>-<py>-<abi>-<platform>.whl`` is considered,
    and only wheels whose Python tag equals ``py_version`` and whose platform tag
    equals ``platform_version`` are kept. Versions are URL-unquoted (``%2B`` ->
    ``+``). Results are memoized per argument tuple.

    Returns:
        A ``defaultdict(dict)`` mapping package name -> version -> absolute
        wheel URL under ``torch_wheel_nightly_base``.

    Raises:
        requests.HTTPError: if downloading the index returns an error status.
    """
    if os.path.isfile(override_file) and os.stat(override_file).st_size:
        with open(override_file, encoding="utf-8") as f:
            html = f.read()
    else:
        # Without a timeout a stalled connection would block indefinitely.
        response = requests.get(url, timeout=60)
        response.raise_for_status()
        html = response.text

    soup = BeautifulSoup(html, "html.parser")
    # Raw string: "\." inside a plain literal is an invalid escape sequence
    # (a SyntaxWarning on Python 3.12+).
    wheel_name_re = re.compile(r"([a-z]*)-(.*)-(.*)-(.*)-(.*)\.whl")
    data = defaultdict(dict)
    for link in soup.find_all("a"):
        group_match = wheel_name_re.search(link.text)
        # some packages (e.g., torch-rec) don't follow this naming convention
        if not group_match:
            continue
        pkg, version, py, _abi, platform = group_match.groups()
        version = urllib.parse.unquote(version)
        if py == py_version and platform == platform_version:
            # Plain concatenation instead of os.path.join: this is a URL, and
            # os.path.join would insert "\" separators on Windows.
            data[pkg][version] = torch_wheel_nightly_base + urllib.parse.quote_plus(
                link.text
            )
    return data
def _version_sort_key(text):
    """Split ``text`` into alternating str/int chunks so digit runs compare numerically."""
    # re.split with a capturing group yields [str, digits, str, digits, ...],
    # so odd positions are always digit runs; converting only those keeps the
    # element types aligned across keys and makes "2.10" sort after "2.9".
    return [
        int(chunk) if index % 2 else chunk
        for index, chunk in enumerate(re.split(r"(\d+)", text))
    ]


def get_nightly_wheel_urls(
    packages: list,
    date: date,
    py_version=PYTORCH_PYTHON_VERSION,
    platform_version="linux_x86_64",
):
    """Gets urls to wheels for specified packages matching the date, py_version, platform_version.

    A wheel matches ``date`` when its version string contains the date as
    ``YYYYMMDD`` (nightly versions embed it, e.g. ``1.12.0.dev20220212``).

    Returns:
        ``{package: {"version": <version>, "wheel": <url>}}`` for every entry
        in ``packages``, or ``None`` as soon as one package has no wheel for
        that date.
    """
    date_str = f"{date.year}{date.month:02}{date.day:02}"
    data = get_wheel_index_data(py_version, platform_version)
    rc = {}
    for pkg in packages:
        # .get() instead of [] so a missing package does not insert an empty
        # entry into the memoized index dict.
        pkg_versions = data.get(pkg, {})
        # multiple versions could happen when bumping the pytorch version number
        # e.g., both torch-1.11.0.dev20220211%2Bcu113-cp38-cp38-linux_x86_64.whl and
        # torch-1.12.0.dev20220212%2Bcu113-cp38-cp38-linux_x86_64.whl exist in the download link
        # Sort numerically so the newest version wins: a plain string sort
        # would rank "2.9..." above "2.10...".
        keys = sorted(
            (key for key in pkg_versions if date_str in key),
            key=_version_sort_key,
            reverse=True,
        )
        if len(keys) > 1:
            print(
                f"Warning: multiple versions matching a single date: {keys}, using {keys[0]}"
            )
        if not keys:
            return None
        rc[pkg] = {
            "version": keys[0],
            "wheel": pkg_versions[keys[0]],
        }
    return rc
def get_nightly_wheels_in_range(
    packages: list,
    start_date: date,
    end_date: date,
    py_version=PYTORCH_PYTHON_VERSION,
    platform_version="linux_x86_64",
    reverse=False,
):
    """Collect one wheel set per day from ``start_date`` to ``end_date`` inclusive.

    Each element is the result of ``get_nightly_wheel_urls`` for one day. Days
    for which that call returns ``None`` are left out. The list is in
    chronological order, or newest-first when ``reverse`` is true.
    """
    day_count = (end_date - start_date).days + 1
    collected = []
    for offset in range(day_count):
        day = start_date + timedelta(days=offset)
        wheels_for_day = get_nightly_wheel_urls(
            packages,
            day,
            py_version=py_version,
            platform_version=platform_version,
        )
        if wheels_for_day is None:
            continue
        collected.append(wheels_for_day)
    return collected[::-1] if reverse else collected
def get_n_prior_nightly_wheels(
    packages: list,
    n: int,
    py_version=PYTORCH_PYTHON_VERSION,
    platform_version="linux_x86_64",
    reverse=False,
):
    """Return the daily wheel sets from ``n`` days ago through today, inclusive.

    The window therefore covers ``n + 1`` calendar days. Days for which
    ``get_nightly_wheel_urls`` returns ``None`` are omitted; ``reverse=True``
    lists the newest day first.
    """
    today = date.today()
    window_start = today - timedelta(days=n)
    return get_nightly_wheels_in_range(
        packages,
        window_start,
        today,
        py_version=py_version,
        platform_version=platform_version,
        reverse=reverse,
    )
def get_most_recent_successful_wheels(
    packages: list, pyver: str, platform: str, max_days: int = 365
) -> Optional[Dict[str, Dict[str, str]]]:
    """Find the newest day on which every package in ``packages`` has a wheel.

    Walks backwards one day at a time starting from today and checks at most
    ``max_days + 1`` dates (today included).

    Returns:
        The ``{package: {"version": ..., "wheel": url}}`` mapping from
        ``get_nightly_wheel_urls`` for the most recent matching day, or
        ``None`` if no day in the window has wheels for all packages.
    """
    today = date.today()
    for days_back in range(max_days + 1):
        wheels = get_nightly_wheel_urls(
            packages,
            today - timedelta(days=days_back),
            py_version=pyver,
            platform_version=platform,
        )
        if wheels:
            return wheels
    # Can't find any valid pytorch package
    return None
def install_wheels(wheels):
    """Install the given wheels with pip.

    Args:
        wheels: mapping of package name to a dict holding a ``"wheel"`` URL,
            as returned by ``get_nightly_wheel_urls``.

    The URLs are written one per line to ``.data/requirements.txt`` next to
    this file, which is then passed to ``pip install -r``.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    wheel_urls = [info["wheel"] for info in wheels.values()]
    work_dir = Path(__file__).parent.joinpath(".data")
    work_dir.mkdir(parents=True, exist_ok=True)
    requirements_file = work_dir.joinpath("requirements.txt").resolve()
    requirements_file.write_text("\n".join(wheel_urls) + "\n", encoding="utf-8")
    # Invoke pip through the running interpreter so the wheels land in this
    # Python environment rather than whichever "pip" is first on PATH.
    command = [sys.executable, "-m", "pip", "install", "-r", str(requirements_file)]
    print(f"Installing pytorch nightly packages command: {command}")
    subprocess.check_call(command)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--pyver",
        type=str,
        default=PYTORCH_PYTHON_VERSION,
        help="PyTorch Python version",
    )
    parser.add_argument(
        "--platform", type=str, default="linux_x86_64", help="PyTorch platform"
    )
    parser.add_argument("--priordays", type=int, default=1, help="Number of days")
    parser.add_argument("--reverse", action="store_true", help="Return reversed result")
    parser.add_argument(
        "--packages", required=True, type=str, nargs="+", help="List of package names"
    )
    parser.add_argument(
        "--install-nightlies",
        action="store_true",
        help="Install the most recent successfully built nightly packages",
    )
    args = parser.parse_args()

    if args.install_nightlies:
        wheels = get_most_recent_successful_wheels(
            args.packages, args.pyver, args.platform
        )
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        # SystemExit with a message prints it to stderr and exits with status 1.
        if not wheels:
            raise SystemExit(
                f"We do not find any successful pytorch nightly build of packages: {args.packages}."
            )
        print(f"Found pytorch nightly wheels: {wheels} ")
        install_wheels(wheels)
    else:
        # Listing mode: print every matching wheel for each day in the window.
        wheels = get_n_prior_nightly_wheels(
            packages=args.packages,
            n=args.priordays,
            py_version=args.pyver,
            platform_version=args.platform,
            reverse=args.reverse,
        )
        for wheelset in wheels:
            for pkg, info in wheelset.items():
                print(f"{pkg}-{info['version']}: {info['wheel']}")