torch.nn.utils.clip_grad_norm_函数详解
torch.nn.utils.clip_grad_norm_
是 PyTorch 中用于控制神经网络训练时梯度范数的常用工具。梯度范数截断(gradient norm clipping)是一种常用的技术,用于解决梯度爆炸问题,特别是在训练循环神经网络(RNN)或深层网络时常会遇到这种问题。
以下是 torch.nn.utils.clip_grad_norm_
函数的详细信息:
函数定义
torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False)
参数详解
parameters: 要进行梯度剪裁的参数列表,通常是模型的参数(可以通过
model.parameters()
获取)。max_norm: 最大的梯度范数。梯度的总范数如果超过这个值,就会被缩放回这个范围。
norm_type: 使用的范数类型,默认为 L2 范数(即范数的类型为 2)。常用的值包括:
2.0
for L2 norm1.0
for L1 norm
你可以指定其他的范数类型,根据需要调整。
error_if_nonfinite: 布尔值,可选。默认为
False
。如果为True
,函数会在梯度中存在无穷大或 NaN 时抛出错误。
工作原理
clip_grad_norm_
的核心思想是:
1. 计算所有参数梯度的范数。
2. 如果这个范数超过了指定的 max_norm
,则按比例缩小每个梯度以使总范数等于 max_norm
。
这样的操作可以使得即便在反向传播过程中出现巨大的梯度(即梯度爆炸),也不会导致模型更新异常大,从而保持训练的稳定性。
返回值
该函数返回的是梯度的总范数(在剪裁之前计算的值)。
使用示例
import torch
import torch.nn as nn
import torch.optim as optim
# 假设我们有一个简单的模型
model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 假设有一个简单的损失
inputs = torch.randn(5, 10)
targets = torch.randn(5, 2)
criterion = nn.MSELoss()
# 前向传播
outputs = model(inputs)
loss = criterion(outputs, targets)
# 反向传播
optimizer.zero_grad()
loss.backward()
# 梯度剪裁
max_norm = 1.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
# 更新参数
optimizer.step()
在这个示例中,我们为一个简单的线性模型应用了梯度剪裁,以确保梯度的范数不会超过指定的最大值 max_norm
。这对保持训练过程的稳定性特别有用。