1 Star 0 Fork 0

ChangFeng2015/InsightFace_TF

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
eval_ckpt_file.py 3.00 KB
一键复制 编辑 原始数据 按行查看 历史
auroua 提交于 2018-05-04 12:22 . update readme file and fix some errors.
import tensorflow as tf
import argparse
from data.eval_data_reader import load_bin
from losses.face_losses import arcface_loss
from nets.L_Resnet_E_IR import get_resnet
import tensorlayer as tl
from verification import ver_test
def get_args(argv=None):
    """Parse the command-line options for evaluating saved checkpoints.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) ``sys.argv[1:]`` is parsed, exactly as before.

    Returns:
        argparse.Namespace holding the evaluation settings. Numeric options
        are converted to ``int`` and list-valued options are always lists,
        whether they come from the defaults or from the command line.
    """
    parser = argparse.ArgumentParser(description='input information')
    # The checkpoint path is a prefix; each --ckpt_index_list entry is appended to it.
    parser.add_argument('--ckpt_file', default='/home/aurora/workspaces2018/InsightFace_TF/output/ckpt_model_c/InsightFace_iter_best_',
                        type=str, help='the ckpt file path prefix')
    # Other sets previously listed here: 'lfw', 'cfp_ff', 'cfp_fp', 'agedb_30'.
    # nargs='+' so that e.g. `--eval_datasets lfw cfp_fp` yields a list, like the default.
    parser.add_argument('--eval_datasets', nargs='+', default=['agedb_30'], help='evaluation datasets')
    parser.add_argument('--eval_db_path', default='./datasets/faces_ms1m_112x112', help='evaluation datasets base path')
    # Two ints: they are unpacked into the spatial dims of the input placeholder.
    parser.add_argument('--image_size', nargs=2, type=int, default=[112, 112], help='the image size (two integers)')
    parser.add_argument('--net_depth', type=int, default=50, help='resnet depth, default is 50')
    parser.add_argument('--num_output', type=int, default=85164,
                        help='number of arcface output classes (must match the checkpoint)')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size used for evaluation')
    parser.add_argument('--ckpt_index_list', nargs='+',
                        default=['1950000.ckpt'], help='ckpt file indexes')
    return parser.parse_args(argv)
def main():
    """Evaluate each checkpoint in --ckpt_index_list on the verification sets.

    Loads every dataset named in --eval_datasets once and builds the
    inference graph once. Then, for each checkpoint index, it restores
    ``args.ckpt_file + index`` into a single session and runs ``ver_test``.
    The list of per-checkpoint results is printed at the end.
    """
    args = get_args()

    # Load the verification sets once; they are reused for every checkpoint.
    ver_list = []
    ver_name_list = []
    for db in args.eval_datasets:
        print('begin db %s convert.' % db)
        ver_list.append(load_bin(db, args.image_size, args))
        ver_name_list.append(db)

    images = tf.placeholder(name='img_inputs', shape=[None, *args.image_size, 3], dtype=tf.float32)
    labels = tf.placeholder(name='img_labels', shape=[None, ], dtype=tf.int64)
    dropout_rate = tf.placeholder(name='dropout_rate', dtype=tf.float32)
    w_init_method = tf.contrib.layers.xavier_initializer(uniform=False)
    net = get_resnet(images, args.net_depth, type='ir', w_init=w_init_method, trainable=False, keep_rate=dropout_rate)
    embedding_tensor = net.outputs
    # The arcface head's output is not used for verification. The call is kept
    # so the head is built into the graph before the Saver is created.
    # NOTE(review): presumably needed so the graph's variables line up with the
    # saved checkpoint; confirm before removing.
    arcface_loss(embedding=net.outputs, labels=labels, w_init=w_init_method, out_num=args.num_output)

    saver = tf.train.Saver()
    result_index = []
    # The context manager guarantees the session is closed, even on error.
    with tf.Session() as sess:
        for file_index in args.ckpt_index_list:
            path = args.ckpt_file + file_index
            saver.restore(sess, path)
            print('ckpt file %s restored!' % file_index)
            # Inference feed: every tensorlayer dropout keep-prob set to 1,
            # plus the network's own dropout placeholder set to 1.0.
            feed_dict_test = dict(tl.utils.dict_to_one(net.all_drop))
            feed_dict_test[dropout_rate] = 1.0
            results = ver_test(ver_list=ver_list, ver_name_list=ver_name_list, nbatch=0, sess=sess,
                               embedding_tensor=embedding_tensor, batch_size=args.batch_size,
                               feed_dict=feed_dict_test, input_placeholder=images)
            result_index.append(results)
    print(result_index)


if __name__ == '__main__':
    main()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/chengfeng2015_bolt/InsightFace_TF.git
git@gitee.com:chengfeng2015_bolt/InsightFace_TF.git
chengfeng2015_bolt
InsightFace_TF
InsightFace_TF
master

搜索帮助