# Source code for data_juicer.ops.common.mano_func

from typing import Optional

from data_juicer.utils.lazy_loader import LazyLoader

smplx = LazyLoader("smplx")
torch = LazyLoader("torch")


class MANO(smplx.MANOLayer):
    """MANO hand layer extended with extra joints and a remapped joint order.

    Subclasses ``smplx.MANOLayer``. On every forward pass, the layer builds an
    extended joint set in three steps:

    1. It appends joints read directly from mesh vertices. The indices come
       from ``smplx.vertex_ids.vertex_ids["mano"]``.
    2. It reorders the result with the ``mano_to_openpose`` index map.
    3. It optionally appends joints regressed by an extra joint regressor
       loaded from disk.
    """

    def __init__(self, *args, joint_regressor_extra: Optional[str] = None, **kwargs):
        """Extension of the official MANO implementation to support more joints.

        Args:
            *args: Positional arguments forwarded unchanged to ``MANOLayer``.
            joint_regressor_extra (str, optional): Path to a pickled extra
                joint regressor. When given, it is unpickled with latin1
                encoding, converted to a float32 tensor and registered as the
                ``joint_regressor_extra`` buffer. ``forward`` then appends the
                joints it regresses from the mesh vertices.
            **kwargs: Keyword arguments forwarded unchanged to ``MANOLayer``.
        """
        super().__init__(*args, **kwargs)

        # Index map applied in ``forward`` to the concatenation
        # [MANO joints, vertex-based joints]. It selects and reorders those
        # joints (per the variable name, into OpenPose hand order).
        # NOTE(review): the original trailing note "2, 3, 5, 4, 1" is kept
        # verbatim. Its meaning is undocumented (presumably a finger ordering).
        mano_to_openpose = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]  # 2, 3, 5, 4, 1

        if joint_regressor_extra is not None:
            import pickle

            # NOTE(review): pickle.load can execute arbitrary code while
            # loading; only pass regressor files from a trusted source.
            # The context manager closes the file handle deterministically
            # instead of leaking it until garbage collection.
            with open(joint_regressor_extra, "rb") as regressor_file:
                regressor = pickle.load(regressor_file, encoding="latin1")
            self.register_buffer(
                "joint_regressor_extra",
                torch.tensor(regressor, dtype=torch.float32),
            )

        # Vertex indices whose positions are appended as extra joints.
        self.register_buffer(
            "extra_joints_idxs",
            smplx.utils.to_tensor(list(smplx.vertex_ids.vertex_ids["mano"].values()), dtype=torch.long),
        )
        self.register_buffer("joint_map", torch.tensor(mano_to_openpose, dtype=torch.long))

    def forward(self, *args, **kwargs) -> smplx.utils.MANOOutput:
        """Run the MANO forward pass and build the extended joint set.

        All arguments are forwarded to ``MANOLayer.forward``. The returned
        output's ``joints`` are replaced as follows:

        1. The vertices selected by ``extra_joints_idxs`` (along dim 1) are
           concatenated after the standard MANO joints.
        2. The combined joints are reindexed with ``joint_map``.
        3. If an extra joint regressor was loaded, the joints it regresses
           from the vertices (``smplx.lbs.vertices2joints``) are appended
           along dim 1.

        Returns:
            The ``MANOOutput`` produced by the parent layer, with its
            ``joints`` attribute overwritten by the extended joint set.
        """
        mano_output = super().forward(*args, **kwargs)

        # Joints taken directly from selected mesh vertices.
        extra_joints = torch.index_select(mano_output.vertices, 1, self.extra_joints_idxs)
        joints = torch.cat([mano_output.joints, extra_joints], dim=1)
        joints = joints[:, self.joint_map, :]

        # The buffer only exists when a regressor path was given to __init__.
        if hasattr(self, "joint_regressor_extra"):
            extra_joints = smplx.lbs.vertices2joints(self.joint_regressor_extra, mano_output.vertices)
            joints = torch.cat([joints, extra_joints], dim=1)

        mano_output.joints = joints
        return mano_output

    def query(self, hmr_output):
        """Run the model on the parameters predicted by a hand-mesh regressor.

        Args:
            hmr_output: Mapping with at least two keys:

                * ``"pred_rotmat"``: reshaped to ``(batch, -1, 3, 3)``. It is
                  interpreted as rotation matrices, with index 0 the global
                  orientation and the rest the hand pose.
                * ``"pred_shape"``: reshaped to ``(batch, 10)`` and passed as
                  ``betas``.

        Returns:
            The output of ``self(...)`` (see ``forward``). ``pose2rot=False``
            is passed because the poses are already rotation matrices.
        """
        batch_size = hmr_output["pred_rotmat"].shape[0]
        pred_rotmat = hmr_output["pred_rotmat"].reshape(batch_size, -1, 3, 3)
        # NOTE(review): assumes the model uses 10 shape coefficients.
        pred_shape = hmr_output["pred_shape"].reshape(batch_size, 10)
        mano_output = self(
            global_orient=pred_rotmat[:, [0]], hand_pose=pred_rotmat[:, 1:], betas=pred_shape, pose2rot=False
        )
        return mano_output