-
Notifications
You must be signed in to change notification settings - Fork 512
/
Copy pathcallback.py
34 lines (24 loc) · 939 Bytes
/
callback.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from typing import Callable
import torch
import torch_xla
import threading
def on_ready_callback(tensor: torch.Tensor,
                      callback: Callable[[torch.Tensor], None]) -> None:
    """Install `callback` on `tensor`, to be called once its underlying buffer is ready.

    `callback` is invoked with `tensor` as its only argument.

    Note: because `callback` is a Python callable, it has to re-acquire the GIL
    before it can run. If the main thread blocks waiting on `callback` while
    holding the GIL, this will result in a deadlock.

    Args:
        tensor: tensor whose underlying buffer readiness triggers `callback`.
        callback: callable invoked as `callback(tensor)`.
    """

    def _callback_wrapper():
        # `_on_ready_callback` is handed a zero-argument callable, so bind
        # `tensor` here and forward it to the user's callback.
        callback(tensor)

    torch_xla._XLAC._on_ready_callback(tensor, _callback_wrapper)
def on_ready_event(tensor: torch.Tensor) -> threading.Event:
    """Create a `threading.Event` that is set once `tensor`'s underlying buffer is ready.

    Callers can `wait()` on the returned event to block until the buffer is
    ready.

    Args:
        tensor: tensor whose buffer readiness the returned event tracks.

    Returns:
        A `threading.Event`, initially unset, that becomes set when the
        buffer is ready.
    """
    buffer_ready = threading.Event()
    # Register a zero-argument callable; `Event.set` is safe to call from any
    # thread, which matters if the runtime fires this off the main thread
    # (assumption — not visible here).
    torch_xla._XLAC._on_ready_callback(tensor, lambda: buffer_ready.set())
    return buffer_ready