1 Star 0 Fork 0

mitbaiyun/tensorflow_demo_yixue

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
utils.py 3.18 KB
一键复制 编辑 原始数据 按行查看 历史
yaoyaowd 提交于 2017-12-18 14:44 . pix2pix and cgan
# Original Version: Taehoon Kim (http://carpedm20.github.io)
# + Source: https://github.com/carpedm20/DCGAN-tensorflow/blob/e30539fb5e20d5a0fed40935853da97e9e55eee8/utils.py
# + License: MIT
"""
Some of this code comes from https://github.com/Newmu/dcgan_code
"""
from __future__ import division
import scipy.misc
import scipy.io
import numpy as np
import os
import sys
import tarfile
import zipfile
from six.moves import urllib
def get_vgg_model(dir_path, model_url):
    """Fetch (if needed) and load a VGG model stored as a MATLAB ``.mat`` file.

    Args:
        dir_path: Directory in which the model file is cached.
        model_url: URL of the model file; its last ``/``-separated component
            is used as the local file name.

    Returns:
        The dict produced by ``scipy.io.loadmat`` for the cached file.

    Raises:
        IOError: If the file is still missing after the download attempt.
    """
    maybe_download_and_extract(dir_path, model_url)
    local_path = os.path.join(dir_path, model_url.split("/")[-1])
    if os.path.exists(local_path):
        return scipy.io.loadmat(local_path)
    raise IOError("VGG Model not found!")
def maybe_download_and_extract(dir_path, url_name, is_tarfile=False, is_zipfile=False):
    """Download ``url_name`` into ``dir_path`` unless it is already cached.

    The file is stored under the last ``/``-separated component of the URL.
    Only when a fresh download happens is the file optionally unpacked into
    ``dir_path``. It is unpacked as a gzip-compressed tar if ``is_tarfile``
    is true, otherwise as a zip if ``is_zipfile`` is true. A file that
    already exists is neither re-downloaded nor re-extracted.

    Args:
        dir_path: Target directory; created if it does not exist.
        url_name: URL to fetch (anything ``urllib`` can open, including
            ``file://`` URLs).
        is_tarfile: Extract the download as a ``.tar.gz`` archive.
        is_zipfile: Extract the download as a ``.zip`` archive.
    """
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
    filename = url_name.split('/')[-1]
    filepath = os.path.join(dir_path, filename)
    if os.path.exists(filepath):
        return

    def _progress(count, block_size, total_size):
        # urlretrieve passes total_size == -1 when the size is unknown and
        # 0 for an empty resource; never divide by it in those cases.
        done = count * block_size
        if total_size > 0:
            percent = min(100.0, float(done) / float(total_size) * 100.0)
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, percent))
        else:
            sys.stdout.write('\r>> Downloading %s %d bytes' % (filename, done))
        sys.stdout.flush()

    # Download to a side file and rename only on success. An interrupted
    # download then never leaves a truncated file that later calls would
    # mistake for a complete cached copy.
    partial_path = filepath + '.part'
    try:
        urllib.request.urlretrieve(url_name, partial_path, reporthook=_progress)
    except BaseException:
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
    os.rename(partial_path, filepath)
    print()
    statinfo = os.stat(filepath)
    print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
    # NOTE(review): extractall() trusts member paths. A crafted archive can
    # write outside dir_path, so only use this with trusted sources.
    if is_tarfile:
        with tarfile.open(filepath, 'r:gz') as tar:
            tar.extractall(dir_path)
    elif is_zipfile:
        with zipfile.ZipFile(filepath) as zf:
            zf.extractall(dir_path)
def get_image(image_path, image_size, is_crop=True):
    """Load the image at ``image_path`` and prepare it for the network.

    The file is read with ``imread`` and then passed to ``transform``, which
    optionally center-crops it to ``image_size`` pixels and rescales the
    pixel values with ``x / 127.5 - 1``.
    """
    raw = imread(image_path)
    return transform(raw, image_size, is_crop)
def save_images(images, size, image_path):
    """Write a batch of images to ``image_path`` as one tiled grid.

    The values are mapped with ``inverse_transform`` (``(x + 1) / 2``) before
    tiling. They are presumably in [-1, 1], the range produced by
    ``transform``. ``size`` gives the grid's ``(rows, cols)``.
    """
    rescaled = inverse_transform(images)
    return imsave(rescaled, size, image_path)
def save_image(image, save_dir, name):
    """Save a single image array as ``<save_dir>/<name>.png``."""
    target = os.path.join(save_dir, name + '.png')
    scipy.misc.imsave(target, image)
def imread(path, size=None):
    """Read ``path`` as an RGB image.

    When ``size`` is falsy, the pixels are returned as a float32 array.
    Otherwise they are resized to ``size`` x ``size`` with
    ``scipy.misc.imresize``, and that call's result is returned.
    NOTE(review): imresize likely returns a different dtype than float32;
    confirm that callers expect this.
    """
    pixels = scipy.misc.imread(path, mode='RGB').astype(np.float32)
    if size:
        return scipy.misc.imresize(pixels, (size, size))
    return pixels
def merge(images, size):
    """Tile a batch of images into a single grid image.

    Images are placed row-major: image ``idx`` goes to grid row
    ``idx // cols`` and grid column ``idx % cols``. Unused cells stay zero.

    Args:
        images: Array (or nested sequence) indexed as ``images[n, y, x]``
            for grayscale or ``images[n, y, x, c]`` with channels.
        size: ``(rows, cols)`` of the grid; both are converted with ``int``.

    Returns:
        A float64 array of shape ``(rows*h, cols*w)`` for 3-D input, or
        ``(rows*h, cols*w, channels)`` for 4-D input. Single-channel 4-D
        input is broadcast to 3 channels, as it always has been.

    Raises:
        ValueError: If there are more images than grid cells, or if
            ``images`` has more than 4 dimensions.
    """
    images = np.asarray(images)
    # Cast once so float grid sizes (e.g. from np.sqrt) also work as slice
    # indices, not just when allocating the canvas.
    rows, cols = int(size[0]), int(size[1])
    h, w = images.shape[1], images.shape[2]
    if len(images) > rows * cols:
        raise ValueError('%d images do not fit in a %dx%d grid'
                         % (len(images), rows, cols))
    if images.ndim == 3:
        img = np.zeros((h * rows, w * cols))
    elif images.ndim == 4:
        # Single-channel batches were historically expanded to RGB by
        # broadcasting; keep that so existing callers get the same output.
        channels = 3 if images.shape[3] == 1 else images.shape[3]
        img = np.zeros((h * rows, w * cols, channels))
    else:
        raise ValueError('expected a 3-D or 4-D image batch, got shape %r'
                         % (images.shape,))
    for idx, image in enumerate(images):
        row, col = divmod(idx, cols)
        img[row * h:(row + 1) * h, col * w:(col + 1) * w] = image
    return img
def imsave(images, size, path):
    """Tile ``images`` into a grid and write it to ``path`` as 8-bit pixels.

    Args:
        images: Batch accepted by ``merge``; values are expected in [0, 1].
        size: ``(rows, cols)`` of the grid.
        path: Destination file path.

    Returns:
        Whatever ``scipy.misc.imsave`` returns.
    """
    img = merge(images, size)
    # Clip before the uint8 cast. Values slightly outside [0, 1] (e.g. from
    # float round-off) would otherwise wrap around, or their cast would be
    # undefined for negatives, producing speckled pixels.
    pixels = np.clip(255 * img, 0, 255).astype(np.uint8)
    return scipy.misc.imsave(path, pixels)
def center_crop(x, crop_h, crop_w=None, resize_w=64):
    """Cut the central ``crop_h`` x ``crop_w`` window of ``x`` and resize it.

    Args:
        x: Image array whose first two axes are height and width.
        crop_h: Crop height in pixels.
        crop_w: Crop width in pixels; defaults to ``crop_h``.
        resize_w: Side length of the square output requested from
            ``scipy.misc.imresize``.

    Returns:
        The result of ``scipy.misc.imresize`` applied to the crop.

    Raises:
        ValueError: If a crop dimension is not positive or exceeds the image.
    """
    if crop_w is None:
        crop_w = crop_h
    h, w = x.shape[:2]
    # A crop larger than the image gives negative offsets. Python slicing
    # silently turns those into a wrong, smaller window, so reject them.
    if not (0 < crop_h <= h and 0 < crop_w <= w):
        raise ValueError('cannot take a %sx%s crop from a %dx%d image'
                         % (crop_h, crop_w, h, w))
    j = int(round((h - crop_h) / 2.))
    i = int(round((w - crop_w) / 2.))
    return scipy.misc.imresize(x[j:j + crop_h, i:i + crop_w],
                               [resize_w, resize_w])
def transform(image, npx=64, is_crop=True, resize_w=64):
    """Optionally center-crop ``image``, then rescale it with ``x / 127.5 - 1``.

    For pixel values in [0, 255] this yields values in [-1, 1].

    Args:
        image: Image array.
        npx: Side length, in pixels, of the square center crop.
        is_crop: When false, ``image`` is used as-is and ``npx`` and
            ``resize_w`` are ignored.
        resize_w: Output side length forwarded to ``center_crop``.

    Returns:
        The rescaled image as a NumPy array.
    """
    source = center_crop(image, npx, resize_w=resize_w) if is_crop else image
    return np.array(source) / 127.5 - 1.
def inverse_transform(images):
    """Undo ``transform``'s scaling by computing ``(x + 1) / 2``.

    This maps [-1, 1] back to [0, 1].
    """
    shifted = images + 1.
    return shifted / 2.
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/mitbaiyun/tensorflow_demo_yixue.git
[email protected]:mitbaiyun/tensorflow_demo_yixue.git
mitbaiyun
tensorflow_demo_yixue
tensorflow_demo_yixue
master

搜索帮助