Motivation
The update methods below work by modifying the lr parameter inside the optimizer. After trying the official example code, I wanted to squeeze out more training accuracy, so I thought of adjusting the learning rate: a fixed learning rate is clearly not ideal for training, so I tried several ways of changing it. To my surprise, the results were actually worse! Perhaps there are some learning-rate schedules I just haven't explored enough.
- Define a simple neural network model: y = Wx + b
```python
import torch
import matplotlib.pyplot as plt
%matplotlib inline
from torch.optim import *
import torch.nn as nn

class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        self.fc = nn.Linear(1, 10)

    def forward(self, x):
        return self.fc(x)
```
- Modify the lr value directly
```python
model = net()
LR = 0.01
optimizer = Adam(model.parameters(), lr=LR)
lr_list = []
for epoch in range(100):
    if epoch % 5 == 0:
        for p in optimizer.param_groups:
            p['lr'] *= 0.9
    lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
plt.plot(range(100), lr_list, color='r')
```
The key is the following two lines, which implement a manual step-wise decay; you can substitute whatever transform function you need:

```python
for p in optimizer.param_groups:
    p['lr'] *= 0.9
```

Using the decay functions provided by lr_scheduler
- torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)
```python
import numpy as np
from torch.optim import lr_scheduler

lr_list = []
model = net()
LR = 0.01
optimizer = Adam(model.parameters(), lr=LR)
# guard against division by zero: the scheduler evaluates the lambda
# at epoch 0 when it is constructed
lambda1 = lambda epoch: np.sin(epoch) / epoch if epoch > 0 else 1.0
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
for epoch in range(100):
    scheduler.step()
    lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
plt.plot(range(100), lr_list, color='r')
```
- torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=0, last_epoch=-1)
```python
lr_list = []
model = net()
LR = 0.01
optimizer = Adam(model.parameters(), lr=LR)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
for epoch in range(100):
    scheduler.step()
    lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
plt.plot(range(100), lr_list, color='r')
```
- torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)
This reduces the learning rate once the loss stops decreasing (or the accuracy stops improving). The parameters mean:

- mode: 'min' reduces the lr when the monitored metric stops decreasing; 'max' when it stops increasing
- factor: the lr is multiplied by this factor on each reduction (new_lr = lr * factor)
- patience: number of epochs with no improvement to wait before reducing
- verbose: print a message each time the lr is updated
- threshold / threshold_mode: how small a change still counts as an improvement ('rel' is relative to the best value, 'abs' is absolute)
- cooldown: epochs to wait after a reduction before resuming normal monitoring
- min_lr: lower bound on the learning rate
- eps: reductions smaller than eps are ignored
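To make the behavior concrete, here is a small sketch of ReduceLROnPlateau in action; the toy loss sequence and the factor/patience values below are my own, chosen only to make the reductions visible:

```python
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = nn.Linear(1, 10)
optimizer = Adam(model.parameters(), lr=0.01)
# mode='min': reduce the lr when the monitored loss stops decreasing
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

# a fake loss that improves once and then plateaus
fake_losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9]
lr_history = []
for loss in fake_losses:
    scheduler.step(loss)  # unlike other schedulers, step() takes the metric
    lr_history.append(optimizer.param_groups[0]['lr'])
```

After the loss stalls for more than `patience` epochs, the lr is halved (and halved again if the plateau continues).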
For other learning-rate update methods, see: https://www.emperinter.info/2020/08/01/learning-rate-in-pytorch/

The update method used in the example
The code offers three options: cosine annealing (the default; the other two are commented out), e^-x exponential decay, and reducing the learning rate when the loss stops decreasing. Try out the others yourself!
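The three options could be switched with a small helper like the following sketch; the `make_scheduler` function, its `kind` labels, and the exact exponential lambda are my own, not from the post's code:

```python
import math
import torch.nn as nn
from torch.optim import Adam, lr_scheduler

def make_scheduler(optimizer, kind="cosine"):
    # cosine annealing -- the default choice in the example
    if kind == "cosine":
        return lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
    # e^-x style exponential decay, expressed via LambdaLR
    if kind == "exp":
        return lr_scheduler.LambdaLR(optimizer,
                                     lr_lambda=lambda epoch: math.exp(-epoch))
    # reduce the lr when the monitored loss plateaus
    return lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')

optimizer = Adam(nn.Linear(1, 10).parameters(), lr=0.01)
scheduler = make_scheduler(optimizer, "cosine")
```

Keeping the choice behind one function makes it easy to comment one line in and out, as the original code does.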
Training screenshots (the first figure shows training_loss, the second shows the learning-rate curve)
```python
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from torch.optim import *

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# For the full example code and what follows, see:
# https://www.emperinter.info/2020/08/01/learning-rate-in-pytorch/
```
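The full training loop lives at the linked post; the sketch below only shows where the scheduler call might sit relative to the optimizer. The tiny stand-in model and random batches are mine, used so the snippet runs without downloading CIFAR-10:

```python
import torch
import torch.nn as nn
from torch.optim import Adam, lr_scheduler

# toy stand-in for the CIFAR-10 network and trainloader above
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.01)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

for epoch in range(2):          # a couple of epochs just to show call order
    for _ in range(3):          # stand-in for iterating trainloader
        inputs = torch.randn(4, 3, 32, 32)
        labels = torch.randint(0, 10, (4,))
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
    scheduler.step()            # advance the schedule once per epoch
    print(epoch, optimizer.param_groups[0]['lr'])
```

Note that `scheduler.step()` is called once per epoch, after the inner batch loop has run `optimizer.step()`.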