""" Return a list of recent PyTorch wheels published on download.pytorch.org. Users can specify package name, python version, platform, and the number of days to return. If one of the packages specified is missing on one day, the script will skip outputing the results on that day. """ import argparse import os import re import subprocess import urllib.parse from collections import defaultdict from datetime import date, timedelta from pathlib import Path from typing import List import requests from bs4 import BeautifulSoup from cuda_utils import CUDA_VERSION_MAP, DEFAULT_CUDA_VERSION # @manual from python_utils import DEFAULT_PYTHON_VERSION, PYTHON_VERSION_MAP PYTORCH_CUDA_VERISON = CUDA_VERSION_MAP[DEFAULT_CUDA_VERSION]["pytorch_url"] PYTORCH_PYTHON_VERSION = PYTHON_VERSION_MAP[DEFAULT_PYTHON_VERSION]["pytorch_url"] torch_wheel_nightly_base = ( f"https://download.pytorch.org/whl/nightly/{PYTORCH_CUDA_VERISON}/" ) torch_nightly_wheel_index = ( f"https://download.pytorch.org/whl/nightly/{PYTORCH_CUDA_VERISON}/torch" ) torch_nightly_wheel_index_override = "torch_nightly.html" def memoize(function): """ """ call_cache = {} def memoized_function(*f_args): if f_args in call_cache: return call_cache[f_args] call_cache[f_args] = result = function(*f_args) return result return memoized_function @memoize def get_wheel_index_data( py_version, platform_version, url=torch_nightly_wheel_index, override_file=torch_nightly_wheel_index_override, ): """ """ if os.path.isfile(override_file) and os.stat(override_file).st_size: with open(override_file) as f: data = f.read() else: r = requests.get(url) r.raise_for_status() data = r.text soup = BeautifulSoup(data, "html.parser") data = defaultdict(dict) for link in soup.find_all("a"): group_match = re.search("([a-z]*)-(.*)-(.*)-(.*)-(.*)\.whl", link.text) # some packages (e.g., torch-rec) doesn't follow this naming convention if not group_match: continue pkg, version, py, py_m, platform = group_match.groups() version = urllib.parse.unquote(version) if py == py_version and platform == platform_version: full_url = os.path.join( torch_wheel_nightly_base, urllib.parse.quote_plus(link.text) ) data[pkg][version] = full_url return data def get_nightly_wheel_urls( packages: list, date: date, py_version=PYTORCH_PYTHON_VERSION, platform_version="linux_x86_64", ): """Gets urls to wheels for specified packages matching the date, py_version, platform_version""" date_str = f"{date.year}{date.month:02}{date.day:02}" data = get_wheel_index_data(py_version, platform_version) rc = {} for pkg in packages: pkg_versions = data[pkg] # multiple versions could happen when bumping the pytorch version number # e.g., both torch-1.11.0.dev20220211%2Bcu113-cp38-cp38-linux_x86_64.whl and # torch-1.12.0.dev20220212%2Bcu113-cp38-cp38-linux_x86_64.whl exist in the download link keys = sorted([key for key in pkg_versions if date_str in key], reverse=True) if len(keys) > 1: print( f"Warning: multiple versions matching a single date: {keys}, using {keys[0]}" ) if len(keys) == 0: return None full_url = pkg_versions[keys[0]] rc[pkg] = { "version": keys[0], "wheel": full_url, } return rc def get_nightly_wheels_in_range( packages: list, start_date: date, end_date: date, py_version=PYTORCH_PYTHON_VERSION, platform_version="linux_x86_64", reverse=False, ): rc = [] curr_date = start_date while curr_date <= end_date: curr_wheels = get_nightly_wheel_urls( packages, curr_date, py_version=py_version, platform_version=platform_version, ) if curr_wheels is not None: rc.append(curr_wheels) curr_date += timedelta(days=1) if reverse: rc.reverse() return rc def get_n_prior_nightly_wheels( packages: list, n: int, py_version=PYTORCH_PYTHON_VERSION, platform_version="linux_x86_64", reverse=False, ): end_date = date.today() start_date = end_date - timedelta(days=n) return get_nightly_wheels_in_range( packages, start_date, end_date, py_version=py_version, platform_version=platform_version, reverse=reverse, ) def get_most_recent_successful_wheels( packages: list, pyver: str, platform: str ) -> List[str]: """Get the most recent successful nightly wheels. Return List[str]""" curr_date = date.today() date_limit = curr_date - timedelta(days=365) while curr_date >= date_limit: wheels = get_nightly_wheel_urls( packages, curr_date, py_version=pyver, platform_version=platform ) if wheels: return wheels curr_date = curr_date - timedelta(days=1) # Can't find any valid pytorch package return None def install_wheels(wheels): """Install the wheels specified in the wheels.""" wheel_urls = list(map(lambda x: wheels[x]["wheel"], wheels.keys())) work_dir = Path(__file__).parent.joinpath(".data") work_dir.mkdir(parents=True, exist_ok=True) requirements_file = work_dir.joinpath("requirements.txt").resolve() with open(requirements_file, "w") as rf: rf.write("\n".join(wheel_urls)) command = ["pip", "install", "-r", str(requirements_file)] print(f"Installing pytorch nightly packages command: {command}") subprocess.check_call(command) if __name__ == "__main__": parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--pyver", type=str, default=PYTORCH_PYTHON_VERSION, help="PyTorch Python version", ) parser.add_argument( "--platform", type=str, default="linux_x86_64", help="PyTorch platform" ) parser.add_argument("--priordays", type=int, default=1, help="Number of days") parser.add_argument("--reverse", action="store_true", help="Return reversed result") parser.add_argument( "--packages", required=True, type=str, nargs="+", help="List of package names" ) parser.add_argument( "--install-nightlies", action="store_true", help="Install the most recent successfully built nightly packages", ) args = parser.parse_args() if args.install_nightlies: wheels = get_most_recent_successful_wheels( args.packages, args.pyver, args.platform ) assert wheels, f"We do not find any successful pytorch nightly build of packages: {args.packages}." print(f"Found pytorch nightly wheels: {wheels} ") install_wheels(wheels) exit(0) wheels = get_n_prior_nightly_wheels( packages=args.packages, n=args.priordays, py_version=args.pyver, platform_version=args.platform, reverse=args.reverse, ) for wheelset in wheels: for pkg in wheelset: print(f"{pkg}-{wheelset[pkg]['version']}: {wheelset[pkg]['wheel']}")