1 Star 1 Fork 0

令1/DCGAN无监督数据生成

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
dcgan.py 3.69 KB
一键复制 编辑 原始数据 按行查看 历史
baozaoderenlei 提交于 2023-02-13 10:44 . first commit
import itertools
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch import nn
from nets.dcgan import generator
from utils.utils import postprocess_output, show_config
class DCGAN(object):
_defaults = {
#-----------------------------------------------#
# model_path指向logs文件夹下的权值文件
#-----------------------------------------------#
"model_path" : 'logs/G_Epoch350-GLoss26.0214-DLoss0.0802.pth',
#-----------------------------------------------#
# 卷积通道数的设置
#-----------------------------------------------#
"channel" : 64,
#-----------------------------------------------#
# 输入图像大小的设置
#-----------------------------------------------#
"input_shape" : [256, 256],
#-------------------------------#
# 是否使用Cuda
# 没有GPU可以设置成False
#-------------------------------#
"cuda" : True,
}
#---------------------------------------------------#
# 初始化DCGAN
#---------------------------------------------------#
def __init__(self, **kwargs):
self.__dict__.update(self._defaults)
for name, value in kwargs.items():
setattr(self, name, value)
self._defaults[name] = value
self.generate()
show_config(**self._defaults)
def generate(self):
#----------------------------------------#
# 创建GAN模型
#----------------------------------------#
self.net = generator(self.channel, self.input_shape).eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.net.load_state_dict(torch.load(self.model_path, map_location=device))
self.net = self.net.eval()
print('{} model loaded.'.format(self.model_path))
if self.cuda:
self.net = nn.DataParallel(self.net)
self.net = self.net.cuda()
#---------------------------------------------------#
# 生成5x5的图片
#---------------------------------------------------#
# def generate_5x5_image(self, save_path):
# with torch.no_grad():
# randn_in = torch.randn((5*5, 100))
# if self.cuda:
# randn_in = randn_in.cuda()
#
# test_images = self.net(randn_in)
#
# size_figure_grid = 5
# fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5))
# for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
# ax[i, j].get_xaxis().set_visible(False)
# ax[i, j].get_yaxis().set_visible(False)
#
# for k in range(5*5):
# i = k // 5
# j = k % 5
# ax[i, j].cla()
# ax[i, j].imshow(np.uint8(postprocess_output(test_images[k].cpu().data.numpy().transpose(1, 2, 0))))
#
# label = 'predict_5x5_results'
# fig.text(0.5, 0.04, label, ha='center')
# plt.savefig(save_path)
#---------------------------------------------------#
# 生成1x1的图片
#---------------------------------------------------#
def generate_1x1_image(self, save_path):
with torch.no_grad():
randn_in = torch.randn((1, 100))
if self.cuda:
randn_in = randn_in.cuda()
test_images = self.net(randn_in)
test_images = postprocess_output(test_images[0].cpu().data.numpy().transpose(1, 2, 0))
Image.fromarray(np.uint8(test_images)).save(save_path)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/zhou-xuanling/dcgan-unsupervised-data.git
[email protected]:zhou-xuanling/dcgan-unsupervised-data.git
zhou-xuanling
dcgan-unsupervised-data
DCGAN无监督数据生成
master

搜索帮助