
Commit f486d34

Make ControlNet SD Training Script torch.compile compatible (#6525)
* update: make controlnet script torch compile compatible
  Signed-off-by: Suvaditya Mukherjee <[email protected]>

* update: correct earlier mistakes for compilation
  Signed-off-by: Suvaditya Mukherjee <[email protected]>

* update: fix code style issues
  Signed-off-by: Suvaditya Mukherjee <[email protected]>

---------

Signed-off-by: Suvaditya Mukherjee <[email protected]>
1 parent e44b205 commit f486d34

File tree

1 file changed (+13, -5 lines)


examples/controlnet/train_controlnet.py

+13, -5
@@ -50,6 +50,7 @@
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
 
 
 if is_wandb_available():
@@ -787,6 +788,12 @@ def main(args):
         logger.info("Initializing controlnet weights from unet")
         controlnet = ControlNetModel.from_unet(unet)
 
+    # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # `accelerate` 0.16.0 will have better support for customized saving
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
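
For context on the `unwrap_model` helper added above: `torch.compile` does not modify a module in place but wraps it in a dynamo `OptimizedModule`, keeping the original module in its `_orig_mod` attribute; that wrapper is what `is_compiled_module` detects. A minimal standalone sketch (the toy `Linear` module is illustrative, not part of the commit):

import torch

# torch.compile returns a wrapper module rather than modifying the model in place
model = torch.nn.Linear(4, 4)
compiled = torch.compile(model)

print(type(compiled).__name__)      # OptimizedModule (the torch._dynamo wrapper)
print(compiled._orig_mod is model)  # True: the original module is kept intact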
@@ -846,9 +853,9 @@ def load_model_hook(models, input_dir):
         " doing mixed precision training, copy of the weights should still be float32."
     )
 
-    if accelerator.unwrap_model(controlnet).dtype != torch.float32:
+    if unwrap_model(controlnet).dtype != torch.float32:
         raise ValueError(
-            f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}"
+            f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}"
         )
 
     # Enable TF32 for faster training on Ampere GPUs,
@@ -1015,7 +1022,7 @@ def load_model_hook(models, input_dir):
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
 
                 controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
 
@@ -1036,7 +1043,8 @@ def load_model_hook(models, input_dir):
                         sample.to(dtype=weight_dtype) for sample in down_block_res_samples
                     ],
                     mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
-                ).sample
+                    return_dict=False,
+                )[0]
 
                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
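
The two `return_dict=False` changes (the text encoder call above and the UNet call here) follow the same pattern: by default these models return an output dataclass read by attribute (e.g. `.sample`), while `return_dict=False` makes them return a plain tuple indexed with `[0]`, a form `torch.compile` has historically traced with fewer graph breaks. A small sketch against a randomly initialized diffusers model (the tiny config is illustrative only, chosen so it runs without downloading weights):

import torch
from diffusers import UNet2DModel

# tiny randomly initialized UNet, illustrative only
unet = UNet2DModel(
    sample_size=8,
    in_channels=3,
    out_channels=3,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "UpBlock2D"),
)
x = torch.randn(1, 3, 8, 8)
t = torch.tensor([1])

pred_attr = unet(x, t).sample                  # default: output dataclass
pred_tuple = unet(x, t, return_dict=False)[0]  # plain tuple, compile-friendly
assert torch.allclose(pred_attr, pred_tuple)   # same tensor either way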
@@ -1109,7 +1117,7 @@ def load_model_hook(models, input_dir):
     # Create the pipeline using using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        controlnet = accelerator.unwrap_model(controlnet)
+        controlnet = unwrap_model(controlnet)
         controlnet.save_pretrained(args.output_dir)
 
     if args.push_to_hub:
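
Taken together, the save-time pattern the commit establishes can be exercised on its own with the same `is_compiled_module` utility the commit imports; the toy module and the omission of the accelerate wrapper below are simplifications for illustration:

import torch
from diffusers.utils.torch_utils import is_compiled_module

def unwrap_model(model):
    # same shape as the commit's helper, minus accelerator.unwrap_model
    return model._orig_mod if is_compiled_module(model) else model

net = torch.compile(torch.nn.Linear(4, 4))
assert is_compiled_module(net)                         # wrapper detected
assert isinstance(unwrap_model(net), torch.nn.Linear)  # saving goes through the unwrapped module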
