jiang/GAN

This repository has not declared an open source license (LICENSE); check the project description and the code's upstream dependencies before use.
generate_mnist.py 12.89 KB
jiang committed on 2018-05-28 10:32: renamed file ttt.py to generate_mnist.py

# encoding=utf-8
import os
import numpy as np
import tensorflow as tf
import tensorflow.contrib.layers as layers
def reduce(feature_map):
    """Space-to-depth downsampling: halve height/width, multiply depth by 4.

    :param feature_map: [batch_size, height, width, depth]
    :return: [batch_size, ceil(height / 2), ceil(width / 2), 4 * depth]
    """
    static_shape = feature_map.shape.as_list()
    dynamic_shape = tf.shape(feature_map)
    # Fall back to the dynamic shape wherever a static dimension is unknown.
    shape = [d if d else dynamic_shape[i] for i, d in enumerate(static_shape)]
    # Pad height and width to even sizes so they can be split into 2x2 blocks.
    padding = tf.stack([[0, 0],
                        [0, shape[1] % 2],
                        [0, shape[2] % 2],
                        [0, 0]])
    feature_map = tf.pad(feature_map, padding)
    # Split each spatial dimension into (blocks, 2), move the 2x2 block offsets
    # next to the channel axis, then fold them into the channels.
    new_shape_ = tf.stack([shape[0], (1 + shape[1]) // 2, 2, (1 + shape[2]) // 2, 2, shape[3]])
    feature_map = tf.reshape(feature_map, new_shape_)
    feature_map = tf.transpose(feature_map, [0, 1, 3, 2, 4, 5])
    new_shape = tf.stack([shape[0], (1 + shape[1]) // 2, (1 + shape[2]) // 2, 2 * 2 * shape[3]])
    feature_map = tf.reshape(feature_map, new_shape)
    return feature_map
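

# Minimal self-check sketch (hypothetical, for illustration only): for an input whose
# height and width are already even, reduce() appears to perform the same 2x2 block
# rearrangement as tf.space_to_depth with block_size=2, so the comparison below is
# expected to hold.
def _check_reduce_matches_space_to_depth():
    x = tf.reshape(tf.range(1 * 4 * 4 * 3, dtype=tf.float32), [1, 4, 4, 3])
    a = reduce(x)                           # [1, 2, 2, 12]
    b = tf.space_to_depth(x, block_size=2)  # [1, 2, 2, 12]
    with tf.Session() as sess:
        ra, rb = sess.run([a, b])
    assert np.array_equal(ra, rb)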


def expand(feature_map):
    """Depth-to-space upsampling: double height/width, divide depth by 4.

    :param feature_map: [batch_size, height, width, depth], depth divisible by 4
    :return: [batch_size, 2 * height, 2 * width, depth / 4]
    """
    static_shape = feature_map.shape.as_list()
    dynamic_shape = tf.shape(feature_map)
    shape = [d if d else dynamic_shape[i] for i, d in enumerate(static_shape)]
    # Split the channels into (2, 2, depth / 4), move the 2x2 offsets back next to
    # the spatial axes, then fold them into height and width.
    new_shape_ = tf.stack([shape[0], shape[1], shape[2], 2, 2, shape[3] // 4])
    feature_map = tf.reshape(feature_map, new_shape_)
    feature_map = tf.transpose(feature_map, [0, 1, 3, 2, 4, 5])
    new_shape = tf.stack([shape[0], shape[1] * 2, shape[2] * 2, shape[3] // 4])
    feature_map = tf.reshape(feature_map, new_shape)
    return feature_map
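

# Minimal self-check sketch (hypothetical, for illustration only): expand() appears to
# invert reduce() on even-sized inputs and to match tf.depth_to_space with
# block_size=2 whenever the channel count is divisible by 4.
def _check_expand_round_trip():
    x = tf.reshape(tf.range(1 * 4 * 4 * 3, dtype=tf.float32), [1, 4, 4, 3])
    y = expand(reduce(x))                # back to [1, 4, 4, 3]
    z = tf.depth_to_space(reduce(x), 2)  # same rearrangement
    with tf.Session() as sess:
        rx, ry, rz = sess.run([x, y, z])
    assert np.array_equal(rx, ry) and np.array_equal(ry, rz)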


def conv_cond_concat(x, y):
    """Concatenate the conditioning vector onto the feature-map channel axis.

    :param x: feature map, [batch_size, height, width, channels]
    :param y: conditioning tensor, [batch_size, 1, 1, y_dim]
    :return: [batch_size, height, width, channels + y_dim]
    """
    x_shapes = x.get_shape()
    y_shapes = y.get_shape()
    # Broadcast y over the spatial dimensions of x, then append it as extra channels.
    return tf.concat([x, y * tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3)
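

# Shape sketch (hypothetical, for illustration only): a one-hot label reshaped to
# [batch, 1, 1, 10] is tiled across the spatial grid and appended on the channel
# axis, so a [32, 14, 14, 128] feature map becomes [32, 14, 14, 138].
def _demo_conv_cond_concat():
    x = tf.zeros([32, 14, 14, 128])
    y = tf.zeros([32, 1, 1, 10])
    out = conv_cond_concat(x, y)
    print(out.get_shape().as_list())  # [32, 14, 14, 138]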


def build_generator(z, y=None, reuse=False):
    """Map a noise vector z (and optional one-hot label y) to a 28x28x1 image."""
    with tf.variable_scope("generator") as scope:
        if reuse:
            scope.reuse_variables()
        if y is not None:
            # Condition the noise on the label and keep a broadcastable copy of y
            # for concatenation onto the intermediate feature maps.
            z = tf.concat((z, y), axis=-1)
            yy = tf.reshape(y, [-1, 1, 1, 10])
        fc0 = layers.fully_connected(z, 4 * 4 * 256, activation_fn=tf.nn.leaky_relu)
        conv1 = tf.reshape(fc0, [-1, 4, 4, 256])  # 4x4
        conv1 = layers.batch_norm(conv1, updates_collections="generator", reuse=reuse,
                                  scope="batch_norm1")
        if y is not None:
            conv1 = conv_cond_concat(conv1, yy)
        # conv2 = layers.conv2d(conv1, 4 * 64, kernel_size=3, activation_fn=tf.nn.leaky_relu)
        # conv2 = expand(conv2)  # 4x4
        # conv2 = layers.batch_norm(conv2, updates_collections="generator", reuse=reuse,
        #                           scope="batch_norm2")
        # if y is not None:
        #     conv2 = conv_cond_concat(conv2, yy)
        conv3 = layers.conv2d(conv1, 4 * 128, kernel_size=3, activation_fn=tf.nn.leaky_relu)
        conv3 = expand(conv3)  # 8x8
        conv3 = layers.batch_norm(conv3, updates_collections="generator", reuse=reuse,
                                  scope="batch_norm3")
        if y is not None:
            conv3 = conv_cond_concat(conv3, yy)
        conv4 = layers.conv2d(conv3, 4 * 32, kernel_size=3, activation_fn=tf.nn.leaky_relu)
        conv4 = expand(conv4)  # 16x16
        conv4 = layers.batch_norm(conv4, updates_collections="generator", reuse=reuse,
                                  scope="batch_norm4")
        if y is not None:
            conv4 = conv_cond_concat(conv4, yy)
        conv5 = layers.conv2d(conv4, 4 * 1, kernel_size=3, activation_fn=tf.nn.sigmoid)
        conv5 = expand(conv5)  # 32x32
        # Crop the central 28x28 patch to match the MNIST resolution.
        output = tf.slice(conv5, [0, 2, 2, 0], [-1, 28, 28, -1])
        return output
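

# Sampling sketch (hypothetical helper, assuming build_gan() has already created the
# generator variables in the current graph): images of a chosen digit can be drawn by
# feeding fresh noise together with the matching one-hot label.
def sample_digits(sess, digit, n=16):
    z = tf.random_normal([n, 100])
    y = tf.one_hot(tf.fill([n], digit), 10, dtype=tf.float32)
    images = build_generator(z, y, reuse=True)  # [n, 28, 28, 1], values in [0, 1]
    return sess.run(images)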


def build_discriminator(x, y=None, reuse=False):
    """Score an image x (optionally conditioned on label y) as real or fake."""
    with tf.variable_scope("discriminator") as scope:
        if reuse:
            scope.reuse_variables()
        y_dim = 0
        if y is not None:
            yy = tf.reshape(y, [-1, 1, 1, 10])
            y_dim = 10
        # Linearly rescale the [0, 1] input before the first convolution.
        x = 2 * x - 0.5
        if y is not None:
            x = conv_cond_concat(x, yy)
        conv1 = layers.conv2d(x, 128, kernel_size=5, activation_fn=tf.nn.leaky_relu)
        conv1 = reduce(conv1)  # 14x14
        conv1 = layers.batch_norm(conv1, updates_collections="discriminator", reuse=reuse,
                                  scope="batch_norm1")
        if y is not None:
            conv1 = conv_cond_concat(conv1, yy)
        conv2 = layers.conv2d(conv1, 256, kernel_size=3, activation_fn=tf.nn.leaky_relu)
        conv2 = reduce(conv2)  # 7x7
        conv2 = layers.batch_norm(conv2, updates_collections="discriminator", reuse=reuse,
                                  scope="batch_norm2")
        if y is not None:
            conv2 = conv_cond_concat(conv2, yy)
        conv3 = layers.conv2d(conv2, 256, kernel_size=3, activation_fn=tf.nn.leaky_relu)
        conv3 = reduce(conv3)  # 4x4
        conv3 = layers.batch_norm(conv3, updates_collections="discriminator", reuse=reuse,
                                  scope="batch_norm3")
        if y is not None:
            conv3 = conv_cond_concat(conv3, yy)
        flat = layers.flatten(conv3)
        fc4 = layers.fully_connected(flat, 1024, activation_fn=tf.nn.leaky_relu)
        if y is not None:
            fc4 = tf.concat((fc4, y), axis=-1)
        fc5 = layers.fully_connected(fc4, 1, activation_fn=tf.nn.sigmoid)
        return fc5


def build_gan():
    """Build the conditional GAN graph and return its tensors and train ops."""
    end_points = {}
    batch_size = 32
    global_step = tf.train.get_or_create_global_step()
    update_step_op = tf.assign_add(global_step, 1)
    real_x = tf.placeholder(tf.float32, [batch_size, 28, 28, 1])
    real_y = tf.placeholder(tf.float32, [batch_size, 10])
    # Noise and random one-hot labels drive the generator.
    random_x = tf.random_normal([batch_size, 100])
    random_y = tf.random_uniform([batch_size], minval=0, maxval=10, dtype=tf.int32)
    random_y = tf.one_hot(random_y, 10, dtype=tf.float32)
    fake_x = build_generator(random_x, random_y, reuse=False)
    # Log generated and real samples to TensorBoard.
    unstack_fake_x = tf.unstack(fake_x)
    for i in range(min(32, len(unstack_fake_x))):
        tf.summary.image("fake_x/%d" % i, tf.expand_dims(unstack_fake_x[i], axis=0))
    unstack_real_x = tf.unstack(real_x)
    for i in range(min(32, len(unstack_real_x))):
        tf.summary.image("real_x/%d" % i, tf.expand_dims(unstack_real_x[i], axis=0))
    real_prd = build_discriminator(real_x, real_y, reuse=False)
    fake_prd = build_discriminator(fake_x, random_y, reuse=True)
    # Also score real images against freshly drawn (usually mismatched) labels.
    random_y_ = tf.random_uniform([batch_size], minval=0, maxval=10, dtype=tf.int32)
    random_y_ = tf.one_hot(random_y_, 10, dtype=tf.float32)
    real_prd_rdm = build_discriminator(real_x, random_y_, reuse=True)
    # 1 where the random label happens to equal the true label, else 0.
    label_rdm = tf.reduce_sum(random_y_ * real_y, axis=-1, keep_dims=True)
    # Smoothed targets: close to 1 for real samples, close to 0 for fake samples.
    noise_label = tf.random_uniform(tf.shape(real_prd), minval=0.98, maxval=1.0)
    noise_label_ = tf.random_uniform(tf.shape(real_prd), minval=0.0, maxval=0.02)
    dis_loss_p = -(noise_label * tf.log(real_prd) + (1 - noise_label) * tf.log(1 - real_prd))
    dis_loss_n = -(noise_label_ * tf.log(fake_prd) + (1 - noise_label_) * tf.log(1 - fake_prd))
    dis_loss_rd = -(label_rdm * tf.log(real_prd_rdm) + (1 - label_rdm) * tf.log(1 - real_prd_rdm))
    dis_loss = dis_loss_p + dis_loss_n + dis_loss_rd
    # The generator tries to make the discriminator score fakes as real.
    gen_loss = -(tf.ones_like(fake_prd) * tf.log(fake_prd))
    dis_loss_p = tf.reduce_mean(dis_loss_p)
    dis_loss_n = tf.reduce_mean(dis_loss_n)
    dis_loss = tf.reduce_mean(dis_loss)
    gen_loss = tf.reduce_mean(gen_loss)
    t_vars = tf.trainable_variables()
    d_vars = [var for var in t_vars if 'discriminator' in var.name]
    g_vars = [var for var in t_vars if 'generator' in var.name]
    # Batch-norm moving-average updates were routed into these custom collections.
    dis_ups = tf.get_collection("discriminator")
    with tf.control_dependencies(dis_ups):
        # train_dis_op = tf.train.GradientDescentOptimizer(learning_rate=0.0001).minimize(loss=dis_loss, var_list=d_vars)
        train_dis_p_op = tf.train.GradientDescentOptimizer(learning_rate=0.0001).minimize(loss=dis_loss_p,
                                                                                           var_list=d_vars)
        train_dis_n_op = tf.train.GradientDescentOptimizer(learning_rate=0.0001).minimize(loss=dis_loss_n,
                                                                                           var_list=d_vars)
        train_dis_op = tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.5).minimize(loss=dis_loss, var_list=d_vars)
    gen_ups = tf.get_collection("generator")
    with tf.control_dependencies(gen_ups):
        # train_gen_op = tf.train.GradientDescentOptimizer(learning_rate=0.0001).minimize(loss=gen_loss, var_list=g_vars)
        train_gen_op = tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.5).minimize(loss=gen_loss, var_list=g_vars)
    end_points["real_x"] = real_x
    end_points["real_y"] = real_y
    end_points["fake_x"] = fake_x
    end_points["dis_loss"] = dis_loss
    end_points["gen_loss"] = gen_loss
    end_points["train_dis_op"] = train_dis_op
    end_points["train_gen_op"] = train_gen_op
    end_points["global_step"] = global_step
    end_points["update_step_op"] = update_step_op
    end_points["summaries"] = tf.summary.merge_all()
    end_points["dis_loss_p"] = dis_loss_p
    end_points["dis_loss_n"] = dis_loss_n
    end_points["train_dis_p_op"] = train_dis_p_op
    end_points["train_dis_n_op"] = train_dis_n_op
    return end_points
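

# Loss sketch (hypothetical, plain NumPy, for illustration only): the discriminator
# targets above are softened to roughly [0.98, 1.0] for real samples and [0.0, 0.02]
# for fake samples, so even a very confident discriminator keeps a small non-zero
# cross-entropy, a common label-smoothing trick in GAN training.
def _demo_smoothed_bce():
    def bce(target, pred):
        return -(target * np.log(pred) + (1 - target) * np.log(1 - pred))
    print(bce(0.99, 0.999))  # real image scored 0.999 against a smoothed target of 0.99
    print(bce(0.01, 0.001))  # fake image scored 0.001 against a smoothed target of 0.01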


def train():
    xx, yy = load_mnist("data/mnist")
    state = {"st": 0}
    batch_size = 32

    def get_next_batch():
        # Slide a window over the shuffled data set, wrapping around at the end.
        st = state['st']
        ed = (st + batch_size) % len(xx)
        state['st'] = ed
        if ed > st:
            return xx[st:ed, ...], yy[st:ed, ...]
        else:
            return np.concatenate([xx[st:, ...], xx[:ed, ...]], axis=0), \
                   np.concatenate([yy[st:, ...], yy[:ed, ...]], axis=0)

    end_points = build_gan()
    real_x = end_points["real_x"]
    real_y = end_points["real_y"]
    fake_x = end_points["fake_x"]
    dis_loss = end_points["dis_loss"]
    gen_loss = end_points["gen_loss"]
    train_dis_op = end_points["train_dis_op"]
    train_gen_op = end_points["train_gen_op"]
    global_step = end_points["global_step"]
    update_step_op = end_points["update_step_op"]
    summaries = end_points["summaries"]
    dis_loss_p = end_points["dis_loss_p"]
    dis_loss_n = end_points["dis_loss_n"]
    train_dis_p_op = end_points["train_dis_p_op"]
    train_dis_n_op = end_points["train_dis_n_op"]
    sv = tf.train.Supervisor(logdir="logdir")
    with sv.managed_session() as sess:
        step = sess.run(global_step)
        while step < 200000:
            bx, by = get_next_batch()
            if step % 100 == 0:
                # Every 100 steps also evaluate the merged summaries for TensorBoard.
                _, dls, fx, sumr = sess.run([train_dis_op, dis_loss, fake_x, summaries],
                                            feed_dict={real_x: bx, real_y: by})
                # _, dls_p = sess.run([train_dis_p_op, dis_loss_p],
                #                     feed_dict={real_x: get_next_batch()})
                # _, dls_n, fx, sumr = sess.run([train_dis_n_op, dis_loss_n, fake_x, summaries])
                # dls = (dls_p + dls_n) / 2
                sv.summary_writer.add_summary(sumr, step)
            else:
                _, dls, fx = sess.run([train_dis_op, dis_loss, fake_x],
                                      feed_dict={real_x: bx, real_y: by})
                # _, dls_p = sess.run([train_dis_p_op, dis_loss_p],
                #                     feed_dict={real_x: get_next_batch()})
                # _, dls_n, fx = sess.run([train_dis_n_op, dis_loss_n, fake_x])
                # dls = (dls_p + dls_n) / 2
            # One discriminator update is followed by three generator updates.
            _, gls0, fx = sess.run([train_gen_op, gen_loss, fake_x])
            _, gls1, fx = sess.run([train_gen_op, gen_loss, fake_x])
            _, gls2, fx = sess.run([train_gen_op, gen_loss, fake_x])
            _, step = sess.run([update_step_op, global_step])
            print("%08d:\t%0.8f\t%0.8f" % (step, dls, (gls0 + gls1 + gls2) / 3))


def load_mnist(data_dir):
    """Read the raw MNIST idx files and return ([0, 1] images, one-hot labels)."""
    fd = open(os.path.join(data_dir, 'train-images-idx3-ubyte'), 'rb')
    loaded = np.fromfile(file=fd, dtype=np.uint8)
    trX = loaded[16:].reshape((60000, 28, 28, 1)).astype(np.float)
    fd = open(os.path.join(data_dir, 'train-labels-idx1-ubyte'), 'rb')
    loaded = np.fromfile(file=fd, dtype=np.uint8)
    trY = loaded[8:].reshape((60000)).astype(np.float)
    fd = open(os.path.join(data_dir, 't10k-images-idx3-ubyte'), 'rb')
    loaded = np.fromfile(file=fd, dtype=np.uint8)
    teX = loaded[16:].reshape((10000, 28, 28, 1)).astype(np.float)
    fd = open(os.path.join(data_dir, 't10k-labels-idx1-ubyte'), 'rb')
    loaded = np.fromfile(file=fd, dtype=np.uint8)
    teY = loaded[8:].reshape((10000)).astype(np.float)
    trY = np.asarray(trY)
    teY = np.asarray(teY)
    # Merge train and test splits, then shuffle images and labels with the same seed.
    X = np.concatenate((trX, teX), axis=0)
    y = np.concatenate((trY, teY), axis=0).astype(np.int)
    seed = 547
    np.random.seed(seed)
    np.random.shuffle(X)
    np.random.seed(seed)
    np.random.shuffle(y)
    # One-hot encode the labels.
    y_vec = np.zeros((len(y), 10), dtype=np.float32)
    for i, label in enumerate(y):
        y_vec[i, label] = 1.0
    return X / 255., y_vec
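

# Usage sketch (assuming the four raw idx files sit under data/mnist, as train()
# expects): load_mnist returns all 70,000 digits as [0, 1] floats plus one-hot labels.
def _demo_load_mnist():
    X, y = load_mnist("data/mnist")
    print(X.shape, X.min(), X.max())    # (70000, 28, 28, 1) 0.0 1.0
    print(y.shape, y.sum(axis=-1)[:3])  # (70000, 10) [1. 1. 1.]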


if __name__ == '__main__':
    train()