1 Star 0 Fork 0

tristenqaq/good-bad-orange

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
input_data.py 4.94 KB
一键复制 编辑 原始数据 按行查看 历史
tristenqaq 提交于 2022-02-01 14:03 . 首次提交
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
def get_train_files(path, random=True):
'''
获取训练图片路径与分类标签
:param path:string,图片目录
:param random:bool,是否对图片乱序
:return:[[],[]],图片路径与对应标签
'''
image_list = []
label_list = []
orange_b = 0
orange_g = 0
list = os.listdir(path)
if not list:
raise ValueError("训练集为空!")
for f in list:
f_path = path + "\\" + f
f_label = f.split(".")[0]
if not os.path.isfile(f_path):
raise ValueError("训练集中有非文件项!")
image_list.append(f_path)
if f_label == "0":
label_list.append(0)
orange_b += 1
else:
label_list.append(1)
orange_g += 1
print("训练集中有{}个Label_0,{}个Label_1。".format(orange_b, orange_g))
image_list = np.array(image_list)
label_list = np.array(label_list)
if random == True: # 如果需要乱序
random_index = np.arange(len(image_list)) # 返回固定步长序列
np.random.shuffle(random_index) # 打乱random_index顺序
# 根据打乱序列的random_index打乱image_list与label_list
image_list = image_list[random_index]
label_list = label_list[random_index]
print(image_list[0])
return image_list, label_list
def get_train_batch(train_list, image_size, batch_size, capacity, random=True):
'''
获取训练批次
:param train_list:[[],[]],get_train_files返回值,[image_list,label_list]
:param image_size:int,图片尺寸调整大小
:param batch_size:int,一个批次图片数量
:param capacity:int,队列容量
:param random:bool,是否乱序
:return:
'''
# tensorflow数据读取,文件队列方式,train_queue是一个文件名队列
train_queue = tf.train.slice_input_producer(train_list, shuffle=False)
# tf.train.slice_input_producer()可以读取tensor_list,tf.train.string_input_producer只能读取string_tensor
# 读取图片
image_train = tf.read_file(train_queue[0])
# 将图片解码为3维张量,像素点位置+像素点灰度值或RGB值,channels为颜色通道数
image_train = tf.image.decode_jpeg(image_train, channels=3)
# 调整图片尺寸,method = 0、1、2、3,双线性插值法、最近邻居法、双三次插值法、面积插值法,默认=0
# 这里不调整图片尺寸/不调用tf.image.resize_images(),下面的会报shape错,等待研究
image_train = tf.image.resize_images(image_train, [image_size, image_size])
# 数据类型转换,并归一化
image_train = tf.cast(image_train, tf.float32) / 255.0
# 读取标签
label_train = train_queue[1]
# 获取批次
if random:
# 通过随机打乱张量的顺序创建批次。采用队列进行读取
# 当一次出列操作完成后,队列中元素的最小数量。队列中元素大于它的时候就输出乱序的bacth
image_train_batch, label_train_batch = tf.train.shuffle_batch([image_train, label_train],
batch_size=batch_size,
capacity=capacity,
min_after_dequeue=100,
num_threads=4) # 多线程读取
else:
image_train_batch, label_train_batch = tf.train.batch([image_train, label_train],
batch_size=1,
capacity=capacity,
num_threads=1)
print("-----")
print(image_train_batch)
print("-----")
return image_train_batch, label_train_batch
if __name__ == '__main__':
image_dir = "data\\train"
train_list = get_train_files(image_dir, True)
print(train_list)
image_train_batch, label_train_batch = get_train_batch(train_list, 208, 1, 200, False)
sess = tf.Session()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
fig = plt.figure()
for step in range(10):
if coord.should_stop(): # 查询是否应该终止所有线程
break
image_batch, label_batch = sess.run([image_train_batch, label_train_batch])
if label_batch[0] == 0:
label = 'bad'
else:
label = 'good'
fig = plt.figure()
plt.imshow(image_batch[0])
plt.title(label)
plt.draw()
plt.pause(0.2)
plt.close(fig)
except tf.errors.OutOfRangeError:
print('Done.')
finally:
coord.request_stop()
coord.join(threads=threads)
sess.close()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/tristenqaq/good-bad-orange.git
[email protected]:tristenqaq/good-bad-orange.git
tristenqaq
good-bad-orange
good-bad-orange
master

搜索帮助