快速风格迁移网络学习

  快速风格迁移网络由《Perceptual Losses for Real-Time Style Transfer and Super-Resolution》提出,其由生成网络和损失网络组成,如下图所示。

faststyle1

快速风格迁移网络

  原图片经过生成网络得到生成的风格图片后分别与目标图片和原图片计算两种损失。其中constant损失使生成图片与原图片整体相似,而风格损失则使生成的图片的风格向目标图片偏移。两种损失均为l2损失。图片的风格由Gram矩阵得到,Gram矩阵的计算公式为

1
Gram=A^T·A

  具体代码为:

1
2
3
4
b,c,h,w = input.size()
F = input.view(b,c,h*w)
G = torch.bmm(F,F.transpose(1,2))
G.div_(h*w)

  快速风格迁移效果如下图所示:
faststyle1

实验效果

  代码地址为My Pytorch-Learn fast style Transfer