Example Code for an Autoregressive Transformer Model
Autoregressive Transformers are a class of models used for generation tasks, such as text generation and various other natural language processing tasks. Below is a simple example of an autoregressive Transformer implemented with the PyTorch framework, using character-level text generation as the running example.
import math
import torch
import torch.nn as nn
import torch.optim as optim
class SimpleTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers, dropout=0.5):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        # batch_first=True keeps every tensor in (batch, seq_len, ...) layout
        self.transformer = nn.Transformer(d_model, nhead, num_layers, num_layers,
                                          dim_feedforward=512, dropout=dropout,
                                          batch_first=True)
        self.fc_out = nn.Linear(d_model, vocab_size)
    def forward(self, src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask):
        src_emb = self.embedding(src) * math.sqrt(self.d_model)
        tgt_emb = self.embedding(tgt) * math.sqrt(self.d_model)
        output = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                  src_padding_mask, tgt_padding_mask)
        return self.fc_out(output)
    def generate(self, src, max_len, start_symbol, end_symbol, src_mask, src_padding_mask):
        src_emb = self.embedding(src) * math.sqrt(self.d_model)
        memory = self.transformer.encoder(src_emb, src_mask, src_padding_mask)
        ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=src.device)
        for _ in range(max_len - 1):
            tgt_emb = self.embedding(ys) * math.sqrt(self.d_model)
            # causal mask so each position only attends to earlier positions
            tgt_mask = self.transformer.generate_square_subsequent_mask(ys.size(1)).to(src.device)
            out = self.transformer.decoder(tgt_emb, memory, tgt_mask)
            prob = self.fc_out(out[:, -1])
            next_word = torch.argmax(prob, dim=1).item()
            ys = torch.cat([ys, torch.full((1, 1), next_word, dtype=torch.long, device=src.device)], dim=1)
            if next_word == end_symbol:
                break
        return ys
# Hyperparameter settings
vocab_size = 100   # e.g. a vocabulary of 100 characters
d_model = 512
nhead = 8
num_layers = 6
max_len = 100
start_token = 0
end_token = 1
padding_idx = 2    # index reserved for padding tokens; adjust to your vocabulary
# Model, optimizer and loss
model = SimpleTransformer(vocab_size, d_model, nhead, num_layers)
optimizer = optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss(ignore_index=padding_idx)  # padded targets do not contribute to the loss
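The training loop below iterates over train_data, which is not defined in this snippet. As a minimal sketch of how character-level data might be prepared, assuming a plain-text corpus file and placeholder names such as corpus.txt, char_to_idx, idx_to_char and make_example (none of which are part of the original example):

# --- Hypothetical data preparation (adjust to your own corpus) ---
text = open("corpus.txt", encoding="utf-8").read()        # assumed corpus file
chars = sorted(set(text))
char_to_idx = {ch: i + 3 for i, ch in enumerate(chars)}   # indices 0/1/2 reserved for start/end/pad
idx_to_char = {i: ch for ch, i in char_to_idx.items()}
# with this scheme vocab_size should be set to len(chars) + 3

def make_example(line, seq_len=32):
    ids = [start_token] + [char_to_idx[c] for c in line[:seq_len]] + [end_token]
    ids += [padding_idx] * (seq_len + 2 - len(ids))        # pad to a fixed length
    return torch.tensor(ids, dtype=torch.long)

# each batch is a (src, tgt) pair of shape (batch_size, seq_len + 2);
# for this illustrative setup the same sequence is reused on both sides,
# while a real seq2seq task would pair different source and target sequences
lines = text.splitlines()
examples = torch.stack([make_example(l) for l in lines if l])
train_data = [(examples[i:i + 32], examples[i:i + 32])
              for i in range(0, len(examples), 32)]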
# Training loop over sample data (adapt the dataset to your task)
for epoch in range(10):  # train for 10 epochs
    for batch, (src, tgt) in enumerate(train_data):
        src_input = src
        tgt_input = tgt[:, :-1]   # decoder input: every token except the last
        tgt_output = tgt[:, 1:]   # training target: every token shifted left by one
        # masks are built over the sequence dimension (dim 1 with batch_first=True)
        src_mask = None  # the encoder may attend to the whole source sequence
        tgt_mask = model.transformer.generate_square_subsequent_mask(tgt_input.shape[1]).to(tgt.device)
        src_padding_mask = src_input == padding_idx
        tgt_padding_mask = tgt_input == padding_idx
        optimizer.zero_grad()
        output = model(src_input, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)
        loss = criterion(output.reshape(-1, vocab_size), tgt_output.reshape(-1))
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            print(f"Epoch {epoch}, Step {batch}, Loss: {loss.item()}")

# Generate with the trained model, using one source sequence from the last batch as a prompt
model.eval()
with torch.no_grad():
    generated_symbols = model.generate(src_input[:1], max_len, start_token, end_token,
                                       src_mask, src_padding_mask[:1])
print("Generated sequence:", generated_symbols)
Please prepare train_data for your own dataset, and adjust vocab_size and the other hyperparameters to fit the task at hand. This code only shows the basic structure needed to set up an autoregressive Transformer for a simple character-level generation task; appropriate data and preprocessing will give better results on a specific task.
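One further caveat: nn.Transformer does not add positional information on its own, so the embeddings above carry no notion of token order. A sinusoidal positional-encoding module is one common way to fill that gap; the sketch below is a standard formulation, not something taken from the original code:

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))  # shape (1, max_len, d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model) when batch_first=True
        return x + self.pe[:, :x.size(1)]

In SimpleTransformer it would typically be applied right after the scaled embedding, e.g. src_emb = self.pos_enc(self.embedding(src) * math.sqrt(self.d_model)), where pos_enc is an instance of this module added in __init__ (the attribute name is just a suggestion).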