Commit a4d83119 authored by liuxingyu
rm dr utils
parent 9deff8b8
Showing 0 additions and 4220 deletions
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
from __future__ import print_function
from __future__ import division
import torch
import torch.nn
##################################################
def perspective_projection(points_bxpx3, faces_fx3, cameras):
# perspective projection using a single camera intrinsic parameter (3x1)
camera_rot_bx3x3, camera_pos_bx3, camera_proj_3x1 = cameras
if camera_proj_3x1.shape[-1] == 4: # 4x4 proj
# NOTE: use real camera perspective projection
return perspective_projection_real(points_bxpx3, faces_fx3, cameras)
cameratrans_rot_bx3x3 = camera_rot_bx3x3.permute(0, 2, 1)
# follow pixel2mesh!!!
# new_p = cam_mat * (old_p - cam_pos)
# NOTE: make sure points_bxpx3 is not modified in place
points_bxpx3 = points_bxpx3 - camera_pos_bx3.view(-1, 1, 3)
points_bxpx3 = torch.matmul(points_bxpx3, cameratrans_rot_bx3x3)
camera_proj_bx1x3 = camera_proj_3x1.view(-1, 1, 3)
xy_bxpx3 = points_bxpx3 * camera_proj_bx1x3
xy_bxpx2 = xy_bxpx3[:, :, :2] / xy_bxpx3[:, :, 2:3]
##########################################################
# 1 points
pf0_bxfx3 = points_bxpx3[:, faces_fx3[:, 0], :]
pf1_bxfx3 = points_bxpx3[:, faces_fx3[:, 1], :]
pf2_bxfx3 = points_bxpx3[:, faces_fx3[:, 2], :]
points3d_bxfx9 = torch.cat((pf0_bxfx3, pf1_bxfx3, pf2_bxfx3), dim=2)
xy_f0 = xy_bxpx2[:, faces_fx3[:, 0], :]
xy_f1 = xy_bxpx2[:, faces_fx3[:, 1], :]
xy_f2 = xy_bxpx2[:, faces_fx3[:, 2], :]
points2d_bxfx6 = torch.cat((xy_f0, xy_f1, xy_f2), dim=2)
######################################################
# 2 normals
v01_bxfx3 = pf1_bxfx3 - pf0_bxfx3
v02_bxfx3 = pf2_bxfx3 - pf0_bxfx3
# torch.cross defaults to the first dimension of size 3, which would be the batch dim if bs == 3, so specify dim explicitly
normal_bxfx3 = torch.cross(v01_bxfx3, v02_bxfx3, dim=2)
return points3d_bxfx9, points2d_bxfx6, normal_bxfx3
def perspective_projection_real(points_bxpx3, faces_fx3, cameras):
# perspective projection using a full 4x4 camera projection matrix
camera_rot_bx3x3, camera_pos_bx3, camera_proj_4x4 = cameras
cameratrans_rot_bx3x3 = camera_rot_bx3x3.permute(0, 2, 1)
# follow pixel2mesh!!!
# new_p = cam_mat * (old_p - cam_pos)
# NOTE: make sure points_bxpx3 is not modified in place
points_bxpx3 = points_bxpx3 - camera_pos_bx3.view(-1, 1, 3)
points_bxpx3 = torch.matmul(points_bxpx3, cameratrans_rot_bx3x3)
b, p = points_bxpx3.shape[:2]
points_bxpx4 = points_bxpx3.new_ones(b, p, 4)
points_bxpx4[:, :, :3] = points_bxpx3
camera_proj_bx4x4 = camera_proj_4x4.view(-1, 4, 4)
xy_bxpx4 = torch.matmul(points_bxpx4, camera_proj_bx4x4)
xy_bxpx2 = xy_bxpx4[:, :, :2] / xy_bxpx4[:, :, 3:4]
##########################################################
# 1 points
pf0_bxfx3 = points_bxpx3[:, faces_fx3[:, 0], :]
pf1_bxfx3 = points_bxpx3[:, faces_fx3[:, 1], :]
pf2_bxfx3 = points_bxpx3[:, faces_fx3[:, 2], :]
points3d_bxfx9 = torch.cat((pf0_bxfx3, pf1_bxfx3, pf2_bxfx3), dim=2)
xy_f0 = xy_bxpx2[:, faces_fx3[:, 0], :]
xy_f1 = xy_bxpx2[:, faces_fx3[:, 1], :]
xy_f2 = xy_bxpx2[:, faces_fx3[:, 2], :]
points2d_bxfx6 = torch.cat((xy_f0, xy_f1, xy_f2), dim=2)
######################################################
# 2 normals
v01_bxfx3 = pf1_bxfx3 - pf0_bxfx3
v02_bxfx3 = pf2_bxfx3 - pf0_bxfx3
# torch.cross defaults to the first dimension of size 3, which would be the batch dim if bs == 3, so specify dim explicitly
normal_bxfx3 = torch.cross(v01_bxfx3, v02_bxfx3, dim=2)
return points3d_bxfx9, points2d_bxfx6, normal_bxfx3
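##################################################
# Usage sketch (added for illustration, not part of the original module):
# projects a single toy triangle through the simple 3x1-projection branch of
# perspective_projection above. All values below are made up.
if __name__ == "__main__":
    points_bxpx3 = torch.rand(1, 3, 3) + torch.tensor([0.0, 0.0, 3.0])  # b=1, p=3
    faces_fx3 = torch.tensor([[0, 1, 2]], dtype=torch.long)  # one face
    camera_rot_bx3x3 = torch.eye(3).unsqueeze(0)
    camera_pos_bx3 = torch.tensor([[0.0, 0.0, -1.0]])
    camera_proj_3x1 = torch.tensor([[1.0], [1.0], [-1.0]])
    p3d, p2d, normals = perspective_projection(
        points_bxpx3, faces_fx3, (camera_rot_bx3x3, camera_pos_bx3, camera_proj_3x1)
    )
    # p3d: [b,f,9] camera-space face vertices; p2d: [b,f,6] projected xy;
    # normals: [b,f,3] unnormalized face normals
    print(p3d.shape, p2d.shape, normals.shape)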
import os
import os.path as osp
import numpy as np
from . import DIBRenderer
import torch
from tqdm import tqdm
import cv2
from core.utils.pose_utils import quat2mat_torch
from lib.pysixd import inout, misc
from lib.dr_utils.rep import TriangleMesh
def load_ply_models(
obj_paths,
texture_paths=None,
vertex_scale=0.001,
device="cuda",
width=512,
height=512,
tex_resize=False,
):
"""
NOTE: width and height are ignored if tex_resize=False
Args:
vertex_scale: default 0.001 is used for BOP models (mm -> m)!
tex_resize: resize the texture to a smaller size to save GPU memory
Returns:
a list of dicts
"""
assert all([".obj" in _path for _path in obj_paths])
models = []
for i, obj_path in enumerate(tqdm(obj_paths)):
model = {}
mesh = TriangleMesh.from_obj(obj_path)
vertices = mesh.vertices[:, :3] * vertex_scale # x,y,z (apply vertex_scale, e.g. mm -> m)
colors = mesh.vertices[:, 3:6] # rgb
faces = mesh.faces.int()
# normalize verts ( - center)
vertices_max = vertices.max()
vertices_min = vertices.min()
vertices_middle = (vertices_max + vertices_min) / 2.0
vertices = vertices - vertices_middle
model["vertices"] = vertices[:, :].to(device)
model["colors"] = colors[:, :].to(device)
model["faces"] = faces[:, :].to(device) # NOTE: -1
if texture_paths is not None:
texture = cv2.imread(texture_paths[i], cv2.IMREAD_COLOR)[:, :, ::-1].astype(np.float32) / 255.0
if tex_resize:
texture = cv2.resize(texture, (width, height), interpolation=cv2.INTER_AREA)
# CHW
texture = torch.from_numpy(texture.transpose(2, 0, 1)).to(device)
model["face_uvs"] = mesh.uvs[:, :].to(device)
model["face_uv_ids"] = mesh.face_textures[:, :].to(device)
model["texture"] = texture
# NOTE: texture_uv is None
model["texture_uv"] = None
models.append(model)
return models
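# Usage sketch (added for illustration; the paths below are hypothetical):
def _demo_load_ply_models():
    models = load_ply_models(
        ["models/obj_000001.obj"],
        texture_paths=["models/obj_000001.png"],
    )
    m = models[0]
    # each dict holds "vertices"/"colors"/"faces" tensors on the device, plus
    # "face_uvs"/"face_uv_ids"/"texture" when texture_paths is given
    return m["vertices"].shape, m["faces"].shape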
class Renderer_dibr(object):
def __init__(self, height, width, mode):
self.dib_ren = DIBRenderer(height, width, mode)
def render_scene(
self,
Rs,
ts,
models,
*,
K,
width,
height,
znear=0.01,
zfar=100,
rot_type="mat",
with_mask=False,
with_depth=True,
):
"""render a scene with m>=1 objects
Args:
Rs: [m,3,3] or [m,4] tensor
ts: [m,3,] tensor
models: list of dicts, each stores {"vertices":, "colors":, "faces":, }
K: [3,3]
Returns:
a dict:
color: (h,w,3)
mask: (h,w) fg mask
depth: (h,w)
"""
ret = {}
self.scene_ren = DIBRenderer(height, width, mode="VertexColorMulti")
self.scene_ren.set_camera_parameters_from_RT_K(
Rs, ts, K, height, width, near=znear, far=zfar, rot_type=rot_type
)
colors = [model["colors"][None] for model in models] # m * [1, p, 3]
points = [[model["vertices"][None], model["faces"].long()] for model in models]
# points: list of [vertices, faces]
# colors: list of colors
color, im_prob, _, im_mask = self.scene_ren.forward(points=points, colors=colors)
ret["color"] = color.squeeze()
ret["prob"] = im_prob.squeeze()
ret["mask"] = im_mask.squeeze()
if with_depth:
# transform xyz
if not isinstance(Rs, torch.Tensor):
Rs = torch.stack(Rs) # list
if rot_type == "quat":
R_mats = quat2mat_torch(Rs)
else:
R_mats = Rs
xyzs = [
misc.transform_pts_Rt_th(model["vertices"], R_mats[_id], ts[_id])[None]
for _id, model in enumerate(models)
]
ren_xyzs, _, _, _ = self.scene_ren.forward(points=points, colors=xyzs)
ret["depth"] = ren_xyzs[0, :, :, 2] # bhw
# color: hw3; mask: hw; depth: hw
return ret
def render_scene_tex(
self,
Rs,
ts,
models,
*,
K,
width,
height,
znear=0.01,
zfar=100,
rot_type="mat",
uv_type="vertex",
with_mask=False,
with_depth=True,
):
"""render a scene with m>=1 object for textured objects
Args:
Rs: [m,3,3] or [m,4] tensor
ts: [m,3] tensor
models: list of dict, each stores
vertex uv: {"vertices":, "faces":, "texture":, "vertex_uvs":,}
face uv: {"vertices":, "faces":, "texture":, "face_uvs":, "face_uv_ids":,}
K: [3,3]
uv_type: `vertex` | `face`
Returns:
dict:
color: (h,w,3)
mask: (h,w) fg mask (to get instance masks, use batch mode)
depth: (h,w)
"""
ret = {}
self.scene_ren = DIBRenderer(height, width, mode="TextureMulti")
self.scene_ren.set_camera_parameters_from_RT_K(
Rs, ts, K, height, width, near=znear, far=zfar, rot_type=rot_type
)
# points: list of [vertices, faces]
points = [[model["vertices"][None], model["faces"].long()] for model in models]
if uv_type == "vertex":
uv_bxpx2 = [model["vertex_uvs"][None] for model in models]
else: # face uv
uv_bxpx2 = [model["face_uvs"][None] for model in models]
ft_fx3_list = [model["face_uv_ids"] for model in models]
texture_bx3xthxtw = [model["texture"][None] for model in models]
dib_ren_im, dib_ren_prob, _, dib_ren_mask = self.scene_ren.forward(
points=points,
uv_bxpx2=uv_bxpx2,
texture_bx3xthxtw=texture_bx3xthxtw,
ts=ts,
ft_fx3=ft_fx3_list,
)
ret["color"] = dib_ren_im.squeeze()
ret["prob"] = dib_ren_prob.squeeze()
ret["mask"] = dib_ren_mask.squeeze()
if with_depth:
# transform xyz
# NOTE: check whether it should be in [0, 1] (maybe need to record min, max and denormalize later)
if not isinstance(Rs, torch.Tensor):
Rs = torch.stack(Rs) # list
if rot_type == "quat":
R_mats = quat2mat_torch(Rs)
else:
R_mats = Rs
xyzs = [
misc.transform_pts_Rt_th(model["vertices"], R_mats[_id], ts[_id])[None]
for _id, model in enumerate(models)
]
dib_ren_vc_batch = DIBRenderer(height, width, mode="VertexColorMulti")
dib_ren_vc_batch.set_camera_parameters(self.scene_ren.camera_params)
ren_xyzs, _, _, _ = dib_ren_vc_batch.forward(points=points, colors=xyzs)
ret["depth"] = ren_xyzs[0, :, :, 2] # hw
# color: hw3; mask: hw; depth: hw
return ret
def render_batch(
self,
Rs,
ts,
models,
*,
Ks,
width,
height,
znear=0.01,
zfar=100,
rot_type="mat",
mode=["color", "depth"],
):
"""render a batch (vertex color), each contain one object
Args:
Rs (tensor): [b,3,3] or [b,4]
ts (tensor): [b,3,]
models (list of dicts): each stores {"vertices":, "colors":, "faces":, }
Ks (tensor): [b,3,3]
mode: color, depth, mask, xyz (one or more must be given)
Returns:
dict:
color: bhw3
mask: bhw
depth: bhw
xyz: bhw3
probs: bhw
"""
assert self.dib_ren.mode in ["VertexColorBatch"], self.dib_ren.mode
ret = {}
self.dib_ren.set_camera_parameters_from_RT_K(Rs, ts, Ks, height, width, near=znear, far=zfar, rot_type=rot_type)
colors = [model["colors"][None] for model in models] # b x [1, p, 3]
points = [[model["vertices"][None], model["faces"].long()] for model in models]
# points: list of [vertices, faces]
# colors: list of colors
color, im_prob, _, im_mask = self.dib_ren.forward(points=points, colors=colors)
ret["color"] = color
ret["prob"] = im_prob.squeeze(-1)
ret["mask"] = im_mask.squeeze(-1)
if "depth" in mode:
# transform xyz
if not isinstance(Rs, torch.Tensor):
Rs = torch.stack(Rs) # list
if rot_type == "quat":
R_mats = quat2mat_torch(Rs)
else:
R_mats = Rs
xyzs = [
misc.transform_pts_Rt_th(model["vertices"], R_mats[_id], ts[_id])[None]
for _id, model in enumerate(models)
]
ren_xyzs, _, _, _ = self.dib_ren.forward(points=points, colors=xyzs)
ret["depth"] = ren_xyzs[:, :, :, 2] # bhw
if "xyz" in mode: # TODO: check this
obj_xyzs = [model["vertices"][None] for _id, model in enumerate(models)]
ren_obj_xyzs, _, _, _ = self.dib_ren.forward(points=points, colors=obj_xyzs)
ret["xyz"] = ren_obj_xyzs
return ret
def render_batch_tex(
self,
Rs,
ts,
models,
*,
Ks,
width,
height,
znear=0.01,
zfar=100,
uv_type="vertex",
rot_type="mat",
mode=["color", "depth"],
):
"""render a batch for textured objects
Args:
Rs: [b,3,3] or [b,4] tensor
ts: [b,3] tensor
models: list of dict, each stores
vertex uv: {"vertices":, "faces":, "texture":, "vertex_uvs":,}
face uv: {"vertices":, "faces":, "texture":, "face_uvs":, "face_uv_ids":,}
Ks: [b,3,3] or [3,3]
uv_type: `vertex` | `face`
mode: color, depth, mask, xyz (one or more must be given)
Returns:
dict:
color: bhw3
mask: bhw
depth: bhw
xyz: bhw3
"""
assert self.dib_ren.mode in ["TextureBatch"], self.dib_ren.mode
ret = {}
self.dib_ren.set_camera_parameters_from_RT_K(Rs, ts, Ks, height, width, near=znear, far=zfar, rot_type=rot_type)
# points: list of [vertices, faces]
points = [[model["vertices"][None], model["faces"].long()] for model in models]
if uv_type == "vertex":
uv_bxpx2 = [model["vertex_uvs"][None] for model in models]
else: # face uv
uv_bxpx2 = [model["face_uvs"][None] for model in models]
ft_fx3_list = [model["face_uv_ids"] for model in models]
texture_bx3xthxtw = [model["texture"][None] for model in models]
# points: list of [vertices, faces]
# colors: list of colors
dib_ren_im, dib_ren_prob, _, dib_ren_mask = self.dib_ren.forward(
points=points,
uv_bxpx2=uv_bxpx2,
texture_bx3xthxtw=texture_bx3xthxtw,
ft_fx3=ft_fx3_list,
)
ret["color"] = dib_ren_im
ret["prob"] = dib_ren_prob.squeeze(-1) # bhw1 -> bhw
ret["mask"] = dib_ren_mask.squeeze(-1) # bhw1 -> bhw
if "depth" in mode:
# transform xyz
# NOTE: check whether it should be in [0, 1] (maybe need to record min, max and denormalize later)
if not isinstance(Rs, torch.Tensor):
Rs = torch.stack(Rs) # list
if rot_type == "quat":
R_mats = quat2mat_torch(Rs)
else:
R_mats = Rs
xyzs = [
misc.transform_pts_Rt_th(model["vertices"], R_mats[_id], ts[_id])[None]
for _id, model in enumerate(models)
]
dib_ren_vc_batch = DIBRenderer(height, width, mode="VertexColorBatch")
dib_ren_vc_batch.set_camera_parameters(self.dib_ren.camera_params)
ren_xyzs, _, _, _ = dib_ren_vc_batch.forward(points=points, colors=xyzs)
if "depth" in mode:
ret["depth"] = ren_xyzs[:, :, :, 2] # bhw
if "xyz" in mode: # TODO: check this
obj_xyzs = [model["vertices"][None] for _id, model in enumerate(models)]
dib_ren_vc_batch = DIBRenderer(height, width, mode="VertexColorBatch")
dib_ren_vc_batch.set_camera_parameters(self.dib_ren.camera_params)
ren_obj_xyzs, _, _, _ = dib_ren_vc_batch.forward(points=points, colors=obj_xyzs)
ret["xyz"] = ren_obj_xyzs
return ret # bxhxwx3 rgb, bhw prob/mask/depth
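##############################################################
# Usage sketch (added for illustration; the path, intrinsics and pose are
# hypothetical, and a CUDA device plus the DIBRenderer package are assumed):
def _demo_render_batch():
    models = load_ply_models(["models/obj_000001.obj"])
    ren = Renderer_dibr(480, 640, "VertexColorBatch")
    R = torch.eye(3, device="cuda")[None]  # [1,3,3] rotation
    t = torch.tensor([[0.0, 0.0, 0.5]], device="cuda")  # [1,3] translation
    K = torch.tensor(
        [[[572.4, 0.0, 325.3], [0.0, 573.6, 242.0], [0.0, 0.0, 1.0]]],
        device="cuda",
    )  # [1,3,3] hypothetical intrinsics
    out = ren.render_batch(R, t, models, Ks=K, width=640, height=480, mode=["color", "depth"])
    return out["color"], out["depth"]  # bhw3, bhw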
from .utils import *
from .mesh import *
from .perspective import *
from .sphericalcoord import *
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import os
import torch
import numpy as np
##################################################################
# faces begin from 0!!!
def face2edge(facenp_fx3):
"""facenp_fx3, int32 return edgenp_ex2, int32."""
f1 = facenp_fx3[:, 0:1]
f2 = facenp_fx3[:, 1:2]
f3 = facenp_fx3[:, 2:3]
e1 = np.concatenate((f1, f1, f2), axis=0)
e2 = np.concatenate((f2, f3, f3), axis=0)
edgenp_ex2 = np.concatenate((e1, e2), axis=1)
# sort & unique
edgenp_ex2 = np.sort(edgenp_ex2, axis=1)
edgenp_ex2 = np.unique(edgenp_ex2, axis=0)
return edgenp_ex2
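# Usage sketch (added for illustration): two triangles sharing edge (0, 2).
def _demo_face2edge():
    faces = np.array([[0, 1, 2], [0, 2, 3]], dtype=np.int32)
    edges = face2edge(faces)
    # edges == [[0,1],[0,2],[0,3],[1,2],[2,3]] -> 5 unique undirected edges
    return edges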
def face2edge2(facenp_fx3, edgenp_ex2):
"""facenp_fx3, int32 edgenp_ex2, int32 return face_fx3, int32 this face is
indexed by edge."""
fnum = facenp_fx3.shape[0]
enum = edgenp_ex2.shape[0]
edgesort = np.sort(edgenp_ex2, axis=1)
edgere_fx3 = np.zeros_like(facenp_fx3)
for i in range(fnum):
for j in range(3):
pbe, pen = facenp_fx3[i, j], facenp_fx3[i, (j + 1) % 3]
if pbe > pen:
pbe, pen = pen, pbe
cond = (edgesort[:, 0] == pbe) & (edgesort[:, 1] == pen)
idx = np.where(cond)[0]
edgere_fx3[i, j] = idx
return edgere_fx3
def edge2face(facenp_fx3, edgenp_ex2):
"""facenp_fx3, int32 edgenp_ex2, int32 return edgenp_ex2, int32 this edge
is indexed by face."""
fnum = facenp_fx3.shape[0]
enum = edgenp_ex2.shape[0]
facesort = np.sort(facenp_fx3, axis=1)
edgesort = np.sort(edgenp_ex2, axis=1)
edgere_ex2 = np.zeros_like(edgesort)
for i in range(enum):
pbe, pen = edgesort[i]
eid = 0
for j in range(fnum):
f1, f2, f3 = facesort[j]
cond1 = f1 == pbe and f2 == pen
cond2 = f1 == pbe and f3 == pen
cond3 = f2 == pbe and f3 == pen
if cond1 or cond2 or cond3:
edgere_ex2[i, eid] = j
eid += 1
return edgere_ex2
def face2pneimtx(facenp_fx3):
"""facenp_fx3, int32 return pointneighbourmtx, pxp, float32 will normalize!
assume it is a good mesh every point has more than one neighbour
"""
pnum = np.max(facenp_fx3) + 1
pointneighbourmtx = np.zeros(shape=(pnum, pnum), dtype=np.float32)
for i in range(3):
be = i
en = (i + 1) % 3
idx1 = facenp_fx3[:, be]
idx2 = facenp_fx3[:, en]
pointneighbourmtx[idx1, idx2] = 1
pointneighbourmtx[idx2, idx1] = 1
pointneicount = np.sum(pointneighbourmtx, axis=1, keepdims=True)
assert np.all(pointneicount > 0)
pointneighbourmtx /= pointneicount
return pointneighbourmtx
def face2pfmtx(facenp_fx3):
"""facenp_fx3, int32 reutrn pfmtx, pxf, float32."""
pnum = np.max(facenp_fx3) + 1
fnum = facenp_fx3.shape[0]
pfmtx = np.zeros(shape=(pnum, fnum), dtype=np.float32)
for i, f in enumerate(facenp_fx3):
pfmtx[f[0], i] = 1
pfmtx[f[1], i] = 1
pfmtx[f[2], i] = 1
return pfmtx
# upsample new points
def meshresample(pointnp_px3, facenp_fx3, edgenp_ex2):
p1 = pointnp_px3[edgenp_ex2[:, 0], :]
p2 = pointnp_px3[edgenp_ex2[:, 1], :]
pmid = (p1 + p2) / 2
point2np_px3 = np.concatenate((pointnp_px3, pmid), axis=0)
# split each original face into 4 new faces using the edge midpoints
# (the original faces are dropped)
face2np_fx3 = []
pnum = np.max(facenp_fx3) + 1
for f in facenp_fx3:
p1, p2, p3 = f
p12 = (edgenp_ex2 == (min(p1, p2), max(p1, p2))).all(axis=1).nonzero()[0] + pnum
p23 = (edgenp_ex2 == (min(p2, p3), max(p2, p3))).all(axis=1).nonzero()[0] + pnum
p31 = (edgenp_ex2 == (min(p3, p1), max(p3, p1))).all(axis=1).nonzero()[0] + pnum
face2np_fx3.append([p1, p12, p31])
face2np_fx3.append([p12, p2, p23])
face2np_fx3.append([p31, p23, p3])
face2np_fx3.append([p12, p23, p31])
face2np_fx3 = np.array(face2np_fx3, dtype=np.int64)
return point2np_px3, face2np_fx3
def mtx2tfsparse(mtx):
m, n = mtx.shape
rows, cols = np.nonzero(mtx)
# N = rows.shape[0]
# value = np.ones(shape=(N,), dtype=np.float32)
value = mtx[rows, cols]
v = torch.FloatTensor(value)
i = torch.LongTensor(np.stack((rows, cols), axis=0))
tfspmtx = torch.sparse.FloatTensor(i, v, torch.Size([m, n]))
return tfspmtx
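# Usage sketch (added for illustration): build the row-normalized
# point-neighbour matrix of a single triangle and convert it to a torch
# sparse tensor.
def _demo_sparse_neighbours():
    faces = np.array([[0, 1, 2]], dtype=np.int32)
    pnei = face2pneimtx(faces)  # 3x3, each row sums to 1 (two 0.5 entries)
    return mtx2tfsparse(pnei)   # torch.sparse.FloatTensor of size [3, 3]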
################################################################
def loadobj(meshfile):
v = []
f = []
meshfp = open(meshfile, "r")
for line in meshfp.readlines():
data = line.strip().split(" ")
data = [da for da in data if len(da) > 0]
if len(data) != 4:
continue
if data[0] == "v":
v.append([float(d) for d in data[1:]])
if data[0] == "f":
data = [da.split("/")[0] for da in data]
f.append([int(d) for d in data[1:]])
meshfp.close()
# torch needs int64 indices
facenp_fx3 = np.array(f, dtype=np.int64) - 1
pointnp_px3 = np.array(v, dtype=np.float32)
return pointnp_px3, facenp_fx3
def loadobjcolor(meshfile):
v = []
vc = []
f = []
meshfp = open(meshfile, "r")
for line in meshfp.readlines():
data = line.strip().split(" ")
data = [da for da in data if len(da) > 0]
if len(data) == 0:
    continue
if data[0] == "v":
v.append([float(d) for d in data[1:4]])
if len(data) == 7:
vc.append([float(d) for d in data[4:7]])
if data[0] == "f":
data = [da.split("/")[0] for da in data]
f.append([int(d) for d in data[1:4]])
meshfp.close()
# torch needs int64 indices
facenp_fx3 = np.array(f, dtype=np.int64) - 1
pointnp_px3 = np.array(v, dtype=np.float32)
if len(vc) > 0:
vc = np.array(vc, dtype=np.float32)
else:
vc = np.ones_like(pointnp_px3)
return pointnp_px3, facenp_fx3, vc
def loadobjtex(meshfile):
v = []
vt = []
f = []
ft = []
meshfp = open(meshfile, "r")
for line in meshfp.readlines():
data = line.strip().split(" ")
data = [da for da in data if len(da) > 0]
if not ((len(data) == 3) or (len(data) == 4) or (len(data) == 5)):
continue
if data[0] == "v":
if len(data) == 4:
v.append([float(d) for d in data[1:]])
if data[0] == "vt":
if len(data) == 3 or len(data) == 4:
vt.append([float(d) for d in data[1:3]])
if data[0] == "f":
data = [da.split("/") for da in data]
if len(data) == 4:
f.append([int(d[0]) for d in data[1:]])
# print(data[1:])
ft.append([int(d[1]) for d in data[1:]])
elif len(data) == 5:
idx1 = [1, 2, 3]
data1 = [data[i] for i in idx1]
f.append([int(d[0]) for d in data1])
ft.append([int(d[1]) for d in data1])
idx2 = [1, 3, 4]
data2 = [data[i] for i in idx2]
f.append([int(d[0]) for d in data2])
ft.append([int(d[1]) for d in data2])
meshfp.close()
# torch needs int64 indices
facenp_fx3 = np.array(f, dtype=np.int64) - 1
ftnp_fx3 = np.array(ft, dtype=np.int64) - 1
pointnp_px3 = np.array(v, dtype=np.float32)
uvs = np.array(vt, dtype=np.float32)
return pointnp_px3, facenp_fx3, uvs, ftnp_fx3
def savemesh(pointnp_px3, facenp_fx3, fname, partinfo=None):
if partinfo is None:
fid = open(fname, "w")
for pidx, p in enumerate(pointnp_px3):
pp = p
fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2]))
for f in facenp_fx3:
f1 = f + 1
fid.write("f %d %d %d\n" % (f1[0], f1[1], f1[2]))
fid.close()
else:
fid = open(fname, "w")
for pidx, p in enumerate(pointnp_px3):
if partinfo[pidx, -1] == 0:
pp = p
color = [1, 0, 0]
else:
pp = p
color = [0, 0, 1]
fid.write("v %f %f %f %f %f %f\n" % (pp[0], pp[1], pp[2], color[0], color[1], color[2]))
for f in facenp_fx3:
f1 = f + 1
fid.write("f %d %d %d\n" % (f1[0], f1[1], f1[2]))
fid.close()
return
def savemeshcolor(pointnp_px3, facenp_fx3, fname, color_px3=None):
if color_px3 is None:
fid = open(fname, "w")
for pidx, p in enumerate(pointnp_px3):
pp = p
fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2]))
for f in facenp_fx3:
f1 = f + 1
fid.write("f %d %d %d\n" % (f1[0], f1[1], f1[2]))
fid.close()
else:
fid = open(fname, "w")
for pidx, p in enumerate(pointnp_px3):
pp = p
color = color_px3[pidx]
fid.write("v %f %f %f %f %f %f\n" % (pp[0], pp[1], pp[2], color[0], color[1], color[2]))
for f in facenp_fx3:
f1 = f + 1
fid.write("f %d %d %d\n" % (f1[0], f1[1], f1[2]))
fid.close()
return
def savemeshtes(pointnp_px3, tcoords_px2, facenp_fx3, fname):
import os
fol, na = os.path.split(fname)
na, _ = os.path.splitext(na)
matname = "%s/%s.mtl" % (fol, na)
fid = open(matname, "w")
fid.write("newmtl material_0\n")
fid.write("Kd 1 1 1\n")
fid.write("Ka 0 0 0\n")
fid.write("Ks 0.4 0.4 0.4\n")
fid.write("Ns 10\n")
fid.write("illum 2\n")
fid.write("map_Kd %s.png\n" % na)
fid.close()
fid = open(fname, "w")
fid.write("mtllib %s.mtl\n" % na)
for pidx, p in enumerate(pointnp_px3):
pp = p
fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2]))
for pidx, p in enumerate(tcoords_px2):
pp = p
fid.write("vt %f %f\n" % (pp[0], pp[1]))
fid.write("usemtl material_0\n")
for f in facenp_fx3:
f1 = f + 1
fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f1[0], f1[1], f1[1], f1[2], f1[2]))
fid.close()
return
def save_textured_mesh(
directory,
file_name,
vertex_pos_px3,
face_fx3,
tex_coord_px2,
normalize_tex_coord=False,
flip_vertical=False,
texture_bias=0.01,
):
"""Save a textured mesh. Assumes the texture is *already* saved into.
<directory> as <file_name>.png.
Args:
directory (str): The path to the folder containing the mesh to be saved.
file_name (str): The name of the mesh to be saved (without extension).
<file_name>.obj and <file_name>.mtl will be saved.
vertex_pos_px3 (numpy.ndarray): An array of shape (num_points, 3).
Denotes the vertex position.
face_fx3 (numpy.ndarray): An array of shape (num_faces, 3).
Specifies, for each face, which vertices are used.
tex_coord_px2 (numpy.ndarray): An array of shape (num_points, 2).
Specifies the texture coordinate of each vertex.
Each coordinate should be in the range [0, 1] or [-1, 1].
If the range is [-1, 1], set normalize_tex_coord to True.
NOTE: if this array is of the same format as specified for
torch.nn.functional.grid_sample(), set both normalize_tex_coord
and flip_vertical to True.
normalize_tex_coord (bool): Whether to normalize texture coordinates,
from [-1, 1] to [0, 1].
flip_vertical (bool): Whether to flip the texture coordinates vertically.
texture_bias (float): If positive, trim the edge of the texture by this
amount to avoid artifacts.
"""
if os.path.splitext(file_name)[1]:
raise ValueError("file_name to save_textured_mesh cannot contain extension")
if file_name.find(" ") != -1:
raise ValueError("file_name cannot contain space")
obj_path = os.path.join(directory, file_name + ".obj")
mtl_path = os.path.join(directory, file_name + ".mtl")
with open(obj_path, "w") as obj_file:
obj_file.write("mtllib ./{}.mtl\n".format(file_name))
for pos in vertex_pos_px3:
obj_file.write("v {} {} {}\n".format(pos[0], pos[1], pos[2]))
for uv in tex_coord_px2:
    if normalize_tex_coord:
        uv = uv * 0.5 + 0.5  # map [-1, 1] to [0, 1]
    uv = uv * (1.0 - texture_bias * 2.0) + texture_bias
    obj_file.write("vt {} {}\n".format(uv[0], 1.0 - uv[1] if flip_vertical else uv[1]))
obj_file.write("usemtl material_0\n")
for i in range(face_fx3.shape[0]):
face = face_fx3[i] + 1
obj_file.write("f {0}/{0} {1}/{1} {2}/{2}\n".format(face[0], face[1], face[2]))
with open(mtl_path, "w") as mtl_file:
mtl_file.write(
"""newmtl material_0
Ka 0.200000 0.200000 0.200000
Kd 1.000000 1.000000 1.000000
Ks 1.000000 1.000000 1.000000
map_Kd {}.png""".format(
file_name
)
)
return
def saveobjscale(meshfile, scale, maxratio, shift=None):
mname, prefix = os.path.splitext(meshfile)
mnamenew = "%s-%.2f%s" % (mname, maxratio, prefix)
meshfp = open(meshfile, "r")
meshfp2 = open(mnamenew, "w")
for line in meshfp.readlines():
data = line.strip().split(" ")
data = [da for da in data if len(da) > 0]
if len(data) != 4:
meshfp2.write(line)
continue
else:
if data[0] == "v":
p = [scale * float(d) for d in data[1:]]
meshfp2.write("v %f %f %f\n" % (p[0], p[1], p[2]))
else:
meshfp2.write(line)
continue
meshfp.close()
meshfp2.close()
return
if __name__ == "__main__":
import cv2
meshjson = "1.obj"
# f begin from 0!!!
pointnp_px3, facenp_fx3 = loadobj(meshjson)
assert np.max(facenp_fx3) == pointnp_px3.shape[0] - 1
assert np.min(facenp_fx3) == 0
pointnp_px3[:, 1] -= 0.05
X = pointnp_px3[:, 0]
Y = pointnp_px3[:, 1]
Z = pointnp_px3[:, 2]
h = 248 * (Y / Z) + 111.5
w = -248 * (X / Z) + 111.5
height = 224
width = 224
im = np.zeros(shape=(height, width), dtype=np.uint8)
for cir in zip(w, h):
cv2.circle(im, (int(cir[0]), int(cir[1])), 3, (255, 0, 0), -1)
cv2.imshow("", im)
cv2.waitKey()
# edge, neighbour and pfmtx
edgenp_ex2 = face2edge(facenp_fx3)
face_edgeidx_fx3 = face2edge2(facenp_fx3, edgenp_ex2)
pneimtx = face2pneimtx(facenp_fx3)
pfmtx = face2pfmtx(facenp_fx3)
# save
savemesh(pointnp_px3, facenp_fx3, "1s.obj")
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import torch
import numpy as np
def unit(v):
norm = np.linalg.norm(v)
if norm == 0:
return v
return v / norm
def lookatnp(eye_3x1, center_3x1, up_3x1):
# the resulting camera axes should all be unit length
camz = center_3x1 - eye_3x1
camz /= np.sqrt(np.sum(camz**2))
camx = np.cross(camz[:, 0], up_3x1[:, 0]).reshape(3, 1)
camy = np.cross(camx[:, 0], camz[:, 0]).reshape(3, 1)
# camx and camy are not guaranteed to be unit length!!!
mtx = np.concatenate([unit(camx), unit(camy), -camz], axis=1).transpose()
shift = -(np.matmul(mtx, eye_3x1))
return mtx, shift
def camera_info(param):
theta = np.deg2rad(param[0])
phi = np.deg2rad(param[1])
camY = param[3] * np.sin(phi)
temp = param[3] * np.cos(phi)
camX = temp * np.cos(theta)
camZ = temp * np.sin(theta)
cam_pos = np.array([camX, camY, camZ])
axisZ = cam_pos.copy()
axisY = np.array([0, 1, 0], dtype=np.float32)
axisX = np.cross(axisY, axisZ)
axisY = np.cross(axisZ, axisX)
# cam_mat = np.array([axisX, axisY, axisZ])
cam_mat = np.array([unit(axisX), unit(axisY), unit(axisZ)])
# for verify
# mtx, shift = lookatnp(cam_pos_3xb.reshape(3, 1), np.zeros(shape=(3, 1), dtype=np.float32), np.array([0,1,0], dtype=np.float32).reshape(3, 1))
# note, it is different from lookatnp
# new_p = mtx * old_p + shift
# new_p = cam_mat * (old_p - cam_pos)
return cam_mat, cam_pos
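# Usage sketch (added for illustration): a camera at azimuth 45 deg,
# elevation 30 deg, distance 2, looking at the origin. param[2] is unused.
def _demo_camera_info():
    param = np.array([45.0, 30.0, 0.0, 2.0], dtype=np.float32)
    cam_mat, cam_pos = camera_info(param)
    return cam_mat, cam_pos  # [3,3] rotation and [3] position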
#####################################################
def perspectiveprojectionnp(fovy, ratio=1.0, near=0.01, far=10.0):
"""
fovy: radian, 2 * atan2(h, 2*fy)
ratio: aspect_ratio, w/h, typically 4/3
"""
tanfov = np.tan(fovy / 2.0) # h/(2*fy)
# top = near * tanfov
# right = ratio * top
# mtx = [near / right, 0, 0, 0, \
# 0, near / top, 0, 0, \
# 0, 0, -(far+near)/(far-near), -2*far*near/(far-near), \
# 0, 0, -1, 0]
mtx = [
[1.0 / (ratio * tanfov), 0, 0, 0],
[0, 1.0 / tanfov, 0, 0],
[0, 0, -(far + near) / (far - near), -2 * far * near / (far - near)],
[0, 0, -1.0, 0],
]
# return np.array(mtx, dtype=np.float32)
# 2*fy/h/ratio=2*fy/w, 2*fy/h
return np.array([[1.0 / (ratio * tanfov)], [1.0 / tanfov], [-1]], dtype=np.float32)
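# Usage sketch (added for illustration): recover fovy from hypothetical
# pinhole intrinsics (fy and image size) as in the docstring above, then
# build the 3x1 projection vector used by the simple perspective branch.
def _demo_proj_from_intrinsics():
    fy, h, w = 573.6, 480.0, 640.0
    fovy = 2.0 * np.arctan2(h, 2.0 * fy)
    return perspectiveprojectionnp(fovy, ratio=w / h)  # [3,1] float32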
def projectiveprojection_real(cam, x0, y0, w, h, nc=0.01, fc=10.0):
# this is for the center view
# NOTE: returns the transposed 4x4 OpenGL-style projection matrix (proj_T)
q = -(fc + nc) / float(fc - nc)
qn = -2 * (fc * nc) / float(fc - nc)
fx = cam[0, 0]
fy = cam[1, 1]
px = cam[0, 2]
py = cam[1, 2]
"""
# transpose: compensate for the flipped image
proj_T = [
[2*fx/w, 0, 0, 0],
[0, 2*fy/h, 0, 0],
[(-2*px+w+2*x0)/w, (2*py-h+2*y0)/h, q, -1],
[0, 0, qn, 0],
]
sometimes: P[1,:] *= -1, P[2,:] *= -1
# Third column is standard glPerspective and sets near and far planes
"""
# Draw our images upside down, so that all the pixel-based coordinate systems are the same
if isinstance(cam, np.ndarray):
proj_T = np.zeros((4, 4), dtype=np.float32)
elif isinstance(cam, torch.Tensor):
proj_T = torch.zeros(4, 4).to(cam)
else:
raise TypeError("cam should be ndarray or tensor, got {}".format(type(cam)))
proj_T[0, 0] = 2 * fx / w
proj_T[1, 0] = -2 * cam[0, 1] / w # =0
proj_T[1, 1] = 2 * fy / h
proj_T[2, 0] = (-2 * px + w + 2 * x0) / w
proj_T[2, 1] = (+2 * py - h + 2 * y0) / h
proj_T[2, 2] = q
proj_T[3, 2] = qn
proj_T[2, 3] = -1.0
return proj_T
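# Usage sketch (added for illustration): a hypothetical 3x3 intrinsic matrix
# mapped to the transposed 4x4 projection, with a full-image viewport.
def _demo_projection_real():
    K = np.array(
        [[572.4, 0.0, 325.3], [0.0, 573.6, 242.0], [0.0, 0.0, 1.0]],
        dtype=np.float32,
    )
    return projectiveprojection_real(K, x0=0, y0=0, w=640, h=480)  # proj_T, 4x4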
#####################################################
def camera_info_batch(param_bx4):
bnum = param_bx4.shape[0]
cam_mat_bx3x3 = []
cam_pos_bx3 = []
for i in range(bnum):
param = param_bx4[i]
cam_mat, cam_pos = camera_info(param)
cam_mat_bx3x3.append(cam_mat)
cam_pos_bx3.append(cam_pos)
cam_mat_bx3x3 = np.stack(cam_mat_bx3x3, axis=0)
cam_pos_bx3 = np.stack(cam_pos_bx3, axis=0)
return cam_mat_bx3x3, cam_pos_bx3
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import numpy as np
##################################################################
# symmetric over z axis
def get_spherical_coords_z(X):
# X is N x 3
rad = np.linalg.norm(X, axis=1)
# Inclination
theta = np.arccos(X[:, 2] / rad)
# Azimuth
phi = np.arctan2(X[:, 1], X[:, 0])
# Normalize both to be between [-1, 1]
vv = (theta / np.pi) * 2 - 1
uu = ((phi + np.pi) / (2 * np.pi)) * 2 - 1
# Return N x 2
return np.stack([uu, vv], 1)
# symmetric over x axis
def get_spherical_coords_x(X):
# X is N x 3
rad = np.linalg.norm(X, axis=1)
# Inclination: arccos(x / rad) in [0, pi]
# (x == rad -> 0, x == -rad -> pi)
theta = np.arccos(X[:, 0] / rad)
# Azimuth
phi = np.arctan2(X[:, 2], X[:, 1])
# Normalize both to be between [-1, 1]
uu = (theta / np.pi) * 2 - 1
vv = ((phi + np.pi) / (2 * np.pi)) * 2 - 1
# Return N x 2
return np.stack([uu, vv], 1)
# symmetric spherical projection
def get_symmetric_spherical_tex_coords(vertex_pos, symmetry_axis=1, up_axis=2, front_axis=0):
# vertex_pos is N x 3
length = np.linalg.norm(vertex_pos, axis=1)
# Inclination
theta = np.arccos(vertex_pos[:, front_axis] / length)
# Azimuth
phi = np.abs(np.arctan2(vertex_pos[:, symmetry_axis], vertex_pos[:, up_axis]))
# Normalize both to be between [-1, 1]
uu = (theta / np.pi) * 2 - 1
# vv = ((phi + np.pi) / (2 * np.pi)) * 2 - 1
vv = (phi / np.pi) * 2 - 1
# Return N x 2
return np.stack([uu, vv], 1)
#########################################################################
if __name__ == "__main__":
from utils.utils_mesh import loadobj, savemeshtes
import cv2
p, f = loadobj("2.obj")
uv = get_spherical_coords_x(p)
uv[:, 0] = -uv[:, 0]
uv[:, 1] = -uv[:, 1]
uv = (uv + 1) / 2
savemeshtes(p, uv, f, "./2_x.obj")
tex = np.zeros(shape=(256, 512, 3), dtype=np.uint8)
font = cv2.FONT_HERSHEY_SIMPLEX
bottomLeftCornerOfText = (10, 200)
fontScale = 5
fontColor = (0, 255, 255)
lineType = 2
cv2.putText(
tex,
"Hello World!",
bottomLeftCornerOfText,
font,
fontScale,
fontColor,
lineType,
)
cv2.imshow("", tex)
cv2.waitKey()
cv2.imwrite("2_x.png", np.transpose(tex, [1, 0, 2]))
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
from __future__ import print_function
from __future__ import division
import torch
import torch.nn
eps = 1e-15
##################################################
def datanormalize(data, axis):
datalen = torch.sqrt(torch.sum(data**2, dim=axis, keepdim=True))
return data / (datalen + eps)
# differentiable renderer utils
import cv2
import numpy as np
import torch
import os.path as osp
from tqdm import tqdm
import hashlib
import logging
import mmcv
from .rep import TriangleMesh
from .dib_renderer_x import DIBRenderer
from core.utils.pose_utils import quat2mat_torch
from lib.utils.utils import iprint
from lib.pysixd import misc
def load_objs(
obj_paths,
texture_paths=None,
height=480,
width=640,
tex_resize=True,
tex_fmt="CHW",
tex_vflip=False,
):
"""
NOTE: width and height are ignored if tex_resize=False
"""
assert all([".obj" in _path for _path in obj_paths])
if texture_paths is not None:
assert len(obj_paths) == len(texture_paths)
models = []
for i, obj_path in enumerate(tqdm(obj_paths)):
model = {}
mesh = TriangleMesh.from_obj(obj_path)
vertices = mesh.vertices[:, :3] # x,y,z
colors = mesh.vertices[:, 3:6] # rgb
faces = mesh.faces.int()
###########################
# normalize verts ( - center)
###########################
vertices_max = vertices.max()
vertices_min = vertices.min()
vertices_middle = (vertices_max + vertices_min) / 2.0
vertices = vertices - vertices_middle
model["vertices"] = vertices[None, :, :].cuda()
model["colors"] = colors[None, :, :].cuda()
model["faces"] = faces[None, :, :].cuda() # NOTE: -1
if texture_paths is not None:
uvs = mesh.uvs
face_textures = mesh.face_textures # NOTE: -1
assert osp.exists(texture_paths[i]), texture_paths[i]
if tex_vflip:
texture = cv2.imread(texture_paths[i], cv2.IMREAD_COLOR)[::-1, :, ::-1].astype(np.float32) / 255.0
else:
texture = cv2.imread(texture_paths[i], cv2.IMREAD_COLOR)[:, :, ::-1].astype(np.float32) / 255.0
if tex_resize:
texture = cv2.resize(texture, (width, height), interpolation=cv2.INTER_AREA)
# print('texture map: ', texture.shape)
if tex_fmt == "CHW":
texture = torch.from_numpy(texture.transpose(2, 0, 1)[None, :, :, :]).cuda()
else: # HWC
texture = torch.from_numpy(texture[None, :, :, :]).cuda()
model["face_uvs"] = uvs[None, :, :].cuda()
model["face_uv_ids"] = face_textures[None, :, :].cuda()
model["texture"] = texture.cuda()
models.append(model)
return models
def render_dib_vc_batch(
ren,
Rs,
ts,
Ks,
obj_ids,
models,
rot_type="quat",
H=480,
W=640,
near=0.01,
far=100.0,
with_depth=False,
):
"""
Args:
ren: A DIB-renderer
models: All models loaded by load_objs
"""
assert ren.mode in ["VertexColorBatch"], ren.mode
bs = len(Rs)
if len(Ks) == 1:
Ks = [Ks[0] for _ in range(bs)]
ren.set_camera_parameters_from_RT_K(Rs, ts, Ks, height=H, width=W, near=near, far=far, rot_type=rot_type)
colors = [models[_id]["colors"] for _id in obj_ids] # b x [1, p, 3]
points = [[models[_id]["vertices"], models[_id]["faces"][0].long()] for _id in obj_ids]
# points: list of [vertices, faces]
# colors: list of colors
predictions, im_probs, _, im_masks = ren.forward(points=points, colors=colors)
if with_depth:
# transform xyz
if not isinstance(Rs, torch.Tensor):
Rs = torch.stack(Rs) # list
if rot_type == "quat":
R_mats = quat2mat_torch(Rs)
else:
R_mats = Rs
xyzs = [
misc.transform_pts_Rt_th(models[obj_id]["vertices"][0], R_mats[_id], ts[_id])[None]
for _id, obj_id in enumerate(obj_ids)
]
ren_xyzs, _, _, _ = ren.forward(points=points, colors=xyzs)
depth = ren_xyzs[:, :, :, 2] # bhw
else:
depth = None
# bxhxwx3 rgb, bhw1 prob, bhw1 mask, bhw depth
return predictions, im_probs, im_masks, depth
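# Usage sketch (added for illustration; the object path and pose are
# hypothetical, and a CUDA device plus the DIBRenderer package are assumed):
def _demo_render_dib_vc_batch():
    models = load_objs(["models/obj_000001.obj"])
    ren = DIBRenderer(480, 640, mode="VertexColorBatch")
    R = torch.tensor([[1.0, 0.0, 0.0, 0.0]], device="cuda")  # identity quaternion
    t = torch.tensor([[0.0, 0.0, 0.5]], device="cuda")
    K = torch.tensor(
        [[[572.4, 0.0, 325.3], [0.0, 573.6, 242.0], [0.0, 0.0, 1.0]]],
        device="cuda",
    )
    color, prob, mask, depth = render_dib_vc_batch(
        ren, R, t, K, obj_ids=[0], models=models, rot_type="quat", with_depth=True
    )
    return color, mask, depth  # bxhxwx3 rgb, bhw1 mask, bhw depth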
def render_dib_tex_batch(
ren,
Rs,
ts,
Ks,
obj_ids,
models,
rot_type="quat",
H=480,
W=640,
near=0.01,
far=100.0,
with_depth=False,
):
assert ren.mode in ["TextureBatch"], ren.mode
bs = len(Rs)
if len(Ks) == 1:
Ks = [Ks[0] for _ in range(bs)]
ren.set_camera_parameters_from_RT_K(Rs, ts, Ks, height=H, width=W, near=near, far=far, rot_type=rot_type)
# points: list of [vertices, faces]
points = [[models[_id]["vertices"], models[_id]["faces"][0].long()] for _id in obj_ids]
uv_bxpx2 = [models[_id]["face_uvs"] for _id in obj_ids]
texture_bx3xthxtw = [models[_id]["texture"] for _id in obj_ids]
ft_fx3_list = [models[_id]["face_uv_ids"][0] for _id in obj_ids]
# points: list of [vertices, faces]
# colors: list of colors
dib_ren_im, dib_ren_prob, _, dib_ren_mask = ren.forward(
points=points,
uv_bxpx2=uv_bxpx2,
texture_bx3xthxtw=texture_bx3xthxtw,
ft_fx3=ft_fx3_list,
)
if with_depth:
# transform xyz
if not isinstance(Rs, torch.Tensor):
Rs = torch.stack(Rs) # list
if rot_type == "quat":
R_mats = quat2mat_torch(Rs)
else:
R_mats = Rs
xyzs = [
misc.transform_pts_Rt_th(models[obj_id]["vertices"][0], R_mats[_id], ts[_id])[None]
for _id, obj_id in enumerate(obj_ids)
]
dib_ren_vc_batch = DIBRenderer(height=H, width=W, mode="VertexColorBatch")
dib_ren_vc_batch.set_camera_parameters(ren.camera_params)
ren_xyzs, _, _, _ = dib_ren_vc_batch.forward(points=points, colors=xyzs)
depth = ren_xyzs[:, :, :, 2] # bhw
else:
depth = None
return (
dib_ren_im,
dib_ren_prob,
dib_ren_mask,
depth,
) # bxhxwx3 rgb, bhw1 prob/mask, bhw depth
def render_dib_vc_multi(
ren,
Rs,
ts,
K,
obj_ids,
models,
rot_type="quat",
H=480,
W=640,
near=0.01,
far=100.0,
):
assert ren.mode in ["VertexColorMulti"], ren.mode
ren.set_camera_parameters_from_RT_K(Rs, ts, K, height=H, width=W, near=near, far=far, rot_type=rot_type)
colors = [models[_id]["colors"] for _id in obj_ids] # b x [1, p, 3]
points = [[models[_id]["vertices"], models[_id]["faces"][0].long()] for _id in obj_ids]
# points: list of [vertices, faces]
# colors: list of colors
predictions, im_prob, _, im_mask = ren.forward(points=points, colors=colors)
# TODO: add depth
return predictions, im_prob, im_mask # 1xhxwx3 rgb
def render_dib_tex_multi(
ren,
Rs,
ts,
K,
obj_ids,
models,
rot_type="quat",
H=480,
W=640,
near=0.01,
far=100.0,
):
assert ren.mode in ["TextureMulti"], ren.mode
ren.set_camera_parameters_from_RT_K(Rs, ts, K, height=H, width=W, near=near, far=far, rot_type=rot_type)
# points: list of [vertices, faces]
points = [[models[_id]["vertices"], models[_id]["faces"][0].long()] for _id in obj_ids]
uv_bxpx2 = [models[_id]["face_uvs"] for _id in obj_ids]
texture_bx3xthxtw = [models[_id]["texture"] for _id in obj_ids]
ft_fx3_list = [models[_id]["face_uv_ids"][0] for _id in obj_ids]
dib_ren_im, dib_ren_prob, _, dib_ren_mask = ren.forward(
points=points,
uv_bxpx2=uv_bxpx2,
texture_bx3xthxtw=texture_bx3xthxtw,
ts=ts,
ft_fx3=ft_fx3_list,
)
# TODO: add depth
return (
dib_ren_im,
dib_ren_prob,
dib_ren_mask,
) # 1xhxwx3 rgb, (1,h,w,1) prob/mask
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import abstractmethod
import torch
import numpy as np
import torch.nn.functional as F
from .helpers import _composedecorator
# from kaolin.rep.Mesh import Mesh
from .Mesh import Mesh
class TriangleMesh(Mesh):
"""Abstract class to represent 3D Trianlge meshes."""
def __init__(
self,
vertices: torch.Tensor,
faces: torch.Tensor,
uvs: torch.Tensor,
face_textures: torch.Tensor,
textures: torch.Tensor,
edges: torch.Tensor,
edge2key: dict,
vv: torch.Tensor,
vv_count: torch.Tensor,
vf: torch.Tensor,
vf_count: torch.Tensor,
ve: torch.Tensor,
ve_count: torch.Tensor,
ff: torch.Tensor,
ff_count: torch.Tensor,
ef: torch.Tensor,
ef_count: torch.Tensor,
ee: torch.Tensor,
ee_count: torch.Tensor,
):
# Vertices of the mesh
self.vertices = vertices
# Faces of the mesh
self.faces = faces
# uv coordinates of each vertex
self.uvs = uvs
# uv indices for each face
self.face_textures = face_textures
# texture for each face
self.textures = textures
# Edges of the mesh
self.edges = edges
# Dictionary that maps an edge (tuple) to an edge idx
self.edge2key = edge2key
# Vertex-Vertex neighborhood tensor (for each vertex, contains
# indices of the vertices neighboring it)
self.vv = vv
# Number of vertices neighbouring each vertex
self.vv_count = vv_count
# Vertex-Face neighborhood tensor
self.vf = vf
# Number of faces neighbouring each vertex
self.vf_count = vf_count
# Vertex-Edge neighborhood tensor
self.ve = ve
# Number of edges neighboring each vertex
self.ve_count = ve_count
# Face-Face neighborhood tensor
self.ff = ff
# Number of faces neighbouring each face
self.ff_count = ff_count
# Edge-Face neighbourhood tensor
self.ef = ef
# Number of edges neighbouring each face
self.ef_count = ef_count
# Edge-Edge neighbourhood tensor
self.ee = ee
# Number of edges neighbouring each edge
self.ee_count = ee_count
# adjacency matrix for verts
self.adj = None
# Initialize device on which tensors reside.
self.device = self.vertices.device
@staticmethod
def normalize_zerosafe(matrix):
"""Normalizes each row of a matrix in a 'division by zero'-safe way.
Args:
matrix (torch.tensor): Matrix where each row contains a vector
to be normalized
"""
assert matrix.dim() == 2, "Need matrix to contain exactly 2 dimensions"
magnitude = torch.sqrt(torch.sum(torch.pow(matrix, 2), dim=1))
valid_inds = magnitude > 0
matrix[valid_inds] = torch.div(matrix[valid_inds], magnitude[valid_inds].unsqueeze(1))
return matrix
def compute_vertex_normals(self):
"""Compute vertex normals for each mesh vertex."""
# Let each face be denoted (a, b, c). For consistent order,
# we vectorize operations, so that a (for example) denotes the first
# vertex of each face in the mesh.
a = torch.index_select(self.vertices, dim=0, index=self.faces[:, 0].flatten())
b = torch.index_select(self.vertices, dim=0, index=self.faces[:, 1].flatten())
c = torch.index_select(self.vertices, dim=0, index=self.faces[:, 2].flatten())
# Compute vertex normals.
# Eg. Normals for vertices 'a' are given by (b-a) x (c - a)
vn_a = TriangleMesh.normalize_zerosafe(torch.cross(b - a, c - a, dim=1))
vn_b = TriangleMesh.normalize_zerosafe(torch.cross(c - b, a - b, dim=1))
vn_c = TriangleMesh.normalize_zerosafe(torch.cross(a - c, b - c, dim=1))
# Using the above, we have duplicate vertex normals (since a vertex is
# usually a part of more than one face). We only select the first face
# each vertex is a 'neighbor' to, to avoid confusion.
face_inds = self.vf[:, 0]
# Now that we know which face each vertex belongs to, we need to find
# the index of the vertex in that selected face. (i.e., is the
# selected vertex the 'a', the 'b', or the 'c' vertex of the
# face?).
vertex_inds = torch.arange(self.vertices.shape[0]).unsqueeze(1).to(self.vertices.device)
# Mask that specifies which index of each face to look at, for the
# vertex we wish to find.
mask_abc = self.faces[face_inds] == vertex_inds.repeat(1, 3)
mask_abc = mask_abc.to(self.vertices.device)
# Array to hold vertex normals
vn = torch.zeros_like(self.vertices)
inds = torch.nonzero(mask_abc[:, 0])
inds = torch.cat((inds, torch.zeros_like(inds)), dim=1)
vn[inds] = vn_a[face_inds[inds]]
inds = torch.nonzero(mask_abc[:, 1])
inds = torch.cat((inds, 1 * torch.ones_like(inds)), dim=1)
vn[inds] = vn_b[face_inds[inds]]
inds = torch.nonzero(mask_abc[:, 2])
inds = torch.cat((inds, 2 * torch.ones_like(inds)), dim=1)
vn[inds] = vn_c[face_inds[inds]]
return vn
def compute_face_normals(self):
r"""Compute normals for each face in the mesh."""
# Let each face be denoted (a, b, c). We vectorize operations, so,
# we take `a` to mean the "first vertex of every face", and so on.
a = torch.index_select(self.vertices, dim=0, index=self.faces[:, 0].flatten())
b = torch.index_select(self.vertices, dim=0, index=self.faces[:, 1].flatten())
c = torch.index_select(self.vertices, dim=0, index=self.faces[:, 2].flatten())
# Compute vertex normals (for each face). Note that the same vertex
# can have different normals for each face.
# Eg. Normals for vertices 'a' are given by (b-a) x (c - a)
vn_a = TriangleMesh.normalize_zerosafe(torch.cross(b - a, c - a, dim=1))
vn_b = TriangleMesh.normalize_zerosafe(torch.cross(c - b, a - b, dim=1))
vn_c = TriangleMesh.normalize_zerosafe(torch.cross(a - c, b - c, dim=1))
# Add and normalize the normals (for a more robust estimate)
face_normals = vn_a + vn_b + vn_c
face_normals_norm = face_normals.norm(dim=1)
face_normals = face_normals / torch.where(
face_normals_norm > 0, face_normals_norm, torch.ones_like(face_normals_norm)
).view(-1, 1)
return face_normals
def compute_edge_lengths(self):
"""Compute edge lengths for each edge of the mesh."""
self.edges = self.edges.to(self.vertices.device)
# Let each edge be denoted (a, b). We perform a vectorized select
# and then compute the magnitude of the vector b - a.
a = torch.index_select(self.vertices, dim=0, index=self.edges[:, 0].flatten())
b = torch.index_select(self.vertices, dim=0, index=self.edges[:, 1].flatten())
return (b - a).norm(dim=1)
def compute_face_areas(self):
raise NotImplementedError
def compute_interior_angles_per_edge(self):
raise NotImplementedError
def compute_dihedral_angles_per_edge(self):
raise NotImplementedError
def save_mesh(self, filename: str):
r"""Save a mesh to a wavefront .obj file format
Args:
filename (str) : target filename
"""
with open(filename, "w") as f:
# write vertices
for vert in self.vertices:
f.write("v %f %f %f\n" % tuple(vert))
# write faces
for face in self.faces:
f.write("f %d %d %d\n" % tuple(face + 1))
def sample(self, num_samples: int, eps: float = 1e-10):
r"""Uniformly samples the surface of a mesh.
Args:
num_samples (int): number of points to sample
eps (float): a small number to prevent division by zero
for small surface areas.
Returns:
(torch.Tensor, torch.Tensor) uniformly sampled points and
the face indexes which each point corresponds to.
Example:
>>> points, chosen_faces = mesh.sample(10)
>>> points
tensor([[ 0.0293, 0.2179, 0.2168],
[ 0.2003, -0.3367, 0.2187],
[ 0.2152, -0.0943, 0.1907],
[-0.1852, 0.1686, -0.0522],
[-0.2167, 0.3171, 0.0737],
[ 0.2219, -0.0289, 0.1531],
[ 0.2217, -0.0115, 0.1247],
[-0.1400, 0.0364, -0.1618],
[ 0.0658, -0.0310, -0.2198],
[ 0.1926, -0.1867, -0.2153]])
>>> chosen_faces
tensor([ 953, 38, 6, 3480, 563, 393, 395, 3309, 373, 271])
"""
if self.vertices.is_cuda:
dist_uni = torch.distributions.Uniform(torch.tensor([0.0]).cuda(), torch.tensor([1.0]).cuda())
else:
dist_uni = torch.distributions.Uniform(torch.tensor([0.0]), torch.tensor([1.0]))
# calculate area of each face
x1, x2, x3 = torch.split(
torch.index_select(self.vertices, 0, self.faces[:, 0])
- torch.index_select(self.vertices, 0, self.faces[:, 1]),
1,
dim=1,
)
y1, y2, y3 = torch.split(
torch.index_select(self.vertices, 0, self.faces[:, 1])
- torch.index_select(self.vertices, 0, self.faces[:, 2]),
1,
dim=1,
)
a = (x2 * y3 - x3 * y2) ** 2
b = (x3 * y1 - x1 * y3) ** 2
c = (x1 * y2 - x2 * y1) ** 2
Areas = torch.sqrt(a + b + c) / 2
# percentage of each face w.r.t. full surface area
Areas = Areas / (torch.sum(Areas) + eps)
# define a discrete distribution w.r.t. the calculated face area ratios
cat_dist = torch.distributions.Categorical(Areas.view(-1))
face_choices = cat_dist.sample([num_samples])
# from each face sample a point
select_faces = self.faces[face_choices]
v0 = torch.index_select(self.vertices, 0, select_faces[:, 0])
v1 = torch.index_select(self.vertices, 0, select_faces[:, 1])
v2 = torch.index_select(self.vertices, 0, select_faces[:, 2])
u = torch.sqrt(dist_uni.sample([num_samples]))
v = dist_uni.sample([num_samples])
points = (1 - u) * v0 + (u * (1 - v)) * v1 + u * v * v2
return points, face_choices
def compute_adjacency_matrix_full(self):
r"""Calcualtes a binary adjacency matrix for a mesh.
Returns:
(torch.Tensor) : binary adjacency matrix
Example:
>>> mesh = TriangleMesh.from_obj('model.obj')
>>> adj_info = mesh.compute_adjacency_matrix_full()
>>> neighborhood_sum = torch.mm( adj_info, mesh.vertices)
"""
adj = torch.zeros((self.vertices.shape[0], self.vertices.shape[0])).to(self.vertices.device)
v1 = self.faces[:, 0]
v2 = self.faces[:, 1]
v3 = self.faces[:, 2]
adj[(v1, v1)] = 1
adj[(v2, v2)] = 1
adj[(v3, v3)] = 1
adj[(v1, v2)] = 1
adj[(v2, v1)] = 1
adj[(v1, v3)] = 1
adj[(v3, v1)] = 1
adj[(v2, v3)] = 1
adj[(v3, v2)] = 1
return adj
@staticmethod
def load_tensors(filename: str, enable_adjacency: bool = False):
r"""Loads the tensor information of the mesh from a saved numpy array.
Args:
filename: Path of the file to load the file from.
Example:
>>> mesh = TriangleMesh.load_tensors('mesh.npy')
"""
data = np.load(filename)
vertices = torch.FloatTensor(data["vertices"])
faces = torch.LongTensor(data["faces"].astype(int))
return TriangleMesh.from_tensors(vertices, faces)
def compute_adjacency_matrix_sparse(self):
r"""Calcualtes a sparse adjacency matrix for a mess
Returns:
(torch.sparse.Tensor) : sparse adjacency matrix
Example:
>>> mesh = Mesh.from_obj('model.obj')
>>> adj_info = mesh.compute_adjacency_matrix_sparse()
>>> neighborhood_sum = torch.sparse.mm(adj_info, mesh.vertices)
"""
if self.adj is None:
v1 = self.faces[:, 0].view(-1, 1)
v2 = self.faces[:, 1].view(-1, 1)
v3 = self.faces[:, 2].view(-1, 1)
vert_len = self.vertices.shape[0]
identity_indices = torch.arange(vert_len).view(-1, 1).to(v1.device)
identity = torch.cat((identity_indices, identity_indices), dim=1).to(v1.device)
identity = torch.cat((identity, identity))
i_1 = torch.cat((v1, v2), dim=1)
i_2 = torch.cat((v1, v3), dim=1)
i_3 = torch.cat((v2, v1), dim=1)
i_4 = torch.cat((v2, v3), dim=1)
i_5 = torch.cat((v3, v2), dim=1)
i_6 = torch.cat((v3, v1), dim=1)
indices = torch.cat((identity, i_1, i_2, i_3, i_4, i_5, i_6), dim=0).t()
values = torch.ones(indices.shape[1]).to(indices.device) * 0.5
self.adj = torch.sparse.FloatTensor(indices, values, torch.Size([vert_len, vert_len]))
return self.adj.clone()
from .Mesh import *
from .TriangleMesh import *
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Several helper functions, for internal use in Kaolin."""
import torch
import hashlib
from pathlib import Path
from typing import Callable, Union
import numpy as np
def _composedecorator(*decs):
"""Returns a composition of several decorators.
Source: https://stackoverflow.com/a/5409569
Usage::
@composedec(decorator1, decorator2)
def func_that_needs_decoration(args):
pass
is equivalent to::
@decorator1
@decorator2
def func_that_needs_decoration(args):
pass
"""
def deco(f):
for dec in reversed(decs):
f = dec(f)
return f
return deco
def _normalize_zerosafe(matrix: torch.Tensor):
"""Normalizes each row of a matrix in a 'division by zero'-safe way.
Args:
matrix (torch.tensor): Matrix where each row contains a vector
to be normalized
"""
assert matrix.dim() == 2, "Need matrix to contain exactly 2 dimensions"
magnitude = torch.sqrt(torch.sum(torch.pow(matrix, 2), dim=1))
valid_inds = magnitude > 0
matrix[valid_inds] = torch.div(matrix[valid_inds], magnitude[valid_inds].unsqueeze(1))
return matrix
def _assert_tensor(inp):
"""Asserts that the input is of type torch.Tensor."""
if not torch.is_tensor(inp):
raise TypeError("Expected input to be of type torch.Tensor." " Got {0} instead".format(type(inp)))
def _assert_dim_gt(inp, tgt):
"""Asserts that the number of dims in inp is greater than the value
specified in tgt.
Args:
inp (torch.Tensor): Input tensor, whose number of dimensions is
to be compared.
tgt (int): Value which the number of dims of inp should exceed.
"""
if inp.dim() <= tgt:
raise ValueError("Expected input to contain more than {0} dims. " "Got {1} instead.".format(tgt, inp.dim()))
def _assert_dim_lt(inp, tgt):
"""Asserts that the number of dims in inp is less than the value sepecified
in tgt.
Args:
inp (torch.Tensor): Input tensor, whose number of dimensions is
to be compared.
tgt (int): Value which the number of dims of inp should be less than.
"""
if inp.dim() >= tgt:
raise ValueError("Expected input to contain less than {0} dims. " "Got {1} instead.".format(tgt, inp.dim()))
def _assert_dim_ge(inp, tgt):
"""Asserts that the number of dims in inp is greater than or equal to the
value specified in tgt.
Args:
inp (torch.Tensor): Input tensor, whose number of dimensions is
to be compared.
tgt (int): Value which the number of dims of inp should exceed.
"""
if inp.dim() < tgt:
raise ValueError("Expected input to contain at least {0} dims. " "Got {1} instead.".format(tgt, inp.dim()))
def _assert_dim_le(inp, tgt):
"""Asserts that the number of dims in inp is less than or equal to the
value specified in tgt.
Args:
inp (torch.Tensor): Input tensor, whose number of dimensions is
to be compared.
tgt (int): Value which the number of dims of inp should not exceed.
"""
if inp.dim() > tgt:
raise ValueError("Expected input to contain at most {0} dims. " "Got {1} instead.".format(tgt, inp.dim()))
def _assert_dim_eq(inp, tgt):
"""Asserts that the number of dims in inp is exactly equal to the value
specified in tgt.
Args:
inp (torch.Tensor): Input tensor, whose number of dimensions is
to be compared.
tgt (int): Value which the number of dims of inp should equal.
"""
if inp.dim() != tgt:
raise ValueError("Expected input to contain exactly {0} dims. " "Got {1} instead.".format(tgt, inp.dim()))
def _assert_shape_eq(inp, tgt_shape, dim=None):
"""Asserts that the shape of tensor `inp` is equal to the tuple `tgt_shape`
along dimension `dim`.
If `dim` is None, shapes along all dimensions must be equal.
"""
if dim is None:
if inp.shape != tgt_shape:
raise ValueError(
"Size mismatch. Input and target have different " "shapes: {0} vs {1}.".format(inp.shape, tgt_shape)
)
else:
if inp.shape[dim] != tgt_shape[dim]:
raise ValueError(
"Size mismatch. Input and target have different "
"shapes at dimension {2}: {0} vs {1}.".format(inp.shape[dim], tgt_shape[dim], dim)
)
def _assert_gt(inp, val):
"""Asserts that all elements in tensor `inp` are greater than value
`val`."""
if not (inp > val).all():
raise ValueError("Each element of input must be greater " "than {0}.".format(val))
def _get_hash(x):
"""Generate a hash from a string, or dictionary."""
if isinstance(x, dict):
x = tuple(sorted(pair for pair in x.items()))
return hashlib.md5(bytes(repr(x), "utf-8")).hexdigest()
class Cache(object):
"""Caches the results of a function to disk.
If already cached, data is returned from disk, otherwise,
the function is executed. Output tensors are always on CPU device.
Args:
func (Callable): The function whose output will be cached.
cache_dir (str): Directory where objects will be cached. Default
to 'cache'.
cache_key (str): Subdirectory name used to namespace the cache.
"""
def __init__(self, func: Callable, cache_dir: Union[str, Path], cache_key: str):
self.func = func
self.cache_dir = Path(cache_dir) / str(cache_key)
self.cache_dir.mkdir(parents=True, exist_ok=True)
self.cached_ids = [p.stem for p in self.cache_dir.glob("*")]
def __call__(self, unique_id: str, *args, **kwargs):
"""Execute self.func if not cached, otherwise, read data from disk.
Args:
unique_id (str): The unique id with which to name the cached file.
**kwargs: The arguments to be passed to self.func.
Returns:
dict of {str: torch.Tensor}: Dictionary of tensors.
"""
fpath = self.cache_dir / f"{unique_id}.p"
if not fpath.exists():
output = self.func(*args, **kwargs)
self._write(output, fpath)
self.cached_ids.append(unique_id)
else:
output = self._read(fpath)
# Read file to move tensors to CPU.
return self._read(fpath)
def _write(self, x, fpath):
torch.save(x, fpath)
def _read(self, fpath):
return torch.load(fpath, map_location="cpu")
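# Usage sketch (added for illustration; the function and directory below are
# hypothetical):
if __name__ == "__main__":
    def expensive(n):
        return {"squares": torch.arange(n) ** 2}

    cache = Cache(expensive, cache_dir="cache", cache_key=_get_hash({"n": 10}))
    out = cache("squares_10", n=10)  # computed once, afterwards read from disk
    print(out["squares"][:5])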