This post mainly looks at how the model, loss, and data parts are implemented and handled.

- model—VQ_models: VQModel (Encoder, VectorQuantizer, Decoder)
- loss—VQLoss_triple_codebook

model—VQ_models

Creating the vq_model simply picks VQ_8 or VQ_16 according to the requested compression factor (8 or 16). Both instantiate the same VQModel class and differ only in ch_mult, which plays the same role as ch_mult in a UNet: it sets each block's channel multiplier for up/down-sampling. Every resolution level except the last halves the spatial size, so the total compression factor is 2^(len(ch_mult) - 1): 16 for [1, 1, 2, 2, 4] and 8 for [1, 2, 2, 4].

```python
# create and load model
vq_model = VQ_models[args.vq_model](
    codebook_size=args.codebook_size,
    codebook_embed_dim=args.codebook_embed_dim,
    commit_loss_beta=args.commit_loss_beta,
    entropy_loss_ratio=args.entropy_loss_ratio,
    dropout_p=args.dropout_p,
    with_clip_supervision=args.with_clip_supervision,
    with_disentanglement=args.with_disentanglement,
    disentanglement_ratio=args.disentanglement_ratio,
)

def VQ_8(**kwargs):
    return VQModel(ModelArgs(encoder_ch_mult=[1, 2, 2, 4], decoder_ch_mult=[1, 2, 2, 4], **kwargs))

def VQ_16(**kwargs):
    return VQModel(ModelArgs(encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs))

VQ_models = {'VQ-16': VQ_16, 'VQ-8': VQ_8}
```

VQModel

The VQModel with its three codebooks is structured as follows:

- Encoder: one Encoder that progressively downsamples the spatial dimensions and emits three z_channels-dimensional features, one per head.
- VectorQuantizer: three VectorQuantizers, a pixel-level one (no teacher), a mid-semantic-level one (DINO teacher), and a high-semantic-level one (CLIP teacher), each paired with a quant_conv that projects z_channels down to codebook_embed_dim.
- Decoder: one post_quant_conv (mapping the channel dimension from 3 * codebook_embed_dim back to z_channels) plus one Decoder that progressively restores the spatial dimensions.
- FeatPredHead: two FeatPredHeads, MLP heads that align the vq features to the CLIP and DINO features for distillation supervision.

```python
class VQModel(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        # Two head encoder
        self.encoder = Encoder(ch_mult=config.encoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)

        # Quantizer for visual detail head
        self.quantize_vis = VectorQuantizer(config.codebook_size, config.codebook_embed_dim,
                                            config.commit_loss_beta, config.entropy_loss_ratio,
                                            config.codebook_l2_norm, config.codebook_show_usage)
        self.quant_conv_vis = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)

        # Quantizer for mid-level semantic head
        self.quantize_sem_mid = VectorQuantizer(config.codebook_size, config.codebook_embed_dim,
                                                config.commit_loss_beta, config.entropy_loss_ratio,
                                                config.codebook_l2_norm, config.codebook_show_usage)
        self.quant_conv_sem_mid = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)

        # Quantizer for high-level semantic head
        self.quantize_sem_high = VectorQuantizer(config.codebook_size, config.codebook_embed_dim,
                                                 config.commit_loss_beta, config.entropy_loss_ratio,
                                                 config.codebook_l2_norm, config.codebook_show_usage)
        self.quant_conv_sem_high = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)

        print("Visual codebook: [{} x {}]".format(config.codebook_size, config.codebook_embed_dim))
        print("Mid Semantic codebook: [{} x {}]".format(config.codebook_size, config.codebook_embed_dim))
        print("High Semantic codebook: [{} x {}]".format(config.codebook_size, config.codebook_embed_dim))

        # Pixel decoder
        input_dim = config.codebook_embed_dim * 3
        self.post_quant_conv = nn.Conv2d(input_dim, config.z_channels, 1)
        self.decoder = Decoder(ch_mult=config.decoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)

        # Down-sample factor in encoder channel multiplier
        self.num_resolutions = len(config.encoder_ch_mult)
        if self.num_resolutions == 5:    # encoder_ch_mult=[1, 1, 2, 2, 4]
            down_factor = 16
        elif self.num_resolutions == 4:  # encoder_ch_mult=[1, 2, 2, 4]
            down_factor = 8
        else:
            raise NotImplementedError

        # Semantic feature prediction
        if self.config.with_clip_supervision:
            print("Include feature prediction head for representation supervision")
            self.mid_sem_feat_pred = FeatPredHead(input_dim=config.codebook_embed_dim, out_dim=384, down_factor=down_factor)
            self.high_sem_feat_pred = FeatPredHead(input_dim=config.codebook_embed_dim, out_dim=768, down_factor=down_factor)
        else:
            print("NO representation supervision")

        if self.config.with_disentanglement:
            print("Disentangle Ratio: ", self.config.disentanglement_ratio)
        else:
            print("No Disentangle Regularization")
```
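As a quick sanity check, the short sketch below works out the downsample factor implied by ch_mult and the encoder channel schedule that the Encoder section revisits later (compression_factor is an illustrative helper of ours, not a function from the repo):

```python
# Sketch: downsample factor and channel schedule implied by ch_mult.
# compression_factor is illustrative only, not part of the repo.

def compression_factor(ch_mult):
    # each resolution level except the last applies a stride-2 Downsample
    return 2 ** (len(ch_mult) - 1)

assert compression_factor([1, 1, 2, 2, 4]) == 16  # VQ-16
assert compression_factor([1, 2, 2, 4]) == 8      # VQ-8

# encoder block output channels: block_out = ch * ch_mult[i], with ch = 128
print([128 * m for m in [1, 2, 2, 4]])     # [128, 256, 256, 512]
print([128 * m for m in [1, 1, 2, 2, 4]])  # [128, 128, 256, 256, 512]
```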
The forward pass has three main stages (encode, vq, decode), plus one extra feature-alignment step needed for knowledge distillation (KD):

1. The input goes through the encoder, producing three different features (h_vis, h_sem_mid, h_sem_high); three quant_convs then project each from z_channels to codebook_embed_dim.
2. Each feature level is fed to its own VectorQuantizer, yielding three quantized features and three emb_losses (each emb_loss is a tuple of vq_loss, commit_loss, and entropy_loss). For knowledge distillation, the FeatPredHeads (mid_sem_feat_pred and high_sem_feat_pred) additionally project the quantized features to the CLIP and DINO feature dimensions. Because the three codebooks should be mutually orthogonal (strongly disentangled), a disentangle_loss is also constructed: an L2 loss on the dot products between the vq features of the different levels.
3. The three quantized features are concatenated and decoded through post_quant_conv and the decoder into the original image's pixel values, dec.

```python
def forward(self, input):
    # 1. encode
    h_vis, h_sem_mid, h_sem_high = self.encoder(input)
    h_vis = self.quant_conv_vis(h_vis)
    h_sem_mid = self.quant_conv_sem_mid(h_sem_mid)
    h_sem_high = self.quant_conv_sem_high(h_sem_high)

    # 2. vq
    quant_vis, emb_loss_vis, _ = self.quantize_vis(h_vis)
    quant_sem_mid, emb_loss_sem_mid, _ = self.quantize_sem_mid(h_sem_mid)
    quant_sem_high, emb_loss_sem_high, _ = self.quantize_sem_high(h_sem_high)

    # for knowledge distillation
    if self.config.with_clip_supervision:
        mid_lvl_sem_feat = self.mid_sem_feat_pred(quant_sem_mid)
        high_lvl_sem_feat = self.high_sem_feat_pred(quant_sem_high)
    else:
        mid_lvl_sem_feat = None
        high_lvl_sem_feat = None

    # disentangle the vq features of the 3 codebooks
    if self.config.with_disentanglement:
        disentangle_loss = (self.compute_disentangle_loss(quant_vis, quant_sem_mid) +
                            self.compute_disentangle_loss(quant_vis, quant_sem_high) +
                            self.compute_disentangle_loss(quant_sem_mid, quant_sem_high)) / 3.0
    else:
        disentangle_loss = 0

    # 3. decode
    quant = torch.cat([quant_vis, quant_sem_mid, quant_sem_high], dim=1)
    dec = self.decode(quant)

    return dec, \
        emb_loss_vis, emb_loss_sem_mid, emb_loss_sem_high, \
        disentangle_loss, \
        mid_lvl_sem_feat, high_lvl_sem_feat
```

The key novelty of this paper (FQ) is exactly this disentangle_loss, which pushes the three codebooks to be mutually orthogonal and decoupled. The idea: if two features are disentangled, their dot product should be close to zero, since they should be orthogonal. Minimizing this loss encourages the model to learn decoupled features at the different levels.

```python
def compute_disentangle_loss(self, quant_vis, quant_sem):
    quant_vis = rearrange(quant_vis, 'b c h w -> (b h w) c')
    quant_sem = rearrange(quant_sem, 'b c h w -> (b h w) c')

    quant_vis = F.normalize(quant_vis, p=2, dim=-1)
    quant_sem = F.normalize(quant_sem, p=2, dim=-1)

    dot_product = torch.sum(quant_vis * quant_sem, dim=1)
    loss = torch.mean(dot_product ** 2) * self.config.disentanglement_ratio
    return loss
```
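To see the behavior concretely, here is a self-contained toy version of the same computation with disentanglement_ratio fixed to 1.0 (illustrative code, not from the repo): identical features give a loss of exactly 1, while random features are near-orthogonal on average and give a small loss.

```python
import torch
import torch.nn.functional as F
from einops import rearrange

def disentangle_loss(a, b):
    # standalone copy of compute_disentangle_loss with ratio = 1.0
    a = F.normalize(rearrange(a, 'b c h w -> (b h w) c'), p=2, dim=-1)
    b = F.normalize(rearrange(b, 'b c h w -> (b h w) c'), p=2, dim=-1)
    return torch.mean(torch.sum(a * b, dim=1) ** 2)

x = torch.randn(2, 8, 4, 4)
y = torch.randn(2, 8, 4, 4)
print(disentangle_loss(x, x))  # 1.0: identical features are maximally entangled
print(disentangle_loss(x, y))  # ~0.125: random directions are near-orthogonal (E[dot^2] = 1/dim)
```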
Encoder

The Encoder passes the input image feature through a shared stack of downsampling conv_blocks and mid blocks, then through three separate adapters to produce three different output features:

- conv_in: the input image feature first goes through conv_in, which maps the channel dimension to 128.
- downsampling conv_blocks: built from ResnetBlock, AttnBlock, and Downsample modules according to ch_mult=(1, 1, 2, 2, 4) or ch_mult=(1, 2, 2, 4), where ch_mult controls how much each conv_block multiplies the channels. Channels grow while h and w shrink. Each block's output channels are 128 * ch_mult[i]; e.g., with ch_mult=(1, 2, 2, 4) there are 4 blocks and the channels go 128 → 128 → 256 → 256 → 512 (see the sanity sketch above).
- mid: composed of ResnetBlock + AttnBlock + ResnetBlock; the convolutions here keep the channel count, so per position they act like an MLP.
- adapter: three separate FactorizedAdapters that turn the shared encoder feature into three different features for the three codebooks' subsequent VQ step.
- conv_out: since the previous step produced three features (h_vis, h_sem_mid, h_sem_high), three separate Conv2d layers (each with its own Normalize) map each feature's channel dimension to z_channels.

```python
class Encoder(nn.Module):
    def __init__(self, in_channels=3, ch=128, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2,
                 norm_type='group', dropout=0.0, resamp_with_conv=True, z_channels=256):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)

        # downsampling
        in_ch_mult = (1,) + tuple(ch_mult)
        self.conv_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            conv_block = nn.Module()
            # res & attn
            res_block = nn.ModuleList()
            attn_block = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
                block_in = block_out
                if i_level == self.num_resolutions - 1:
                    attn_block.append(AttnBlock(block_in, norm_type))
            conv_block.res = res_block
            conv_block.attn = attn_block
            # downsample
            if i_level != self.num_resolutions - 1:
                conv_block.downsample = Downsample(block_in, resamp_with_conv)
            self.conv_blocks.append(conv_block)

        # middle
        self.mid = nn.ModuleList()
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
        self.mid.append(AttnBlock(block_in, norm_type=norm_type))
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))

        if self.num_resolutions == 5:
            down_factor = 16
        elif self.num_resolutions == 4:
            down_factor = 8
        else:
            raise NotImplementedError

        # semantic head mid-level
        self.semantic_head_mid = nn.ModuleList()
        self.semantic_head_mid.append(FactorizedAdapter(down_factor))

        # semantic head high-level
        self.semantic_head_high = nn.ModuleList()
        self.semantic_head_high.append(FactorizedAdapter(down_factor))

        # visual details head
        self.visual_head = nn.ModuleList()
        self.visual_head.append(FactorizedAdapter(down_factor))

        # end
        self.norm_out_sem_mid = Normalize(block_in, norm_type)
        self.conv_out_sem_mid = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)

        self.norm_out_sem_high = Normalize(block_in, norm_type)
        self.conv_out_sem_high = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)

        self.norm_out_vis = Normalize(block_in, norm_type)
        self.conv_out_vis = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        h = self.conv_in(x)
        # downsampling
        for i_level, block in enumerate(self.conv_blocks):
            for i_block in range(self.num_res_blocks):
                h = block.res[i_block](h)
                if len(block.attn) > 0:
                    h = block.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = block.downsample(h)

        # middle
        for mid_block in self.mid:
            h = mid_block(h)

        h_vis = h
        h_sem_mid = h
        h_sem_high = h

        # semantic head mid-level
        for blk in self.semantic_head_mid:
            h_sem_mid = blk(h_sem_mid)
        h_sem_mid = self.norm_out_sem_mid(h_sem_mid)
        h_sem_mid = nonlinearity(h_sem_mid)
        h_sem_mid = self.conv_out_sem_mid(h_sem_mid)

        # semantic head high-level
        for blk in self.semantic_head_high:
            h_sem_high = blk(h_sem_high)
        h_sem_high = self.norm_out_sem_high(h_sem_high)
        h_sem_high = nonlinearity(h_sem_high)
        h_sem_high = self.conv_out_sem_high(h_sem_high)

        # visual head
        for blk in self.visual_head:
            h_vis = blk(h_vis)
        h_vis = self.norm_out_vis(h_vis)
        h_vis = nonlinearity(h_vis)
        h_vis = self.conv_out_vis(h_vis)

        return h_vis, h_sem_mid, h_sem_high
```
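The heads call nonlinearity and Normalize, which this post never shows. In VQGAN-style codebases such as LlamaGen, which this code closely follows, they are typically defined as below; treat this as a reference sketch, since the repo's own versions may differ:

```python
# Typical VQGAN-style helpers (reference sketch, may differ from the repo).
import torch
import torch.nn as nn

def nonlinearity(x):
    # swish / SiLU
    return x * torch.sigmoid(x)

def Normalize(in_channels, norm_type='group'):
    assert norm_type in ['group', 'batch']
    if norm_type == 'group':
        return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
    return nn.SyncBatchNorm(in_channels)
```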
VectorQuantizer

VectorQuantizer's initialization mainly creates the codebook embedding, an nn.Embedding of size [codebook_size, codebook_embed_dim].

```python
class VectorQuantizer(nn.Module):
    def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
        super().__init__()
        self.n_e = n_e                                # codebook_size
        self.e_dim = e_dim                            # codebook_embed_dim
        self.beta = beta                              # commitment_loss scale
        self.entropy_loss_ratio = entropy_loss_ratio  # entropy_loss scale
        self.l2_norm = l2_norm                        # l2_norm for codebook embeddings
        self.show_usage = show_usage                  # show codebook usage

        # create codebook embedding and initialize
        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
        if self.l2_norm:  # normalize embeddings
            self.embedding.weight.data = F.normalize(self.embedding.weight.data, p=2, dim=-1)
        if self.show_usage:  # initialize codebook usage buffer
            self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))  # 1048576
```

The forward pass works just like VQGAN: every token embedding of the image feature z is quantized by looking up the codebook entry at argmin(distances), giving the quantized feature z_q, while three losses are computed to optimize the codebook embeddings.

- l2 norm: L2-normalizing both z and the codebook embeddings rescales every vector to the same length, i.e., projects them onto the unit sphere. Each vector then plays an equal role in the distance metric, which makes vectors easier to compare and match and improves training stability and reconstruction quality.
- argmin(distances): the classic VQ distance computation expands the squared difference into two squared terms and one inner product, (z - e)^2 = z^2 + e^2 - 2·e·z (see the worked example below). argmin(distances) then gives, for each embedding in z, the index of its nearest codebook embedding, and the corresponding codebook embeddings are gathered to form z_q.
- codebook usage: the utilization rate of the codebook embeddings, computed over a rolling buffer of the most recent indices.
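To make the expanded-distance trick concrete, here is a small standalone check (toy sizes, illustrative code) that the expansion matches a direct pairwise squared distance and that gathering by argmin performs the codebook lookup:

```python
import torch

z_flat = torch.randn(5, 4)      # 5 tokens, dim 4
codebook = torch.randn(7, 4)    # 7 codebook entries

# expanded form: (z - e)^2 = z^2 + e^2 - 2 e * z
d = (z_flat ** 2).sum(dim=1, keepdim=True) \
    + (codebook ** 2).sum(dim=1) \
    - 2 * z_flat @ codebook.t()

# direct form for comparison
d_ref = torch.cdist(z_flat, codebook) ** 2
assert torch.allclose(d, d_ref, atol=1e-4)

idx = torch.argmin(d, dim=1)    # nearest codebook index per token
z_q = codebook[idx]             # (5, 4) quantized tokens
```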
The forward pass in full:

```python
def forward(self, z):
    # reshape z -> (batch, height, width, channel) and flatten
    z = torch.einsum('b c h w -> b h w c', z).contiguous()
    z_flattened = z.view(-1, self.e_dim)

    if self.l2_norm:  # normalize z and the codebook embeddings (maps vectors onto the unit sphere)
        z = F.normalize(z, p=2, dim=-1)
        z_flattened = F.normalize(z_flattened, p=2, dim=-1)
        embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
    else:
        embedding = self.embedding.weight

    # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
    d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
        torch.sum(embedding ** 2, dim=1) - 2 * \
        torch.einsum('bd,dn->bn', z_flattened, torch.einsum('n d -> d n', embedding))

    # argmin(distances)
    min_encoding_indices = torch.argmin(d, dim=1)
    # replace each z_i with its closest embedding e_j
    z_q = embedding[min_encoding_indices].view(z.shape)

    perplexity = None
    min_encodings = None
    vq_loss = None
    commit_loss = None
    entropy_loss = None
    codebook_usage = 0

    # compute codebook usage
    if self.show_usage and self.training:
        cur_len = min_encoding_indices.shape[0]
        self.codebook_used[:-cur_len] = self.codebook_used[cur_len:].clone()  # copy last cur_len elements to front
        self.codebook_used[-cur_len:] = min_encoding_indices                  # set last cur_len elements as min_encoding_indices
        codebook_usage = len(torch.unique(self.codebook_used)) / self.n_e
```

- vq_loss (embedding loss): the mean squared error (MSE) between the quantized vectors z_q and the original input vectors z. z.detach() detaches z from the computation graph, so z receives no gradient from this term; the loss encourages the codebook embeddings z_q to move as close as possible to the encoder outputs.
- commit_loss: also an MSE, but here z_q is detached, so z_q receives no gradient. It encourages the encoder to stay committed to the codebook during quantization, i.e., the quantized z_q should reflect the information in z as faithfully as possible. The hyperparameter self.beta weights this term in the total loss.
- entropy_loss: encourages uniform usage of the codebook, improving generalization. compute_entropy_loss(-d) takes the negated squared distances from each z to all embeddings as affinities and applies a temperature softmax; the loss then combines the per-sample entropy (pushing each assignment to be confident) with the negative entropy of the average distribution (pushing overall usage toward uniform). The hyperparameter self.entropy_loss_ratio weights this term in the total loss.

```python
    # compute the 3 losses for the codebook embeddings
    if self.training:
        vq_loss = torch.mean((z_q - z.detach()) ** 2)
        commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
        entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)

    # preserve gradients (straight-through estimator)
    z_q = z + (z_q - z).detach()

    # reshape back to match original input shape
    z_q = torch.einsum('b h w c -> b c h w', z_q)

    return z_q, (vq_loss, commit_loss, entropy_loss, codebook_usage), (perplexity, min_encodings, min_encoding_indices)


def compute_entropy_loss(affinity, loss_type='softmax', temperature=0.01):
    flat_affinity = affinity.reshape(-1, affinity.shape[-1])
    flat_affinity /= temperature
    probs = F.softmax(flat_affinity, dim=-1)
    log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
    if loss_type == 'softmax':
        target_probs = probs
    else:
        raise ValueError("Entropy loss {} not supported".format(loss_type))
    avg_probs = torch.mean(target_probs, dim=0)
    avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
    sample_entropy = -torch.mean(torch.sum(target_probs * log_probs, dim=-1))
    loss = sample_entropy - avg_entropy
    return loss
```
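The line z_q = z + (z_q - z).detach() in forward is the straight-through estimator: the forward value equals z_q, but the gradient with respect to z is the identity, so the encoder still trains through the non-differentiable lookup. A minimal standalone demonstration, with torch.round standing in for the codebook lookup:

```python
import torch

z = torch.randn(4, requires_grad=True)
z_q = torch.round(z)             # stand-in for the non-differentiable nearest-codebook lookup
z_st = z + (z_q - z).detach()    # forward: equals z_q; backward: identity w.r.t. z

z_st.sum().backward()
print(torch.equal(z_st, z_q))    # True
print(z.grad)                    # tensor([1., 1., 1., 1.])
```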
get_codebook_entry is used after the autoregressive Transformer has predicted an index sequence: it looks the indices up in the codebook and turns them into the corresponding embeddings.

```python
def get_codebook_entry(self, indices, shape=None, channel_first=True):
    # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel)
    if self.l2_norm:
        embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
    else:
        embedding = self.embedding.weight
    z_q = embedding[indices]  # (b*h*w, c)
    if shape is not None:
        if channel_first:
            z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()
        else:
            z_q = z_q.view(shape)
    return z_q
```

Decoder

The VQ step itself does not change the feature's shape (z and z_q match), and after the three quantized features are concatenated, post_quant_conv maps the channels back to z_channels=256. The Decoder then decodes z_q into image pixel values:

- conv_in: a Conv2d maps the channels from z_channels to block_in (determined by ch=128 and ch_mult).
- middle: as in the Encoder, ResnetBlock + AttnBlock + ResnetBlock, keeping the channel count.
- upsampling conv_blocks: the mirror image of the Encoder; ch_mult builds several blocks that progressively upsample, growing the spatial dimensions and shrinking the channels.
- conv_out: the final conv_out maps the channels from block_in to out_channels=3, producing the image pixel values.

```python
class Decoder(nn.Module):
    def __init__(self, z_channels=256, ch=128, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2,
                 norm_type='group', dropout=0.0, resamp_with_conv=True, out_channels=3):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = ch * ch_mult[self.num_resolutions - 1]

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.ModuleList()
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
        self.mid.append(AttnBlock(block_in, norm_type=norm_type))
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))

        # upsampling
        self.conv_blocks = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            conv_block = nn.Module()
            # res & attn
            res_block = nn.ModuleList()
            attn_block = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
                block_in = block_out
                if i_level == self.num_resolutions - 1:
                    attn_block.append(AttnBlock(block_in, norm_type))
            conv_block.res = res_block
            conv_block.attn = attn_block
            # upsample
            if i_level != 0:
                conv_block.upsample = Upsample(block_in, resamp_with_conv)
            self.conv_blocks.append(conv_block)

        # end
        self.norm_out = Normalize(block_in, norm_type)
        self.conv_out = nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

    @property
    def last_layer(self):
        return self.conv_out.weight

    def forward(self, z):
        # z to block_in
        h = self.conv_in(z)

        # middle
        for mid_block in self.mid:
            h = mid_block(h)

        # upsampling
        for i_level, block in enumerate(self.conv_blocks):
            for i_block in range(self.num_res_blocks + 1):
                h = block.res[i_block](h)
                if len(block.attn) > 0:
                    h = block.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = block.upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
```
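The last_layer property exposes the final conv weights; the loss class below uses them in calculate_adaptive_weight to balance reconstruction and adversarial gradients. For reference, the taming-transformers-style implementation that such code usually follows looks roughly like this sketch (the repo's version may differ in details):

```python
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
    # ratio of gradient norms at the decoder's last layer (taming-transformers style sketch)
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
    return d_weight
```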
loss—VQLoss_triple_codebook

A forward pass of the VQModel above already produces three emb_losses and one disentangle_loss:

- codebook embedding loss: with three codebooks, the three VQ steps produce three emb_losses (emb_loss_vis, emb_loss_sem_mid, emb_loss_sem_high), each composed of vq_loss, commit_loss, and entropy_loss, used to optimize the codebooks.
- disentangle loss: one of the paper's contributions; the summed L2 loss on the dot products between the z_q of different codebooks, pushing the codebooks toward mutual orthogonality.

On top of these, VQLoss_triple_codebook computes the reconstruction_loss, perceptual_loss, and kd_teacher_loss during training:

- pixel loss: reconstruction_loss is the l1_loss or l2_loss between the pixel values of the VQModel's input and output; perceptual_loss uses VGG-based LPIPS between input and output.
- discriminator loss: optimizes the discriminator (PatchGAN or StyleGAN), which takes a real or reconstructed image and outputs real/fake logits; the discriminator_loss type can be hinge, vanilla, or non-saturating.
- gen_adv_loss: optimizes the generator, whose goal is to produce fakes close enough to real data to fool the discriminator; either hinge or non-saturating, both pushing logits_fake on reconstructed images toward "real".
- semantic loss (kd_teacher_loss): the two FeatPredHead outputs (projected vq features) are compared against the CLIP and DINO features respectively, distilling general-purpose understanding representations.

The initialization of VQLoss_triple_codebook just prepares the parameters and models needed for the losses above:

```python
class VQLoss_triple_codebook(nn.Module):
    def __init__(self, disc_start, disc_loss='hinge', disc_dim=64, disc_type='patchgan', image_size=256,
                 disc_num_layers=3, disc_in_channels=3, disc_weight=1.0, disc_adaptive_weight=False,
                 gen_adv_loss='hinge', reconstruction_loss='l2', reconstruction_weight=1.0,
                 codebook_weight=1.0, perceptual_weight=1.0,
                 with_clip_supervision=False, semantic_weight=0.5):
        super().__init__()
        # 1. discriminator loss
        assert disc_type in ['patchgan', 'stylegan']
        assert disc_loss in ['hinge', 'vanilla', 'non-saturating']
        # discriminator
        if disc_type == 'patchgan':
            print('Using patchgan D')
            self.discriminator = PatchGANDiscriminator(
                input_nc=disc_in_channels,
                n_layers=disc_num_layers,
                ndf=disc_dim,
            )
        elif disc_type == 'stylegan':
            print('Using stylegan D')
            self.discriminator = StyleGANDiscriminator(
                input_nc=disc_in_channels,
                image_size=image_size,
            )
        else:
            raise ValueError(f"Unknown GAN discriminator type {disc_type}.")

        # disc_loss type
        if disc_loss == 'hinge':
            self.disc_loss = hinge_d_loss
        elif disc_loss == 'vanilla':
            self.disc_loss = vanilla_d_loss
        elif disc_loss == 'non-saturating':
            self.disc_loss = non_saturating_d_loss
        else:
            raise ValueError(f"Unknown GAN discriminator loss {disc_loss}.")

        self.discriminator_iter_start = disc_start
        self.disc_weight = disc_weight
        self.disc_adaptive_weight = disc_adaptive_weight

        assert gen_adv_loss in ['hinge', 'non-saturating']
        # 2. gen_adv_loss
        if gen_adv_loss == 'hinge':
            self.gen_adv_loss = hinge_gen_loss
        elif gen_adv_loss == 'non-saturating':
            self.gen_adv_loss = non_saturating_gen_loss
        else:
            raise ValueError(f"Unknown GAN generator loss {gen_adv_loss}.")

        # 3. perceptual loss
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight

        # 4. semantic loss
        self.with_clip_supervision = with_clip_supervision
        if with_clip_supervision:
            self.clip_model = CLIPVisionTower('/mnt/workspace/Project/UnderGenTokenizer/FQGAN/models/clip-vit-base-patch16').eval()
            self.dinov2_model = DinoVisionTower('/mnt/workspace/Project/UnderGenTokenizer/FQGAN/models/dinov2-small').eval()
            self.clip_model.requires_grad_(False)
            self.dinov2_model.requires_grad_(False)
            self.semantic_weight = semantic_weight
        else:
            self.clip_model = None
            self.dinov2_model = None
            self.semantic_weight = None

        # 5. reconstruction loss
        if reconstruction_loss == 'l1':
            self.rec_loss = F.l1_loss
        elif reconstruction_loss == 'l2':
            self.rec_loss = F.mse_loss
        else:
            raise ValueError(f"Unknown rec loss {reconstruction_loss}.")
        self.rec_weight = reconstruction_weight

        # 6. codebook loss
        self.codebook_weight = codebook_weight
```
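hinge_d_loss, hinge_gen_loss, and non_saturating_gen_loss are referenced but not shown. The standard definitions these names carry in taming/LlamaGen-style code are roughly the following (a sketch; the repo's versions may differ):

```python
# Standard GAN losses (reference sketch, may differ from the repo).
import torch
import torch.nn.functional as F

def hinge_d_loss(logits_real, logits_fake):
    # discriminator hinge loss: push real logits above +1, fake logits below -1
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    return 0.5 * (loss_real + loss_fake)

def hinge_gen_loss(logits_fake):
    # generator hinge loss: raise the discriminator's logits on reconstructions
    return -torch.mean(logits_fake)

def non_saturating_gen_loss(logits_fake):
    # non-saturating generator loss: BCE toward the "real" label
    return torch.mean(F.binary_cross_entropy_with_logits(logits_fake, torch.ones_like(logits_fake)))
```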
The forward pass of VQLoss_triple_codebook has two modes, selected by optimizer_idx. Both modes run back-to-back on the same batch, i.e., each training step calls this loss's forward twice: once for the generator loss and once for the discriminator loss. The generator and discriminator use two separate optimizers, optimizer and optimizer_disc (a training-loop sketch follows the code):

- optimizer_idx == 0 optimizes the generator: compute the reconstruction loss, perceptual loss, semantic loss, and gen_adv_loss, and combine them by weighted sum with the codebook_embed_losses and disentangle_loss already produced by the VQModel forward, yielding the total loss that trains the VQModel.
- optimizer_idx == 1 optimizes the discriminator: compute the discriminator loss and use it to train the Discriminator.

```python
def forward(self,
            codebook_loss_vis, codebook_loss_sem_mid, codebook_loss_sem_high,
            inputs, reconstructions,
            disentangle_loss,
            semantic_feat_mid, semantic_feat_high,
            optimizer_idx, global_step, last_layer=None,
            logger=None, log_every=100):
    # generator update
    if optimizer_idx == 0:
        # reconstruction loss
        rec_loss = self.rec_loss(inputs.contiguous(), reconstructions.contiguous())

        # semantic loss
        if semantic_feat_mid is not None:
            assert semantic_feat_high is not None
            semantic_loss_mid = self.dinov2_model(inputs.contiguous(), semantic_feat_mid)  # how to compute semantic loss?
            semantic_loss_mid = torch.mean(semantic_loss_mid)
            semantic_loss_high = self.clip_model(inputs.contiguous(), semantic_feat_high)
            semantic_loss_high = torch.mean(semantic_loss_high)
        else:
            assert self.with_clip_supervision == False
            semantic_loss_mid = torch.mean(torch.zeros_like(rec_loss))
            semantic_loss_high = torch.mean(torch.ones_like(rec_loss))

        # perceptual loss
        p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
        p_loss = torch.mean(p_loss)

        # generator adversarial loss
        logits_fake = self.discriminator(reconstructions.contiguous())
        generator_adv_loss = self.gen_adv_loss(logits_fake)

        if self.disc_adaptive_weight:
            null_loss = self.rec_weight * rec_loss + self.perceptual_weight * p_loss  # pixel loss
            disc_adaptive_weight = self.calculate_adaptive_weight(null_loss, generator_adv_loss, last_layer=last_layer)
        else:
            disc_adaptive_weight = 1
        disc_weight = adopt_weight(self.disc_weight, global_step, threshold=self.discriminator_iter_start)

        loss = self.rec_weight * rec_loss + \
            self.perceptual_weight * p_loss + \
            disc_adaptive_weight * disc_weight * generator_adv_loss + \
            codebook_loss_vis[0] + codebook_loss_vis[1] + codebook_loss_vis[2] + \
            codebook_loss_sem_mid[0] + codebook_loss_sem_mid[1] + codebook_loss_sem_mid[2] + \
            codebook_loss_sem_high[0] + codebook_loss_sem_high[1] + codebook_loss_sem_high[2] + \
            self.semantic_weight * semantic_loss_mid + self.semantic_weight * semantic_loss_high + disentangle_loss

        if global_step % log_every == 0:
            rec_loss = self.rec_weight * rec_loss
            p_loss = self.perceptual_weight * p_loss
            generator_adv_loss = disc_adaptive_weight * disc_weight * generator_adv_loss
            logger.info(f"(Generator) rec_loss: {rec_loss:.4f}, perceptual_loss: {p_loss:.4f}, "
                        f"vq_loss_sem_mid: {codebook_loss_sem_mid[0]:.4f}, "
                        f"commit_loss_sem_mid: {codebook_loss_sem_mid[1]:.4f}, "
                        f"entropy_loss_sem_mid: {codebook_loss_sem_mid[2]:.4f}, "
                        f"codebook_usage_sem_mid: {codebook_loss_sem_mid[3]:.4f}, "
                        f"vq_loss_sem_high: {codebook_loss_sem_high[0]:.4f}, "
                        f"commit_loss_sem_high: {codebook_loss_sem_high[1]:.4f}, "
                        f"entropy_loss_sem_high: {codebook_loss_sem_high[2]:.4f}, "
                        f"codebook_usage_sem_high: {codebook_loss_sem_high[3]:.4f}, "
                        f"vq_loss_vis: {codebook_loss_vis[0]:.4f}, "
                        f"commit_loss_vis: {codebook_loss_vis[1]:.4f}, "
                        f"entropy_loss_vis: {codebook_loss_vis[2]:.4f}, "
                        f"codebook_usage_vis: {codebook_loss_vis[3]:.4f}, "
                        f"disentangle_loss: {disentangle_loss: .4f}, "
                        f"generator_adv_loss: {generator_adv_loss:.4f}, "
                        f"disc_adaptive_weight: {disc_adaptive_weight:.4f}, disc_weight: {disc_weight:.4f}, "
                        f"semantic_loss_mid: {semantic_loss_mid:.4f}, semantic_loss_high: {semantic_loss_high:.4f}")
            if dist.get_rank() == 0:
                wandb.log({
                    'rec_loss': rec_loss,
                    'perceptual_loss': p_loss,
                    'disentangle_loss': disentangle_loss,
                    'codebook_loss_sem_mid': codebook_loss_sem_mid[0],
                    'commit_loss_sem_mid': codebook_loss_sem_mid[1],
                    'entropy_loss_sem_mid': codebook_loss_sem_mid[2],
                    'codebook_usage_sem_mid': codebook_loss_sem_mid[3],
                    'codebook_loss_sem_high': codebook_loss_sem_high[0],
                    'commit_loss_sem_high': codebook_loss_sem_high[1],
                    'entropy_loss_sem_high': codebook_loss_sem_high[2],
                    'codebook_usage_sem_high': codebook_loss_sem_high[3],
                    'codebook_loss_vis': codebook_loss_vis[0],
                    'commit_loss_vis': codebook_loss_vis[1],
                    'entropy_loss_vis': codebook_loss_vis[2],
                    'codebook_usage_vis': codebook_loss_vis[3],
                    'generator_adv_loss': generator_adv_loss,
                    'disc_adaptive_weight': disc_adaptive_weight,
                    'disc_weight': disc_weight,
                    'semantic_loss_mid': semantic_loss_mid,
                    'semantic_loss_high': semantic_loss_high,
                })
        return loss

    # discriminator update
    if optimizer_idx == 1:
        logits_real = self.discriminator(inputs.contiguous().detach())
        logits_fake = self.discriminator(reconstructions.contiguous().detach())

        disc_weight = adopt_weight(self.disc_weight, global_step, threshold=self.discriminator_iter_start)
        d_adversarial_loss = disc_weight * self.disc_loss(logits_real, logits_fake)

        if global_step % log_every == 0:
            logits_real = logits_real.detach().mean()
            logits_fake = logits_fake.detach().mean()
            logger.info(f"(Discriminator) discriminator_adv_loss: {d_adversarial_loss:.4f}, "
                        f"disc_weight: {disc_weight:.4f}, "
                        f"logits_real: {logits_real:.4f}, logits_fake: {logits_fake:.4f}")
            if dist.get_rank() == 0:
                wandb.log({
                    'discriminator_adv_loss': d_adversarial_loss,
                    'disc_weight': disc_weight,
                    'logits_real': logits_real,
                    'logits_fake': logits_fake,
                })
        return d_adversarial_loss
```
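Putting the two modes together, one training step might look like the simplified sketch below. All names here (vq_model, vq_loss_fn, optimizer, optimizer_disc, loader, logger) are hypothetical stand-ins, not the repo's actual training script, which additionally handles details such as DDP, AMP, and checkpointing.

```python
# Sketch of one training step with alternating generator/discriminator updates
# (hypothetical names; not the repo's train script).
for global_step, (x, _) in enumerate(loader):
    # generator pass: optimizer_idx = 0
    dec, emb_vis, emb_mid, emb_high, dis_loss, feat_mid, feat_high = vq_model(x)
    loss_gen = vq_loss_fn(emb_vis, emb_mid, emb_high, x, dec, dis_loss,
                          feat_mid, feat_high, optimizer_idx=0, global_step=global_step,
                          last_layer=vq_model.decoder.last_layer, logger=logger)
    optimizer.zero_grad()
    loss_gen.backward()
    optimizer.step()

    # discriminator pass: optimizer_idx = 1 (inputs/reconstructions are detached inside)
    loss_disc = vq_loss_fn(emb_vis, emb_mid, emb_high, x, dec.detach(), dis_loss,
                           feat_mid, feat_high, optimizer_idx=1, global_step=global_step,
                           logger=logger)
    optimizer_disc.zero_grad()
    loss_disc.backward()
    optimizer_disc.step()
```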