GauGAN损失函数篇
这两天整理一个GauGAN的损失函数及其曲线变化。
Generator
首先,对于生成器而言,损失函数包括如下几项:GAN_loss、Feat_loss、VGG_loss、KLD_loss四项。我已经拖了两三周没理它了,两眼一抹黑策略失败,后天必须给陈老师讲清楚。
GAN LOSS
在函数generate_fake
中,将标签图与GT作为生成器的输入,返回大小为[bs,3,256,256]
的逼真图像。
随后将语义标签图、GT、生成图作为鉴别器的输入,返回鉴别结果。鉴别器的输入看作是两个部分的结合:
- 语义标签图和生成图在
dim=1
处拼接得到fake_concat
- 语义标签图和GT在
dim=1
处拼接得到real_concat
- 两者在
dim=0
处拼接得到的real_and_fake
为输入
1 | fake_concat = torch.cat([input_semantics, fake_image], dim=1) # [bs, 305, 256, 256] |
然而此鉴别器由2个子鉴别器组合而成,第一个子鉴别的结构如下:
所以把fake_and_real
作为输入后,中间层的尺寸变化是这样的:
1 | intermediate_output: torch.Size([bs*2, 64, 129, 129]) |
并且在定义子鉴别器结构的函数NLayerDiscriminator
中,它把所有的中间结果都保存起来,组合成一个结果序列,并作为该鉴别器的计算结果返回,因此该序列的长度为5。
在函数MultiscaleDiscriminator
中,它再次将每个子鉴别器的返回结果作为序列保存起来,因此最后netD
返回的尺寸是opt.num_D x opt.n_layers_D
。
另外需要注意的是,第二个子鉴别器的输入尺寸是对第一个鉴别器的input做downsample
的结果,同理,第三个子鉴别器的输入尺寸是对第二个鉴别器的input做downsample
的结果。
接着在函数divide_pred
中需要将生成图和GT的鉴别结果分开,对于多鉴别器的结果序列而言:
1 | fake = [] |
在GANLoss
这个类中,有几个初始化的参数如下,其中target_real_label
的意思就是GT进入鉴别器后得到的结果应该越接近1越好,反之生成图越接近0越好:
1 | gan_mode target_real_label=1.0 target_fake_label=0.0 # 默认gan_mode为hinge |
计算GAN损失只需要鉴别器最后一层的结果,因此在criterionGAN
函数中,对pred_fake
的鉴别结果进行计算,并返回结果存储在G_losses
中:
1 | G_losses['GAN'] = self.criterionGAN(pred_fake, target_is_real=True, |
在loss
函数中,我们只需要看hinge部分:
1 | elif self.gan_mode == 'hinge': |
在上面这段代码中,计算GAN_loss
时的input尺寸是pred_fake
中两个元素的最后一层,尺寸分别为[bs, 1, 35, 35]
和[bs, 1, 19, 19]
,其预测的结果值应该在0~1的范围之内,且越接近1越好(target is real),才能证明生成器的效果好,并最终返回的是多个鉴别器的平均结果。
1 | if isinstance(input, list): |
GAN FEAT LOSS
GAN feat loss
实际上是计算GT和生成图分别丢给鉴别器后,每一层输出结果之间的L1损失,且不包括最后一层。
1 | # self.criterionFeat = torch.nn.L1Loss() |
opt.lambda_feat
等于10,最后返回的是多个鉴别器的平均结果。
VGG LOSS
GauGAN网络中使用的是VGG19,这个类用于计算感知损失perceptual loss
,即将生成图与GT作为VGG的输入,输出结果是长度为5的序列,每个元素是各层输出的中间特征。
在VGG中,计算GT和生成图相同层的中间特征的L1损失,最后将得到的损失值乘以lambda_vgg
,其值为10:
1 | # self.criterion = nn.L1loss() |
综上所述,所有损失都是基于全局计算,且越小越好的。
Discriminator
那么接下来,对于鉴别器而言,损失项包括两个:D_fake
和D_real
,我们继续一一介绍吧。
在函数compute_discriminator_loss
中,语义标签图和GT作为输入,此时假设已经训练好生成器G,那么通过generate_fake
得到生成图fake_image
。与上面相同,将GT和生成图丢进鉴别器中,返回得到pred_fake
和pred_real
鉴别结果。
但计算损失的时候略有不同,此时鉴别器的目标有两个:
- 对于
pred_fake
,鉴别器的鉴别结果越接近0越好。 - 对于
pred_real
,鉴别器的鉴别结果越接近1越好。
1 | D_losses['D_Fake'] = self.criterionGAN(pred_fake, target_is_real=False, |
D_fake LOSS
当gan_mode=hinge
,且目标为0时:
1 | def get_zero_tensor(self, input): |
-input-1
绝大多数情况为负值,其绝对值越小越好。比如当input=0.5
与input=0.2
时,-input - 1
的绝对值结果分别为1.5
和1.2
,显然对于生成图而言,input=0.2
的鉴别结果更好。
D_real LOSS
当gan_mode=hinge
,且目标为1时:
1 | def get_zero_tensor(self, input): |
同理,input-1
绝大多数情况为负值,其绝对值越小越好。比如当input=0.5
与input=0.2
时,input - 1
的绝对值结果分别为0.5
和0.8
,显然对于GT而言,input=0.5
的鉴别结果更好。
Optimizer
在函数run_generator_one_step
中,生成器损失g_loss
等于Generator中所有损失值之和的均值。同理,鉴别器损失d_loss
等于Discriminator
中所有损失值之和的均值。
然后通过Adam优化器进行优化迭代。