import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from tqdm.auto import tqdm
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    """Display up to `num_images` images from a flat batch as a 5-wide grid.

    The tensor is detached and moved to CPU, reshaped to (-1, *size),
    truncated to `num_images`, tiled with make_grid, and shown via pyplot.
    """
    cpu_batch = image_tensor.detach().cpu()
    images = cpu_batch.view(-1, *size)[:num_images]
    grid = make_grid(images, nrow=5)
    # Channels-last for matplotlib; squeeze drops the channel dim for grayscale.
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()
def get_generator_block(input_dim, output_dim):
    """Return one generator layer: Linear -> BatchNorm1d -> ReLU."""
    linear = nn.Linear(input_dim, output_dim)
    norm = nn.BatchNorm1d(output_dim)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(linear, norm, activation)
class Generator(nn.Module):
    """Fully connected generator: maps a z_dim noise vector to a flat im_dim
    image with values squashed into [0, 1] by a final Sigmoid.
    """

    def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):
        super(Generator, self).__init__()
        # Widths double at each hidden stage: z_dim -> h -> 2h -> 4h -> 8h.
        widths = [z_dim, hidden_dim, hidden_dim * 2, hidden_dim * 4, hidden_dim * 8]
        blocks = [get_generator_block(i, o) for i, o in zip(widths, widths[1:])]
        self.gen = nn.Sequential(
            *blocks,
            nn.Linear(widths[-1], im_dim),
            nn.Sigmoid(),
        )

    def forward(self, noise):
        """Map noise of shape (batch, z_dim) to images of shape (batch, im_dim)."""
        return self.gen(noise)

    def get_gen(self):
        """Return the underlying nn.Sequential of the generator."""
        return self.gen
def get_noise(n_samples, z_dim, device='cpu'):
    """Sample an (n_samples, z_dim) standard-normal noise tensor on `device`."""
    shape = (n_samples, z_dim)
    return torch.randn(*shape, device=device)
def get_discriminator_block(input_dim, output_dim):
    """Return one discriminator layer: Linear followed by LeakyReLU(0.2)."""
    layers = [
        nn.Linear(input_dim, output_dim),
        nn.LeakyReLU(0.2, inplace=True),
    ]
    return nn.Sequential(*layers)
class Discriminator(nn.Module):
    """Fully connected discriminator: scores a flat im_dim image with a single
    raw logit (no Sigmoid — pair with BCEWithLogitsLoss).
    """

    def __init__(self, im_dim=784, hidden_dim=128):
        super(Discriminator, self).__init__()
        # Widths halve at each stage: im_dim -> 4h -> 2h -> h.
        widths = [im_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim]
        blocks = [get_discriminator_block(i, o) for i, o in zip(widths, widths[1:])]
        self.disc = nn.Sequential(*blocks, nn.Linear(hidden_dim, 1))

    def forward(self, image):
        """Score a (batch, im_dim) image batch; returns (batch, 1) logits."""
        return self.disc(image)

    def get_disc(self):
        """Return the underlying nn.Sequential of the discriminator."""
        return self.disc
# --- Training hyperparameters ---
criterion = nn.BCEWithLogitsLoss()  # expects raw logits from the discriminator
n_epochs = 200
z_dim = 64            # dimensionality of the generator's noise vector
display_step = 500    # steps between loss logging / sample visualization
batch_size = 128
lr = 0.00001
# FIX: the original hard-coded 'cuda', which crashes on CPU-only machines;
# fall back to CPU when no CUDA device is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Build the MNIST training loader.
# FIX: the original used download=False, which raises RuntimeError when the
# dataset is not already present under './MNIST'. download=True fetches it on
# first run and is a no-op when the files already exist locally.
dataloader = DataLoader(
    MNIST('.', download=True, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True,
)
# Instantiate both networks on the target device, then give each its own
# Adam optimizer so the two players are updated independently.
gen = Generator(z_dim).to(device)
disc = Discriminator().to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)
def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):
    """Average of the discriminator's BCE loss on a fresh fake batch and on `real`.

    The fake images are detached so this loss only produces gradients for the
    discriminator, never the generator.
    """
    noise = get_noise(num_images, z_dim, device=device)
    fake_images = gen(noise)
    # Fake branch: the discriminator should label generated images as zeros.
    fake_logits = disc(fake_images.detach())
    loss_on_fake = criterion(fake_logits, torch.zeros_like(fake_logits))
    # Real branch: the discriminator should label real images as ones.
    real_logits = disc(real)
    loss_on_real = criterion(real_logits, torch.ones_like(real_logits))
    return (loss_on_fake + loss_on_real) / 2
def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
    """Generator BCE loss: how far the discriminator is from calling fakes real."""
    noise = get_noise(num_images, z_dim, device=device)
    fake_logits = disc(gen(noise))
    # The generator wins when the discriminator outputs "real" (ones) on fakes.
    return criterion(fake_logits, torch.ones_like(fake_logits))
# --- Training loop state ---
cur_step = 0                 # global batch counter across all epochs
mean_generator_loss = 0      # running mean of gen loss over the last display_step steps
mean_discriminator_loss = 0  # running mean of disc loss over the last display_step steps
logs = []                    # collected per-display_step summary strings
msg = ''
for epoch in range(n_epochs):
    for real, _ in tqdm(dataloader, desc=f'epoch {epoch + 1}'):
        cur_batch_size = len(real)
        # Flatten images to (batch, 784) vectors for the MLP discriminator.
        real = real.view(cur_batch_size, -1).to(device)

        # --- Discriminator update: one step on real + freshly generated fakes ---
        disc_opt.zero_grad()
        disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)
        # NOTE(review): retain_graph=True looks unnecessary here since the fake
        # batch is detached inside get_disc_loss — confirm before removing.
        disc_loss.backward(retain_graph=True)
        disc_opt.step()

        # --- Generator update: one step on a new fake batch ---
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)
        gen_loss.backward()
        gen_opt.step()

        # Accumulate running means; each term contributes 1/display_step so the
        # sums equal the mean over a full display_step window.
        mean_discriminator_loss += disc_loss.item() / display_step
        mean_generator_loss += gen_loss.item() / display_step

        # Every display_step batches (skipping step 0): log and visualize.
        if cur_step % display_step == 0 and cur_step > 0:
            msg = f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}"
            logs.append(msg)
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            fake = gen(fake_noise)
            show_tensor_images(fake)
            show_tensor_images(real)
            # Reset the running means for the next window.
            mean_discriminator_loss = 0
            mean_generator_loss = 0
        cur_step += 1
# Replay all collected loss summaries after training.
# FIX: the original used `str` as the loop variable, shadowing the builtin.
for log_msg in logs:
    print(log_msg)