小批量随机梯度下降
概念
在每一次迭代中,梯度下降使用整个训练数据集来计算梯度,因此它有时也被称为批量梯度下降(batch gradient descent)。而随机梯度下降在每次迭代中只随机采样一个样本来计算梯度。
我们还可以在每轮迭代中随机均匀采样多个样本来组成一个小批量,然后使用这个小批量来计算梯度。下面就来描述小批量随机梯度下降。
我们可以通过重复采样(sampling with replacement)或者不重复采样(sampling without replacement)得到一个小批量中的各个样本。前者允许同一个小批量中出现重复的样本,后者则不允许如此,且更常见。
对于这两者间的任一种方式,都可以使用
gt←∇fBt(xt−1)=∣B∣1i∈Bt∑∇fi(xt−1)
同随机梯度一样,重复采样所得的小批量随机梯度gt也是对梯度∇f(xt−1)的无偏估计。给定学习率ηt(取正数),小批量随机梯度下降对自变量的迭代如下:
xt←xt−1−ηtgt.
小批量随机梯度下降中每次迭代的计算开销为O(∣B∣)。当批量大小为1时,该算法即为随机梯度下降;当批量大小等于训练数据样本数时,该算法即为梯度下降。当批量较小时,每次迭代中使用的样本少,这会导致并行处理和内存使用效率变低。这使得在计算同样数目样本的情况下比使用更大批量时所花时间更多。当批量较大时,每个小批量梯度里可能含有更多的冗余信息。为了得到较好的解,批量较大时比批量较小时需要计算的样本数目可能更多,例如增大迭代周期数。
由于mini-batch SGD 比 SGD 效果好很多,所以人们一般说SGD都指的是 mini-batch gradient descent. 大家不要和原始的SGD混淆。现在基本所有的大规模深度学习训练都是分为小batch进行训练的。
读取数据
使用一个来自NASA的测试不同飞机机翼噪音的数据集来比较各个优化算法 [1]。我们使用该数据集的前1,500个样本和5个特征,并使用标准化对数据进行预处理。
1 2 3 4 5 6 7
| def get_data_ch7(): data = np.genfromtxt('../../../Datasets/airfoil_self_noise.dat', delimiter='\t') data = (data - data.mean(axis=0)) / data.std(axis=0) return torch.tensor(data[:1500, :-1], dtype=torch.float32), \ torch.tensor(data[:1500, -1], dtype=torch.float32)
features, labels = get_data_ch7()
|
Pytorch实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
| batch_size = 20 num_epoch = 5
dataloader = DataLoader(TensorDataset(features, labels), batch_size=batch_size, shuffle=True)
net = nn.Sequential( nn.Linear(5, 1) )
loss = nn.MSELoss() optimizer = optim.SGD(net.parameters(), lr=0.05)
def eval_loss(): return loss(net(features).view(-1), labels).item() / 2
ans = [eval_loss()]
for epoch in range(num_epoch): start = time.time() for step, (x, y) in enumerate(dataloader): output = net(x).view(-1) ls = loss(output, y) / 2 optimizer.zero_grad() ls.backward() optimizer.step()
if (step + 1) * batch_size % 100 == 0: ans.append(eval_loss())
print('loss: %f, %f sec per epoch' % (ans[-1], time.time() - start)) fig = plt.figure() ax = fig.add_subplot() plt.plot(np.linspace(0, num_epoch, len(ans)), ans) plt.xlabel('epoch') plt.ylabel('loss') plt.show()
|
最后结果如下:
1
| loss: 0.242433, 0.071540 sec per epoch
|
![](https://tva1.sinaimg.cn/large/007S8ZIlly1ghxdl7dwsij30hs0dc74k.jpg)