Modifying the attn_processor in the UNet (for designing new attention modules)

How IP-Adapter sets up its attention processors

Reference code: the official training code of Tencent AI Lab's IP-Adapter.

First, the adapter-initialization code; the UNet variables it relies on are examined afterwards.

# init adapter modules
# attn_procs will hold the rebuilt attention-processor dict
attn_procs = {}
unet_sd = unet.state_dict()
for name in unet.attn_processors.keys():
    # self-attention (attn1) gets no cross_attention_dim; cross-attention (attn2)
    # uses the UNet's configured cross-attention dimension
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    # work out the channel width (hidden_size) of the block this processor sits in
    if name.startswith("mid_block"):
        # 'block_out_channels': [320, 640, 1280, 1280]
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        # the character right after "up_blocks." is the index of the up block
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    if cross_attention_dim is None:
        attn_procs[name] = AttnProcessor()
    else:
        layer_name = name.split(".processor")[0]
        # copy the original k/v projection weights of this cross-attention layer out of unet_sd
        weights = {
            "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
            "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
        }
        # replace the processor for this layer with our own IPAttnProcessor
        attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
        # initialize the new layers with the original SD UNet cross-attention weights
        attn_procs[name].load_state_dict(weights)
# finally, hand the rebuilt processor dict back to the UNet
unet.set_attn_processor(attn_procs)

unet.attn_processors

unet.attn_processors is a dict holding all of the UNet's attention processors, while unet.state_dict() holds the weights of every sub-module.
To make a modification, we rebuild a dict with the same keys, swap in our own processor class for the modules we want to change, and then hand the dict back via unet.set_attn_processor.
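As a minimal sketch of that pattern (the checkpoint path is a placeholder; any SD 1.5 UNet works):

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor

# Minimal sketch: rebuild the processor dict with the same keys,
# swap in whatever processor class each layer should use,
# then hand the dict back to the UNet.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder checkpoint path
    subfolder="unet",
)

attn_procs = {}
for name in unet.attn_processors.keys():
    attn_procs[name] = AttnProcessor()  # replace with a custom class where needed

unet.set_attn_processor(attn_procs)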

Let's look at what unet.attn_processors contains.
unet.attn_processors is a dict with 32 entries.
Its keys describe where each processor sits in the UNet. Combining this with the UNet structure and the number of transformer blocks per stage (2, 2, 2 in the down blocks, 1 in the mid block, 3, 3, 3 in the up blocks, i.e. 16 transformer blocks, each containing one self-attention and one cross-attention module, hence 32 attention modules in total), the parts of each key mean the following.
For example:

'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor'

down_blocks.1. (can be 0, 1, or 2; three of the four down blocks contain attention)
the index of the down block

attentions.0. (can be 0 or 1; each attention-carrying down block has 2 "attentions" modules)
the index of the attention module inside that down block

transformer_blocks.0.
always 0 here (SD 1.5 has a single transformer block per attentions module)

attn1.processor (each transformer block has 2 attention modules: one self-attention and one cross-attention)
whether this is the self-attention or the cross-attention (attn1 is the self-attention layer, attn2 is the cross-attention layer)
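To check this layout on a loaded UNet, a small sketch that groups the 32 keys by attention type and by top-level block:

from collections import Counter

# Group the processor keys: attn1 = self-attention, attn2 = cross-attention,
# and the prefix before ".attentions" tells us which UNet block the layer lives in.
kinds = Counter()
blocks = Counter()
for name in unet.attn_processors.keys():
    kinds["self (attn1)" if ".attn1." in name else "cross (attn2)"] += 1
    blocks[name.split(".attentions")[0]] += 1

print(kinds)   # expect 16 self-attention and 16 cross-attention processors
print(blocks)  # down_blocks.0/1/2: 4 keys each, up_blocks.1/2/3: 6 keys each, mid_block: 2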

unet.config

unet.config holds the UNet's configuration parameters:

FrozenDict([('sample_size', 64),
 ('in_channels', 4), 
('out_channels', 4), 
('center_input_sample', False), ('flip_sin_to_cos', True), ('freq_shift', 0), ('down_block_types', ['CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D']), ('mid_block_type', 'UNetMidBlock2DCrossAttn'), ('up_block_types', ['UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D']), 
('only_cross_attention', False), 
('block_out_channels', [320, 640, 1280, 1280]),
 ('layers_per_block', 2), ('downsample_padding', 1), ('mid_block_scale_factor', 1), ('dropout', 0.0), ('act_fn', 'silu'), ('norm_num_groups', 32), 
 ('norm_eps', 1e-05), ('cross_attention_dim', 768), ('transformer_layers_per_block', 1), ('reverse_transformer_layers_per_block', None),
  ('encoder_hid_dim', None), ('encoder_hid_dim_type', None), ('attention_head_dim', 8), ('num_attention_heads', None), ('dual_cross_attention', False), ('use_linear_projection', False), ('class_embed_type', None), ('addition_embed_type', None), ('addition_time_embed_dim', None), ('num_class_embeds', None), ('upcast_attention', False), ('resnet_time_scale_shift', 'default'), ('resnet_skip_time_act', False), ('resnet_out_scale_factor', 1.0), ('time_embedding_type', 'positional'), ('time_embedding_dim', None), ('time_embedding_act_fn', None), ('timestep_post_act', None), ('time_cond_proj_dim', None), ('conv_in_kernel', 3), ('conv_out_kernel', 3), ('projection_class_embeddings_input_dim', None), ('attention_type', 'default'), ('class_embeddings_concat', False), ('mid_block_only_cross_attention', None), ('cross_attention_norm', None), ('addition_embed_type_num_heads', 64), ('_use_default_values', ['addition_embed_type', 'encoder_hid_dim', 'transformer_layers_per_block', 'addition_embed_type_num_heads', 'upcast_attention', 'conv_in_kernel', 'attention_type', 'resnet_out_scale_factor', 'time_embedding_dim', 'time_embedding_act_fn', 'conv_out_kernel', 'reverse_transformer_layers_per_block', 'mid_block_type', 'class_embeddings_concat', 'time_embedding_type', 'use_linear_projection', 'class_embed_type', 'only_cross_attention', 'resnet_time_scale_shift', 'encoder_hid_dim_type', 'projection_class_embeddings_input_dim', 'dual_cross_attention', 'addition_time_embed_dim', 'cross_attention_norm', 'dropout', 'timestep_post_act', 'resnet_skip_time_act', 'num_attention_heads', 'time_cond_proj_dim', 'mid_block_only_cross_attention', 'num_class_embeds']), ('_class_name', 'UNet2DConditionModel'), ('_diffusers_version', '0.6.0'), ('_name_or_path', '/media/dell/DATA/RK/pretrained_model/stable-diffusion-v1-5')])
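Only a few of these config entries matter for the replacement loop; a quick sketch that pulls them out:

# The config fields the replacement loop actually uses:
print(unet.config.cross_attention_dim)   # 768 (CLIP text-embedding width in SD 1.5)
print(unet.config.block_out_channels)    # [320, 640, 1280, 1280]

# Up blocks mirror the down blocks, so their hidden sizes are read from the
# reversed channel list: up_blocks.0 -> 1280, up_blocks.1 -> 1280,
# up_blocks.2 -> 640, up_blocks.3 -> 320.
print(list(reversed(unet.config.block_out_channels)))  # [1280, 1280, 640, 320]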

unet_sd

unet_sd = unet.state_dict() is a dict containing the weights of every sub-module of every layer.
Here the original k/v projection weights of the cross-attention layer are copied out of unet_sd, to be used for initializing our own attention processor:

weights = {
    "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
    "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
}
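A quick shape check is a useful sanity test here: for a cross-attention layer, to_k is a Linear(cross_attention_dim, hidden_size), so its stored weight has shape (hidden_size, cross_attention_dim), which is exactly what the new to_k_ip layer expects. A sketch for the first down block:

# Sanity check: the copied weight matches what to_k_ip expects.
layer_name = "down_blocks.0.attentions.0.transformer_blocks.0.attn2"
w_k = unet_sd[layer_name + ".to_k.weight"]
print(w_k.shape)  # torch.Size([320, 768]): (hidden_size, cross_attention_dim)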

Inspecting the UNet's attn_processors after the replacement

Here all the values() of unet.attn_processors are wrapped into a ModuleList:

adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())

This is what the attention processors look like after the IP-Adapter replacement:

ModuleList(
  (0): AttnProcessor2_0()
  (1): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
  )
  (2): AttnProcessor2_0()
  (3): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
  )
  (4): AttnProcessor2_0()
  (5): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
  )
  (6): AttnProcessor2_0()
  (7): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
  )
  (8): AttnProcessor2_0()
  (9): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
  (10): AttnProcessor2_0()
  (11): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
  (12): AttnProcessor2_0()
  (13): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
  (14): AttnProcessor2_0()
  (15): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
  (16): AttnProcessor2_0()
  (17): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
  (18): AttnProcessor2_0()
  (19): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
  )
  (20): AttnProcessor2_0()
  (21): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
  )
  (22): AttnProcessor2_0()
  (23): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
  )
  (24): AttnProcessor2_0()
  (25): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
  )
  (26): AttnProcessor2_0()
  (27): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
  )
  (28): AttnProcessor2_0()
  (29): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
  )
  (30): AttnProcessor2_0()
  (31): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
)
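The point of wrapping the processors in a ModuleList is that the new to_k_ip / to_v_ip layers become registered parameters an optimizer can see; the plain AttnProcessor2_0 entries contribute none. A hedged sketch of how they might be handed to an optimizer (the learning rate is a placeholder, not the official training setting):

import torch

# adapter_modules is the ModuleList built above; only the IPAttnProcessor entries
# contribute parameters (to_k_ip / to_v_ip), so this optimizes just the adapter layers.
optimizer = torch.optim.AdamW(adapter_modules.parameters(), lr=1e-4)  # lr is a placeholder
print(sum(p.numel() for p in adapter_modules.parameters()))  # trainable adapter parameter count

In the official training script, the image projection model's parameters are optimized together with these adapter weights.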

Defining your own attn processor by modifying the original one

Define the new attn processor classes in the original attention_processor.py file.

The original AttnProcessor in attention_processor.py:

class AttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
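For orientation: each diffusers Attention module stores one such processor, and its forward() delegates to it, passing the module itself as attn, which is why the processor can reuse the module's own projections (attn.to_q, attn.to_k, attn.to_v, attn.to_out). A hedged sketch (assuming a recent diffusers version) that drives a standalone Attention module through the default processor:

import torch
from diffusers.models.attention_processor import Attention, AttnProcessor

# Standalone cross-attention module: query_dim matches the block's hidden size,
# cross_attention_dim matches the text-encoder width (768 for SD 1.5).
attn = Attention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)
attn.set_processor(AttnProcessor())

hidden_states = torch.randn(1, 64 * 64, 320)       # flattened image tokens
encoder_hidden_states = torch.randn(1, 77, 768)    # text tokens
out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # torch.Size([1, 4096, 320])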

# 3. The attention processor newly defined by IP-Adapter
class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4; should be 16 for ip_adapter_plus):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = attn.head_to_batch_dim(ip_key)
        ip_value = attn.head_to_batch_dim(ip_value)

        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
        self.attn_map = ip_attention_probs
        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
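The num_tokens split only works because, before the UNet is called, the image-prompt tokens are appended after the text tokens along the sequence dimension. A hedged sketch of that input preparation, using dummy tensors in place of the real CLIP text embeddings and the adapter's projected image tokens, and assuming the unet from earlier with the IP-Adapter processors installed:

import torch

# Dummy stand-ins (shapes follow SD 1.5 + IP-Adapter defaults): 77 text tokens of
# width 768, plus num_tokens = 4 image-prompt tokens of the same width.
text_embeds = torch.randn(1, 77, 768)   # from the CLIP text encoder in practice
ip_tokens = torch.randn(1, 4, 768)      # from IP-Adapter's image projection model in practice

# The image tokens are appended after the text tokens along the sequence dim;
# IPAttnProcessor.__call__ later splits the last num_tokens entries back off
# and routes them through to_k_ip / to_v_ip, blended in with weight `scale`.
encoder_hidden_states = torch.cat([text_embeds, ip_tokens], dim=1)  # (1, 81, 768)

noisy_latents = torch.randn(1, 4, 64, 64)
timesteps = torch.tensor([10])
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
print(noise_pred.shape)  # torch.Size([1, 4, 64, 64])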

Define two new attn processors of your own and put them into the same file as well:

class StyleAttnProcessor(nn.Module):
    r"""
    Style attention processor (for now a copy of the default AttnProcessor; customize its __call__ as needed).
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class LayoutAttnProcessor(nn.Module):
    r"""
    Layout attention processor (for now a copy of the default AttnProcessor; customize its __call__ as needed).
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

Then import the two new attn processors:

from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor, \
    LayoutAttnProcessor, StyleAttnProcessor

The replacement then looks like this.

Here the third down block (down_blocks.2) is switched to LayoutAttnProcessor, and the first up block that carries attention (up_blocks.1) is switched to StyleAttnProcessor:

attn_procs = {}
unet_sd = unet.state_dict()
for name in unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    if cross_attention_dim is None:
        attn_procs[name] = AttnProcessor()
    # keys of the third down block start with this prefix
    elif name.startswith("down_blocks.2.attentions"):
        attn_procs[name] = LayoutAttnProcessor()
    # keys of the first attention-carrying up block start with this prefix
    elif name.startswith("up_blocks.1.attentions"):
        attn_procs[name] = StyleAttnProcessor()
    else:
        layer_name = name.split(".processor")[0]
        weights = {
            "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
            "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
        }
        attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
        attn_procs[name].load_state_dict(weights)
unet.set_attn_processor(attn_procs)
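After unet.set_attn_processor(attn_procs) is applied, a short check (sketch) confirms which processor classes ended up where:

from collections import Counter

# down_blocks.2 has 2 cross-attention layers -> 2 LayoutAttnProcessor,
# up_blocks.1 has 3 cross-attention layers  -> 3 StyleAttnProcessor,
# the remaining 11 cross-attention layers stay IPAttnProcessor,
# and all 16 self-attention layers keep the default processor.
print(Counter(type(p).__name__ for p in unet.attn_processors.values()))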

The modified attn_processors dict then looks like this:

{
'down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
), 'down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
), 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
), 'down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
),

## Note: the cross-attention processors of down_blocks.2 have been replaced with our LayoutAttnProcessor
 'down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor': LayoutAttnProcessor(), 'down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor': LayoutAttnProcessor(), 

## Note: the cross-attention processors of up_blocks.1 have been replaced with our StyleAttnProcessor
'up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': StyleAttnProcessor(), 'up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': StyleAttnProcessor(), 'up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor': StyleAttnProcessor(), 



'up_blocks.2.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
), 'up_blocks.2.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
), 'up_blocks.2.attentions.2.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
), 'up_blocks.3.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
), 'up_blocks.3.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
), 'up_blocks.3.attentions.2.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
), 'mid_block.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'mid_block.attentions.0.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
)
}