[WIP] Sample images when checkpointing. #2157

Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

+1
```python
vae=vae,
unet=accelerator.unwrap_model(unet),
revision=args.revision,
torch_dtype=weight_dtype,
```
```python
# run inference
prompt = [args.validation_prompt]
images = pipeline(prompt, num_images_per_prompt=args.num_validation_images).images
```
Suggested change:

```python
with torch.autocast("cuda"):
    images = pipeline(prompt, num_images_per_prompt=args.num_validation_images).images
```
@patil-suraj that's one of the cases where we should use autocast, I think, no?
Hey @LucasSloan, thanks a lot for opening the PR. Let me explain a bit what's going on here. When training / fine-tuning stable diffusion models with mixed precision, we noticed the following: the weights that are being trained have to stay in full fp32 precision, while the frozen components (VAE, text encoder) run in fp16, so a pipeline assembled from these models mixes dtypes.

Now this works fine in training because the training loop casts the intermediate tensors where needed, but the inference pipeline does not, so fp16 activations hit fp32 weights. Note: we never recommend using autocast for pure inference, only for such special training cases. Does this make sense? Could you try whether it works with autocast?

Also, could you maybe try to add the wandb and tensorboard loggers here as well? Also cc @patil-suraj
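Since the comment above asks for tracker integration, here is a hedged sketch of how the sampled images might be pushed to both loggers, following the pattern used in other diffusers example scripts; `images`, `epoch`, and `args.validation_prompt` are names from this PR's script, and the exact tags are assumptions:

```python
import numpy as np
import wandb

# Sketch: report the freshly sampled validation images to whichever
# trackers were configured via accelerator.init_trackers(...).
for tracker in accelerator.trackers:
    if tracker.name == "tensorboard":
        np_images = np.stack([np.asarray(img) for img in images])
        tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
    elif tracker.name == "wandb":
        tracker.log(
            {
                "validation": [
                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                    for i, image in enumerate(images)
                ]
            }
        )
```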
I met a similar issue and solved it by disabling the safety checker (it isn't used during training, so its dtype may never get converted). Meanwhile, casting the latents, noise, and noisy_latents to self.unet.dtype may help as well (when I trained with Lightning, variables defined in the training loop weren't converted...).
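For reference, a minimal sketch of the two workarounds described in the comment above. Passing `safety_checker=None` is a supported `StableDiffusionPipeline` option; the exact cast points are assumptions about this script:

```python
from diffusers import StableDiffusionPipeline

# Workaround 1: load the pipeline without the safety checker, since it is
# never exercised during training and its dtype may be left unconverted.
pipeline = StableDiffusionPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    safety_checker=None,
)

# Workaround 2: make the tensors built inside the training loop match the
# UNet's dtype before the forward pass.
noisy_latents = noisy_latents.to(unet.dtype)
noise = noise.to(unet.dtype)
```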
Related issue and discussion.

Using torch.autocast() makes sense to me, but doesn't seem to resolve the issue. The other PR implementing similar functionality seems to get around it by not using fp16 weights at all (reloading the full weights instead). Any other thoughts?

Fixed it by not unwrapping the unet.

Added wandb and tensorboard integration.
Thanks a lot, I left a couple of comments.
```python
# run inference
prompt = [args.validation_prompt]
with torch.autocast("cuda"):
    images = pipeline(prompt, num_images_per_prompt=args.num_validation_images).images
```
This looks good to me!

We should do the inference in a loop over num_validation_images to avoid OOM, cf. diffusers/examples/text_to_image/train_text_to_image_lora.py, lines 768 to 772 in 9213d81:

```python
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
images = []
for _ in range(args.num_validation_images):
    images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
```
```
@@ -691,6 +710,40 @@ def collate_fn(examples):
    accelerator.save_state(save_path)
    logger.info(f"Saved state to {save_path}")

    if args.validation_prompt:
```
We should add validation_epochs and generate according to that instead of generating after each loop. You could refer to this script to see how to do that:

```python
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
```
And this should be wrapped under the main process condition (`if accelerator.is_main_process:`) to handle multi-GPU training.
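Put together, the two suggestions might look like this; a sketch only, where `validation_epochs` is the new flag being proposed and `log_validation` is a hypothetical helper wrapping the sampling code:

```python
# Only the main process samples images, and only every
# --validation_epochs epochs instead of after every epoch.
if accelerator.is_main_process:
    if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
        log_validation(pipeline, args, accelerator, epoch)  # hypothetical helper
```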
I'm currently using the args.checkpointing_steps. Do we have a preference for # of epochs vs. # of global steps? I slightly favor # of global steps, since that's how we're controlling checkpointing.
Yeah, for validation images, we prefer epochs since conceptually it's a bit simpler to think of when the inference is going to take place.
Thanks a lot for working on it.

Apart from the comments shared here, I would also like to point out that inference should be performed with the EMA'd weights (when `use_ema` is True), as mentioned by @patil-suraj here.
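A sketch of what that might look like, assuming `EMAModel` ends up exposing the `store`/`restore` helpers discussed later in this thread (only `copy_to` appears in the snippets below); `run_validation` is a hypothetical sampling helper:

```python
if args.use_ema:
    # Stash the raw training weights, then swap in the EMA weights
    # so validation images reflect the averaged model.
    ema_unet.store(unet.parameters())
    ema_unet.copy_to(unet.parameters())

images = run_validation(pipeline)  # hypothetical sampling helper

if args.use_ema:
    # Put the raw weights back so optimization continues unchanged.
    ema_unet.restore(unet.parameters())
```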
```python
)
del pipeline
torch.cuda.empty_cache()

logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
```
Do we also want to add another final inference logging as done in the train_dreambooth_lora.py example?

```python
if args.validation_prompt and args.num_validation_images > 0:
```
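If added, that final pass might look roughly like the following. This is a sketch modeled on the dreambooth LoRA example; loading from `args.output_dir` after `save_pretrained` and the generator seeding are assumptions:

```python
import torch
from diffusers import StableDiffusionPipeline

# Hypothetical end-of-training validation, after the trained pipeline
# has been saved to args.output_dir.
if args.validation_prompt and args.num_validation_images > 0:
    pipeline = StableDiffusionPipeline.from_pretrained(
        args.output_dir, torch_dtype=weight_dtype
    ).to(accelerator.device)
    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
    images = [
        pipeline(args.validation_prompt, generator=generator).images[0]
        for _ in range(args.num_validation_images)
    ]
```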
```python
# run inference
prompt = [args.validation_prompt]
with torch.autocast("cuda"):
```
Suggested change:

```python
with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"):
```

Courtesy: @patil-suraj (#2173 (review))
Another minor nit.
So, I have been testing it since this morning. Here are my findings:
```
Traceback (most recent call last):
  File "train_text_to_image.py", line 982, in <module>
    main()
  File "train_text_to_image.py", line 953, in main
    pipeline(
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 611, in __call__
    noise_pred = self.unet(
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py", line 482, in forward
    sample = self.conv_in(sample)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (c10::Half) and bias type (float) should be the same
```

I think we can potentially get around the issue if we cast the UNet and the Safety Checker modules to `weight_dtype`:

```python
if accelerator.is_main_process:
    unet = accelerator.unwrap_model(unet)
    if args.use_ema:
        ema_unet.copy_to(unet.parameters())
    pipeline = StableDiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        revision=args.revision,
    )
    ...
    # before running inference
    pipeline.unet = unet.to(weight_dtype)
    pipeline.safety_checker = pipeline.safety_checker.to(weight_dtype)
```

Here's my gist that contains the modified script. Let me know if anything is unclear.
@sayakpaul Thanks for testing this! Think we should wrap the …
@patil-suraj here's what I did:

```python
pipeline = StableDiffusionPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    unet=unet,
    revision=args.revision,
)
pipeline.save_pretrained(args.output_dir)
```

With these changes, things seem to work. Would you like to test it with the gist (from #2157 (comment)) and the changes above? Also, would you like me to open a PR adding the needed methods to EMAModel?
I've been testing this change, and it seems like it doesn't actually do the fine-tuning. Possibly creating the pipeline overwrites the weights, losing the training progress? I'll do some more testing to confirm, but if someone else could try (and maybe think about why it would happen), that'd be great.
Another feature I'd like to have is the ability to provide multiple prompts. However, when I added that, I got an error where tensorboard doesn't like the fact that the …
Could you try with the …

Our example scripts are meant to be as simple as possible, so I would de-provision this feature for the time being :) But to support it in the tracker, you could maybe try to use …
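If multiple prompts do come back later, one hypothetical way to keep tensorboard happy is to log each prompt under its own tag, so every logged batch has a uniform shape; `args.validation_prompts` and `images_per_prompt` are invented names for illustration:

```python
import numpy as np

# Sketch: one tensorboard tag per prompt, so each add_images call
# receives a consistent batch of same-sized images.
for prompt, prompt_images in zip(args.validation_prompts, images_per_prompt):
    np_images = np.stack([np.asarray(img) for img in prompt_images])
    tracker.writer.add_images(f"validation/{prompt}", np_images, epoch, dataformats="NHWC")
```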
Force-pushed from be4e755 to bdd961b.
@sayakpaul if you could add the methods to EMAModel, that'd be great.
Thanks for the work!

Some pending changes. I will let @patil-suraj also do a review.
Figured out why the model wasn't training - I was using the …
Force-pushed from 0aa774d to f8b4215.
Tried re-installing xformers, and instead of no training, the safety_checker tripped (training on the Pokemon dataset, validation prompt "Yoda"). Tried again with the safety_checker disabled, and I got black images anyway, along with the error message: …
Added EMA support. Can someone look into those test failures? I wouldn't expect changes to this file to do anything... Are they known issues?

Thanks a lot for the changes. Yeah, the failing tests are unrelated.
```python
pipeline.set_progress_bar_config(disable=True)

# run inference
prompt = [args.validation_prompt]
```
Is this required? We can simply pass the validation prompt, no?
Thanks a lot for the changes.
If you tested the latest changes, could you please comment on the exact command and environment you used?
This will be helpful for the community as well as for the diffusers team to replicate your tests.
Force-pushed from 552e573 to 8ebcccb.

Use batch size 1 and iterate over num_validation_images to avoid OOM. Set autocast device from accelerator.device.

Force-pushed from 8ebcccb to a3300ba.
Rebased and the tests fixed themselves.
@sayakpaul do you think this PR still makes sense? Should we try to get it in?
I think so, yes! I am going to review and test the PR tomorrow and comment accordingly.
@LucasSloan I tried testing the code today and observed the following. With this command:

```bash
accelerate launch --mixed_precision="fp16" examples/text_to_image/train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME --dataset_name=$DATASET_NAME \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 --gradient_checkpointing \
  --max_train_steps=20 --max_train_samples=5 \
  --enable_xformers_memory_efficient_attention \
  --learning_rate=1e-05 --max_grad_norm=1 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --mixed_precision="fp16" \
  --validation_prompt="cute dragon creature" --num_validation_images=3 --validation_steps=1 \
  --seed=666 \
  --output_dir="sd-pokemon-model"
```

there's no validation inference being done. Is there something I am missing out on?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed, please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
I based this on the in-progress sampling in https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py, due to the suggestion on #2030 that it was a good example to follow.

Unfortunately, this code doesn't work at present and I'm not sure why. I get the error `RuntimeError: Input type (c10::Half) and bias type (float) should be the same`, full stack trace: …

I tried to fix it on line 713 by setting `torch_dtype=weight_dtype` on the StableDiffusionPipeline, but that didn't work.