0%

GAN - 生成对抗网络的基础代码实现

最近沉迷于GAN,整理了一下基础部分的代码。

原理简介

GAN(Generative adversarial network),即生成对抗网络。网络中有两个模型,分别是生成器(Generator)和判别器(Discriminator)。生成器负责生成所需数据,优化目标是让生成的数据骗过判别器的判断;判别器负责对数据进行判断,优化目标是让它的判断更准确。

库文件及工具函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt # 用于显示图像

# 从Tensor显示图像
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    """Display up to ``num_images`` images from a flat batch as a 5-column grid."""
    batch = image_tensor.detach().cpu()
    unflattened = batch.view(-1, *size)
    grid = make_grid(unflattened[:num_images], nrow=5)
    grid_hwc = grid.permute(1, 2, 0)  # CHW -> HWC for matplotlib
    plt.imshow(grid_hwc.squeeze())
    plt.show()

# 生成一个随机噪音
def get_noise(n_samples, z_dim, device='cpu'):
    """Sample an ``(n_samples, z_dim)`` tensor of standard-normal noise on *device*."""
    shape = (n_samples, z_dim)
    return torch.randn(*shape, device=device)

生成器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class Generator(nn.Module):
    """Simple MLP generator: maps a z_dim noise vector to a flattened image.

    Four Linear+BatchNorm+ReLU blocks of doubling hidden width, then a final
    Linear projection to ``im_dim`` and a Sigmoid to keep pixels in [0, 1].
    """

    def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):
        super(Generator, self).__init__()
        widths = [z_dim, hidden_dim, hidden_dim * 2, hidden_dim * 4, hidden_dim * 8]
        layers = [get_generator_block(w_in, w_out)
                  for w_in, w_out in zip(widths[:-1], widths[1:])]
        layers.append(nn.Linear(widths[-1], im_dim))
        layers.append(nn.Sigmoid())
        self.gen = nn.Sequential(*layers)

    def forward(self, noise):
        """Generate a batch of flattened images from a batch of noise vectors."""
        return self.gen(noise)

    def get_gen(self):
        """Return the underlying nn.Sequential model."""
        return self.gen

def get_generator_block(input_dim, output_dim):
    """One generator layer: Linear -> BatchNorm1d -> ReLU."""
    linear = nn.Linear(input_dim, output_dim)
    norm = nn.BatchNorm1d(output_dim)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(linear, norm, activation)

判别器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class Discriminator(nn.Module):
    """MLP discriminator: maps a flattened image to a single real/fake logit.

    Three Linear+LeakyReLU blocks of halving hidden width, then a final
    Linear producing one unnormalized score per image.
    """

    def __init__(self, im_dim=784, hidden_dim=128):
        super(Discriminator, self).__init__()
        widths = [im_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim]
        layers = [get_discriminator_block(w_in, w_out)
                  for w_in, w_out in zip(widths[:-1], widths[1:])]
        layers.append(nn.Linear(widths[-1], 1))
        self.disc = nn.Sequential(*layers)

    def forward(self, image):
        """Return one logit per image (pair with BCEWithLogitsLoss)."""
        return self.disc(image)

    def get_disc(self):
        """Return the underlying nn.Sequential model."""
        return self.disc

def get_discriminator_block(input_dim, output_dim):
    """One discriminator layer: Linear followed by LeakyReLU with slope 0.2."""
    layers = [nn.Linear(input_dim, output_dim),
              nn.LeakyReLU(0.2, inplace=True)]
    return nn.Sequential(*layers)

超参数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# ---- Hyperparameters and training setup ----
criterion = nn.BCEWithLogitsLoss()  # cross-entropy on logits pushes predictions toward 1 or 0
n_epochs = 200
z_dim = 64          # dimensionality of the generator's input noise
display_step = 500  # log/visualize every this many steps
batch_size = 128
lr = 0.00001
# Fall back to CPU when no GPU is present instead of crashing at .to(device).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

dataloader = DataLoader(
    # download=True fetches MNIST on first run and is a no-op when it already exists.
    MNIST('.', download=True, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True
)

gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

损失函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):
    """Discriminator loss: mean of its loss on a fresh fake batch (target 0)
    and on the real batch (target 1). The fake batch is detached so no
    gradient flows back into the generator."""
    noise = get_noise(num_images, z_dim, device=device)
    fake_images = gen(noise)
    fake_logits = disc(fake_images.detach())
    loss_on_fakes = criterion(fake_logits, torch.zeros_like(fake_logits))
    real_logits = disc(real)
    loss_on_reals = criterion(real_logits, torch.ones_like(real_logits))
    return (loss_on_fakes + loss_on_reals) / 2


def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
    """Generator loss: distance of the discriminator's verdict on a fresh
    fake batch from the 'real' label (1). Gradients flow through gen."""
    latent = get_noise(num_images, z_dim, device=device)
    generated = gen(latent)
    verdict = disc(generated)
    target = torch.ones_like(verdict)
    return criterion(verdict, target)

训练过程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# ---- Training loop ----
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
logs = []
msg = ''
for epoch in range(n_epochs):
    for real, _ in tqdm(dataloader, desc=f'epoch {epoch + 1}'):
        cur_batch_size = len(real)
        # Flatten 28x28 images into 784-length vectors for the MLP models.
        real = real.view(cur_batch_size, -1).to(device)

        # --- Update the discriminator ---
        disc_opt.zero_grad()
        disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)
        # No retain_graph needed: the generator step below runs its own fresh
        # forward passes, so nothing from this graph is reused.
        disc_loss.backward()
        disc_opt.step()

        # --- Update the generator ---
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)
        gen_loss.backward()
        gen_opt.step()

        # Running means over one display window.
        mean_discriminator_loss += disc_loss.item() / display_step
        mean_generator_loss += gen_loss.item() / display_step

        # Periodically record losses and visualize generated vs. real samples.
        if cur_step % display_step == 0 and cur_step > 0:
            msg = f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}"
            logs.append(msg)
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            fake = gen(fake_noise)
            show_tensor_images(fake)
            show_tensor_images(real)
            mean_discriminator_loss = 0
            mean_generator_loss = 0
        cur_step += 1

# 'entry' instead of 'str' — the original shadowed the builtin.
for entry in logs:
    print(entry)

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    """Display up to ``num_images`` images from a flat batch as a 5-column grid."""
    on_cpu = image_tensor.detach().cpu()
    images = on_cpu.view(-1, *size)[:num_images]
    grid = make_grid(images, nrow=5)
    plt.imshow(grid.permute(1, 2, 0).squeeze())  # CHW -> HWC for matplotlib
    plt.show()


def get_generator_block(input_dim, output_dim):
    """One generator layer: Linear -> BatchNorm1d -> ReLU."""
    layers = [nn.Linear(input_dim, output_dim),
              nn.BatchNorm1d(output_dim),
              nn.ReLU(inplace=True)]
    return nn.Sequential(*layers)


class Generator(nn.Module):
    """Simple MLP generator: maps a z_dim noise vector to a flattened image.

    Four Linear+BatchNorm+ReLU blocks of doubling hidden width, then a final
    Linear projection to ``im_dim`` and a Sigmoid to keep pixels in [0, 1].
    """

    def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):
        super(Generator, self).__init__()
        widths = [z_dim, hidden_dim, hidden_dim * 2, hidden_dim * 4, hidden_dim * 8]
        blocks = [get_generator_block(w_in, w_out)
                  for w_in, w_out in zip(widths[:-1], widths[1:])]
        blocks += [nn.Linear(widths[-1], im_dim), nn.Sigmoid()]
        self.gen = nn.Sequential(*blocks)

    def forward(self, noise):
        """Generate a batch of flattened images from a batch of noise vectors."""
        return self.gen(noise)

    def get_gen(self):
        """Return the underlying nn.Sequential model."""
        return self.gen


def get_noise(n_samples, z_dim, device='cpu'):
    """Sample an ``(n_samples, z_dim)`` tensor of standard-normal noise on *device*."""
    return torch.randn((n_samples, z_dim), device=device)


def get_discriminator_block(input_dim, output_dim):
    """One discriminator layer: Linear followed by LeakyReLU with slope 0.2."""
    linear = nn.Linear(input_dim, output_dim)
    activation = nn.LeakyReLU(0.2, inplace=True)
    return nn.Sequential(linear, activation)


class Discriminator(nn.Module):
    """MLP discriminator: maps a flattened image to a single real/fake logit.

    Three Linear+LeakyReLU blocks of halving hidden width, then a final
    Linear producing one unnormalized score per image.
    """

    def __init__(self, im_dim=784, hidden_dim=128):
        super(Discriminator, self).__init__()
        widths = [im_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim]
        blocks = [get_discriminator_block(w_in, w_out)
                  for w_in, w_out in zip(widths[:-1], widths[1:])]
        blocks.append(nn.Linear(widths[-1], 1))
        self.disc = nn.Sequential(*blocks)

    def forward(self, image):
        """Return one logit per image (pair with BCEWithLogitsLoss)."""
        return self.disc(image)

    def get_disc(self):
        """Return the underlying nn.Sequential model."""
        return self.disc


# ---- Hyperparameters and training setup ----
criterion = nn.BCEWithLogitsLoss()  # cross-entropy on logits pushes predictions toward 1 or 0
n_epochs = 200
z_dim = 64          # dimensionality of the generator's input noise
display_step = 500  # log/visualize every this many steps
batch_size = 128
lr = 0.00001
# Fall back to CPU when no GPU is present instead of crashing at .to(device).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

dataloader = DataLoader(
    # download=True fetches MNIST on first run and is a no-op when it already exists.
    MNIST('.', download=True, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True
)

gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)


def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):
    """Discriminator loss: mean of its loss on a fresh fake batch (target 0)
    and on the real batch (target 1). The fake batch is detached so no
    gradient flows back into the generator."""
    noise = get_noise(num_images, z_dim, device=device)
    fake_images = gen(noise)
    fake_logits = disc(fake_images.detach())
    loss_on_fakes = criterion(fake_logits, torch.zeros_like(fake_logits))
    real_logits = disc(real)
    loss_on_reals = criterion(real_logits, torch.ones_like(real_logits))
    return (loss_on_fakes + loss_on_reals) / 2


def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
    """Generator loss: distance of the discriminator's verdict on a fresh
    fake batch from the 'real' label (1). Gradients flow through gen."""
    latent = get_noise(num_images, z_dim, device=device)
    verdict = disc(gen(latent))
    return criterion(verdict, torch.ones_like(verdict))


# ---- Training loop ----
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
logs = []
msg = ''
for epoch in range(n_epochs):
    for real, _ in tqdm(dataloader, desc=f'epoch {epoch + 1}'):
        cur_batch_size = len(real)
        # Flatten 28x28 images into 784-length vectors for the MLP models.
        real = real.view(cur_batch_size, -1).to(device)

        # --- Update the discriminator ---
        disc_opt.zero_grad()
        disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)
        # No retain_graph needed: the generator step below runs its own fresh
        # forward passes, so nothing from this graph is reused.
        disc_loss.backward()
        disc_opt.step()

        # --- Update the generator ---
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)
        gen_loss.backward()
        gen_opt.step()

        # Running means over one display window.
        mean_discriminator_loss += disc_loss.item() / display_step
        mean_generator_loss += gen_loss.item() / display_step

        # Periodically record losses and visualize generated vs. real samples.
        if cur_step % display_step == 0 and cur_step > 0:
            msg = f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}"
            logs.append(msg)
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            fake = gen(fake_noise)
            show_tensor_images(fake)
            show_tensor_images(real)
            mean_discriminator_loss = 0
            mean_generator_loss = 0
        cur_step += 1

# 'entry' instead of 'str' — the original shadowed the builtin.
for entry in logs:
    print(entry)