这两天整理一个GauGAN的损失函数及其曲线变化。

Generator

首先,对于生成器而言,损失函数包括如下几项:GAN_loss、Feat_loss、VGG_loss、KLD_loss四项。我已经拖了两三周没理它了,两眼一抹黑策略失败,后天必须给陈老师讲清楚。

GAN LOSS

在函数generate_fake中,将标签图与GT作为生成器的输入,返回大小为[bs,3,256,256]的逼真图像。

随后将语义标签图、GT、生成图作为鉴别器的输入,返回鉴别结果。鉴别器的输入看作是两个部分的结合:

  • 语义标签图和生成图在dim=1处拼接得到fake_concat
  • 语义标签图和GT在dim=1处拼接得到real_concat
  • 两者在dim=0处拼接得到的real_and_fake为输入
1
2
3
4
5
fake_concat = torch.cat([input_semantics, fake_image], dim=1)  # [bs, 305, 256, 256]
real_concat = torch.cat([input_semantics, real_image], dim=1)

fake_and_real = torch.cat([fake_concat, real_concat], dim=0) # [bs*2, 305, 256, 256]
discriminator_out = self.netD(fake_and_real)

然而此鉴别器由2个子鉴别器组合而成,第一个子鉴别的结构如下:

所以把fake_and_real作为输入后,中间层的尺寸变化是这样的:

1
2
3
4
5
intermediate_output: torch.Size([bs*2, 64, 129, 129])
intermediate_output: torch.Size([bs*2, 128, 65, 65])
intermediate_output: torch.Size([bs*2, 256, 33, 33])
intermediate_output: torch.Size([bs*2, 512, 34, 34])
intermediate_output: torch.Size([bs*2, 1, 35, 35])

并且在定义子鉴别器结构的函数NLayerDiscriminator中,它把所有的中间结果都保存起来,组合成一个结果序列,并作为该鉴别器的计算结果返回,因此该序列的长度为5。

在函数MultiscaleDiscriminator中,它再次将每个子鉴别器的返回结果作为序列保存起来,因此最后netD返回的尺寸是opt.num_D x opt.n_layers_D

另外需要注意的是,第二个子鉴别器的输入尺寸是对第一个鉴别器的input做downsample的结果,同理,第三个子鉴别器的输入尺寸是对第二个鉴别器的input做downsample的结果。

接着在函数divide_pred中需要将生成图和GT的鉴别结果分开,对于多鉴别器的结果序列而言:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
fake = []
real = []

for p in pred:
fake.append([tensor[:tensor.size(0) // 2] for tensor in p])
real.append([tensor[tensor.size(0) // 2:]for tensor in p])

return fake, real # (bs, layers) layers = 5

# 长度为2的pred序列p中tensor的尺寸如下:
torch.Size([4, 64, 129, 129])
torch.Size([4, 128, 65, 65])
torch.Size([4, 256, 33, 33])
torch.Size([4, 512, 34, 34])
torch.Size([4, 1, 35, 35])
----------------------------
torch.Size([4, 64, 65, 65])
torch.Size([4, 128, 33, 33])
torch.Size([4, 256, 17, 17])
torch.Size([4, 512, 18, 18])
torch.Size([4, 1, 19, 19])

# 最后得到的pred_fake和pred_real的尺寸都是 (bs, 5)

GANLoss这个类中,有几个初始化的参数如下,其中target_real_label的意思就是GT进入鉴别器后得到的结果应该越接近1越好,反之生成图越接近0越好:

1
gan_mode target_real_label=1.0 target_fake_label=0.0  # 默认gan_mode为hinge

计算GAN损失只需要鉴别器最后一层的结果,因此在criterionGAN函数中,对pred_fake的鉴别结果进行计算,并返回结果存储在G_losses中:

1
2
G_losses['GAN'] = self.criterionGAN(pred_fake, target_is_real=True, 			
for_discriminator=False)

loss函数中,我们只需要看hinge部分:

1
2
3
4
5
6
7
8
9
10
11
12
13
elif self.gan_mode == 'hinge':
if for_discriminator:
if target_is_real:
minval = torch.min(input - 1, self.get_zero_tensor(input))
loss = -torch.mean(minval)
else:
minval = torch.min(-input - 1, self.get_zero_tensor(input))
loss = -torch.mean(minval)
else:
# GAN LOSS
assert target_is_real, "The generator's hinge loss must be aiming for real"
loss = -torch.mean(input) # torch.mean(input)越大越好 因此加个负号
return loss

在上面这段代码中,计算GAN_loss时的input尺寸是pred_fake中两个元素的最后一层,尺寸分别为[bs, 1, 35, 35][bs, 1, 19, 19],其预测的结果值应该在0~1的范围之内,且越接近1越好(target is real),才能证明生成器的效果好,并最终返回的是多个鉴别器的平均结果。

1
2
3
4
5
6
7
8
9
10
if isinstance(input, list):
loss = 0
for pred_i in input:
if isinstance(pred_i, list):
pred_i = pred_i[-1] # [bs, 1, 35, 35] [bs, 1, 19, 19]
loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
loss += new_loss
return loss / len(input)

GAN FEAT LOSS

GAN feat loss实际上是计算GT和生成图分别丢给鉴别器后,每一层输出结果之间的L1损失,且不包括最后一层。

1
2
3
4
5
6
7
8
# self.criterionFeat = torch.nn.L1Loss()

for i in range(num_D):
num_intermediate_outputs = len(pre_fake[i] - 1) # 4
for j in range(num_intermediate_outputs):
unweighted_loss = self.criterionFeat(
pred_fake[i][j], pred_real[i][j].detach())
GAN_Feat_loss += unweighted_loss * self.opt.lambda_feat / num_D

opt.lambda_feat等于10,最后返回的是多个鉴别器的平均结果。

VGG LOSS

GauGAN网络中使用的是VGG19,这个类用于计算感知损失perceptual loss,即将生成图与GT作为VGG的输入,输出结果是长度为5的序列,每个元素是各层输出的中间特征。

在VGG中,计算GT和生成图相同层的中间特征的L1损失,最后将得到的损失值乘以lambda_vgg,其值为10:

1
2
3
4
5
6
7
8
9
10
11
# self.criterion = nn.L1loss()

x_vgg, y_vgg = self.vgg(x), self.vgg(y)
loss = 0
for i in range(len(x_vgg)):
loss += self.weight[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
return loss

# VGG损失的最后结果
G_losses['VGG'] = self.criterionVGG(fake_image, real_image) \
* self.opt.lambda_vgg

综上所述,所有损失都是基于全局计算,且越小越好的。

Discriminator

那么接下来,对于鉴别器而言,损失项包括两个:D_fakeD_real,我们继续一一介绍吧。

在函数compute_discriminator_loss中,语义标签图和GT作为输入,此时假设已经训练好生成器G,那么通过generate_fake得到生成图fake_image。与上面相同,将GT和生成图丢进鉴别器中,返回得到pred_fakepred_real鉴别结果。

但计算损失的时候略有不同,此时鉴别器的目标有两个:

  • 对于pred_fake,鉴别器的鉴别结果越接近0越好。
  • 对于pred_real,鉴别器的鉴别结果越接近1越好。
1
2
3
4
D_losses['D_Fake'] = self.criterionGAN(pred_fake, target_is_real=False,
for_discriminator=True)
D_losses['D_real'] = self.criterionGAN(pred_real, target_is_real=True,
for_discriminator=True)

D_fake LOSS

gan_mode=hinge,且目标为0时:

1
2
3
4
5
6
7
8
9
10
def get_zero_tensor(self, input):
if self.zero_tensor is None:
self.zero_tensor = self.Tensor(1).fill_(0)
self.zero_tensor.requires_grad_(False)
return self.zero_tensor.expand_as(input)

minval = torch.min(-input - 1, self.get_zero_tensor(input))
loss = -torch.mean(minval)

return loss

-input-1绝大多数情况为负值,其绝对值越小越好。比如当input=0.5input=0.2时,-input - 1的绝对值结果分别为1.51.2,显然对于生成图而言,input=0.2的鉴别结果更好。

D_real LOSS

gan_mode=hinge,且目标为1时:

1
2
3
4
5
6
7
8
9
10
def get_zero_tensor(self, input):
if self.zero_tensor is None:
self.zero_tensor = self.Tensor(1).fill_(0)
self.zero_tensor.requires_grad_(False)
return self.zero_tensor.expand_as(input)

minval = torch.min(input - 1, self.get_zero_tensor(input))
loss = -torch.mean(minval)

return loss

同理,input-1绝大多数情况为负值,其绝对值越小越好。比如当input=0.5input=0.2时,input - 1的绝对值结果分别为0.50.8,显然对于GT而言,input=0.5的鉴别结果更好。

Optimizer

在函数run_generator_one_step中,生成器损失g_loss等于Generator中所有损失值之和的均值。同理,鉴别器损失d_loss等于Discriminator中所有损失值之和的均值。

然后通过Adam优化器进行优化迭代。