For an introduction to generative adversarial networks (GANs), see: https://www.jianshu.com/p/34d9d0755f51
The WGAN variant introduced here, WGAN-GP, adds a regularization term (a gradient penalty) to the loss function.
Paper: "Improving protein function prediction with synthetic feature samples created by generative adversarial networks"
Here a modified way of defining the loss function is adopted. For an ordinary GAN, the loss function is defined as:

\min_G \max_D \; \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
where G(z) is the fake data produced by the generator G and x is the real data. For WGAN-GP, the loss function is:

L = \mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})] - \mathbb{E}_{x \sim P_r}[D(x)] + \lambda \, \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big]
where \tilde{x} = G(z) is the fake data produced by the generator G, x is the real data, and \hat{x} = \epsilon x + (1 - \epsilon)\tilde{x} with \epsilon \sim U[0, 1] is a point sampled on the straight line between a real and a fake sample. In this study the real data are protein feature samples (258-dimensional feature vectors, matching the dimensions used in the code below).
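Splitting this objective into the two alternating updates used in the code below gives the critic (discriminator) loss and the generator loss:

L_D = \mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})] - \mathbb{E}_{x \sim P_r}[D(x)] + \lambda \, \mathbb{E}_{\hat{x}}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big]
L_G = -\mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})]

These two terms correspond directly to D_cost and G_cost in the training loop below.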
Code: https://github.com/psipred/FFPredGAN/blob/master/src/Generating_Synthetic_Positive_Samples_FFPred-GAN.py
Here we focus only on the part of the code that implements the objective-function constraint:
ITERS = 100000        # total number of training iterations
CRITIC_ITERS = 5      # critic (discriminator) updates per generator update

# Training loop
for iteration in range(ITERS):
    # Enable gradients for the discriminator parameters
    for p in netD.parameters():
        p.requires_grad = True
    data = inf_train_gen()
    real_data = torch.FloatTensor(data)
    real_data_v = autograd.Variable(real_data)
    noise = torch.randn(BATCH_SIZE, 258)
    # volatile=True disables gradient tracking (old PyTorch API; torch.no_grad() in newer versions)
    noisev = autograd.Variable(noise, volatile=True)
    fake = autograd.Variable(netG(noisev, real_data_v).data)
    fake_output = fake.data.cpu().numpy()

    # Train the discriminator netD
    for iter_d in range(CRITIC_ITERS):
        # Zero the gradients
        netD.zero_grad()
        D_real, hidden_output_real_1, hidden_output_real_2, hidden_output_real_3 = netD(real_data_v)
        # Average the critic scores over the batch to get a scalar
        D_real = D_real.mean()
        noise = torch.randn(BATCH_SIZE, 258)
        noisev = autograd.Variable(noise, volatile=True)
        fake = autograd.Variable(netG(noisev, real_data_v).data)
        inputv = fake
        D_fake, hidden_output_fake_1, hidden_output_fake_2, hidden_output_fake_3 = netD(inputv)
        # Average the critic scores over the batch to get a scalar
        D_fake = D_fake.mean()
        # Compute the regularization (gradient penalty) term
        gradient_penalty = calc_gradient_penalty(netD, real_data_v.data, fake.data)
        # WGAN-GP discriminator loss
        D_cost = D_fake - D_real + gradient_penalty
        # Estimate of the Wasserstein distance (for monitoring)
        Wasserstein_D = D_real - D_fake
        # Backpropagate the loss
        D_cost.backward()
        # Update the discriminator parameters
        optimizerD.step()

    # Train the generator netG
    # Freeze the discriminator while updating the generator
    for p in netD.parameters():
        p.requires_grad = False
    netG.zero_grad()
    real_data = torch.Tensor(data)
    real_data_v = autograd.Variable(real_data)
    noise = torch.randn(BATCH_SIZE, 258)
    noisev = autograd.Variable(noise)
    fake = netG(noisev, real_data_v)
    G, hidden_output_ignore_1, hidden_output_ignore_2, hidden_output_ignore_3 = netD(fake)
    G = G.mean()
    # Generator loss: maximize the critic score on fake samples
    G_cost = -G
    # Backpropagate the loss
    G_cost.backward()
    # Update the generator parameters
    optimizerG.step()
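The excerpt above relies on definitions that are not shown (BATCH_SIZE, netG, netD, inf_train_gen, the optimizers). Below is a minimal sketch of what such a setup could look like; the class names, layer sizes, and hyper-parameters are assumptions for illustration, not the repository's exact definitions:

import numpy as np
import torch
import torch.nn as nn
from torch import autograd, optim

use_cuda = torch.cuda.is_available()
BATCH_SIZE = 64      # assumed batch size; check the repository for the actual value
FEATURE_DIM = 258    # dimensionality of the protein feature vectors
LAMBDA = 10          # gradient-penalty coefficient from the WGAN-GP paper

class Generator(nn.Module):
    # Maps a noise vector to a synthetic feature sample. The repository's generator
    # also receives the real batch as a second argument; this sketch accepts it but
    # only transforms the noise.
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(FEATURE_DIM, 512), nn.ReLU(),
            nn.Linear(512, FEATURE_DIM),
        )

    def forward(self, noise, real_data):
        return self.net(noise)

class Discriminator(nn.Module):
    # Critic that returns a scalar score plus three hidden-layer activations,
    # matching the four return values unpacked in the training loop.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(FEATURE_DIM, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.out = nn.Linear(128, 1)
        self.act = nn.ReLU()

    def forward(self, x):
        h1 = self.act(self.fc1(x))
        h2 = self.act(self.fc2(h1))
        h3 = self.act(self.fc3(h2))
        return self.out(h3), h1, h2, h3

def inf_train_gen():
    # Placeholder data source: one random batch of "real" feature vectors.
    # In the repository this yields batches of real protein feature samples.
    return np.random.rand(BATCH_SIZE, FEATURE_DIM).astype(np.float32)

netG = Generator()
netD = Discriminator()
# WGAN-GP is commonly trained with Adam and betas=(0.5, 0.9); assumed here.
optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))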
The code for computing gradient_penalty is:
def calc_gradient_penalty(netD, real_data, fake_data):
    # Random interpolation coefficients, one per sample in the batch
    alpha = torch.rand(BATCH_SIZE, 1)
    alpha = alpha.expand(real_data.size())
    alpha = alpha.cuda() if use_cuda else alpha
    # Points sampled on the straight lines between real and fake samples
    interpolates = alpha * real_data + ((1 - alpha) * fake_data)
    interpolates = autograd.Variable(interpolates, requires_grad=True)
    disc_interpolates, hidden_output_1, hidden_output_2, hidden_output_3 = netD(interpolates)
    # Gradient of the critic output with respect to the interpolated inputs
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()).cuda() if use_cuda
                              else torch.ones(disc_interpolates.size()),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]
    # Penalty term: squared deviation of the gradient L2 norm from 1, scaled by LAMBDA
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gradient_penalty
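As a quick sanity check, the function can be called on two random batches; this reuses the assumed sketch definitions above (netD, BATCH_SIZE, FEATURE_DIM, LAMBDA, use_cuda):

real_batch = torch.randn(BATCH_SIZE, FEATURE_DIM)
fake_batch = torch.randn(BATCH_SIZE, FEATURE_DIM)
if use_cuda:
    netD = netD.cuda()
    real_batch, fake_batch = real_batch.cuda(), fake_batch.cuda()
gp = calc_gradient_penalty(netD, real_batch, fake_batch)
print(float(gp))  # a single non-negative scalar, already scaled by LAMBDA

This penalty is how WGAN-GP enforces the 1-Lipschitz constraint on the critic: instead of clipping the weights as in the original WGAN, it pushes the gradient norm at points interpolated between real and fake samples towards 1.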