东南大学-王雁刚 / CVPR2020-OOH — modules.py (11.05 KB)
boycehbz committed on 2022-03-15 22:58: add fitting mesh code

import time
import os
from utils.logger import Logger
import yaml
from utils.uv_map_generator import UV_Map_Generator
from utils.smpl_torch_batch import SMPLModel
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from utils.imutils import uv_to_torch_noModifyChannel, img_reshape, est_trans, convert_color
import cv2
from utils.resample import resample_torch, resample_np, batch_resample_np
import numpy as np
import torch.utils.data as data
from utils.imutils import im_to_torch
from utils.render import Renderer
from utils.fitting.SMPLfitting import SMPLfitting
import pickle

min_batch = 0


def init(note='occlusion', dtype=torch.float32, **kwargs):
    # Create the folder for the current experiment
    mon, day, hour, min, sec = time.localtime(time.time())[1:6]
    out_dir = os.path.join('output', note)
    out_dir = os.path.join(out_dir, '%02d.%02d-%02dh%02dm%02ds' % (mon, day, hour, min, sec))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # Create the log for the current experiment
    logger = Logger(os.path.join(out_dir, 'log.txt'), title="occlusion")
    logger.set_names([note])
    logger.set_names(['%02d/%02d-%02dh%02dm%02ds' % (mon, day, hour, min, sec)])

    # Store the arguments for the current experiment
    conf_fn = os.path.join(out_dir, 'conf.yaml')
    with open(conf_fn, 'w') as conf_file:
        yaml.dump(kwargs, conf_file)

    # load smpl model
    model_smpl = SMPLModel(
        device=torch.device('cpu'),
        model_path='./data/SMPL_NEUTRAL.pkl',
        dtype=dtype,
    )

    # load UV generator
    generator = UV_Map_Generator(
        UV_height=256,
        UV_pickle='./data/param.pkl'  # separate UV map
        # totalhuman.pickle            # connecting UV map
    )

    # load virtual occlusion
    if kwargs.get('virtual_mask'):
        occlusion_folder = os.path.join(kwargs.get('data_folder'), 'occlusion/images')
        occlusions = [os.path.join(occlusion_folder, k) for k in os.listdir(occlusion_folder)]
    else:
        occlusions = None

    return out_dir, logger, model_smpl, generator, occlusions


class ImageLoader(data.Dataset):
    def __init__(self, data_folder='./data', **kwargs):
        self.images = [os.path.join(data_folder, img) for img in os.listdir(data_folder)]
        self.len = len(self.images)

    def create_UV_maps(self, index=0):
        data = {}
        image_path = self.images[index]
        image = cv2.imread(image_path)
        h, w = image.shape[:2]
        if h != 256 or w != 256:
            # scale the longer side to 256, then let img_reshape produce the 256x256 input
            max_size = max(h, w)
            ratio = 256 / max_size
            image = cv2.resize(image, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_CUBIC)
            image = img_reshape(image)
        assert image.shape[0] == 256 and image.shape[1] == 256, "The image size must be 256*256*3"
        dst_image = image
        inp = im_to_torch(dst_image)
        data['img'] = inp
        return data

    def __getitem__(self, index):
        data = self.create_UV_maps(index)
        return data

    def __len__(self):
        return self.len


class ModelLoader():
    def __init__(self, model=None, lr=0.001, device=torch.device('cpu'), pretrain=False,
                 pretrain_dir='', output='', smpl=None, generator=None, uv_mask=None,
                 batchsize=10, fitting=False, **kwargs):
        self.smpl = smpl
        self.generator = generator
        self.output = output
        self.batchsize = batchsize

        # dynamically import the network class named by `model`
        # (model/<name>.py is expected to define a class of the same name)
        self.model_type = model
        exec('from model.' + self.model_type + ' import ' + self.model_type)
        self.model = eval(self.model_type)()
        self.device = device

        # if uv_mask:
        # binary UV mask (MASK.png); UV maps are multiplied by it before resampling
        self.uv_mask = cv2.imread('./data/MASK.png')
        if self.uv_mask.max() > 1:
            self.uv_mask = self.uv_mask / 255.

        if fitting:
            self.fitting = SMPLfitting()
        else:
            self.fitting = None

        print('load model: %s' % self.model_type)
        model_params = 0
        for parameter in self.model.segnet.parameters():
            model_params += parameter.numel()
        print('INFO: Model parameter count:', model_params)

        self.render = Renderer()

        if torch.cuda.is_available():
            self.model.to(self.device)
            print("device: cuda")
        else:
            print("device: cpu")

        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min', factor=0.1, patience=1, verbose=True)

        # load pretrain parameters
        if pretrain:
            model_dict = self.model.state_dict()
            premodel_dict = torch.load(pretrain_dir).state_dict()
            premodel_dict = {k: v for k, v in premodel_dict.items() if k in model_dict}
            model_dict.update(premodel_dict)
            self.model.load_state_dict(model_dict)
            print("load pretrain parameters from %s" % pretrain_dir)

        # load fixed model
        if kwargs.get('task') == 'latent':
            fixmodel_dir = kwargs.pop('fixmodel_dir')
            if self.model_type == 'latent':
                exec('from model.inpainting import inpainting')
                self.inpainting = eval('inpainting')()
            else:
                exec('from model.inpaintingR50 import inpaintingR50')
                self.inpainting = eval('inpaintingR50')()
            inpainting_dict = self.inpainting.state_dict()
            fixmodel_dict = torch.load(fixmodel_dir).state_dict()
            fixmodel_dict = {k: v for k, v in fixmodel_dict.items() if k in inpainting_dict}
            inpainting_dict.update(fixmodel_dict)
            self.inpainting.load_state_dict(inpainting_dict)
            for param in self.inpainting.parameters():
                param.requires_grad = False
            self.inpainting.to(self.device)
            print("load fixed model from %s" % fixmodel_dir)
    def save_results(self, results, iter):
        """
        object order:
        """
        output = os.path.join(self.output, 'images')
        if not os.path.exists(output):
            os.makedirs(output)
        for item in results:
            index = 0
            opt = results[item]
            for img in opt:
                img_name = "%05d_%s.jpg" % (iter * self.batchsize + index, item)
                img = img.transpose(1, 2, 0)  # H*W*C
                # save mesh
                if item == 'pred' or item == 'uv_gt':
                    resample_img = img.copy()
                    resample_img = resample_img * self.uv_mask
                    resampled_mesh = resample_np(self.generator, resample_img)
                    self.smpl.write_obj(
                        resampled_mesh,
                        os.path.join(output, '%05d_%s_mesh.obj' % (iter * self.batchsize + index, item))
                    )
                # save img
                if item == 'pred' or item == 'uv_gt' or item == 'uv_in':
                    img = img * self.uv_mask
                    img = (img + 0.5) * 255
                else:
                    img = img * 255
                cv2.imwrite(os.path.join(output, img_name), img)
                index += 1
    def save_results_render(self, results, iter):
        """
        object order:
        """
        global min_batch
        regressor = self.smpl.joint_regressor.clone().cpu().numpy()
        output = os.path.join(self.output, 'images')
        if not os.path.exists(output):
            os.makedirs(output)

        imgs = results['img'].transpose(0, 2, 3, 1) * 255
        masks = results['mask'].transpose(0, 2, 3, 1) * 255
        heatmaps = results['heatmap'].transpose(0, 2, 3, 1)
        preds = results['pred'].transpose(0, 2, 3, 1)

        # track the largest batch size seen so far; it is used as the stride
        # when numbering the output files
        if len(imgs) > min_batch:
            min_batch = len(imgs)

        # mask the predicted UV maps and resample them back to SMPL vertices
        preds = preds * self.uv_mask
        resampled_meshes = (batch_resample_np(self.generator, preds) + 0.5) * 2
        joint3ds = np.matmul(regressor, resampled_meshes)

        # 2D joints stored as (x, y, confidence); confidences below 0.3 are zeroed
        joint2ds = np.zeros((heatmaps.shape[0], heatmaps.shape[-1], 3))
        confidence = np.max(heatmaps, axis=(1, 2))
        confidence[np.where(confidence < 0.3)] = 0
        joint2ds[:, :, -1] = confidence

        for index, (img, mask, heatmap, mesh, joint3d, joint2d, pred) in enumerate(
                zip(imgs, masks, heatmaps, resampled_meshes, joint3ds, joint2ds, preds)):
            # recover 2D joint locations from the peak of each heatmap channel
            for j in range(heatmap.shape[-1]):
                if joint2d[j][2] < 0.3:
                    continue
                joint2d[j][:2] = np.mean(np.where(heatmap[:, :, j] == joint2d[j][2]), axis=1)[::-1]

            if self.fitting is not None:
                print('Fit %05d model' % (iter * min_batch + index))
                params, mesh = self.fitting(mesh)
                param_name = "%05d_param.pkl" % (iter * min_batch + index)
                result_fn = os.path.join(output, param_name)
                with open(result_fn, 'wb') as result_file:
                    pickle.dump(params, result_file, protocol=2)

            # (mesh_3d, mesh_2d), cam_t = wp_project_render(resampled_mesh, joints_lsp, joints_2ds, img, focal=1000)
            # estimate camera rotation, translation and intrinsics from the 3D-2D joints, then render the mesh over the image
            rot, trans, intri = est_trans(mesh, joint3d, joint2d, img, focal=1000)
            render_out = self.render(mesh, self.smpl.faces, rot.copy(), trans.copy(), intri.copy(), img.copy(), color=[1, 1, 0.9])
            # self.render.vis_img('render', render_out)

            self.smpl.write_obj(
                mesh, os.path.join(output, '%05d_pred_mesh.obj' % (iter * min_batch + index))
            )

            uv_name = "%05d_uv.jpg" % (iter * min_batch + index)
            cv2.imwrite(os.path.join(output, uv_name), (pred + 0.5) * 255)

            # render = draw_smpl(mesh_2d, self.smpl.faces, img.copy())
            render_name = "%05d_render.jpg" % (iter * min_batch + index)
            cv2.imwrite(os.path.join(output, render_name), render_out)

            heatmap = np.max(heatmap, axis=2)
            heatmap = convert_color(heatmap * 255)
            heatmap = cv2.addWeighted(img.astype(np.uint8), 0.5, heatmap.astype(np.uint8), 0.5, 0)
            heatmap_name = "%05d_heatmap.jpg" % (iter * min_batch + index)
            cv2.imwrite(os.path.join(output, heatmap_name), heatmap)

            mask = convert_color(mask)
            mask = cv2.addWeighted(img.astype(np.uint8), 0.5, mask.astype(np.uint8), 0.5, 0)
            mask_name = "%05d_mask.jpg" % (iter * min_batch + index)
            cv2.imwrite(os.path.join(output, mask_name), mask)

            img_name = "%05d_img.jpg" % (iter * min_batch + index)
            cv2.imwrite(os.path.join(output, img_name), img)
    def viz_result(self, rgb_img=None, masks=None, pred=None):
        masks = masks.detach().data.cpu().numpy().astype(np.float32)
        rgb_image = rgb_img.detach().data.cpu().numpy().astype(np.float32)
        img_decoded = pred.detach().data.cpu().numpy().astype(np.float32)
        for mask, rgb, img_d in zip(masks, rgb_image, img_decoded):
            mask = mask.transpose(1, 2, 0)
            rgb = rgb.transpose(1, 2, 0)
            img_d = img_d.transpose(1, 2, 0)
            img_d = img_d * self.uv_mask
            cv2.imshow("mask", mask)
            cv2.imshow("rgb_img", rgb)
            cv2.imshow("d_img", img_d + 0.5)
            cv2.waitKey()
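

# ---------------------------------------------------------------------------
# Illustrative usage only (not part of the original modules.py): a minimal
# sketch of how init(), ImageLoader and ModelLoader might be wired together.
# The paths, the network name 'latent', the batch size and the pretrained
# checkpoint are assumptions for illustration; the model's forward pass is
# omitted because its signature is defined elsewhere in the repository.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # output folder, logger, SMPL model, UV generator and (optional) occlusion list
    out_dir, logger, model_smpl, generator, occlusions = init(
        note='demo', data_folder='./data', virtual_mask=False)

    # wrap the test images in a standard PyTorch DataLoader
    dataset = ImageLoader(data_folder='./data/test_images')
    loader = data.DataLoader(dataset, batch_size=10, shuffle=False)

    # build the model wrapper around the network class in model/<name>.py
    model_loader = ModelLoader(
        model='latent', lr=0.001, device=device,
        pretrain=True, pretrain_dir='./data/pretrained.pkl',
        output=out_dir, smpl=model_smpl, generator=generator,
        batchsize=10, fitting=False)

    for i, batch in enumerate(loader):
        imgs = batch['img'].to(device)
        # ... run model_loader.model on imgs, collect the outputs into a dict
        # with keys 'img', 'mask', 'heatmap', 'pred', and pass it to
        # model_loader.save_results_render(results, i).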