# Source code for data_juicer.ops.mapper.video_clip_reassembly_mapper

import hashlib

import cv2
import numpy as np
from loguru import logger

from data_juicer.utils.constant import CameraCalibrationKeys, Fields, MetaKeys
from data_juicer.utils.file_utils import load_numpy

from ..base_op import OPERATORS, Mapper

OP_NAME = "video_clip_reassembly_mapper"


@OPERATORS.register_module(OP_NAME)
class VideoClipReassemblyMapper(Mapper):
    """Reassemble hand-action results from overlapping video clips.

    When long videos are chopped into overlapping clips (e.g. 5 s with 2 s
    overlap via ``VideoSplitByDurationMapper``), each clip is processed
    independently through the 3-D motion labelling pipeline. This operator
    merges the per-clip results back into **one unified result** per original
    video, including:

    * ``hand_action_tags`` — states, actions, valid_frame_ids, joints
    * ``video_camera_pose_tags`` — ``cam_c2w`` array
    * ``hand_reconstruction_hawor_tags`` — frame_ids converted to global
    * ``video_frames`` — per-clip frame path lists merged into one global list
    * ``camera_calibration_moge_tags`` — per-clip depth/intrinsics merged
    * ``clips`` — replaced with the original video path

    Clip global offsets are determined automatically by **pixel-matching**
    overlapping frames between consecutive clips, rather than assuming an
    ideal step size. This handles ffmpeg keyframe-alignment drift that causes
    actual clip boundaries to differ from the nominal
    ``(split_duration - overlap_duration) * fps`` calculation. When those
    three parameters are given, the nominal step only narrows the frame
    search and serves as the fallback when matching fails.

    Before merging, each clip's world frame is rigidly aligned to clip 0's
    world frame using the camera poses of overlapping frames; frames covered
    by several clips are blended with linear cross-fade weights.

    Reference (paper §3.1): "To enhance efficiency, we chop long videos into
    overlapping 20-second clips in this stage and recompose their results."
    """

    def __init__(
        self,
        hand_action_field: str = MetaKeys.hand_action_tags,
        camera_pose_field: str = MetaKeys.video_camera_pose_tags,
        hand_reconstruction_field: str = (MetaKeys.hand_reconstruction_hawor_tags),
        frame_field: str = MetaKeys.video_frames,
        moge_field: str = MetaKeys.camera_calibration_moge_tags,
        clip_field: str = "clips",
        video_key: str = "videos",
        split_duration: float = None,
        overlap_duration: float = None,
        fps: float = None,
        *args,
        **kwargs,
    ):
        """Initialization.

        :param hand_action_field: meta key of the per-clip hand-action results.
        :param camera_pose_field: meta key of the per-clip camera poses
            (dicts holding a ``cam_c2w`` array).
        :param hand_reconstruction_field: meta key of the per-clip HaWoR
            reconstruction results.
        :param frame_field: sample key of the per-clip frame path lists.
        :param moge_field: meta key of the per-clip MoGe calibration results.
        :param clip_field: sample key of the clip paths; replaced by the
            original video path(s) after reassembly.
        :param video_key: sample key of the original video path(s).
        :param split_duration: nominal clip length in seconds (optional).
        :param overlap_duration: nominal clip overlap in seconds (optional).
        :param fps: frame rate used to convert the durations to frames
            (optional).
        """
        super().__init__(*args, **kwargs)
        self.hand_action_field = hand_action_field
        self.camera_pose_field = camera_pose_field
        self.hand_reconstruction_field = hand_reconstruction_field
        self.frame_field = frame_field
        self.moge_field = moge_field
        self.clip_field = clip_field
        self.video_key = video_key
        self.split_duration = split_duration
        self.overlap_duration = overlap_duration
        self.fps = fps

    # ------------------------------------------------------------------
    # Detect actual clip offsets via frame content matching
    # ------------------------------------------------------------------
    @staticmethod
    def _frame_hash(path: str) -> str:
        """Return an MD5 hex digest of the decoded pixels of an image file.

        Returns ``None`` when OpenCV cannot read the file.
        """
        img = cv2.imread(path)
        if img is None:
            return None
        return hashlib.md5(img.tobytes()).hexdigest()

    @classmethod
    def _detect_clip_offsets(
        cls,
        per_clip_frames: list[list[str]],
        nominal_step: int = None,
        search_radius: int = 30,
    ) -> list[int]:
        """Determine the global frame offset for each clip.

        The first frame of clip ``i`` is looked up among the frames of clip
        ``i-1`` (a hit is confirmed with the second frame when both clips have
        one more frame). With a nominal step, only the window
        ``[nominal_step - search_radius, nominal_step + search_radius)`` is
        searched. Without one, the whole previous clip (except its first
        frame) is searched, because the overlap can start anywhere. If no
        match is found, the nominal step, or else the previous clip's length,
        is used as the step.

        :param per_clip_frames: frame image paths, one list per clip.
        :param nominal_step: expected frame step between consecutive clips.
        :param search_radius: half-width of the search window around
            ``nominal_step``.
        :return: global offsets, one per clip; ``offsets[0]`` is always 0.
        """
        hash_cache: dict = {}

        def frame_hash(path):
            # Each clip is compared once as "current" and once as "previous";
            # cache digests so every image is decoded at most once.
            if path not in hash_cache:
                hash_cache[path] = cls._frame_hash(path)
            return hash_cache[path]

        offsets = [0]
        for ci in range(1, len(per_clip_frames)):
            prev_frames = per_clip_frames[ci - 1]
            curr_frames = per_clip_frames[ci]
            step = nominal_step or len(prev_frames)

            if not curr_frames or not prev_frames:
                offsets.append(offsets[-1] + step)
                continue

            h_curr_0 = frame_hash(curr_frames[0])
            if h_curr_0 is None:
                offsets.append(offsets[-1] + step)
                continue

            if nominal_step:
                search_start = max(0, nominal_step - search_radius)
                search_end = min(len(prev_frames), nominal_step + search_radius)
            else:
                search_start, search_end = 1, len(prev_frames)

            h_curr_1 = frame_hash(curr_frames[1]) if len(curr_frames) > 1 else None

            found = False
            for j in range(search_start, search_end):
                if frame_hash(prev_frames[j]) != h_curr_0:
                    continue
                # Verify with a second frame to avoid accidental matches.
                if len(curr_frames) > 1 and j + 1 < len(prev_frames):
                    if frame_hash(prev_frames[j + 1]) != h_curr_1:
                        continue
                offsets.append(offsets[-1] + j)
                found = True
                logger.debug(
                    f"Clip {ci}: detected offset {j} from clip {ci-1} " f"(global offset {offsets[-1]})",
                )
                break

            if not found:
                offsets.append(offsets[-1] + step)
                logger.warning(
                    f"Clip {ci}: frame matching failed, using nominal " f"step {step} (global offset {offsets[-1]})",
                )
        return offsets

    # ------------------------------------------------------------------
    # World-frame alignment between clips
    # ------------------------------------------------------------------
    @staticmethod
    def _compute_alignment_transforms(
        cam_pose_list: list[dict],
        offsets: list[int],
        clip_lengths: list[int],
    ) -> list[np.ndarray]:
        """Compute 4x4 transforms mapping each clip's world frame to clip 0's.

        Frame ``k`` of clip ``ci`` is frame
        ``offsets[ci] - offsets[ci-1] + k`` of clip ``ci-1``. For each
        overlapping pair, ``c2w_prev[prev_idx] @ inv(c2w_curr[k])`` maps
        points from clip ``ci``'s world frame into clip ``ci-1``'s. The
        per-frame estimates are reduced to one rigid transform: the median
        translation and the normalized mean of hemisphere-aligned quaternions.
        That transform is chained onto ``transforms[ci-1]``.

        A clip whose poses or overlap are missing inherits the previous clip's
        transform, i.e. its world frame is assumed to equal the previous one.

        :return: list of (4, 4) float64 transforms; ``transforms[0]`` is
            the identity.
        """
        from scipy.spatial.transform import Rotation

        transforms = [np.eye(4, dtype=np.float64)]
        for ci in range(1, len(cam_pose_list)):
            cp_prev = cam_pose_list[ci - 1]
            cp_curr = cam_pose_list[ci]
            if not (isinstance(cp_prev, dict) and cp_prev and isinstance(cp_curr, dict) and cp_curr):
                transforms.append(transforms[-1].copy())
                continue

            raw_prev = cp_prev.get(CameraCalibrationKeys.cam_c2w)
            raw_curr = cp_curr.get(CameraCalibrationKeys.cam_c2w)
            if raw_prev is None or raw_curr is None:
                transforms.append(transforms[-1].copy())
                continue

            c2w_prev = np.asarray(load_numpy(raw_prev), dtype=np.float64)
            c2w_curr = np.asarray(load_numpy(raw_curr), dtype=np.float64)

            step_in_prev = offsets[ci] - offsets[ci - 1]
            overlap_len = clip_lengths[ci - 1] - step_in_prev
            n_pairs = 0
            if overlap_len > 0 and step_in_prev >= 0:
                n_pairs = max(0, min(overlap_len, len(c2w_curr), len(c2w_prev) - step_in_prev))
            if n_pairs == 0:
                transforms.append(transforms[-1].copy())
                continue

            # One curr_world -> prev_world estimate per overlapping frame.
            T_local = c2w_prev[step_in_prev : step_in_prev + n_pairs] @ np.linalg.inv(c2w_curr[:n_pairs])

            # Robust average: median translation, mean quaternion rotation.
            t_median = np.median(T_local[:, :3, 3], axis=0)
            quats = Rotation.from_matrix(T_local[:, :3, :3]).as_quat()
            # q and -q encode the same rotation; move all quaternions into
            # the hemisphere of the first one before averaging.
            quats *= np.where(quats @ quats[0] < 0, -1.0, 1.0)[:, np.newaxis]
            mean_quat = quats.mean(axis=0)
            mean_quat /= np.linalg.norm(mean_quat)
            R_mean = Rotation.from_quat(mean_quat).as_matrix()

            # T_prev_curr maps clip ci's world into clip ci-1's world
            # (x_prev = T_prev_curr @ x_curr); chaining with the accumulated
            # transform of clip ci-1 yields a mapping into clip 0's world.
            T_prev_curr = np.eye(4, dtype=np.float64)
            T_prev_curr[:3, :3] = R_mean
            T_prev_curr[:3, 3] = t_median
            transforms.append(transforms[ci - 1] @ T_prev_curr)

            logger.debug(
                f"Clip {ci} alignment: rotation "
                f"{np.degrees(Rotation.from_matrix(R_mean).magnitude()):.1f}°, "
                f"translation {np.linalg.norm(t_median):.4f}m",
            )
        return transforms

    @classmethod
    def _apply_transform_to_hand_data(
        cls,
        hand_data: dict,
        T: np.ndarray,
    ) -> dict:
        """Re-express one clip's hand states/joints in another world frame.

        Positions (``states[:, 0:3]``) and ``joints_world`` are mapped by the
        rigid transform ``T`` (4x4). Orientations (``states[:, 3:6]``, xyz
        Euler angles in radians) are left-multiplied by ``T``'s rotation.
        Actions are recomputed from the transformed states. ``joints_cam`` is
        camera-relative and is left unchanged. The input dict is not mutated.
        """
        from scipy.spatial.transform import Rotation

        if not cls._has_states(hand_data):
            return hand_data

        R = T[:3, :3]
        t = T[:3, 3]

        # np.array (not asarray) so an ndarray input is never modified in place.
        states = np.array(hand_data["states"], dtype=np.float64)
        states[:, 0:3] = (R @ states[:, 0:3].T).T + t
        orig_rots = Rotation.from_euler("xyz", states[:, 3:6], degrees=False)
        states[:, 3:6] = (Rotation.from_matrix(R) * orig_rots).as_euler("xyz", degrees=False)

        result = dict(hand_data)
        result["states"] = states.tolist()

        jw = hand_data.get("joints_world")
        if jw is not None and len(jw) > 0:
            jw_arr = np.asarray(jw, dtype=np.float64)
            flat = jw_arr.reshape(-1, 3)
            result["joints_world"] = ((R @ flat.T).T + t).reshape(jw_arr.shape).tolist()

        # Use this class's own helper (the same one the merge step uses).
        result["actions"] = cls._recompute_actions(states).tolist()
        return result

    @staticmethod
    def _apply_transform_to_c2w(
        c2w: np.ndarray,
        T: np.ndarray,
    ) -> np.ndarray:
        """Transform a stack of cam_c2w matrices into the target world frame.

        ``c2w`` maps camera -> local world and ``T`` maps local world ->
        target world, so each matrix becomes ``T @ c2w[n]``.
        """
        return np.einsum("ij,njk->nik", T, c2w)

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _empty_hand_result(hand_type: str) -> dict:
        """Return an empty hand-action result for ``hand_type``."""
        return {
            "hand_type": hand_type,
            "states": [],
            "actions": [],
            "valid_frame_ids": [],
            "joints_cam": [],
            "joints_world": [],
        }

    @staticmethod
    def _has_states(hand_data) -> bool:
        """Return True when ``hand_data`` is non-empty and has non-empty states."""
        if not hand_data:
            return False
        states = hand_data.get("states")
        return states is not None and len(states) > 0

    @staticmethod
    def _project_to_rotation(mats: np.ndarray) -> np.ndarray:
        """Project (..., 3, 3) matrices onto the nearest proper rotations.

        Uses the SVD (orthogonal Procrustes) solution ``U @ diag(1, 1, ±1) @
        Vt``, with the sign chosen so that the determinant is +1.
        """
        u, _, vt = np.linalg.svd(mats)
        sign = np.where(np.linalg.det(u @ vt) < 0, -1.0, 1.0)
        u[..., :, 2] *= sign[..., np.newaxis]
        return u @ vt

    @staticmethod
    def _fill_missing(values: list) -> list:
        """Fill ``None`` slots in place with the closest preceding value.

        Leading ``None`` slots take the first non-``None`` value. A list
        without any value is left untouched. Returns ``values``.
        """
        last = next((v for v in values if v is not None), None)
        for i, v in enumerate(values):
            if v is None:
                values[i] = last
            else:
                last = v
        return values

    @staticmethod
    def _overlap_extents(
        ci: int,
        n_clips: int,
        offsets: list[int],
        clip_lengths: list[int],
    ) -> tuple:
        """Return ``(overlap_prev, overlap_next)`` frame counts for clip ``ci``."""
        overlap_prev = 0
        if ci > 0:
            prev_end = offsets[ci - 1] + clip_lengths[ci - 1]
            overlap_prev = max(0, prev_end - offsets[ci])
        overlap_next = 0
        if ci < n_clips - 1:
            this_end = offsets[ci] + clip_lengths[ci]
            overlap_next = max(0, this_end - offsets[ci + 1])
        return overlap_prev, overlap_next

    @staticmethod
    def _recompute_actions(states: np.ndarray) -> np.ndarray:
        """Recompute 7-DoF actions from 8-dim states.

        For ``t < T-1``: ``actions[t, 0:3]`` is the position delta,
        ``actions[t, 3:6]`` is the relative rotation ``R_{t+1} R_t^-1`` as xyz
        Euler angles, and ``actions[t, 6]`` is ``states[t+1, 7]``. The last
        row only carries ``states[T-1, 7]`` in column 6.
        """
        from scipy.spatial.transform import Rotation

        T = len(states)
        actions = np.zeros((T, 7), dtype=np.float32)
        for t in range(T - 1):
            actions[t, 0:3] = states[t + 1, 0:3] - states[t, 0:3]
            R_prev = Rotation.from_euler(
                "xyz",
                states[t, 3:6],
                degrees=False,
            )
            R_next = Rotation.from_euler(
                "xyz",
                states[t + 1, 3:6],
                degrees=False,
            )
            R_delta = R_next * R_prev.inv()
            actions[t, 3:6] = R_delta.as_euler("xyz", degrees=False)
            actions[t, 6] = states[t + 1, 7]
        if T > 0:
            actions[T - 1, 6] = states[T - 1, 7]
        return actions

    def _compute_nominal_step(self) -> int:
        """Compute the nominal clip step in frames from constructor params.

        Returns ``None`` when ``split_duration``/``fps`` are unset, when
        ``overlap_duration`` is ``None``, or when the step is not positive.
        ``round`` is used so float error (e.g. 89.999...) does not lose a frame.
        """
        if self.split_duration and self.overlap_duration is not None and self.fps:
            step = round((self.split_duration - self.overlap_duration) * self.fps)
            return step if step > 0 else None
        return None

    def _blend_weight(
        self,
        clip_idx: int,
        local_fid: int,
        n_clips: int,
        clip_len: int,
        overlap_prev: int,
        overlap_next: int,
    ) -> float:
        """Compute the blending weight for a frame given its clip position.

        Weights ramp up linearly over the overlap with the previous clip and
        down over the overlap with the next clip. For matched frames of two
        overlapping clips, the two weights sum to 1.

        Args:
            overlap_prev: number of frames this clip overlaps with the
                previous clip (ramp-up at start).
            overlap_next: number of frames this clip overlaps with the next
                clip (ramp-down at end).
        """
        w = 1.0
        if clip_idx > 0 and overlap_prev > 0 and local_fid < overlap_prev:
            w = (local_fid + 1) / (overlap_prev + 1)
        if clip_idx < n_clips - 1 and overlap_next > 0 and local_fid >= clip_len - overlap_next:
            frames_from_end = clip_len - local_fid
            w_end = frames_from_end / (overlap_next + 1)
            w = min(w, w_end)
        return w

    # ------------------------------------------------------------------
    # video_frames merge
    # ------------------------------------------------------------------
    @classmethod
    def _merge_video_frames(
        cls,
        per_clip_frames: list[list[str]],
        offsets: list[int],
    ) -> list[str]:
        """Merge per-clip frame path lists into one global ordered list.

        Earlier clips win on overlapping frames. Gaps are filled with the
        closest preceding path (or the first path, for leading gaps).
        """
        total_frames = max(
            (offset + len(clip_frames) for offset, clip_frames in zip(offsets, per_clip_frames)),
            default=0,
        )
        merged = [None] * total_frames
        for offset, clip_frames in zip(offsets, per_clip_frames):
            for local_fid, frame_path in enumerate(clip_frames):
                gfid = offset + local_fid
                if merged[gfid] is None:
                    merged[gfid] = frame_path
        return cls._fill_missing(merged)

    # ------------------------------------------------------------------
    # moge calibration merge
    # ------------------------------------------------------------------
    @classmethod
    def _merge_moge(
        cls,
        moge_list: list[dict],
        offsets: list[int],
    ) -> dict:
        """Merge per-clip MoGe calibration results into one global result.

        Keys whose value is a list longer than one in any clip are treated as
        per-frame and laid out on the global timeline (earlier clips win,
        gaps filled from the closest preceding frame). Every other key keeps
        the first clip's value.
        """
        total_frames = 0
        for ci, m in enumerate(moge_list):
            if not m or not isinstance(m, dict):
                continue
            for k in ("depth", "hfov", "intrinsics", "vfov"):
                v = m.get(k)
                if isinstance(v, list) and len(v) > 0:
                    total_frames = max(total_frames, offsets[ci] + len(v))
                    break
        if total_frames == 0:
            return moge_list[0] if moge_list else {}

        per_frame_keys = set()
        scalar_fields = {}
        for m in moge_list:
            if not m or not isinstance(m, dict):
                continue
            for k, v in m.items():
                if isinstance(v, list) and len(v) > 1:
                    per_frame_keys.add(k)
                elif k not in scalar_fields:
                    scalar_fields[k] = v

        merged = dict(scalar_fields)
        for key in per_frame_keys:
            arr = [None] * total_frames
            for ci, m in enumerate(moge_list):
                if not m or not isinstance(m, dict):
                    continue
                vals = m.get(key)
                if not isinstance(vals, list):
                    continue
                for local_fid, val in enumerate(vals):
                    gfid = offsets[ci] + local_fid
                    if gfid < total_frames and arr[gfid] is None:
                        arr[gfid] = val
            merged[key] = cls._fill_missing(arr)
        return merged

    # ------------------------------------------------------------------
    # hand action merge
    # ------------------------------------------------------------------
    def _merge_hand_across_clips(
        self,
        clips_hand_data: list,
        hand_type: str,
        n_clips: int,
        offsets: list[int],
        clip_lengths: list[int],
    ) -> dict:
        """Merge one hand's data across all clips into a single trajectory.

        Overlapping frames are blended with :meth:`_blend_weight`. Positions
        and the other state channels are averaged linearly. Orientations of
        frames covered by several clips are averaged as rotation matrices and
        projected back onto SO(3), because a linear average of Euler angles
        breaks across the ±pi wrap-around. Joints are averaged only over the
        clips that actually provide them. Actions are recomputed from the
        merged states.
        """
        from scipy.spatial.transform import Rotation

        clip_entries = []
        for clip_idx, hand_data in enumerate(clips_hand_data):
            if not self._has_states(hand_data):
                continue
            global_offset = offsets[clip_idx]
            local_ids = hand_data["valid_frame_ids"]
            jw = hand_data.get("joints_world")
            jc = hand_data.get("joints_cam")
            clip_entries.append(
                {
                    "clip_idx": clip_idx,
                    "local_ids": local_ids,
                    "global_ids": [fid + global_offset for fid in local_ids],
                    "states": np.asarray(hand_data["states"], dtype=np.float64),
                    "joints_world": (np.asarray(jw, dtype=np.float64) if jw is not None and len(jw) > 0 else None),
                    "joints_cam": (np.asarray(jc, dtype=np.float64) if jc is not None and len(jc) > 0 else None),
                }
            )

        if not clip_entries:
            return self._empty_hand_result(hand_type)

        if len(clip_entries) == 1:
            entry = clip_entries[0]
            src = clips_hand_data[entry["clip_idx"]]
            actions = src.get("actions")
            if actions is None:
                actions = self._recompute_actions(entry["states"]).tolist()
            return {
                "hand_type": hand_type,
                "states": src["states"],
                "actions": actions,
                "valid_frame_ids": entry["global_ids"],
                "joints_cam": src.get("joints_cam", []),
                "joints_world": src.get("joints_world", []),
            }

        all_gids = [gid for e in clip_entries for gid in e["global_ids"]]
        min_fid = min(all_gids)
        n_total = max(all_gids) - min_fid + 1

        state_dim = clip_entries[0]["states"].shape[1]
        state_sum = np.zeros((n_total, state_dim), dtype=np.float64)
        rot_sum = np.zeros((n_total, 3, 3), dtype=np.float64)
        weight_sum = np.zeros(n_total, dtype=np.float64)
        n_contrib = np.zeros(n_total, dtype=np.int64)

        def _accumulators(key):
            # Shape the joint accumulators after the first clip providing them.
            shape = next((e[key].shape[1:] for e in clip_entries if e[key] is not None), None)
            if shape is None:
                return None, None
            return np.zeros((n_total, *shape), dtype=np.float64), np.zeros(n_total, dtype=np.float64)

        jw_sum, jw_weight = _accumulators("joints_world")
        jc_sum, jc_weight = _accumulators("joints_cam")

        for entry in clip_entries:
            ci = entry["clip_idx"]
            clip_len = clip_lengths[ci]
            overlap_prev, overlap_next = self._overlap_extents(ci, n_clips, offsets, clip_lengths)
            states = entry["states"]
            rot_mats = Rotation.from_euler("xyz", states[:, 3:6], degrees=False).as_matrix()
            jw, jc = entry["joints_world"], entry["joints_cam"]

            for i, (gfid, local_fid) in enumerate(zip(entry["global_ids"], entry["local_ids"])):
                idx = gfid - min_fid
                w = self._blend_weight(
                    ci,
                    local_fid,
                    n_clips,
                    clip_len,
                    overlap_prev,
                    overlap_next,
                )
                state_sum[idx] += states[i] * w
                rot_sum[idx] += rot_mats[i] * w
                weight_sum[idx] += w
                if w > 0:
                    n_contrib[idx] += 1
                if jw_sum is not None and jw is not None and i < len(jw):
                    jw_sum[idx] += jw[i] * w
                    jw_weight[idx] += w
                if jc_sum is not None and jc is not None and i < len(jc):
                    jc_sum[idx] += jc[i] * w
                    jc_weight[idx] += w

        valid_idx = np.flatnonzero(weight_sum > 1e-8)
        if valid_idx.size == 0:
            return self._empty_hand_result(hand_type)

        merged_states = state_sum[valid_idx] / weight_sum[valid_idx, np.newaxis]

        # Linear averaging of Euler angles breaks at the ±pi wrap-around
        # (+3.1 and -3.1 rad average to 0, a ~180° flip). Where several clips
        # contribute, average the rotation matrices and project onto SO(3).
        blended = n_contrib[valid_idx] > 1
        if np.any(blended):
            rot_mean = self._project_to_rotation(rot_sum[valid_idx][blended])
            merged_states[blended, 3:6] = Rotation.from_matrix(rot_mean).as_euler("xyz", degrees=False)

        def _mean_joints(total, weight):
            if total is None:
                return None
            # Divide by the weight of clips that supplied joints; frames with
            # no joint data stay zero.
            denom = np.where(weight > 1e-8, weight, 1.0)[valid_idx]
            return (total[valid_idx] / denom.reshape((-1,) + (1,) * (total.ndim - 1))).tolist()

        merged_jw = _mean_joints(jw_sum, jw_weight)
        merged_jc = _mean_joints(jc_sum, jc_weight)

        return {
            "hand_type": hand_type,
            "states": merged_states.astype(np.float32).tolist(),
            "actions": self._recompute_actions(merged_states).tolist(),
            "valid_frame_ids": (valid_idx + min_fid).tolist(),
            "joints_cam": merged_jc if merged_jc else [],
            "joints_world": merged_jw if merged_jw else [],
        }

    # ------------------------------------------------------------------
    # camera pose (cam_c2w) merge
    # ------------------------------------------------------------------
    def _merge_cam_c2w(
        self,
        cam_pose_list: list[dict],
        offsets: list[int],
        clip_lengths: list[int],
    ) -> dict:
        """Merge per-clip cam_c2w (N,4,4) arrays into a single global array.

        Overlapping frames are blended with :meth:`_blend_weight`. The rotation
        block of frames blended from several clips is projected back onto SO(3),
        since a weighted sum of rotation matrices is not a rotation. Frames
        covered by no clip get the identity. Other keys come from the first
        valid clip dict.
        """
        n_clips = len(cam_pose_list)
        clip_c2ws: list[tuple[int, np.ndarray]] = []
        for ci, cp in enumerate(cam_pose_list):
            if not cp or not isinstance(cp, dict):
                continue
            raw = cp.get(CameraCalibrationKeys.cam_c2w)
            if raw is None:
                continue
            clip_c2ws.append((ci, np.asarray(load_numpy(raw), dtype=np.float64)))

        if not clip_c2ws:
            return cam_pose_list[0] if cam_pose_list else {}

        max_global = max(offsets[ci] + len(arr) for ci, arr in clip_c2ws)
        c2w_sum = np.zeros((max_global, 4, 4), dtype=np.float64)
        w_sum = np.zeros(max_global, dtype=np.float64)
        n_contrib = np.zeros(max_global, dtype=np.int64)

        for ci, arr in clip_c2ws:
            offset = offsets[ci]
            clip_len = clip_lengths[ci]
            overlap_prev, overlap_next = self._overlap_extents(ci, n_clips, offsets, clip_lengths)
            for local_fid in range(len(arr)):
                gfid = offset + local_fid
                w = self._blend_weight(
                    ci,
                    local_fid,
                    n_clips,
                    clip_len,
                    overlap_prev,
                    overlap_next,
                )
                c2w_sum[gfid] += arr[local_fid] * w
                w_sum[gfid] += w
                if w > 0:
                    n_contrib[gfid] += 1

        valid = w_sum > 1e-8
        c2w_sum[valid] /= w_sum[valid, np.newaxis, np.newaxis]
        c2w_sum[~valid] = np.eye(4)

        blended = valid & (n_contrib > 1)
        if np.any(blended):
            c2w_sum[blended, :3, :3] = self._project_to_rotation(c2w_sum[blended, :3, :3])

        merged: dict = {}
        for cp in cam_pose_list:
            if cp and isinstance(cp, dict):
                for k, v in cp.items():
                    if k != CameraCalibrationKeys.cam_c2w and k not in merged:
                        merged[k] = v
                break
        merged[CameraCalibrationKeys.cam_c2w] = c2w_sum.tolist()
        return merged

    # ------------------------------------------------------------------
    # hawor reconstruction merge
    # ------------------------------------------------------------------
    def _merge_hawor(
        self,
        hawor_list: list[dict],
        offsets: list[int],
    ) -> dict:
        """Merge per-clip HaWoR results, converting local frame_ids to global.

        For each hand, frames are de-duplicated by global id (earlier clips
        win) and sorted by global id. Per-frame fields travel with their frame,
        so a field shorter than ``frame_ids`` cannot misalign or crash the
        sort. ``fov_x``/``img_focal`` come from the first valid clip dict.
        """
        per_frame_fields = ("transl", "global_orient", "hand_pose", "betas")

        merged: dict = {}
        first_valid = next((hw for hw in hawor_list if hw and isinstance(hw, dict)), None)
        if first_valid is not None:
            for k in ("fov_x", "img_focal"):
                if k in first_valid:
                    merged[k] = first_valid[k]

        hand_types: set[str] = set()
        for hw in hawor_list:
            if hw and isinstance(hw, dict):
                hand_types.update(k for k in ("left", "right") if isinstance(hw.get(k), dict))

        for ht in sorted(hand_types):
            records: dict = {}
            for ci, hw in enumerate(hawor_list):
                if not hw or not isinstance(hw, dict):
                    continue
                hand = hw.get(ht, {})
                if not hand or not isinstance(hand, dict):
                    continue
                offset = offsets[ci]
                columns = {f: hand.get(f) for f in (*per_frame_fields, "joints_cam")}
                for i, local_fid in enumerate(hand.get("frame_ids", [])):
                    gfid = local_fid + offset
                    if gfid in records:
                        continue
                    records[gfid] = {f: col[i] for f, col in columns.items() if col is not None and i < len(col)}

            if not records:
                continue

            ordered = sorted(records.items(), key=lambda item: item[0])
            hand_result = {"frame_ids": [gfid for gfid, _ in ordered]}
            for f in per_frame_fields:
                hand_result[f] = [rec[f] for _, rec in ordered if f in rec]
            joints_cam = [rec["joints_cam"] for _, rec in ordered if "joints_cam" in rec]
            if joints_cam:
                hand_result["joints_cam"] = joints_cam
            merged[ht] = hand_result
        return merged

    # ------------------------------------------------------------------
    # alignment / merge steps used by process_single
    # ------------------------------------------------------------------
    def _align_hand_actions(self, hand_action_list: list, align_transforms: list) -> None:
        """Move clips 1.. of ``hand_action_list`` into clip 0's world, in place."""
        n_aligned = min(len(hand_action_list), len(align_transforms))
        if n_aligned < len(hand_action_list):
            logger.warning(
                f"Only {len(align_transforms)} alignment transforms for "
                f"{len(hand_action_list)} hand-action clips; extra clips left unaligned.",
            )
        for ci in range(1, n_aligned):
            T = align_transforms[ci]
            clip_result = hand_action_list[ci]
            if not clip_result or np.allclose(T, np.eye(4)):
                continue
            for ht in clip_result:
                try:
                    clip_result[ht] = self._apply_transform_to_hand_data(clip_result[ht], T)
                except Exception as e:
                    logger.warning(
                        f"Alignment failed clip {ci} {ht}: {e}",
                    )

    def _merge_hand_actions(
        self,
        hand_action_list: list,
        n_clips: int,
        offsets: list[int],
        clip_lengths: list[int],
    ) -> dict:
        """Merge every hand type across clips, falling back to the first clip on error."""
        hand_types: set[str] = set()
        for clip_result in hand_action_list:
            if clip_result:
                hand_types.update(clip_result.keys())

        merged_result: dict = {}
        for ht in sorted(hand_types):
            per_clip = [(cr.get(ht) if cr else None) for cr in hand_action_list]
            try:
                merged_result[ht] = self._merge_hand_across_clips(
                    per_clip,
                    ht,
                    n_clips,
                    offsets,
                    clip_lengths,
                )
            except Exception as e:
                logger.warning(
                    f"Hand '{ht}' reassembly failed: {e}. " f"Falling back to first clip.",
                )
                first_valid = next((d for d in per_clip if self._has_states(d)), None)
                merged_result[ht] = first_valid if first_valid else self._empty_hand_result(ht)
        return merged_result

    def _align_cam_poses(self, cam_pose_list: list, align_transforms: list) -> None:
        """Move clips 1.. of ``cam_pose_list`` into clip 0's world, in place."""
        for ci in range(1, min(len(cam_pose_list), len(align_transforms))):
            cp = cam_pose_list[ci]
            if not cp or not isinstance(cp, dict):
                continue
            raw = cp.get(CameraCalibrationKeys.cam_c2w)
            if raw is None:
                continue
            T = align_transforms[ci]
            if np.allclose(T, np.eye(4)):
                continue
            try:
                c2w_arr = np.asarray(
                    load_numpy(raw),
                    dtype=np.float64,
                )
                aligned = self._apply_transform_to_c2w(c2w_arr, T)
                cam_pose_list[ci] = dict(cp)
                cam_pose_list[ci][CameraCalibrationKeys.cam_c2w] = aligned.tolist()
            except Exception as e:
                logger.warning(
                    f"cam_c2w alignment failed clip {ci}: {e}",
                )

    # ------------------------------------------------------------------
    # main entry
    # ------------------------------------------------------------------
    def process_single(self, sample=None, rank=None):
        """Reassemble all per-clip results of ``sample`` into per-video results.

        The sample is returned unchanged when it has no meta field, or when its
        frame field is not a list of at least two per-clip frame lists. Each
        merge step is best-effort: a failure is logged and leaves that field as
        it was.
        """
        if Fields.meta not in sample:
            return sample
        meta = sample[Fields.meta]
        hand_action_list = meta.get(self.hand_action_field)

        # --- detect actual clip offsets from frame content ---
        per_clip_frames = sample.get(self.frame_field)
        has_multi_clips = (
            per_clip_frames
            and isinstance(per_clip_frames, list)
            and len(per_clip_frames) > 1
            and isinstance(per_clip_frames[0], list)
        )
        if not has_multi_clips:
            return sample

        n_clips = len(per_clip_frames)
        clip_lengths = [len(cf) for cf in per_clip_frames]
        offsets = self._detect_clip_offsets(per_clip_frames, self._compute_nominal_step())
        total_frames = max(off + clen for off, clen in zip(offsets, clip_lengths))
        logger.info(
            f"Clip offsets: {offsets}, clip_lengths: {clip_lengths}, " f"total_frames: {total_frames}",
        )

        # --- merge video_frames ---
        try:
            merged_frames = self._merge_video_frames(
                per_clip_frames,
                offsets,
            )
            sample[self.frame_field] = [merged_frames]
            logger.debug(
                f"Merged {n_clips} clip frame lists into " f"{len(merged_frames)} global frames",
            )
        except Exception as e:
            logger.warning(f"video_frames reassembly failed: {e}")

        # --- merge moge ---
        moge_list = meta.get(self.moge_field)
        if moge_list and isinstance(moge_list, list) and len(moge_list) > 1:
            try:
                meta[self.moge_field] = [self._merge_moge(moge_list, offsets)]
            except Exception as e:
                logger.warning(f"MoGe reassembly failed: {e}")

        # --- compute world-frame alignment transforms ---
        cam_pose_list = meta.get(self.camera_pose_field)
        align_transforms = None
        if cam_pose_list and len(cam_pose_list) > 1:
            try:
                align_transforms = self._compute_alignment_transforms(
                    cam_pose_list,
                    offsets,
                    clip_lengths,
                )
            except Exception as e:
                logger.warning(f"Alignment transform computation failed: {e}")

        # --- align hand actions to clip 0's world frame, then merge ---
        if hand_action_list and len(hand_action_list) > 1:
            if align_transforms:
                self._align_hand_actions(hand_action_list, align_transforms)
            meta[self.hand_action_field] = [
                self._merge_hand_actions(hand_action_list, n_clips, offsets, clip_lengths),
            ]

        # --- align cam_c2w to clip 0's world frame, then merge ---
        cam_pose_list = meta.get(self.camera_pose_field)
        if cam_pose_list and len(cam_pose_list) > 1:
            if align_transforms:
                self._align_cam_poses(cam_pose_list, align_transforms)
            try:
                merged_cam = self._merge_cam_c2w(
                    cam_pose_list,
                    offsets,
                    clip_lengths,
                )
                meta[self.camera_pose_field] = [merged_cam]
            except Exception as e:
                logger.warning(f"cam_c2w reassembly failed: {e}")

        # --- merge hawor ---
        hawor_list = meta.get(self.hand_reconstruction_field)
        if hawor_list and len(hawor_list) > 1:
            try:
                meta[self.hand_reconstruction_field] = [self._merge_hawor(hawor_list, offsets)]
            except Exception as e:
                logger.warning(f"HaWoR reassembly failed: {e}")

        # --- merge clips → original video ---
        clips = sample.get(self.clip_field)
        if clips and isinstance(clips, list) and len(clips) > 1:
            videos = sample.get(self.video_key)
            if videos and isinstance(videos, list) and len(videos) > 0:
                sample[self.clip_field] = videos
            else:
                sample[self.clip_field] = [clips[0]]
        return sample