-
Notifications
You must be signed in to change notification settings - Fork 26k
Implement public API InferenceMode and its error handling #53343
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
💊 CI failures summary and remediationsAs of commit 09f5448 (more details on the Dr. CI page):
This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions to the (internal) Dr. CI Users group. |
…handling" [ghstack-poisoned]
…handling" [ghstack-poisoned]
…handling" [ghstack-poisoned]
…handling" [ghstack-poisoned]
…handling" [ghstack-poisoned]
…handling" [ghstack-poisoned]
…handling" [ghstack-poisoned]
…handling" [ghstack-poisoned]
…handling" [ghstack-poisoned]
…handling" [ghstack-poisoned]
|
Was reviewing the spec some more and I think we can probably do some more safe relaxations:
(Don't block this PR on it) |
RFC: pytorch/rfcs#17 Differential Revision: [D26973911](https://our.internmc.facebook.com/intern/diff/D26973911) [ghstack-poisoned]
RFC: pytorch/rfcs#17 Differential Revision: [D26973911](https://our.internmc.facebook.com/intern/diff/D26973911) [ghstack-poisoned]
RFC: pytorch/rfcs#17 Differential Revision: [D26973911](https://our.internmc.facebook.com/intern/diff/D26973911) [ghstack-poisoned]
bhosmer
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This LGTM per Ed's analysis here and offline conversation 😁 one note about setting requires_grad, but if we want to do anything about that it can definitely wait for a followup.
test/cpp/api/inference_mode.cpp
Outdated
| ASSERT_TRUE(is_inference_tensor(c)); | ||
|
|
||
| torch::Tensor tmp = torch::ones({1, 2, 3}).set_requires_grad(true); | ||
| ASSERT_TRUE(tmp.requires_grad()); // requires_grad is silently ignored when it's an inference tensor. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did we decide there was a reason we wanted to silently ignore, like code reuse? Seems clear it would be saner devex to actually error if you set requires_grad true on an inference tensor.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code got edited above, but just to be super explicit: it's not silently ignored; we DO set requires_grad=True (just like in no grad mode); the idea to not error is so that code that initializes parameters can run in inference mode without triggering errors
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yea code reuse is the main reason. And another reason is inference tensor can actually have requires_grad=true (and gradients) as shown in the second PR. :D
RFC: pytorch/rfcs#17 Differential Revision: [D26973911](https://our.internmc.facebook.com/intern/diff/D26973911) [ghstack-poisoned]
RFC: pytorch/rfcs#17 Differential Revision: [D26973911](https://our.internmc.facebook.com/intern/diff/D26973911) [ghstack-poisoned]
albanD
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One case that I don't see mentioned here is:
a = torch.rand(10, requires_grad=True)
with inference_mode():
b = a.view_as(a)
with torch.no_grad():
c = b.view_as(b)
c += 1
I guess this is fine, but just want to make sure
pytorch/torch/csrc/autograd/variable.h
Lines 508 to 515 in dfc7fa0
| /// Handles correctly propagating CreationMeta when a new view is created from a previous view. | |
| /// In general, we don't want the new view to be _less_ restrictive than the previous view | |
| /// (it's okay to be _more_ restrictive). A CreationMeta value of DEFAULT is currently the least | |
| /// restrictive, as the behavior for all other CreationMeta values is to error out for in-place ops. | |
| /// If this changes, the logic here will need to be updated to properly handle the new semantics. | |
| inline CreationMeta propagate_creation_meta(CreationMeta prev_view_creation_meta, CreationMeta new_view_creation_meta) { | |
| return (new_view_creation_meta == CreationMeta::DEFAULT) ? prev_view_creation_meta : new_view_creation_meta; | |
| } |
Also after looking at all the examples, I am now convinced that we should present inference mode as no grad but more restrictive. As most of the behavior is actually the same.
RFC: pytorch/rfcs#17 Differential Revision: [D26973911](https://our.internmc.facebook.com/intern/diff/D26973911) [ghstack-poisoned]
RFC: pytorch/rfcs#17 Differential Revision: [D26973911](https://our.internmc.facebook.com/intern/diff/D26973911) [ghstack-poisoned]
albanD
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Thanks for all the updates!
|
This pull request has been reverted by 263180d. |
) Summary: Pull Request resolved: pytorch#55008 reland of pytorch#53343 For easier review, here's a diff between the version before revert. https://www.internalfb.com/phabricator/paste/view/P361764610 Test Plan: Imported from OSS Differential Revision: D27443229 Pulled By: ailzhang fbshipit-source-id: faeaff3b6165b933c9f354d5f0344e38269fbb12
Summary: https://www.internalfb.com/phabricator/paste/view/P360377337Pull Request resolved: #53343 For easier review, here's a diff between the version before revert. https://www.internalfb.com/phabricator/paste/view/P360750919 Pull Request resolved: #55008 Test Plan: Imported from OSS Pulled By: ailzhang Reviewed By: bhosmer Differential Revision: D27443229 fbshipit-source-id: 01b03446a1f6373f43dd5c7170d26226b50f363c
Stack from ghstack:
RFC: pytorch/rfcs#17
Differential Revision: D26973911