Refactor gradscaler #99301
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/99301
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 94f5fee. This comment was automatically generated by Dr. CI and updates every 15 minutes.
I'll let @janeyx99 take a first stab at reviewing this.
Thanks for working on this! Left several comments
"is deprecated. It will be removed in the future and "
"use torch.amp.grad_scaler._refresh_per_optimizer_state()")
Suggested change:
-    "is deprecated. It will be removed in the future and "
-    "use torch.amp.grad_scaler._refresh_per_optimizer_state()")
+    "is deprecated. It will be removed in the future. Instead "
+    "use torch.amp.grad_scaler._refresh_per_optimizer_state()")
nit for clarity
@@ -170,7 +48,8 @@ def scale(self, outputs):
         return outputs * self._scale.to(device=outputs.device, non_blocking=True)

     # Invoke the more complex machinery only if we're treating multiple outputs.
-    stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scale
+    # holds a reference that can be overwritten by apply_scale
+    stash: List[torch.amp.grad_scaler._MultiDeviceReplicator] = []
Instead of typing out the full `torch.amp.grad_scaler.` prefix every time in the file, can these be imported at the top for clarity and ease of reading?
    Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
    """
    def __init__(self, master_tensor: torch.Tensor) -> None:
        self.master = master_tensor
I see you removed the assertion for the device check (CUDA and XLA). We will still want to keep these assertions during the refactoring until we can confidently say all devices are supported.
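The guard being discussed looks roughly like the following sketch. A minimal `FakeTensor` stands in for `torch.Tensor` so the example runs without torch; in torch the check reads something like `master_tensor.is_cuda or master_tensor.device.type == "xla"`.

```python
class FakeTensor:
    """Minimal stand-in for torch.Tensor, just enough for the check."""

    def __init__(self, device_type: str):
        self.device_type = device_type

    @property
    def is_cuda(self) -> bool:
        return self.device_type == "cuda"


def check_master_tensor(t: FakeTensor) -> None:
    # Roughly the assertion the reviewer wants kept until other
    # devices are known to work:
    assert t.is_cuda or t.device_type == "xla", "needs a CUDA or XLA tensor"


check_master_tensor(FakeTensor("cuda"))  # accepted
check_master_tensor(FakeTensor("xla"))   # accepted
rejected = False
try:
    check_master_tensor(FakeTensor("cpu"))
except AssertionError:
    rejected = True
print("cpu rejected:", rejected)  # cpu rejected: True
```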
    iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).

    Args:
        device
Please add documentation for this new variable so people know what values it can take and what it's used for!
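The requested documentation might look like this sketch. The wording, the set of accepted values, and the attribute name are illustrative assumptions, not the merged docstring.

```python
class GradScaler:
    """Sketch of the docstring the reviewer is asking for.

    Args:
        device (str, optional, default="cuda"): Device type on which the
            scaler's internal tensors (scale, growth tracker) are kept,
            e.g. ``"cuda"`` or ``"xla"``. Other backends can subclass
            and pass their own value.
    """

    def __init__(self, device: str = "cuda") -> None:
        self._device = device


s = GradScaler("xla")
print(s._device)  # xla
```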
        Default: ``True``
    """
    def __init__(self,
                 device="cuda",
Would prefer not defaulting to CUDA here, but passing in the right value for `device` from `cuda/amp/grad_scaler`.
Also, we may want to assert that the device is either cuda or xla here
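Both suggestions together (no `"cuda"` default in the base class, plus a device assertion) could look like this pure-Python sketch; the class names are hypothetical stand-ins for the real classes in `torch/amp` and `torch/cuda/amp`.

```python
class GradScaler:
    """Device-generic base (sketch of the proposed torch.amp scaler)."""

    _SUPPORTED = ("cuda", "xla")

    def __init__(self, device: str) -> None:
        # Per the review: no "cuda" default, and assert the device is
        # one we know works until more backends have been validated.
        assert device in self._SUPPORTED, f"unsupported device: {device}"
        self._device = device


class CudaGradScaler(GradScaler):
    """Sketch of a cuda/amp subclass passing the value explicitly."""

    def __init__(self) -> None:
        super().__init__(device="cuda")


print(CudaGradScaler()._device)  # cuda
raised = False
try:
    GradScaler("cpu")
except AssertionError:
    raised = True
print(raised)  # True
```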
            return outputs

        # Short-circuit for the common case.
        if isinstance(outputs, torch.Tensor):
Also, I see the assert statement is no longer there; same comment as above about keeping the assertions until we're confident other devices work.
        for device, per_dtype_grads in per_device_and_dtype_grads.items():
            for grads in per_dtype_grads.values():
                torch._amp_foreach_non_finite_check_and_unscale_(grads,
To give more context about device support, this call, for example, only exists for CUDA.
""" | ||
return self._enabled | ||
|
||
def state_dict(self): |
This should also include the new device attr, along with `load_state_dict` and the other state-related functions below.
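A sketch of what persisting the device in the scaler state might look like. The key names and attribute names are assumptions for illustration, not the merged API.

```python
class GradScaler:
    """Minimal stand-in carrying only the state relevant to the comment."""

    def __init__(self, device: str = "cuda", init_scale: float = 2.0**16):
        self._device = device
        self._scale = init_scale
        self._enabled = True

    def state_dict(self):
        if not self._enabled:
            return {}
        # Reviewer's point: persist the new device attribute too...
        return {"device": self._device, "scale": self._scale}

    def load_state_dict(self, state):
        # ...and restore it symmetrically.
        self._device = state["device"]
        self._scale = state["scale"]


src = GradScaler(device="xla", init_scale=1024.0)
dst = GradScaler()
dst.load_state_dict(src.state_dict())
print(dst._device, dst._scale)  # xla 1024.0
```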
@@ -83,6 +83,7 @@ class ShardedGradScaler(GradScaler):

     def __init__(
         self,
+        device: str = "cuda",
If this is CUDA only, then shouldn't it inherit from the GradScaler from torch.cuda.amp.grad_scaler?
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as
Fixes #ISSUE_NUMBER
Now, the GradScaler-related code lives in `torch/cuda/amp`, but we think that for different devices (cuda/xpu, etc.) the strategy for updating the scale should be basically the same. So it may be better to move this code to `torch/amp`, so that for other devices (cuda/xla/mps, ... and custom devices) we can inherit the `GradScaler` defined in `torch/amp`. Most importantly, this will not break backward compatibility. @bdhirsh @albanD
As we discussed here: https://dev-discuss.pytorch.org/t/improve-the-extension-with-privateuse1-for-custom-device/1196/7
cc @mcarilli @ptrblck @leslie-fang-intel @jgong5
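The structure the PR description proposes could be sketched as below: one device-generic base class whose scale-update strategy is shared, with thin per-device subclasses. The class names, attribute names, and the simplified update rule are illustrative assumptions (the real scaler grows the scale only after a `growth_interval` of clean steps).

```python
class GradScaler:
    """Device-generic scaler base (sketch of a torch/amp class):
    the scale-update strategy is shared; only `device` differs."""

    def __init__(self, device: str, init_scale: float = 2.0**16,
                 growth_factor: float = 2.0, backoff_factor: float = 0.5):
        self._device = device
        self._scale = init_scale
        self._growth_factor = growth_factor
        self._backoff_factor = backoff_factor

    def update(self, found_inf: bool) -> None:
        # Shared strategy: back off on overflow, otherwise grow the scale.
        if found_inf:
            self._scale *= self._backoff_factor
        else:
            self._scale *= self._growth_factor


class CudaGradScaler(GradScaler):
    """Sketch of the torch.cuda.amp subclass kept for backward compat."""

    def __init__(self, **kwargs):
        super().__init__(device="cuda", **kwargs)


class MyBackendGradScaler(GradScaler):
    """A custom (PrivateUse1-style) backend only picks its device."""

    def __init__(self, **kwargs):
        super().__init__(device="my_device", **kwargs)


s = CudaGradScaler(init_scale=8.0)
s.update(found_inf=True)
print(s._device, s._scale)  # cuda 4.0
```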