Refactor gradscaler #99301


Closed

Conversation

@heidongxianhua (Contributor) commented Apr 17, 2023

Fixes #ISSUE_NUMBER
Currently, the GradScaler-related code lives at torch/cuda/amp, but we think the strategy for updating the scale should be basically the same across devices (cuda/xpu, etc.). So it may be better to move this code to torch/amp, so that the GradScaler defined in torch/amp can be inherited for other devices (cuda/xla/mps, ..., and custom devices).
And most importantly, this will not break backward compatibility. @bdhirsh @albanD

As we discussed here: https://dev-discuss.pytorch.org/t/improve-the-extension-with-privateuse1-for-custom-device/1196/7
cc @mcarilli @ptrblck @leslie-fang-intel @jgong5
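
To illustrate the intent, here is a minimal sketch (not part of this PR's diff) of how a custom backend could reuse a device-generic scaler; the import path and the `device` argument are assumptions about the proposed refactor, not the final API:

from torch.amp.grad_scaler import GradScaler  # assumed location after the refactor


class CustomDeviceGradScaler(GradScaler):
    """Scaler for a custom backend that reuses the generic scale-update logic."""

    def __init__(self, **kwargs) -> None:
        # Only the device string differs; the growth/backoff bookkeeping is
        # inherited from the device-generic base class. "privateuseone" is a
        # placeholder name for a custom backend's device type.
        super().__init__(device="privateuseone", **kwargs)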

@pytorch-bot bot commented Apr 17, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/99301

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 94f5fee:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot bot added the release notes: distributed (sharded) label Apr 17, 2023
@heidongxianhua force-pushed the refactor_gradscaler branch 2 times, most recently from b2a76fe to 8b1f176 on April 17, 2023 11:08
@albanD (Collaborator) commented Apr 18, 2023

I'll let @janeyx99 take a first stab at reviewing this.

@janeyx99 self-requested a review April 18, 2023 17:21
@janeyx99 (Contributor) left a comment


Thanks for working on this! Left several comments

Comment on lines +13 to +14
"is deprecated. It will be removed in the future and "
"use torch.amp.grad_scaler._refresh_per_optimizer_state()")

Suggested change
"is deprecated. It will be removed in the future and "
"use torch.amp.grad_scaler._refresh_per_optimizer_state()")
"is deprecated. It will be removed in the future. Instead "
"use torch.amp.grad_scaler._refresh_per_optimizer_state()")

nit for clarity

@@ -170,7 +48,8 @@ def scale(self, outputs):
return outputs * self._scale.to(device=outputs.device, non_blocking=True)

# Invoke the more complex machinery only if we're treating multiple outputs.
stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scale
# holds a reference that can be overwritten by apply_scale
stash: List[torch.amp.grad_scaler._MultiDeviceReplicator] = []

Instead of typing out the full torch.amp.grad_scaler. every time in the file, can these be imported at the top for clarity + ease of reading?
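
For example, something like this at the top of torch/cuda/amp/grad_scaler.py (a sketch of the reviewer's suggestion, assuming the refactored torch/amp/grad_scaler.py exposes these names):

from typing import List

# Assumed exports of the refactored module; imported once at module level.
from torch.amp.grad_scaler import GradScaler, _MultiDeviceReplicator

# ...so the annotations later in the file stay short:
stash: List[_MultiDeviceReplicator] = []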

Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
"""
def __init__(self, master_tensor: torch.Tensor) -> None:
self.master = master_tensor

I see you removed the assertion for the device check (CUDA and XLA). We will still want to keep these assertions during the refactoring until we can confidently say all devices are supported.
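
For reference, a trimmed sketch of the check being discussed, mirroring the pre-refactor CUDA/XLA assertion (the surrounding class is abbreviated here):

from typing import Dict

import torch


class _MultiDeviceReplicator:
    """Lazily serves copies of a tensor to requested devices (trimmed sketch)."""

    def __init__(self, master_tensor: torch.Tensor) -> None:
        # The assertion the review asks to keep until other devices are verified.
        assert master_tensor.is_cuda or master_tensor.device.type == "xla"
        self.master = master_tensor
        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}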

iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).

Args:
device

Please add documentation for this new variable so people know what values it can take and what it's used for!

Default: ``True``
"""
def __init__(self,
device="cuda",

Would prefer not defaulting to CUDA here but passing in the right value in the cuda/amp/grad_scaler for device.
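
In code, the suggestion would look roughly like this in torch/cuda/amp/grad_scaler.py (a sketch; the generic base class and its device parameter are assumptions about the refactor):

from torch.amp.grad_scaler import GradScaler as _GradScaler  # assumed generic base


class GradScaler(_GradScaler):
    """CUDA-facing wrapper that passes the device explicitly."""

    def __init__(self, **kwargs) -> None:
        # The CUDA subclass pins device="cuda", so the generic base class does
        # not need to default to CUDA.
        super().__init__(device="cuda", **kwargs)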


Also, we may want to assert that the device is either cuda or xla here
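
For example, the constructor could guard the accepted values until broader device support is verified (a sketch; the accepted set is an assumption, not settled API):

class GradScaler:  # trimmed sketch, not the full class
    def __init__(self, device: str = "cuda", enabled: bool = True) -> None:
        # Reject devices other than cuda/xla until they are known to work.
        assert device in ("cuda", "xla"), (
            f"GradScaler currently supports 'cuda' and 'xla', got {device!r}"
        )
        self._device = device
        self._enabled = enabled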

return outputs

# Short-circuit for the common case.
if isinstance(outputs, torch.Tensor):

Also, I see the assert statement is no longer there; same comment as above about keeping the assertions until we're confident other devices work.


for device, per_dtype_grads in per_device_and_dtype_grads.items():
for grads in per_dtype_grads.values():
torch._amp_foreach_non_finite_check_and_unscale_(grads,

To give more context about device support, this call, for example, only exists for CUDA.

"""
return self._enabled

def state_dict(self):

This should also include the new device attr, and the same applies to load_state_dict and the other state-related functions below.
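
For instance, something along these lines (a sketch of the suggested change to the method; the "device" key and its handling in load_state_dict are assumptions):

def state_dict(self):
    # Sketch only: persist the device alongside the existing scaler state so
    # that load_state_dict can restore it; the "device" key is an assumption.
    if not self._enabled:
        return {}
    return {
        "device": self._device,
        "scale": self.get_scale(),
        "growth_factor": self._growth_factor,
        "backoff_factor": self._backoff_factor,
        "growth_interval": self._growth_interval,
        "_growth_tracker": self._get_growth_tracker(),
    }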

@@ -83,6 +83,7 @@ class ShardedGradScaler(GradScaler):

def __init__(
self,
device: str = "cuda",

If this is CUDA only, then shouldn't it inherit from the GradScaler from torch.cuda.amp.grad_scaler?
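
In code, that alternative would look roughly like this (a sketch; it assumes the CUDA-specific scaler stays at torch/cuda/amp/grad_scaler.py):

from torch.cuda.amp.grad_scaler import GradScaler


class ShardedGradScaler(GradScaler):
    # Keep the CUDA-only sharded scaler on the CUDA-specific base rather than
    # the new device-generic one.
    def __init__(self, init_scale: float = 2.0**16, **kwargs) -> None:
        super().__init__(init_scale=init_scale, **kwargs)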

@mikaylagawarecki added the triaged label Apr 25, 2023
@github-actions bot commented

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions bot added the Stale label Jun 24, 2023
@github-actions bot closed this Jul 24, 2023
@heidongxianhua deleted the refactor_gradscaler branch January 8, 2025 03:50
Labels
module: amp (automated mixed precision), open source, release notes: distributed (sharded), Stale, triaged

5 participants