一百行代码实现一个GAN模型

GAN:对抗性生成网络,通俗来讲,即有两个网络一个是g(generator )网络,用于生成,一个是d(discriminator)网络,用于判断。
GAN网络的目的就是使其自己生成一副图片,比如说经过对一系列猫的图片的处理,g网络可以自己“绘制”出一张猫的图片,且尽量真实。
d网络则是用来进行判断的,将一张真实的图片和一张由g网络生成的照片同时交给d网络,不断训练d网络,使其可以准确判断,将d网络生成的“假图片”找出来。
再回到两个网络上,g网络不断改进使其可以骗过d网络,而d网络不断改进使其可以更准确找到“假图片”,这种相互促进相互对抗的关系,就叫做对抗网络。

我们可以使用tensorflow中的mnist手写体数据来进行实现。
实现原理如下:


将一张随机像素的图片经过一个全连接层后经过一个Leaky ReLU处理,之后为了避免过拟合dropout后再经过一个全连接层进行tanh激活后,生成一张“假图片”

def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
    with tf.variable_scope("generator", reuse=reuse):
        hidden1 = tf.layers.dense(noise_img, n_units)  # 全连接层
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        hidden1 = tf.layers.dropout(hidden1, rate=0.2)
        logits = tf.layers.dense(hidden1, out_dim)
        outputs = tf.tanh(logits)
        return logits, outputs

将待判定的图片经过全连接层-->Leaky ReLU-->全连接层-->sigmoid激活函数处理后,得到0或1的结果。

def get_discriminator(img, n_units, reuse=False, alpha=0.01):
    with tf.variable_scope("discriminator", reuse=reuse):
        hidden1 = tf.layers.dense(img, n_units)
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        logits = tf.layers.dense(hidden1, 1)
        outputs = tf.sigmoid(logits)
        return logits, outputs

在实现时,我们可以首先把MNIST数据中的标签为0的图像提取出来,存到列表中。

i = j = 0
while i<5000:
    if mnist.train.labels[j] == 0:
        samples.append(mnist.train.images[j])
        i += 1
    j += 1

这样就可以在训练时只训练标签为0的图像。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import numpy as np

mnist = input_data.read_data_sets("D:/python/MNIST_data/")
img = mnist.train.images[50]


def get_inputs(real_size, noise_size):
    real_img = tf.placeholder(tf.float32, [None, real_size], name="real_img")
    noise_img = tf.placeholder(tf.float32, [None, noise_size], name="noise_img")
    return real_img, noise_img


# 生成
def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
    with tf.variable_scope("generator", reuse=reuse):
        hidden1 = tf.layers.dense(noise_img, n_units)  # 全连接层
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        hidden1 = tf.layers.dropout(hidden1, rate=0.2)
        logits = tf.layers.dense(hidden1, out_dim)
        outputs = tf.tanh(logits)
        return logits, outputs


# 判别
def get_discriminator(img, n_units, reuse=False, alpha=0.01):
    with tf.variable_scope("discriminator", reuse=reuse):
        hidden1 = tf.layers.dense(img, n_units)
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        logits = tf.layers.dense(hidden1, 1)
        outputs = tf.sigmoid(logits)
        return logits, outputs

img_size = mnist.train.images[0].shape[0]
noise_size = 100
g_units = 128
d_units = 128
alpha = 0.01
learning_rate = 0.001
smooth = 0.1
tf.reset_default_graph()
real_img, noise_img = get_inputs(img_size, noise_size)
g_logits, g_outputs = get_generator(noise_img, g_units, img_size)

d_logits_real, d_outputs_real = get_discriminator(real_img, d_units)
d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, d_units, reuse=True)

d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_real, labels=tf.ones_like(d_logits_real)
) * (1 - smooth))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)
))
d_loss = tf.add(d_loss_real, d_loss_fake)
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake, labels=tf.ones_like(d_logits_fake)
) * (1 - smooth))

train_vars = tf.trainable_variables()
g_vars = [var for var in train_vars if var.name.startswith("generator")]
d_vars = [var for var in train_vars if var.name.startswith("discriminator")]

d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)


epochs = 5000
samples = []
n_sample = 10
losses = []

i = j = 0
while i<5000:
    if mnist.train.labels[j] == 0:
        samples.append(mnist.train.images[j])
        i += 1
    j += 1

print(len(samples))
size = samples[0].size

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for e in range(epochs):
        batch_images = samples[e] * 2 -1
        batch_noise = np.random.uniform(-1, 1, size=noise_size)

        _ = sess.run(d_train_opt, feed_dict={real_img:[batch_images], noise_img:[batch_noise]})
        _ = sess.run(g_train_opt, feed_dict={noise_img:[batch_noise]})

    sample_noise = np.random.uniform(-1, 1, size=noise_size)
    g_logit, g_output = sess.run(get_generator(noise_img, g_units, img_size,
                                         reuse=True), feed_dict={
        noise_img:[sample_noise]
    })
    print(g_logit.size)
    g_output = (g_output+1)/2
    plt.imshow(g_output.reshape([28, 28]), cmap='Greys_r')
    plt.show()

运行结果:


可以看出,在经过了5000次的迭代后,g网络生成的图片已经可以大致呈现出一个0的形状。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 159,117评论 4 362
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,328评论 1 293
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 108,839评论 0 243
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 44,007评论 0 206
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,384评论 3 287
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,629评论 1 219
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,880评论 2 313
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,593评论 0 198
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,313评论 1 243
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,575评论 2 246
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,066评论 1 260
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,392评论 2 253
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,052评论 3 236
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,082评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,844评论 0 195
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,662评论 2 274
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,575评论 2 270

推荐阅读更多精彩内容