[WIP] Sample images when checkpointing. #2157

Closed
wants to merge 17 commits

Conversation

LucasSloan

I based this on the in-progress sampling in https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py, following the suggestion on #2030 that it was a good example to follow.

Unfortunately, this code doesn't work at present and I'm not sure why. I get the error RuntimeError: Input type (c10::Half) and bias type (float) should be the same; full stack trace:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /mnt/c/Users/lucas/Development/diffusers/examples/text_to_image/train_text_to_image.py:757 in    │
│ <module>                                                                                         │
│                                                                                                  │
│   754                                                                                            │
│   755                                                                                            │
│   756 if __name__ == "__main__":                                                                 │
│ ❱ 757 │   main()                                                                                 │
│   758                                                                                            │
│                                                                                                  │
│ /mnt/c/Users/lucas/Development/diffusers/examples/text_to_image/train_text_to_image.py:720 in    │
│ main                                                                                             │
│                                                                                                  │
│   717 │   │   │   │   │   │   │                                                                  │
│   718 │   │   │   │   │   │   │   # run inference                                                │
│   719 │   │   │   │   │   │   │   prompt = [args.validation_prompt]                              │
│ ❱ 720 │   │   │   │   │   │   │   images = pipeline(prompt, num_images_per_prompt=args.num_val   │
│   721 │   │   │   │   │   │   │                                                                  │
│   722 │   │   │   │   │   │   │   for i, image in enumerate(images):                             │
│   723 │   │   │   │   │   │   │   │   image.save(os.path.join(args.output_dir, f"sample-{globa   │
│                                                                                                  │
│ /home/lucas/.local/lib/python3.8/site-packages/torch/utils/_contextlib.py:115 in                 │
│ decorate_context                                                                                 │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /home/lucas/.local/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_sta │
│ ble_diffusion.py:611 in __call__                                                                 │
│                                                                                                  │
│   608 │   │   │   │   latent_model_input = self.scheduler.scale_model_input(latent_model_input   │
│   609 │   │   │   │                                                                              │
│   610 │   │   │   │   # predict the noise residual                                               │
│ ❱ 611 │   │   │   │   noise_pred = self.unet(                                                    │
│   612 │   │   │   │   │   latent_model_input,                                                    │
│   613 │   │   │   │   │   t,                                                                     │
│   614 │   │   │   │   │   encoder_hidden_states=prompt_embeds,                                   │
│                                                                                                  │
│ /home/lucas/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1488 in _call_impl     │
│                                                                                                  │
│   1485 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1486 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1487 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1488 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1489 │   │   # Do not call functions when jit is used                                          │
│   1490 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1491 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /home/lucas/.local/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py:482 in      │
│ forward                                                                                          │
│                                                                                                  │
│   479 │   │   │   emb = emb + class_emb                                                          │
│   480 │   │                                                                                      │
│   481 │   │   # 2. pre-process                                                                   │
│ ❱ 482 │   │   sample = self.conv_in(sample)                                                      │
│   483 │   │                                                                                      │
│   484 │   │   # 3. down                                                                          │
│   485 │   │   down_block_res_samples = (sample,)                                                 │
│                                                                                                  │
│ /home/lucas/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1488 in _call_impl     │
│                                                                                                  │
│   1485 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1486 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1487 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1488 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1489 │   │   # Do not call functions when jit is used                                          │
│   1490 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1491 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /home/lucas/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:463 in forward           │
│                                                                                                  │
│    460 │   │   │   │   │   │   self.padding, self.dilation, self.groups)                         │
│    461 │                                                                                         │
│    462 │   def forward(self, input: Tensor) -> Tensor:                                           │
│ ❱  463 │   │   return self._conv_forward(input, self.weight, self.bias)                          │
│    464                                                                                           │
│    465 class Conv3d(_ConvNd):                                                                    │
│    466 │   __doc__ = r"""Applies a 3D convolution over an input signal composed of several inpu  │
│                                                                                                  │
│ /home/lucas/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:459 in _conv_forward     │
│                                                                                                  │
│    456 │   │   │   return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=sel  │
│    457 │   │   │   │   │   │   │   weight, bias, self.stride,                                    │
│    458 │   │   │   │   │   │   │   _pair(0), self.dilation, self.groups)                         │
│ ❱  459 │   │   return F.conv2d(input, weight, bias, self.stride,                                 │
│    460 │   │   │   │   │   │   self.padding, self.dilation, self.groups)                         │
│    461 │                                                                                         │
│    462 │   def forward(self, input: Tensor) -> Tensor:                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Input type (c10::Half) and bias type (float) should be the same

I tried to fix it on line 713 by setting torch_dtype=weight_dtype on the StableDiffusionPipeline, but that didn't work.
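Roughly, that attempt looked like this (a sketch based on the review context quoted below; text_encoder, vae, unet, and weight_dtype come from earlier in the script):

from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    text_encoder=text_encoder,
    vae=vae,
    unet=accelerator.unwrap_model(unet),
    revision=args.revision,
    torch_dtype=weight_dtype,  # the attempted fix that did not resolve the dtype mismatch
)
pipeline = pipeline.to(accelerator.device)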

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@Weifeng-Chen

+1

vae=vae,
unet=accelerator.unwrap_model(unet),
revision=args.revision,
torch_dtype=weight_dtype,
Contributor

Suggested change
- torch_dtype=weight_dtype,

# run inference
prompt = [args.validation_prompt]
images = pipeline(prompt, num_images_per_prompt=args.num_validation_images).images
Contributor

Suggested change
- images = pipeline(prompt, num_images_per_prompt=args.num_validation_images).images
+ with torch.autocast("cuda"):
+     images = pipeline(prompt, num_images_per_prompt=args.num_validation_images).images

@patil-suraj that's one of the cases where we should use autocast, I think, no?

@patrickvonplaten
Contributor

Hey @LucasSloan,

Thanks a lot for opening the PR. Want to explain a bit what's going on here.

When training / fine-tuning stable diffusion models we noticed the following: the trainable UNet has to stay in fp32, while the frozen text encoder and VAE can be cast down to fp16.

This works fine in training because accelerate takes care of everything. However, during inference we run into the problem that the UNet is in fp32 while the text encoder is in fp16. We could cast the UNet to float16, which would work just fine for inference, but then we'd break the model for any training that happens after inference (remember, trainable weights should stay in fp32). Thus, the solution here is to use autocast to automatically cast down the UNet if necessary, as shown in the PR review.

Note: We never recommend using autocast for pure inference but only for such special training cases. Does this make sense? Could you try whether it works with autocast?
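A minimal sketch of the dtype setup being described, assuming the variable names used in train_text_to_image.py (accelerator, text_encoder, vae, unet, pipeline, args):

import torch

# with --mixed_precision="fp16", the frozen modules are cast down
weight_dtype = torch.float16 if accelerator.mixed_precision == "fp16" else torch.float32
text_encoder.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
# the trainable unet stays in fp32; accelerate handles mixed precision in the training step

# at validation time, autocast bridges the fp32 unet and the fp16 text encoder outputs
with torch.autocast("cuda"):
    images = pipeline([args.validation_prompt], num_images_per_prompt=args.num_validation_images).images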

Also, could you maybe add the wandb and tensorboard loggers here as well:

for tracker in accelerator.trackers:

?
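A sketch of what that tracker logging could look like, modeled on the pattern in train_dreambooth_lora.py (the numpy stacking, tracker names, and the global_step variable are assumptions):

import numpy as np

for tracker in accelerator.trackers:
    if tracker.name == "tensorboard":
        np_images = np.stack([np.asarray(img) for img in images])
        tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
    if tracker.name == "wandb":
        import wandb
        tracker.log(
            {
                "validation": [
                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                    for i, image in enumerate(images)
                ]
            }
        )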

Also cc @patil-suraj

@Weifeng-Chen

I met a similar issue and solved it by disabling the safety checker (since it isn't used during training, its dtype may not have been converted). Casting the latents, noise, and noisy_latents to self.unet.dtype may help as well (when I train with Lightning, variables defined in the training loop aren't converted...).
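A sketch of that workaround (variable names follow the training script; safety_checker=None is a supported argument of from_pretrained):

from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    unet=accelerator.unwrap_model(unet),
    safety_checker=None,  # the safety checker isn't used (or cast) during training
    revision=args.revision,
)

# inside the training loop: keep intermediate tensors in the unet's dtype
model_dtype = accelerator.unwrap_model(unet).dtype
latents = latents.to(model_dtype)
noise = noise.to(model_dtype)
noisy_latents = noisy_latents.to(model_dtype)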

@patil-suraj
Contributor

Related issue and discussion

#2163 (comment)
#2173 (review)

@LucasSloan
Author

Using torch.autocast() makes sense to me, but doesn't seem to resolve the issue. The other PR implementing similar functionality seems to be getting around it by not using fp16 weights at all (by reloading the full weights). Any other thoughts?

@LucasSloan
Author

Fixed it by not unwrapping the unet.

@LucasSloan
Author

Added wandb and tensorboard integration.

Contributor

@patil-suraj left a comment

Thanks a lot, I left a couple of comments.

Comment on lines 724 to 727
# run inference
prompt = [args.validation_prompt]
with torch.autocast("cuda"):
    images = pipeline(prompt, num_images_per_prompt=args.num_validation_images).images
Contributor

This looks good to me!

We should do the inference in a loop over num_validation_images to avoid OOM, cf.

generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
images = []
for _ in range(args.num_validation_images):
    images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])

@@ -691,6 +710,40 @@ def collate_fn(examples):
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")

if args.validation_prompt:
Contributor

We should add validation_epochs and generate according to that, instead of generating after each loop.
You could refer to this script to see how to do that:

if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
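A sketch of the matching CLI flag (the name follows the snippet above; the default value is an assumption):

parser.add_argument(
    "--validation_epochs",
    type=int,
    default=5,
    help="Run sampling with the validation prompt every X epochs.",
)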

Member

And this should be wrapped under the main process condition (if accelerator.is_main_process:) to handle multi-GPU training correctly.
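For example, roughly (log_validation is a hypothetical helper standing in for the sampling and logging code discussed above):

if accelerator.is_main_process:
    if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
        # hypothetical helper wrapping the pipeline construction, sampling, and tracker logging
        log_validation(args, accelerator, unet, epoch)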

Author

I'm currently using the args.checkpointing_steps. Do we have a preference for # of epochs vs. # of global steps? I slightly favor # of global steps, since that's how we're controlling checkpointing.

Member

Yeah, for validation images, we prefer epochs since conceptually it's a bit simpler to think of when the inference is going to take place.

Member

@sayakpaul left a comment

Thanks a lot for working on it.

Apart from the comments shared here, I would also like to point out that inference should be performed with the EMA'd weights (when use_ema is True), as mentioned by @patil-suraj here.
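A sketch of that, mirroring what the script already does when saving the final pipeline (ema_unet is the EMAModel instance):

unet = accelerator.unwrap_model(unet)
if args.use_ema:
    # copy the EMA-averaged parameters into the unet before building the inference pipeline
    ema_unet.copy_to(unet.parameters())
    # note: for mid-training validation, the original weights need to be put back afterwards
    # (see the store()/restore() discussion later in this thread)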

)
del pipeline
torch.cuda.empty_cache()

logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
Member

Do we also want to add a final round of inference and logging, as done in the train_dreambooth_lora.py example?

if args.validation_prompt and args.num_validation_images > 0:
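A sketch of such a final run, loosely modeled on the dreambooth LoRA example (reloading the pipeline from args.output_dir and the step count are assumptions):

import torch
from diffusers import StableDiffusionPipeline

if args.validation_prompt and args.num_validation_images > 0:
    pipeline = StableDiffusionPipeline.from_pretrained(args.output_dir, revision=args.revision)
    pipeline = pipeline.to(accelerator.device)
    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
    images = [
        pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
        for _ in range(args.num_validation_images)
    ]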


# run inference
prompt = [args.validation_prompt]
with torch.autocast("cuda"):
Member

@sayakpaul Feb 2, 2023

Suggested change
- with torch.autocast("cuda"):
+ with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"):

Courtesy: @patil-suraj (#2173 (review))

Member

@sayakpaul left a comment

Another minor nit.

@sayakpaul
Member

sayakpaul commented Feb 2, 2023

So, I have been testing it since this morning. Here are my findings:

  • I could run the intermediate validation inference successfully with FP16, autocasting, and appropriate EMA updates.
  • However, during the final inference run which we usually do after pushing the pipeline files to the Hub, it is still failing.
Traceback (most recent call last):
  File "train_text_to_image.py", line 982, in <module>
    main()
  File "train_text_to_image.py", line 953, in main
    pipeline(
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 611, in __call__
    noise_pred = self.unet(
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py", line 482, in forward
    sample = self.conv_in(sample)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (c10::Half) and bias type (float) should be the same

I think we can potentially get around the issue if we cast the UNet and the Safety Checker modules to weight_dtype after initializing the pipeline:

if accelerator.is_main_process:
    unet = accelerator.unwrap_model(unet)
    if args.use_ema:
        ema_unet.copy_to(unet.parameters())

    pipeline = StableDiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        revision=args.revision,
    )

    ...

    # before running inference
    pipeline.unet = unet.to(weight_dtype)
    pipeline.safety_checker = pipeline.safety_checker.to(weight_dtype)

Here's my gist that contains the modified train_text_to_image.py script, ema.py script (thanks to @patil-suraj for the suggestions here), and the execution instructions.

Let me know if anything is unclear.

@patil-suraj
Contributor

@sayakpaul Thanks for testing this! I think we should wrap the pipe call in autocast here as well; we cannot explicitly cast the model here, as we always save models in full precision.

@sayakpaul
Member

sayakpaul commented Feb 2, 2023

@patil-suraj here's what I did:

  • While saving the final pipeline, I didn't use the text_encoder and vae to avoid the mismatch issues:
pipeline = StableDiffusionPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    unet=unet,
    revision=args.revision,
)
pipeline.save_pretrained(args.output_dir)
  • And then, before running the final inference, I moved the pipeline to accelerator.device, roughly as sketched below.
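That is, something along these lines (a sketch; the autocast guard mirrors the suggestion earlier in this thread):

pipeline = pipeline.to(accelerator.device)
with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"):
    images = pipeline(args.validation_prompt, num_inference_steps=30).images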

With these changes, things seem to work.

Would you like to test it with the gist (from #2157 (comment)) and the changes above?

Also, would you like me to open a PR adding the store() and restore() methods to EMAModel?
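For context, those methods would allow swapping the EMA weights in and out around validation, roughly like this (a sketch; store() and restore() do not exist on EMAModel at this point):

ema_unet.store(unet.parameters())    # stash the current training weights
ema_unet.copy_to(unet.parameters())  # load the EMA-averaged weights for sampling
# ... run validation inference ...
ema_unet.restore(unet.parameters())  # put the training weights back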

@LucasSloan
Author

I've been testing this change, and it seems like it doesn't actually do the fine tuning. Possibly creating the pipeline overwrites the weights, losing the training progress? I'll do some more testing to confirm, but if someone else could try (and maybe think about why it would happen), that'd be great.

@LucasSloan
Author

Another feature I'd like to have is the ability to provide multiple prompts. However, when I added action="append" to the --validation_prompt argument, I get an error on this line:

accelerator.init_trackers("text2image-fine-tune", config=vars(args))

Where tensorboard doesn't like the fact that the --validation_prompt argument is a list instead of one of the basic types it supports. Does anyone have a suggestion for fixing that?

@sayakpaul
Member

I've been testing this change, and it seems like it doesn't actually do the fine tuning. Possibly creating the pipeline overwrites the weights, losing the training progress? I'll do some more testing to confirm, but if someone else could try (and maybe think about why it would happen), that'd be great.

Could you try with the train_text_to_image.py script mentioned in this gist (as mentioned here)? Also, take note of the changes suggested in #2157 (comment).

Another feature I'd like to have is the ability to provide multiple prompts. However, when I added action="append" to the --validation_prompt argument, I get an error on this line:

accelerator.init_trackers("text2image-fine-tune", config=vars(args))

Where tensorboard doesn't like the fact that the --validation_prompt argument is a list instead of one of the basic types it supports. Does anyone have a suggestion for fixing that?

Our example scripts are meant to be as simple as possible. So, I would de-provision this feature for the time being :) But to support it in the tracker, you could maybe try to use nargs? This post has some good reference examples.
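A rough sketch of the nargs approach (the argument shape and the string join for the tracker config are assumptions):

parser.add_argument(
    "--validation_prompt",
    type=str,
    nargs="+",
    default=None,
    help="One or more prompts used during validation.",
)

# trackers like tensorboard only accept scalar hparam values, so flatten the list first
tracker_config = dict(vars(args))
tracker_config["validation_prompt"] = " | ".join(args.validation_prompt or [])
accelerator.init_trackers("text2image-fine-tune", config=tracker_config)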

@LucasSloan
Author

@sayakpaul if you could add the methods to EMAModel, that'd be great.

Member

@sayakpaul left a comment

Thanks for the work!

Some pending changes. I will let @patil-suraj also to do a review.

@LucasSloan
Author

Figured out why the model wasn't training - I was using the --enable_xformers_memory_efficient_attention flag. Is that known not to work, or do I have an issue with my setup?

@LucasSloan
Author

Tried re-installing xformers, and instead of no training, the safety_checker tripped (training on the Pokemon dataset, validation prompt "Yoda"). Tried again, disabling the safety_checker, and I got black images anyway, along with the error message:

/home/lucas/.local/lib/python3.8/site-packages/diffusers/pipelines/pipeline_utils.py:813: RuntimeWarning: invalid value encountered in cast
  images = (images * 255).round().astype("uint8")

@LucasSloan
Author

Added EMA support.

Can someone look into those test failures? I wouldn't expect changes to this file to do anything... Are they known issues?

@sayakpaul
Member

Can someone look into those test failures? I wouldn't expect changes to this file to do anything... Are they known issues?

Thanks a lot for the changes. Yeah the failing tests are unrelated.

pipeline.set_progress_bar_config(disable=True)

# run inference
prompt = [args.validation_prompt]
Member

Is this required? We can simply pass the validation prompt, no?

Member

@sayakpaul left a comment

Thanks a lot for the changes.

If you tested the latest changes, could you please comment on the exact command and environment you used?

This will be helpful for the community as well as for the diffusers team to replicate your tests.

@LucasSloan
Author

Rebased and the tests fixed themselves.

Contributor

@patrickvonplaten left a comment

@sayakpaul do you think this PR still makes sense? Should we try to get it in?

@sayakpaul
Member

@sayakpaul do you think this PR still makes sense? Should we try to get it in?

I think so, yes! I am going to review and test the PR tomorrow and comment accordingly.

@sayakpaul
Member

@LucasSloan I tried testing the code today with the following command:

accelerate launch --mixed_precision="fp16"  examples/text_to_image/train_text_to_image.py   \
	--pretrained_model_name_or_path=$MODEL_NAME   --dataset_name=$DATASET_NAME   \
	--use_ema   \
	--resolution=512 --center_crop --random_flip   \
	--train_batch_size=1   \
	--gradient_accumulation_steps=4   --gradient_checkpointing   \
	--max_train_steps=20 --max_train_samples=5 \
	--enable_xformers_memory_efficient_attention \
	--learning_rate=1e-05   --max_grad_norm=1   --lr_scheduler="constant" --lr_warmup_steps=0   \
	--mixed_precision="fp16" \
	--validation_prompt="cute dragon creature" --num_validation_images=3 --validation_steps=1  \
	--seed=666  \
	--output_dir="sd-pokemon-model"

There's no validation inference being done. Is there something I am missing?

@github-actions
Contributor

github-actions bot commented Apr 7, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot added the stale (Issues that haven't received updates) label on Apr 7, 2023
github-actions bot closed this on Apr 16, 2023