_diffusion.py
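"""Runs the Stable Diffusion XL pipeline on JAX via torchax.

Each heavy submodule of the diffusers pipeline (both text encoders, the UNet,
and the VAE) is wrapped in a CompiledModule, which moves its weights to the
JAX device once and runs its forward pass through torch.func.functional_call
under jax.jit.
"""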
import functools
from time import time

import torch
import torch.func
from torch.utils import _pytree as pytree
from diffusers import DiffusionPipeline

import jax
import torchax
import torchax.functions
from torchax.extra import torch_view, jax_view


class CompiledModule:
    """Wraps a torch.nn.Module so its forward pass runs under jax.jit via torchax."""

    def __init__(self, model):
        weights = model.state_dict()
        weights.update(model.named_parameters())
        # Move all weights and buffers to the JAX device once, up front.
        self._weights = pytree.tree_map_only(
            torch.Tensor, torchax.tensor.move_to_device, weights)
        self._model = model
        self._func_jitted_torch = None  # built lazily on the first forward() call

    def _maybe_move_tensor(self, tensor):
        if isinstance(tensor, torch.Tensor) and not isinstance(
                tensor, torchax.tensor.Tensor):
            return torchax.tensor.move_to_device(tensor)
        return tensor

    def _make_jitted(self, args, kwargs):
        # Non-tensor arguments have to be treated as static by jax.jit.
        static = []
        for i, a in enumerate(args):
            if not isinstance(a, torch.Tensor):
                static.append(i + 1)  # position 0 is the weights pytree
        static_argnames = []
        for k, v in kwargs.items():
            if not isinstance(v, torch.Tensor):
                static_argnames.append(k)

        def f(weights, *args, **kwargs):
            weights, args, kwargs = torchax.tensor.wrap((weights, args, kwargs))
            with torchax.functions.XLAFunctionMode(), torchax.tensor.XLADispatchMode():
                res = torch.func.functional_call(self._model, weights, args, kwargs)
                if isinstance(res, tuple) and len(res) == 1:
                    res = res[0]
                return torchax.tensor.unwrap(res)

        fjit = jax.jit(f,
                       static_argnums=tuple(static),
                       static_argnames=tuple(static_argnames))
        return torch_view(fjit)

    def forward(self, *args, **kwargs):
        (args, kwargs) = pytree.tree_map(self._maybe_move_tensor, (args, kwargs))
        if self._func_jitted_torch is None:
            self._func_jitted_torch = self._make_jitted(args, kwargs)
        return self._func_jitted_torch(self._weights, *args, **kwargs)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def __getattr__(self, key):
        # Delegate everything else (config, dtype, etc.) to the wrapped model.
        return getattr(self._model, key)

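# compile_pipe swaps each heavy submodule of the diffusers pipeline for a
# CompiledModule, so subsequent pipe(...) calls dispatch through jax.jit.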
def compile_pipe(pipe):
    pipe.text_encoder = CompiledModule(pipe.text_encoder)
    pipe.text_encoder_2 = CompiledModule(pipe.text_encoder_2)
    pipe.unet = CompiledModule(pipe.unet)
    pipe.vae = CompiledModule(pipe.vae)

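# main() downloads SDXL base 1.0, compiles the pipeline submodules, and times a
# few batched text-to-image runs; the first iteration also includes the jax.jit
# compilation time.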
def main():
    pipe = DiffusionPipeline.from_pretrained(
        # "stabilityai/stable-diffusion-xl-base-0.9",
        "stabilityai/stable-diffusion-xl-base-1.0",
        use_safetensors=True,
    )
    compile_pipe(pipe)

    global_bs = 10
    inference_steps = 20
    resol = 1024
    prompts = ["a photo of an astronaut riding a horse on mars"] * global_bs
    print(f'global batch size {global_bs}',
          f'inference steps {inference_steps}',
          f'Image resolution {resol}',
          flush=True)

    iters = 5
    for i in range(iters):
        prompt = prompts
        start = time()
        image = pipe(prompt,
                     num_inference_steps=inference_steps,
                     height=resol,
                     width=resol).images[0]
        print(f'Step {i} inference time {time()-start} sec', flush=True)


if __name__ == '__main__':
    main()