当前位置: 首页>后端>正文

WGAN-GP 原理和代码分析

生成对抗模型(GAN)简介可以参考:https://www.jianshu.com/p/34d9d0755f51
这里介绍的WGAN,将损失函数进行了正则化

文章链接:《Improving protein function prediction with synthetic feature samples created by generative adversarial networks》

这里作者提出了一种新的损失函数定义模式,对于普通GAN的损失函数定义:


WGAN-GP 原理和代码分析,第1张

WGAN-GP 原理和代码分析,\widetilde{x},第2张 由生成器 G 产生的 fake data,x 代表 real data,那么对于WGAN-GP 它的损失函数为:

WGAN-GP 原理和代码分析,第3张

其中 WGAN-GP 原理和代码分析,\widetilde{x},第2张 由生成器 G 产生的 fake data,x 代表 real data,WGAN-GP 原理和代码分析,\widehat{x},第5张 在本研究中代表:
WGAN-GP 原理和代码分析,第6张
α 代表随机的参数,λ 这一项代表正则项作为梯度约束

代码部分:https://github.com/psipred/FFPredGAN/blob/master/src/Generating_Synthetic_Positive_Samples_FFPred-GAN.py

这里只重点讲讲目标函数约束的代码部分:

ITERS = 100000 
CRITIC_ITERS = 5

# 训练模型
for iteration in range(ITERS):
    for p in netD.parameters():  
        p.requires_grad = True  

    data = inf_train_gen()
    real_data = torch.FloatTensor(data)
    real_data_v = autograd.Variable(real_data)
    
    noise = torch.randn(BATCH_SIZE, 258)
    noisev = autograd.Variable(noise, volatile=True)  
    fake = autograd.Variable(netG(noisev, real_data_v).data)

    fake_output=fake.data.cpu().numpy()
    
    # 训练判别器 netD
    for iter_d in range(CRITIC_ITERS):
        # 梯度清零
        netD.zero_grad()

        D_real, hidden_output_real_1, hidden_output_real_2, hidden_output_real_3 = netD(real_data_v)

        # 高维张量取平均值,变成一个标量
        D_real = D_real.mean()

        noise = torch.randn(BATCH_SIZE, 258)
        noisev = autograd.Variable(noise, volatile=True)  
        fake = autograd.Variable(netG(noisev, real_data_v).data)
        
        inputv = fake
        D_fake, hidden_output_fake_1, hidden_output_fake_2, hidden_output_fake_3 = netD(inputv)
       
        # 高维张量取平均值,变成一个标量
        D_fake = D_fake.mean()
        
        # 计算正则项
        gradient_penalty = calc_gradient_penalty(netD, real_data_v.data, fake.data)
        
        # WGAN-GP 损失函数
        D_cost = D_fake - D_real + gradient_penalty
        Wasserstein_D = D_real - D_fake

        # 反向传播损失函数
        D_cost.backward()
        # 迭代更新
        optimizerD.step()

    # 训练生成器 netG
    for p in netD.parameters():
            p.requires_grad = False

        netG.zero_grad()
        real_data = torch.Tensor(data)
        real_data_v = autograd.Variable(real_data)

        noise = torch.randn(BATCH_SIZE, 258)
        noisev = autograd.Variable(noise)
        fake = netG(noisev, real_data_v)
        G, hidden_output_ignore_1, hidden_output_ignore_2, hidden_output_ignore_3 = netD(fake)

        G = G.mean()
        G_cost = -G
        # 反向传播损失函数
        G_cost.backward()
        # 迭代更新
        optimizerG.step()

计算gradient_penalty的代码为:

def calc_gradient_penalty(netD, real_data, fake_data):
    alpha = torch.rand(BATCH_SIZE, 1)
    alpha = alpha.expand(real_data.size())
    alpha = alpha.cuda() if use_cuda else alpha

    interpolates = alpha * real_data + ((1 - alpha) * fake_data)

    interpolates = autograd.Variable(interpolates, requires_grad=True)

    disc_interpolates, hidden_output_1, hidden_output_2, hidden_output_3 = netD(interpolates) 
    
    # 求梯度
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()).cuda() if use_cuda else torch.ones(
                                  disc_interpolates.size()),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]
    
    # 正则项,二阶范数
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gradient_penalty

https://www.xamrdz.com/backend/39d1997602.html

相关文章: