GauGAN损失函数篇
这两天整理一个GauGAN的损失函数及其曲线变化。
Generator
首先,对于生成器而言,损失函数包括如下几项:GAN_loss、Feat_loss、VGG_loss、KLD_loss四项。我已经拖了两三周没理它了,两眼一抹黑策略失败,后天必须给陈老师讲清楚。
GAN LOSS
在函数generate_fake中,将标签图与GT作为生成器的输入,返回大小为[bs,3,256,256]的逼真图像。
随后将语义标签图、GT、生成图作为鉴别器的输入,返回鉴别结果。鉴别器的输入看作是两个部分的结合:
- 语义标签图和生成图在
dim=1处拼接得到fake_concat - 语义标签图和GT在
dim=1处拼接得到real_concat - 两者在
dim=0处拼接得到的real_and_fake为输入
1 | fake_concat = torch.cat([input_semantics, fake_image], dim=1) # [bs, 305, 256, 256] |
然而此鉴别器由2个子鉴别器组合而成,第一个子鉴别的结构如下:

所以把fake_and_real作为输入后,中间层的尺寸变化是这样的:
1 | intermediate_output: torch.Size([bs*2, 64, 129, 129]) |
并且在定义子鉴别器结构的函数NLayerDiscriminator中,它把所有的中间结果都保存起来,组合成一个结果序列,并作为该鉴别器的计算结果返回,因此该序列的长度为5。
在函数MultiscaleDiscriminator中,它再次将每个子鉴别器的返回结果作为序列保存起来,因此最后netD返回的尺寸是opt.num_D x opt.n_layers_D。
另外需要注意的是,第二个子鉴别器的输入尺寸是对第一个鉴别器的input做downsample的结果,同理,第三个子鉴别器的输入尺寸是对第二个鉴别器的input做downsample的结果。
接着在函数divide_pred中需要将生成图和GT的鉴别结果分开,对于多鉴别器的结果序列而言:
1 | fake = [] |
在GANLoss这个类中,有几个初始化的参数如下,其中target_real_label的意思就是GT进入鉴别器后得到的结果应该越接近1越好,反之生成图越接近0越好:
1 | gan_mode target_real_label=1.0 target_fake_label=0.0 # 默认gan_mode为hinge |
计算GAN损失只需要鉴别器最后一层的结果,因此在criterionGAN函数中,对pred_fake的鉴别结果进行计算,并返回结果存储在G_losses中:
1 | G_losses['GAN'] = self.criterionGAN(pred_fake, target_is_real=True, |
在loss函数中,我们只需要看hinge部分:
1 | elif self.gan_mode == 'hinge': |
在上面这段代码中,计算GAN_loss时的input尺寸是pred_fake中两个元素的最后一层,尺寸分别为[bs, 1, 35, 35]和[bs, 1, 19, 19],其预测的结果值应该在0~1的范围之内,且越接近1越好(target is real),才能证明生成器的效果好,并最终返回的是多个鉴别器的平均结果。
1 | if isinstance(input, list): |
GAN FEAT LOSS
GAN feat loss实际上是计算GT和生成图分别丢给鉴别器后,每一层输出结果之间的L1损失,且不包括最后一层。
1 | # self.criterionFeat = torch.nn.L1Loss() |
opt.lambda_feat等于10,最后返回的是多个鉴别器的平均结果。
VGG LOSS
GauGAN网络中使用的是VGG19,这个类用于计算感知损失perceptual loss,即将生成图与GT作为VGG的输入,输出结果是长度为5的序列,每个元素是各层输出的中间特征。
在VGG中,计算GT和生成图相同层的中间特征的L1损失,最后将得到的损失值乘以lambda_vgg,其值为10:
1 | # self.criterion = nn.L1loss() |
综上所述,所有损失都是基于全局计算,且越小越好的。
Discriminator
那么接下来,对于鉴别器而言,损失项包括两个:D_fake和D_real,我们继续一一介绍吧。
在函数compute_discriminator_loss中,语义标签图和GT作为输入,此时假设已经训练好生成器G,那么通过generate_fake得到生成图fake_image。与上面相同,将GT和生成图丢进鉴别器中,返回得到pred_fake和pred_real鉴别结果。
但计算损失的时候略有不同,此时鉴别器的目标有两个:
- 对于
pred_fake,鉴别器的鉴别结果越接近0越好。 - 对于
pred_real,鉴别器的鉴别结果越接近1越好。
1 | D_losses['D_Fake'] = self.criterionGAN(pred_fake, target_is_real=False, |
D_fake LOSS
当gan_mode=hinge,且目标为0时:
1 | def get_zero_tensor(self, input): |
-input-1绝大多数情况为负值,其绝对值越小越好。比如当input=0.5与input=0.2时,-input - 1的绝对值结果分别为1.5和1.2,显然对于生成图而言,input=0.2的鉴别结果更好。
D_real LOSS
当gan_mode=hinge,且目标为1时:
1 | def get_zero_tensor(self, input): |
同理,input-1绝大多数情况为负值,其绝对值越小越好。比如当input=0.5与input=0.2时,input - 1的绝对值结果分别为0.5和0.8,显然对于GT而言,input=0.5的鉴别结果更好。
Optimizer
在函数run_generator_one_step中,生成器损失g_loss等于Generator中所有损失值之和的均值。同理,鉴别器损失d_loss等于Discriminator中所有损失值之和的均值。
然后通过Adam优化器进行优化迭代。








