1 Star 0 Fork 0

tristenqaq/good-bad-orange

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
test.py 2.60 KB
一键复制 编辑 原始数据 按行查看 历史
tristenqaq 提交于 2022-02-01 14:03 . 首次提交
import os
from input_data import *
from model import *
import matplotlib.pyplot as plt
import configparser
def _format_label(prediction):
    """Return a caption for the top-scoring class of the first sample in *prediction*.

    ``prediction`` is indexed as ``prediction[0][class_index]``; with the batch
    size of 1 used by ``test`` this is the output row of the single image.
    The caption looks like ``"97.31% is a Label_0."``.
    """
    import numpy as np

    probabilities = prediction[0]
    # Argmax over the first row only, so the result is a class index even if
    # the batch ever holds more than one image (a flat argmax would not be).
    max_index = int(np.argmax(probabilities))
    return "%.2f%% is a Label_%d." % (probabilities[max_index] * 100, max_index)


def _new_axes(fig, rows=2, cols=5):
    """Add a fresh ``rows x cols`` grid of subplots to *fig*; return them in row-major order."""
    return [fig.add_subplot(rows, cols, i + 1) for i in range(rows * cols)]


def test(test_dir, logs_dir, img_size):
    """Restore the latest checkpoint and visually spot-check its predictions.

    Builds the input pipeline over ``test_dir`` and the ``"test"`` inference
    graph, restores the checkpoint reported by ``tf.train.get_checkpoint_state``
    for ``logs_dir``, then runs up to ``MAX_STEP`` single-image batches. The
    images are shown ten at a time in a 2x5 matplotlib grid, each captioned
    with the predicted class index and its score.

    Args:
        test_dir: directory passed to ``get_train_files``.
        logs_dir: directory searched for a checkpoint.
        img_size: image size passed to ``get_train_batch``.

    If no checkpoint is found, a message is printed and the function returns
    without running the graph (its variables would be uninitialized).
    """
    N_CLASS = 2
    IMG_SIZE = img_size
    BATCH_SIZE = 1
    CAPACITY = 200
    MAX_STEP = 100
    LIST_CHANNELS = [3, 16, 32, 128, 128]
    PER_PAGE = 10  # must equal rows * cols used by _new_axes

    keep_prob = tf.placeholder(tf.float32)
    test_list = get_train_files(test_dir, random=True)
    image_batch, _ = get_train_batch(test_list, IMG_SIZE, BATCH_SIZE, CAPACITY, True)
    softmax = inference(image_batch, N_CLASS, LIST_CHANNELS, "test", keep_prob)
    saver = tf.train.Saver()

    # The context manager closes the session on every exit path: the early
    # return below, normal completion, and any exception (e.g. from restore).
    with tf.Session() as sess:
        # Restore the checkpoint.
        print("载入检查点...")
        ckpt = tf.train.get_checkpoint_state(logs_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            print("没有找到检查点")
            # Running softmax without restored weights would fail on
            # uninitialized variables, so stop here instead of crashing later.
            return
        # Treat the text after the last "-" in the file name as the global step.
        checkpoint_name = os.path.basename(ckpt.model_checkpoint_path)
        global_step = checkpoint_name.rsplit("-", 1)[-1]
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("载入成功,global_step=%s" % global_step)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        fig = plt.figure(figsize=(16, 8))
        axes = _new_axes(fig)
        try:
            for step in range(MAX_STEP):
                if coord.should_stop():
                    break
                # keep_prob is fed 1; presumably a dropout keep-probability
                # inside inference(), i.e. dropout disabled — confirm in model.py.
                image, prediction = sess.run(
                    [image_batch, softmax], feed_dict={keep_prob: 1}
                )
                ax = axes[step % PER_PAGE]
                ax.set_title(_format_label(prediction), fontsize=10, y=1.02)
                ax.imshow(image[0])
                if (step + 1) % PER_PAGE == 0:
                    # Show a full page, then clear it and start a new grid
                    # unless this was the last step.
                    plt.draw()
                    plt.pause(5)
                    plt.clf()
                    if step + 1 != MAX_STEP:
                        axes = _new_axes(fig)
        except tf.errors.OutOfRangeError:
            print("Done.")
        finally:
            # Close the figure and stop the input threads on every path,
            # including OutOfRangeError, before the session is closed.
            plt.close(fig)
            coord.request_stop()
            coord.join(threads=threads)
def main():
    """Read IMG_SIZE from config.ini and run ``test`` on data/test with logs/ checkpoints.

    Raises:
        FileNotFoundError: if config.ini could not be read.
    """
    config = configparser.ConfigParser()
    # ConfigParser.read silently skips files it cannot open and returns the
    # list it did parse; fail clearly here instead of with a NoSectionError.
    if not config.read("config.ini", encoding="utf-8"):
        raise FileNotFoundError("config.ini not found or unreadable")
    IMG_SIZE = config.getint("section_2", "IMG_SIZE")
    # os.path.join instead of a hard-coded backslash so the path also
    # resolves correctly outside Windows.
    DATA = os.path.join("data", "test")
    LOGS = "logs"
    test(DATA, LOGS, IMG_SIZE)
if __name__ == "__main__":
    # Entry point: only runs when the file is executed directly, not on import.
    main()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/tristenqaq/good-bad-orange.git
git@gitee.com:tristenqaq/good-bad-orange.git
tristenqaq
good-bad-orange
good-bad-orange
master

搜索帮助