Source code for data_juicer.ops.mapper.video_hand_reconstruction_mapper

import os
import subprocess
import sys

from pydantic import PositiveInt

import data_juicer
from data_juicer.ops.load import load_ops
from data_juicer.utils.cache_utils import DATA_JUICER_ASSETS_CACHE
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.mm_utils import SpecialTokens
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Mapper
from ..op_fusion import LOADED_VIDEOS

OP_NAME = "video_hand_reconstruction_mapper"


# LazyLoader.check_packages(["numpy==1.26"])
# To output visual overlay images, it is necessary to install pyrender.
# Note that pyrender requires numpy==1.26 to correctly generate rendering results.
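# For example, to enable overlay rendering one could apply the pin noted above
# before running the op (a sketch mirroring the commented check; the exact pin
# requirement comes from the note above, not from a separate constraint file):
#     LazyLoader.check_packages(["numpy==1.26", "pyrender"])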

numpy = LazyLoader("numpy")
cv2 = LazyLoader("cv2", "opencv-python")
torch = LazyLoader("torch")


@OPERATORS.register_module(OP_NAME)
@LOADED_VIDEOS.register_module(OP_NAME)
class VideoHandReconstructionMapper(Mapper):
    """Use the WiLoR model for hand localization and reconstruction."""

    _accelerator = "cuda"
    def __init__(
        self,
        wilor_model_path: str = "wilor_final.ckpt",
        wilor_model_config: str = "model_config.yaml",
        detector_model_path: str = "detector.pt",
        mano_right_path: str = "path_to_mano_right_pkl",
        frame_num: PositiveInt = 3,
        duration: float = 0,
        batch_size: int = 16,
        tag_field_name: str = MetaKeys.hand_reconstruction_tags,
        frame_dir: str = DATA_JUICER_ASSETS_CACHE,
        if_save_visualization: bool = True,
        save_visualization_dir: str = DATA_JUICER_ASSETS_CACHE,
        if_save_mesh: bool = True,
        save_mesh_dir: str = DATA_JUICER_ASSETS_CACHE,
        *args,
        **kwargs,
    ):
        """
        Initialization method.

        :param wilor_model_path: The path to 'wilor_final.ckpt'.
        :param wilor_model_config: The path to 'model_config.yaml' for the
            WiLoR model.
        :param detector_model_path: The path to 'detector.pt' for the WiLoR
            model.
        :param mano_right_path: The path to 'MANO_RIGHT.pkl'. Users need to
            download this file from https://mano.is.tue.mpg.de/ and comply
            with the MANO license.
        :param frame_num: The number of frames to be extracted uniformly from
            the video. If it's 1, only the middle frame will be extracted. If
            it's 2, only the first and the last frames will be extracted. If
            it's larger than 2, in addition to the first and the last frames,
            other frames will be extracted uniformly within the video
            duration. If "duration" > 0, frame_num is the number of frames
            per segment.
        :param duration: The duration of each segment in seconds. If 0,
            frames are extracted from the entire video. If duration > 0, the
            video is split into multiple segments based on duration, and
            frames are extracted from each segment.
        :param batch_size: Batch size for simultaneous hand inference.
        :param tag_field_name: The field name to store the tags. It's
            "hand_reconstruction_tags" by default.
        :param frame_dir: Output directory to save extracted frames.
        :param if_save_visualization: Whether to save overlay images.
        :param save_visualization_dir: The path for saving overlay images.
        :param if_save_mesh: Whether to save images of the hand mesh.
        :param save_mesh_dir: The path for saving images of the hand mesh.
        :param args: extra args
        :param kwargs: extra args
        """
        super().__init__(*args, **kwargs)
        LazyLoader.check_packages(["chumpy @ git+https://github.com/mattloper/chumpy"])
        LazyLoader.check_packages(["smplx==0.1.28", "yacs", "timm", "pyrender", "pytorch_lightning"])
        LazyLoader.check_packages(["scikit-image"], pip_args=["--no-deps"])

        self.video_extract_frames_mapper_args = {
            "frame_sampling_method": "uniform",
            "frame_num": frame_num,
            "duration": duration,
            "frame_dir": frame_dir,
            "frame_key": MetaKeys.video_frames,
        }
        self.fused_ops = load_ops([{"video_extract_frames_mapper": self.video_extract_frames_mapper_args}])

        self.model_key = prepare_model(
            model_type="wilor",
            wilor_model_path=wilor_model_path,
            wilor_model_config=wilor_model_config,
            detector_model_path=detector_model_path,
            mano_right_path=mano_right_path,
        )

        wilor_repo_path = os.path.join(DATA_JUICER_ASSETS_CACHE, "WiLoR")
        if not os.path.exists(wilor_repo_path):
            subprocess.run(["git", "clone", "https://github.com/rolpotamias/WiLoR.git", wilor_repo_path], check=True)
        sys.path.append(wilor_repo_path)
        from wilor.datasets.vitdet_dataset import ViTDetDataset
        from wilor.utils import recursive_to
        from wilor.utils.renderer import cam_crop_to_full

        self.ViTDetDataset = ViTDetDataset
        self.cam_crop_to_full = cam_crop_to_full
        self.recursive_to = recursive_to

        self.frame_num = frame_num
        self.duration = duration
        self.batch_size = batch_size
        self.tag_field_name = tag_field_name
        self.frame_dir = frame_dir
        self.if_save_visualization = if_save_visualization
        self.save_visualization_dir = save_visualization_dir
        self.if_save_mesh = if_save_mesh
        self.save_mesh_dir = save_mesh_dir
    def project_full_img(self, points, cam_trans, focal_length, img_res):
        # Pinhole intrinsics with the principal point at the image center
        camera_center = [img_res[0] / 2.0, img_res[1] / 2.0]
        K = torch.eye(3)
        K[0, 0] = focal_length
        K[1, 1] = focal_length
        K[0, 2] = camera_center[0]
        K[1, 2] = camera_center[1]

        # Move points into the camera frame, divide by depth, then apply K
        points = points + cam_trans
        points = points / points[..., -1:]

        V_2d = (K @ points.T).T
        return V_2d[..., :-1]
    def process_single(self, sample=None, rank=None):
        # check if it's generated already
        if self.tag_field_name in sample[Fields.meta]:
            return sample

        # there is no video in this sample
        if self.video_key not in sample or not sample[self.video_key]:
            return []

        # load videos
        ds_list = [{"text": SpecialTokens.video, "videos": sample[self.video_key]}]
        dataset = data_juicer.core.data.NestedDataset.from_list(ds_list)
        dataset = self.fused_ops[0].run(dataset)

        temp_frame_name = os.path.splitext(os.path.basename(sample[self.video_key][0]))[0]
        frames_root = os.path.join(self.frame_dir, temp_frame_name)
        frame_names = os.listdir(frames_root)
        frames_path = sorted([os.path.join(frames_root, frame_name) for frame_name in frame_names])

        wilor_model, detector, model_cfg, renderer = get_model(self.model_key, rank, self.use_cuda())
        if rank is not None:
            device = f"cuda:{rank}" if self.use_cuda() else "cpu"
        else:
            device = "cuda" if self.use_cuda() else "cpu"

        if self.if_save_visualization:
            visualization_frame_dir = os.path.join(self.save_visualization_dir, temp_frame_name)
            os.makedirs(visualization_frame_dir, exist_ok=True)
        if self.if_save_mesh:
            mesh_frame_dir = os.path.join(self.save_mesh_dir, temp_frame_name)
            os.makedirs(mesh_frame_dir, exist_ok=True)

        final_all_verts = []
        final_all_cam_t = []
        final_all_right = []
        final_all_joints = []
        final_all_kpts = []
        for img_path in frames_path:
            img_cv2 = cv2.imread(img_path)

            detections = detector(img_cv2, conf=0.3, verbose=False)[0]
            bboxes = []
            is_right = []
            for det in detections:
                Bbox = det.boxes.data.cpu().detach().squeeze().numpy()
                is_right.append(det.boxes.cls.cpu().detach().squeeze().item())
                bboxes.append(Bbox[:4].tolist())
            if len(bboxes) == 0:
                final_all_verts.append([])
                final_all_cam_t.append([])
                final_all_right.append([])
                final_all_joints.append([])
                final_all_kpts.append([])
                continue

            boxes = numpy.stack(bboxes)
            right = numpy.stack(is_right)
            dataset = self.ViTDetDataset(model_cfg, img_cv2, boxes, right, rescale_factor=2.0)
            dataloader = torch.utils.data.DataLoader(
                dataset, batch_size=self.batch_size, shuffle=False, num_workers=0
            )

            all_verts = []
            all_cam_t = []
            all_right = []
            all_joints = []
            all_kpts = []

            for batch in dataloader:
                batch = self.recursive_to(batch, device)
                with torch.no_grad():
                    out = wilor_model(batch)

                multiplier = 2 * batch["right"] - 1
                pred_cam = out["pred_cam"]
                pred_cam[:, 1] = multiplier * pred_cam[:, 1]
                box_center = batch["box_center"].float()
                box_size = batch["box_size"].float()
                img_size = batch["img_size"].float()
                scaled_focal_length = model_cfg.EXTRA.FOCAL_LENGTH / model_cfg.MODEL.IMAGE_SIZE * img_size.max()
                pred_cam_t_full = (
                    self.cam_crop_to_full(pred_cam, box_center, box_size, img_size, scaled_focal_length)
                    .detach()
                    .cpu()
                    .numpy()
                )

                # Render the result
                batch_size = batch["img"].shape[0]
                for n in range(batch_size):
                    # Get filename from path img_path
                    img_fn, _ = os.path.splitext(os.path.basename(img_path))

                    verts = out["pred_vertices"][n].detach().cpu().numpy()
                    joints = out["pred_keypoints_3d"][n].detach().cpu().numpy()
                    is_right = batch["right"][n].cpu().numpy()
                    verts[:, 0] = (2 * is_right - 1) * verts[:, 0]
                    joints[:, 0] = (2 * is_right - 1) * joints[:, 0]
                    cam_t = pred_cam_t_full[n]
                    kpts_2d = self.project_full_img(verts, cam_t, scaled_focal_length, img_size[n])

                    all_verts.append(verts)
                    all_cam_t.append(cam_t)
                    all_right.append(is_right)
                    all_joints.append(joints)
                    all_kpts.append(kpts_2d)

                    # Save all meshes to disk
                    if self.if_save_mesh:
                        camera_translation = cam_t.copy()
                        tmesh = renderer.vertices_to_trimesh(
                            verts, camera_translation, (0.25098039, 0.274117647, 0.65882353), is_right=is_right
                        )
                        tmesh.export(os.path.join(mesh_frame_dir, f"{img_fn}_{n}.obj"))

            final_all_verts.append(all_verts)
            final_all_cam_t.append(all_cam_t)
            final_all_right.append(all_right)
            final_all_joints.append(all_joints)
            final_all_kpts.append(all_kpts)

            # Render front view
            if self.if_save_visualization:
                if len(all_verts) > 0:
                    misc_args = dict(
                        mesh_base_color=(0.25098039, 0.274117647, 0.65882353),
                        scene_bg_color=(1, 1, 1),
                        focal_length=scaled_focal_length,
                    )
                    cam_view = renderer.render_rgba_multiple(
                        all_verts, cam_t=all_cam_t, render_res=img_size[n], is_right=all_right, **misc_args
                    )

                    # Overlay image
                    input_img = img_cv2.astype(numpy.float32)[:, :, ::-1] / 255.0
                    input_img = numpy.concatenate(
                        [input_img, numpy.ones_like(input_img[:, :, :1])], axis=2
                    )  # Add alpha channel
                    input_img_overlay = (
                        input_img[:, :, :3] * (1 - cam_view[:, :, 3:]) + cam_view[:, :, :3] * cam_view[:, :, 3:]
                    )

                    cv2.imwrite(
                        os.path.join(visualization_frame_dir, f"{img_fn}.jpg"), 255 * input_img_overlay[:, :, ::-1]
                    )

        sample[Fields.meta][self.tag_field_name] = {
            "vertices": final_all_verts,
            "camera_translation": final_all_cam_t,
            "if_right_hand": final_all_right,
            "joints": final_all_joints,
            "keypoints": final_all_kpts,
        }
        return sample
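

# Example usage (a sketch, not part of the module itself; the asset paths, the
# "videos" sample key, and the example video file below are placeholders that
# must point to locally downloaded WiLoR/MANO files and real data):
#
#     op = VideoHandReconstructionMapper(
#         wilor_model_path="/path/to/wilor_final.ckpt",
#         wilor_model_config="/path/to/model_config.yaml",
#         detector_model_path="/path/to/detector.pt",
#         mano_right_path="/path/to/MANO_RIGHT.pkl",
#         frame_num=3,
#         if_save_visualization=False,
#         if_save_mesh=False,
#     )
#     sample = {"videos": ["/path/to/video.mp4"], Fields.meta: {}}
#     result = op.process_single(sample)
#     hand_tags = result[Fields.meta][MetaKeys.hand_reconstruction_tags]
#     # hand_tags holds per-frame lists of vertices, camera translations,
#     # handedness flags, 3D joints, and 2D keypoints.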