Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance (GitHub)
Abstract
Recent studies have shown that diffusion models can generate high-quality samples, but sample quality depends heavily on sampling guidance techniques such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques are often inapplicable to unconditional generation or to various downstream tasks such as image restoration. In this paper, we propose a novel sampling guidance method, Perturbed-Attention Guidance (PAG), which improves the quality of diffusion samples in both unconditional and conditional settings, without additional training or external modules. PAG progressively enhances the structure of samples over the denoising process. Exploiting the observation that self-attention captures structural information, it generates structurally degraded intermediate samples by replacing the self-attention maps in the UNet with an identity matrix, and guides the denoising process away from these degraded samples. For both ADM and Stable Diffusion, PAG significantly improves sample quality in conditional and even unconditional scenarios. Moreover, PAG substantially improves baseline performance on various downstream tasks where existing guidance such as CG or CFG cannot be fully exploited, including ControlNet with empty prompts and image restoration tasks such as inpainting and deblurring.

The paper observes that in the self-attention modules of the diffusion U-Net, the query-key interaction mainly determines structure, while the values mainly determine appearance. Directly perturbing V_t would push samples out-of-distribution (OOD), so PAG instead replaces the query-key attention map with an identity matrix.

Which part of the UNet should be perturbed? Evaluating on 5k samples with a PAG guidance scale of s = 2.5 and 25 DDIM steps, the authors found that perturbing the mid-block "m0" performs best.
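Written out, the perturbed self-attention and the guided noise prediction take roughly the following form (a sketch consistent with the paper's description and with the Diffusers code shown below; s is the PAG scale):

```latex
% Standard self-attention vs. perturbed self-attention (PSA):
% replacing the softmax attention map with the identity matrix
% reduces the block's output to its value projection.
\mathrm{SA}(Q_t, K_t, V_t) = \operatorname{softmax}\!\left(\frac{Q_t K_t^{\top}}{\sqrt{d}}\right) V_t,
\qquad
\mathrm{PSA}(Q_t, K_t, V_t) = I \, V_t = V_t

% PAG: steer denoising away from the structurally degraded
% prediction \hat{\epsilon}_\theta made with PSA.
\tilde{\epsilon}_\theta(x_t)
  = \epsilon_\theta(x_t) + s \left( \epsilon_\theta(x_t) - \hat{\epsilon}_\theta(x_t) \right)

% Combined with CFG (as in _apply_perturbed_attention_guidance below):
\tilde{\epsilon}
  = \epsilon_{\text{uncond}}
  + s_{\text{cfg}} \left( \epsilon_{\text{text}} - \epsilon_{\text{uncond}} \right)
  + s_{\text{pag}} \left( \epsilon_{\text{text}} - \epsilon_{\text{perturb}} \right)
```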
Code
Diffusers already supports PAG for a variety of tasks, and it can be combined with ControlNet and IP-Adapter.
```python
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained(
    "~/.cache/modelscope/hub/AI-ModelScope/stable-diffusion-xl-base-1___0",
    enable_pag=True,             # add
    pag_applied_layers=["mid"],  # add
    torch_dtype=torch.float16,
)
pipeline.enable_model_cpu_offload()

prompt = "an insect robot preparing a delicious meal, anime style"
generator = torch.Generator(device="cpu").manual_seed(0)
images = pipeline(
    prompt=prompt,
    num_inference_steps=25,
    guidance_scale=7.0,
    generator=generator,
    pag_scale=2.5,
).images
images[0].save("pag.jpg")
```
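If a pipeline is already loaded, Diffusers can also wrap it with PAG via `from_pipe` (a minimal sketch based on my reading of the Diffusers PAG documentation; the model path is the same local SDXL checkpoint as above):

```python
from diffusers import AutoPipelineForText2Image
import torch

# Load a plain SDXL pipeline first...
pipeline_sdxl = AutoPipelineForText2Image.from_pretrained(
    "~/.cache/modelscope/hub/AI-ModelScope/stable-diffusion-xl-base-1___0",
    torch_dtype=torch.float16,
)
# ...then create a PAG-enabled pipeline from it, reusing the loaded weights.
pipeline = AutoPipelineForText2Image.from_pipe(pipeline_sdxl, enable_pag=True)
```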
PAG Code Details

When PAG and CFG are used together, the `prompt_embeds` fed into the UNet are arranged as `[uncond, cond, cond]`:

```python
def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
    cond = torch.cat([cond] * 2, dim=0)
    if do_classifier_free_guidance:
        cond = torch.cat([uncond, cond], dim=0)
    return cond
```

In `PAGCFGIdentitySelfAttnProcessor2_0`, the first two chunks `[uncond, cond]` go through regular self-attention (SA), while the second `cond` chunk goes through perturbed self-attention (PSA).
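As a quick sanity check of that `[uncond, cond, cond]` layout (a standalone toy with made-up embedding shapes, not pipeline code), the concatenation produces the batch of three that the attention processor later splits with `chunk(3)`:

```python
import torch

# Hypothetical shapes for SDXL-style prompt embeddings: (batch, tokens, dim)
cond = torch.ones(1, 77, 2048)
uncond = torch.zeros(1, 77, 2048)

# Same concatenation as _prepare_perturbed_attention_guidance with CFG enabled:
# duplicate cond for the PAG branch, then prepend uncond for the CFG branch.
embeds = torch.cat([uncond, torch.cat([cond] * 2, dim=0)], dim=0)
print(embeds.shape)  # torch.Size([3, 77, 2048]) -> [uncond, cond, cond]
```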
```python
from typing import Optional

import torch
import torch.nn.functional as F

from diffusers.models.attention_processor import Attention


class PAGCFGIdentitySelfAttnProcessor2_0:
    r"""
    Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch
    2.0). PAG reference: https://arxiv.org/abs/2403.17377
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # chunk: [uncond, cond] take the normal SA path, the second cond takes the PSA path
        hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
        hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])

        # original path
        batch_size, sequence_length, _ = hidden_states_org.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states_org)
        key = attn.to_k(hidden_states_org)
        value = attn.to_v(hidden_states_org)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states_org = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states_org = hidden_states_org.to(query.dtype)

        # linear proj
        hidden_states_org = attn.to_out[0](hidden_states_org)
        # dropout
        hidden_states_org = attn.to_out[1](hidden_states_org)

        if input_ndim == 4:
            hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)

        # perturbed path (identity attention): the attention map is the identity
        # matrix, so the output reduces to the value projection
        batch_size, sequence_length, _ = hidden_states_ptb.shape

        if attn.group_norm is not None:
            hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)

        value = attn.to_v(hidden_states_ptb)
        hidden_states_ptb = value
        hidden_states_ptb = hidden_states_ptb.to(query.dtype)

        # linear proj
        hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
        # dropout
        hidden_states_ptb = attn.to_out[1](hidden_states_ptb)

        if input_ndim == 4:
            hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)

        # cat
        hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
```

After the UNet forward pass, `noise_pred` is combined as follows:

```python
def _apply_perturbed_attention_guidance(
    self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False
):
    r"""
    Apply perturbed attention guidance to the noise prediction.

    Args:
        noise_pred (torch.Tensor): The noise prediction tensor.
        do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
        guidance_scale (float): The scale factor for the guidance term.
        t (int): The current time step.
        return_pred_text (bool): Whether to return the text noise prediction.

    Returns:
        Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The updated noise prediction tensor after
        applying perturbed attention guidance and the text noise prediction.
    """
    pag_scale = self._get_pag_scale(t)
    if do_classifier_free_guidance:
        noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
        noise_pred = (
            noise_pred_uncond
            + guidance_scale * (noise_pred_text - noise_pred_uncond)
            + pag_scale * (noise_pred_text - noise_pred_perturb)
        )
    else:
        noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
        noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)

    if return_pred_text:
        return noise_pred, noise_pred_text
    return noise_pred
```
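To see concretely why the perturbed branch in `PAGCFGIdentitySelfAttnProcessor2_0` can skip the attention computation entirely (`hidden_states_ptb = value`), here is a small standalone check with toy tensors (not pipeline code) that an identity attention map reduces self-attention to the value projection:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, dim = 4, 8
q, k, v = torch.randn(3, seq_len, dim)

# Regular self-attention: softmax(Q K^T / sqrt(d)) @ V
sa_out = F.softmax(q @ k.T / dim**0.5, dim=-1) @ v

# Perturbed self-attention: force the attention map to the identity matrix;
# the output is then exactly the value projection.
psa_out = torch.eye(seq_len) @ v

print(torch.allclose(psa_out, v))       # True
print(torch.allclose(sa_out, psa_out))  # False in general
```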