LoRA
LoRA (Low-Rank Adaptation) is a method for quickly training a model for a new task. It works by freezing the original model weights and adding a small number of new trainable parameters. This means it is significantly faster and cheaper to adapt an existing model to new tasks, such as generating images in a new style.
LoRA checkpoints are typically only a couple hundred MBs in size, so they're very lightweight and easy to store. Load this smaller set of weights into an existing base model with load_lora_weights() and specify the file name.
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
    "ostris/super-cereal-sdxl-lora",
    weight_name="cereal_box_sdxl_v1.safetensors",
    adapter_name="cereal"
)
pipeline("bears, pizza bites").images[0]
The load_lora_weights() method is the preferred way to load LoRA weights into the UNet and text encoder because it can handle cases where:
- the LoRA weights don’t have separate UNet and text encoder identifiers
- the LoRA weights have separate UNet and text encoder identifiers
The load_lora_adapter() method is used to directly load a LoRA adapter at the model level, as long as the model is a Diffusers model that is a subclass of PeftAdapterMixin. It builds and prepares the necessary model configuration for the adapter and loads the LoRA adapter into the model it is called on, such as the UNet.
For example, if you're only loading a LoRA into the UNet, load_lora_adapter() ignores the text encoder keys. Use the prefix parameter to filter and load the appropriate state dict, "unet" in this case.
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
).to("cuda")
pipeline.unet.load_lora_adapter(
    "jbilcke-hf/sdxl-cinematic-1",
    weight_name="pytorch_lora_weights.safetensors",
    adapter_name="cinematic",
    prefix="unet"
)
# use cnmt in the prompt to trigger the LoRA
pipeline("A cute cnmt eating a slice of pizza, stunning color scheme, masterpiece, illustration").images[0]
torch.compile
torch.compile speeds up inference by compiling the PyTorch model to use optimized kernels. Before compiling, the LoRA weights need to be fused into the base model and unloaded first.
import torch
from diffusers import DiffusionPipeline

# load base model and LoRA
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
    "ostris/ikea-instructions-lora-sdxl",
    weight_name="ikea_instructions_xl_v1_5.safetensors",
    adapter_name="ikea"
)

# activate LoRA and set adapter weight
pipeline.set_adapters("ikea", adapter_weights=0.7)

# fuse LoRA and unload weights
pipeline.fuse_lora(adapter_names=["ikea"], lora_scale=1.0)
pipeline.unload_lora_weights()
Typically, the UNet is compiled because it's the most compute-intensive component of the pipeline.
pipeline.unet.to(memory_format=torch.channels_last)
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
pipeline("A bowl of ramen shaped like a cute kawaii bear").images[0]
Refer to the hotswapping section to learn how to avoid recompilation when working with compiled models and multiple LoRAs.
Weight scale
The scale parameter is used to control how much of a LoRA to apply. A value of 0 is equivalent to only using the base model weights and a value of 1 is equivalent to fully using the LoRA.
For simple use cases, you can pass cross_attention_kwargs={"scale": 1.0} to the pipeline.
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
    "ostris/super-cereal-sdxl-lora",
    weight_name="cereal_box_sdxl_v1.safetensors",
    adapter_name="cereal"
)
pipeline("bears, pizza bites", cross_attention_kwargs={"scale": 1.0}).images[0]
Hotswapping
Hotswapping LoRAs is an efficient way to work with multiple LoRAs while avoiding the memory that accumulates from multiple calls to load_lora_weights() and, in some cases, recompilation if a model is compiled. This workflow requires a loaded LoRA because the new LoRA weights are swapped in place of the existing loaded LoRA.
import torch
from diffusers import DiffusionPipeline

# load base model and LoRA
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
    "ostris/ikea-instructions-lora-sdxl",
    weight_name="ikea_instructions_xl_v1_5.safetensors",
    adapter_name="ikea"
)
Hotswapping is unsupported for LoRAs that target the text encoder.
Set hotswap=True in load_lora_weights() to swap the second LoRA. Use the adapter_name parameter to indicate which LoRA to swap (default_0 is the default name).
pipeline.load_lora_weights(
    "lordjia/by-feng-zikai",
    hotswap=True,
    adapter_name="ikea"
)
Compiled models
For compiled models, use enable_lora_hotswap() to avoid recompilation when hotswapping LoRAs. This method should be called before loading the first LoRA, and torch.compile should be called after loading the first LoRA.
The enable_lora_hotswap() method isn't always necessary if the second LoRA uses the same LoRA ranks and scales as the first LoRA.
Within enable_lora_hotswap(), the target_rank parameter is important because it sets the rank used for all LoRA adapters. Set it to the highest rank among the LoRAs you plan to load (the max_rank placeholder in the example below); for LoRAs with different ranks, use the higher rank value. The default rank value is 128.
import torch
from diffusers import DiffusionPipeline

# load base model and LoRAs
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
).to("cuda")

# 1. enable_lora_hotswap before loading the first LoRA
# max_rank is the highest LoRA rank among the adapters you plan to load
# (128 is used here as an example; it is also the default)
max_rank = 128
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(
    "ostris/ikea-instructions-lora-sdxl",
    weight_name="ikea_instructions_xl_v1_5.safetensors",
    adapter_name="ikea"
)

# 2. torch.compile after loading the first LoRA
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)

# 3. hotswap the second LoRA in place of the first
pipeline.load_lora_weights(
    "lordjia/by-feng-zikai",
    hotswap=True,
    adapter_name="ikea"
)
Move your code inside the with torch._dynamo.config.patch(error_on_recompile=True) context manager to detect if a model was recompiled. If a model is recompiled despite following all the steps above, please open an issue with a reproducible example.
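For example, here is a minimal sketch of performing the hotswap inside that context manager, reusing the compiled pipeline and adapter name from the example above:

# detect unwanted recompilation while hotswapping
# assumes `pipeline` was compiled and the "ikea" LoRA was loaded as shown above
with torch._dynamo.config.patch(error_on_recompile=True):
    pipeline.load_lora_weights(
        "lordjia/by-feng-zikai",
        hotswap=True,
        adapter_name="ikea"
    )
    # raises an error if the compiled UNet recompiles during this call
    image = pipeline("A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai").images[0]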
There are still scenarios where recompilation is unavoidable, such as when the hotswapped LoRA targets more layers than the initial adapter. Try to load the LoRA that targets the most layers first. For more details about this limitation, refer to the PEFT hotswapping docs.
Merge
The weights from each LoRA can be merged together to produce a blend of multiple existing styles. There are several methods for merging LoRAs, each of which differs in how the weights are merged (and may affect generation quality).
set_adapters
The set_adapters() method merges LoRAs by concatenating their weighted matrices. Pass the LoRA names to set_adapters() and use the adapter_weights parameter to control the scaling of each LoRA. For example, if adapter_weights=[0.5, 0.5], the output is an average of both LoRAs.
The "scale" parameter determines how much of the merged LoRA to apply. See the Weight scale section for more details.
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
    "ostris/ikea-instructions-lora-sdxl",
    weight_name="ikea_instructions_xl_v1_5.safetensors",
    adapter_name="ikea"
)
pipeline.load_lora_weights(
    "lordjia/by-feng-zikai",
    weight_name="fengzikai_v1.0_XL.safetensors",
    adapter_name="feng"
)
pipeline.set_adapters(["ikea", "feng"], adapter_weights=[0.7, 0.8])
# use by Feng Zikai to activate the lordjia/by-feng-zikai LoRA
pipeline("A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai", cross_attention_kwargs={"scale": 1.0}).images[0]

add_weighted_adapter
This is an experimental method and you can refer to PEFT's Model merging guide for more details. Take a look at this issue if you're interested in the motivation and design behind this integration.
The add_weighted_adapter() method from PEFT enables more efficient merging methods like TIES or DARE. These merging methods remove redundant and potentially interfering parameters from merged models. Keep in mind that the LoRAs need to have identical ranks to be merged.
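Since the ranks must match, a quick way to check them is through the peft_config attached to the UNet. This is a minimal sketch, assuming both adapters have already been loaded into pipeline.unet with the adapter names used in this guide:

# compare LoRA ranks before merging; both adapters must already be loaded
ikea_rank = pipeline.unet.peft_config["ikea"].r
feng_rank = pipeline.unet.peft_config["feng"].r
assert ikea_rank == feng_rank, "add_weighted_adapter requires identical LoRA ranks"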
Make sure the latest stable versions of Diffusers and PEFT are installed.
pip install -U -q diffusers peft
Load a UNet that corresponds to the LoRA's UNet.
import copy
import torch
from diffusers import AutoModel, DiffusionPipeline
from peft import get_peft_model, LoraConfig, PeftModel

unet = AutoModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
    subfolder="unet",
).to("cuda")
Load a pipeline, pass the UNet to it, and load a LoRA.
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    variant="fp16",
    torch_dtype=torch.float16,
    unet=unet
).to("cuda")
pipeline.load_lora_weights(
    "ostris/ikea-instructions-lora-sdxl",
    weight_name="ikea_instructions_xl_v1_5.safetensors",
    adapter_name="ikea"
)
Create a PeftModel from the LoRA checkpoint by combining the first UNet you loaded and the LoRA UNet from the pipeline.
sdxl_unet = copy.deepcopy(unet)
ikea_peft_model = get_peft_model(
    sdxl_unet,
    pipeline.unet.peft_config["ikea"],
    adapter_name="ikea"
)

original_state_dict = {f"base_model.model.{k}": v for k, v in pipeline.unet.state_dict().items()}
ikea_peft_model.load_state_dict(original_state_dict, strict=True)
You can save and reuse the ikea_peft_model by pushing it to the Hub as shown below.
ikea_peft_model.push_to_hub("ikea_peft_model", token=TOKEN)
Repeat this process and create a PeftModel for the second LoRA.
pipeline.delete_adapters("ikea")
sdxl_unet.delete_adapters("ikea")
pipeline.load_lora_weights(
"lordjia/by-feng-zikai",
weight_name="fengzikai_v1.0_XL.safetensors",
adapter_name="feng"
)
pipeline.set_adapters(adapter_names="feng")
feng_peft_model = get_peft_model(
sdxl_unet,
pipeline.unet.peft_config["feng"],
adapter_name="feng"
)
original_state_dict = {f"base_model.model.{k}": v for k, v in pipe.unet.state_dict().items()}
feng_peft_model.load_state_dict(original_state_dict, strict=True)
Load a base UNet model and load the adapters.
base_unet = AutoModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
    subfolder="unet",
).to("cuda")

model = PeftModel.from_pretrained(
    base_unet,
    "stevhliu/ikea_peft_model",
    use_safetensors=True,
    subfolder="ikea",
    adapter_name="ikea"
)
model.load_adapter(
    "stevhliu/feng_peft_model",
    use_safetensors=True,
    subfolder="feng",
    adapter_name="feng"
)
Merge the LoRAs with add_weighted_adapter() and specify how you want to merge them with combination_type. The example below uses the "dare_linear" method (refer to this blog post to learn more about these merging methods), which randomly prunes some weights and then performs a weighted sum of the tensors based on the weight set for each LoRA in weights.
Activate the merged LoRA with set_adapters().
model.add_weighted_adapter(
    adapters=["ikea", "feng"],
    combination_type="dare_linear",
    weights=[1.0, 1.0],
    adapter_name="ikea-feng"
)
model.set_adapters("ikea-feng")

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    unet=model,
    variant="fp16",
    torch_dtype=torch.float16,
).to("cuda")
pipeline("A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai").images[0]

fuse_lora
The fuse_lora() method fuses the LoRA weights directly with the original UNet and text encoder weights of the underlying model. This reduces the overhead of loading the underlying model for each LoRA because it only loads the model once, which lowers memory usage and increases inference speed.
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
    "ostris/ikea-instructions-lora-sdxl",
    weight_name="ikea_instructions_xl_v1_5.safetensors",
    adapter_name="ikea"
)
pipeline.load_lora_weights(
    "lordjia/by-feng-zikai",
    weight_name="fengzikai_v1.0_XL.safetensors",
    adapter_name="feng"
)
pipeline.set_adapters(["ikea", "feng"], adapter_weights=[0.7, 0.8])
Call fuse_lora() to fuse them. The lora_scale parameter controls how much the output is scaled by the LoRA weights. It is important to make this adjustment now because passing scale to cross_attention_kwargs won't work in the pipeline.
pipeline.fuse_lora(adapter_names=["ikea", "feng"], lora_scale=1.0)
Unload the LoRA weights since they're already fused with the underlying model. Save the fused pipeline with either save_pretrained() to save it locally or push_to_hub() to save it to the Hub.
pipeline.unload_lora_weights()
pipeline.save_pretrained("path/to/fused-pipeline")
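Alternatively, a minimal sketch of pushing the fused pipeline to the Hub (the repository name and TOKEN below are placeholders):

# hypothetical repository name and access token
pipeline.push_to_hub("username/fused-ikea-feng", token=TOKEN)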
The fused pipeline can now be quickly loaded for inference without requiring each LoRA to be separately loaded.
pipeline = DiffusionPipeline.from_pretrained(
    "username/fused-ikea-feng", torch_dtype=torch.float16,
).to("cuda")
pipeline("A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai").images[0]
Use unfuse_lora() to restore the underlying model's weights, for example, if you want to use a different lora_scale value. You can only unfuse if there is a single LoRA fused. For example, it won't work with the pipeline from above because there are multiple fused LoRAs. In these cases, you'll need to reload the entire model.
pipeline.unfuse_lora()
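For instance, here is a minimal sketch of re-fusing with a different lora_scale, assuming a pipeline where only the "ikea" LoRA is fused and the LoRA weights haven't been unloaded yet:

# assumes only the "ikea" LoRA is fused and still loaded in the pipeline
pipeline.unfuse_lora()
# fuse again with a different lora_scale
pipeline.fuse_lora(adapter_names=["ikea"], lora_scale=0.8)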

Manage
Diffusers provides several methods to help you manage working with LoRAs. These methods can be especially useful if you’re working with multiple LoRAs.
set_adapters
set_adapters() also activates the LoRA to use if multiple LoRAs are loaded. This allows you to switch between different LoRAs by specifying their name.
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
    "ostris/ikea-instructions-lora-sdxl",
    weight_name="ikea_instructions_xl_v1_5.safetensors",
    adapter_name="ikea"
)
pipeline.load_lora_weights(
    "lordjia/by-feng-zikai",
    weight_name="fengzikai_v1.0_XL.safetensors",
    adapter_name="feng"
)
# activates the feng LoRA instead of the ikea LoRA
pipeline.set_adapters("feng")
save_lora_adapter
Save an adapter with save_lora_adapter().
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
).to("cuda")
pipeline.unet.load_lora_adapter(
    "jbilcke-hf/sdxl-cinematic-1",
    weight_name="pytorch_lora_weights.safetensors",
    adapter_name="cinematic",
    prefix="unet"
)
pipeline.unet.save_lora_adapter("path/to/save", adapter_name="cinematic")
unload_lora_weights
The unload_lora_weights() method unloads any LoRA weights in the pipeline to restore the underlying model weights.
pipeline.unload_lora_weights()
disable_lora
The disable_lora() method disables all LoRAs (but they’re still kept on the pipeline) and restores the pipeline to the underlying model weights.
pipeline.disable_lora()
get_active_adapters
The get_active_adapters() method returns a list of active LoRAs attached to a pipeline.
pipeline.get_active_adapters()
["cereal", "ikea"]
get_list_adapters
The get_list_adapters() method returns the active LoRAs for each component in the pipeline.
pipeline.get_list_adapters()
{"unet": ["cereal", "ikea"], "text_encoder_2": ["cereal"]}
delete_adapters
The delete_adapters() method completely removes a LoRA and its layers from a model.
pipeline.delete_adapters("ikea")
Resources
Browse the LoRA Studio for different LoRAs to use or you can upload your favorite LoRAs from Civitai to the Hub with the Space below.
You can find additional LoRAs in the FLUX LoRA the Explorer and LoRA the Explorer Spaces.