AuraFlow

AuraFlow is inspired by Stable Diffusion 3 and is by far the largest text-to-image generation model that comes with an Apache 2.0 license. This model achieves state-of-the-art results on the GenEval benchmark.

It was developed by the Fal team and more details about it can be found in this blog post.

AuraFlow can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out this section for more details.
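
For example, here is a minimal sketch of one such memory optimization, model CPU offloading via the generic diffusers enable_model_cpu_offload() helper (the section linked above may recommend additional or different techniques):

import torch
from diffusers import AuraFlowPipeline

pipeline = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow",
    torch_dtype=torch.float16,
)
# keep components on the CPU and move each one to the GPU only while it is in use,
# trading some speed for a much lower peak VRAM footprint
pipeline.enable_model_cpu_offload()

prompt = "a tiny astronaut hatching from an egg on the moon"
image = pipeline(prompt).images[0]
image.save("auraflow.png")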

Quantization

Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have a varying impact on image quality depending on the model.

Refer to the Quantization overview to learn more about supported quantization backends and how to select one that fits your use case. The example below demonstrates how to load a quantized [AuraFlowPipeline] for inference with bitsandbytes.

import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, AuraFlowTransformer2DModel, AuraFlowPipeline
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig, T5EncoderModel

# quantize the text encoder to 8-bit with bitsandbytes
quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = T5EncoderModel.from_pretrained(
    "fal/AuraFlow",
    subfolder="text_encoder",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

# quantize the AuraFlow transformer to 8-bit with bitsandbytes
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = AuraFlowTransformer2DModel.from_pretrained(
    "fal/AuraFlow",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

pipeline = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow",
    text_encoder=text_encoder_8bit,
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
    device_map="balanced",
)

prompt = "a tiny astronaut hatching from an egg on the moon"
image = pipeline(prompt).images[0]
image.save("auraflow.png")

Loading GGUF checkpoints is also supported:

import torch
from diffusers import (
    AuraFlowPipeline,
    GGUFQuantizationConfig,
    AuraFlowTransformer2DModel,
)

# load a GGUF-quantized (Q2_K) AuraFlow transformer checkpoint directly from the Hub
transformer = AuraFlowTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)

pipeline = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow-v0.3",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)

prompt = "a cute pony in a field of flowers"
image = pipeline(prompt).images[0]
image.save("auraflow.png")

Support for torch.compile()

AuraFlow can be compiled with torch.compile() to reduce inference latency, even across different resolutions. First, install PyTorch nightly following the instructions from here. The snippet below shows the changes needed to enable this:

+ torch.fx.experimental._config.use_duck_shape = False
+ pipeline.transformer = torch.compile(
    pipeline.transformer, fullgraph=True, dynamic=True
)

Setting use_duck_shape to False instructs the compiler not to reuse the same symbolic variable for input sizes that happen to be equal, which is what lets the compiled transformer handle different resolutions. For more details, check out this comment.

This yields speed improvements ranging from 100% (at low resolutions) to 30% (at 1536x1536 resolution).
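
To put the diff in context, a minimal end-to-end sketch (assuming a recent PyTorch nightly and the fal/AuraFlow checkpoint from the examples above) might look like this:

import torch
from diffusers import AuraFlowPipeline

# disable duck shaping so equal-sized dimensions get their own symbolic variables
torch.fx.experimental._config.use_duck_shape = False

pipeline = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow", torch_dtype=torch.bfloat16
).to("cuda")

# compile the transformer once; dynamic=True keeps the compiled graph valid
# across different output resolutions
pipeline.transformer = torch.compile(
    pipeline.transformer, fullgraph=True, dynamic=True
)

prompt = "a tiny astronaut hatching from an egg on the moon"
image = pipeline(prompt, height=1024, width=1024).images[0]
image.save("auraflow_compiled.png")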

Thanks to AstraliteHeart, who helped us rewrite the [AuraFlowTransformer2DModel] class so that the above works for different resolutions (PR).

AuraFlowPipeline

[[autodoc]] AuraFlowPipeline
  - all
  - __call__