Source code for data_juicer.ops.mapper.video_atomic_action_segment_mapper

import numpy as np
from loguru import logger

from data_juicer.utils.constant import Fields, MetaKeys

from ..base_op import OPERATORS, Mapper

# Key under which the operator class below is registered in ``OPERATORS``.
OP_NAME = "video_atomic_action_segment_mapper"


@OPERATORS.register_module(OP_NAME)
class VideoAtomicActionSegmentMapper(Mapper):
    """Segment a unified hand trajectory into atomic action clips.

    Implements the algorithm from paper https://arxiv.org/pdf/2510.21571:
    "we detect speed minima of the 3D hand wrists in the world space and
    use them as cutting points. We smooth the hand trajectory and select
    points that are local speed minima within a fixed window centered on
    each point."

    The operator reads the merged hand_action_tags (output of
    ``VideoClipReassemblyMapper``) and produces a list of segments.
    Each segment contains the start and end frame indices, plus sliced
    states / actions / joints for that segment.

    Segmentation is applied **independently** for left and right hands.
    A frame is a cutting point if it is a speed local minimum within a
    window of ``min_window`` frames on each side.

    Output field (``segment_field``) structure::

        [
            {
                "hand_type": "right",
                "segment_id": 0,
                "start_frame": 10,
                "end_frame": 45,
                "states": [...],
                "actions": [...],
                "valid_frame_ids": [...],
                "joints_world": [...],  # only if present in the input
                "joints_cam": [...],    # only if present in the input
            },
            ...
        ]
    """

    def __init__(
        self,
        hand_action_field: str = MetaKeys.hand_action_tags,
        segment_field: str = "atomic_action_segments",
        speed_smooth_window: int = 5,
        min_window: int = 15,
        min_segment_frames: int = 8,
        max_segment_frames: int = 300,
        hand_type: str = "both",
        *args,
        **kwargs,
    ):
        """
        Initialization method.

        :param hand_action_field: Meta field storing merged hand action
            results (output of VideoClipReassemblyMapper).
        :param segment_field: Output meta field for atomic segments.
        :param speed_smooth_window: Window size for Savitzky-Golay
            smoothing of the speed signal before minima detection.
            Should be odd; an even value is reduced by one.
        :param min_window: Half-window size for local minima detection.
            A frame is a local minimum only if it is the minimum within
            ``[t - min_window, t + min_window]``. Larger values give
            fewer, longer segments.
        :param min_segment_frames: Minimum frames per segment. Cut points
            that would create a shorter segment are dropped, which merges
            the short segment into its neighbor.
        :param max_segment_frames: Maximum frames per segment. Longer
            segments are split at the deepest interior speed minimum,
            repeatedly, as long as both halves stay longer than
            ``min_segment_frames``.
        :param hand_type: Which hand(s) to segment: 'left', 'right',
            or 'both'.
        """
        super().__init__(*args, **kwargs)
        self.hand_action_field = hand_action_field
        self.segment_field = segment_field
        self.speed_smooth_window = speed_smooth_window
        self.min_window = min_window
        self.min_segment_frames = min_segment_frames
        self.max_segment_frames = max_segment_frames
        self.hand_type = hand_type

    # ------------------------------------------------------------------
    # Speed computation & smoothing
    # ------------------------------------------------------------------
    @staticmethod
    def _compute_speed(positions: np.ndarray) -> np.ndarray:
        """Compute per-frame wrist speed from positions.

        Speed is the Euclidean norm of the displacement between
        consecutive rows of ``positions`` (no division by a time step).

        :param positions: Array with one position per row.
        :return: Array of length N where ``speed[0] = 0``.
        """
        if len(positions) < 2:
            return np.zeros(len(positions))
        vel = np.linalg.norm(np.diff(positions, axis=0), axis=1)
        return np.concatenate([[0.0], vel])

    @staticmethod
    def _smooth_speed(
        speed: np.ndarray,
        window: int,
    ) -> np.ndarray:
        """Smooth the speed signal with a Savitzky-Golay filter.

        Smoothing is best-effort: the unsmoothed signal (a copy) is
        returned when the signal has fewer than 5 samples, when the
        effective window is below 3, when scipy is not installed, or
        when the filter rejects its arguments.

        :param speed: 1-D speed signal.
        :param window: Requested filter window; clamped to the signal
            length and made odd by subtracting one if necessary.
        :return: Smoothed signal of the same length as ``speed``.
        """
        n = len(speed)
        if n < 5:
            return speed.copy()

        win = min(window, n)
        if win % 2 == 0:
            win -= 1
        # polyorder=2 needs a window of at least 3 samples.
        if win < 3:
            return speed.copy()

        try:
            from scipy.signal import savgol_filter
        except ImportError:
            logger.warning(
                "scipy is not available; speed signal is left unsmoothed.",
            )
            return speed.copy()

        try:
            return savgol_filter(speed, win, polyorder=2)
        except (ValueError, TypeError) as e:
            # Keep the fallback, but do not hide the reason.
            logger.warning(
                f"Savitzky-Golay smoothing failed ({e}); using raw speed.",
            )
            return speed.copy()

    # ------------------------------------------------------------------
    # Local minima detection
    # ------------------------------------------------------------------
    @staticmethod
    def _find_local_minima(
        speed: np.ndarray,
        half_window: int,
    ) -> list[int]:
        """Find indices that are local speed minima within a window.

        A frame t is a local minimum if speed[t] <= speed[k] for all
        k in [t - half_window, t + half_window] (clipped to the signal).
        The first and last frames are never reported. Because the
        comparison is non-strict, every frame of a flat run that equals
        the window minimum is reported.

        :param speed: 1-D speed signal.
        :param half_window: Number of frames inspected on each side.
        :return: Ascending list of frame indices.
        """
        n = len(speed)
        minima = []
        for t in range(1, n - 1):
            lo = max(0, t - half_window)
            hi = min(n, t + half_window + 1)
            if speed[t] <= np.min(speed[lo:hi]):
                minima.append(t)
        return minima

    # ------------------------------------------------------------------
    # Segment merging (too-short) and splitting (too-long)
    # ------------------------------------------------------------------
    def _merge_short_segments(
        self,
        cut_points: list[int],
        n_frames: int,
    ) -> list[int]:
        """Remove cut points that would produce segments shorter than
        ``min_segment_frames``.

        Frame 0 and ``n_frames`` act as implicit boundaries, so the first
        and the last segment are checked as well.

        :param cut_points: Ascending candidate cut indices.
        :param n_frames: Total number of frames in the trajectory.
        :return: Filtered ascending list of cut indices.
        """
        kept: list[int] = []
        previous = 0  # start of the segment currently being grown
        for cp in cut_points:
            if cp - previous >= self.min_segment_frames:
                kept.append(cp)
                previous = cp

        # Dropping the last cut merges a too-short tail into the segment
        # before it; one removal is enough because the merged segment is
        # strictly longer than the one that just passed the check above.
        if kept and n_frames - kept[-1] < self.min_segment_frames:
            kept.pop()
        return kept

    def _split_long_segments(
        self,
        cut_points: list[int],
        speed: np.ndarray,
        n_frames: int,
    ) -> list[int]:
        """Split segments exceeding ``max_segment_frames`` at the deepest
        speed minimum within the segment.

        The split point is searched only in the interior of the segment,
        leaving more than ``min_segment_frames`` frames on both sides.
        Searching the whole segment would usually pick the boundary
        itself (boundaries are speed minima) and never split. Pieces are
        split again until they fit or cannot be split any further.

        :param cut_points: Ascending cut indices.
        :param speed: Speed signal of length ``n_frames``.
        :param n_frames: Total number of frames in the trajectory.
        :return: Sorted, de-duplicated cut indices as Python ints.
        """
        margin = max(0, self.min_segment_frames)
        boundaries = [0] + list(cut_points) + [n_frames]
        all_cuts = {int(cp) for cp in cut_points}

        pending = list(zip(boundaries[:-1], boundaries[1:]))
        while pending:
            start, end = pending.pop()
            if end - start <= self.max_segment_frames:
                continue
            # Candidates satisfy mid - start > margin and end - mid > margin.
            lo = start + margin + 1
            hi = end - margin
            if lo >= hi:
                # No split point keeps both halves long enough.
                continue
            mid = lo + int(np.argmin(speed[lo:hi]))
            all_cuts.add(mid)
            pending.append((start, mid))
            pending.append((mid, end))
        return sorted(all_cuts)

    # ------------------------------------------------------------------
    # Segment one hand
    # ------------------------------------------------------------------
    def _segment_hand(
        self,
        hand_data: dict,
        hand_type: str,
    ) -> list[dict]:
        """Segment a single hand's trajectory into atomic actions.

        The first three columns of each state are used as the wrist
        position for the speed computation.

        :param hand_data: Dict with ``states`` and optionally ``actions``,
            ``valid_frame_ids``, ``joints_world`` and ``joints_cam``.
        :param hand_type: Label stored in each produced segment.
        :return: List of segment dicts; empty if there are no states or
            fewer than ``min_segment_frames`` of them.
        """
        states = hand_data.get("states")
        if not states or len(states) < self.min_segment_frames:
            return []

        states_arr = np.asarray(states, dtype=np.float64)
        positions = states_arr[:, 0:3]
        n_frames = len(states_arr)

        # 1. Compute and smooth speed
        speed = self._compute_speed(positions)
        smooth_speed = self._smooth_speed(speed, self.speed_smooth_window)

        # 2. Detect local minima
        minima = self._find_local_minima(smooth_speed, self.min_window)

        # 3. Merge short segments, split long ones
        cut_points = self._merge_short_segments(minima, n_frames)
        cut_points = self._split_long_segments(
            cut_points,
            smooth_speed,
            n_frames,
        )

        # 4. Build segment boundaries
        boundaries = [0] + cut_points + [n_frames]

        valid_fids = hand_data.get("valid_frame_ids")
        if valid_fids is None:
            # Key missing or explicitly None: use positional indices.
            valid_fids = list(range(n_frames))
        actions = hand_data.get("actions", [])
        joints_world = hand_data.get("joints_world", [])
        joints_cam = hand_data.get("joints_cam", [])

        segments = []
        for seg_idx, (s, e) in enumerate(zip(boundaries[:-1], boundaries[1:])):
            if e - s < 2:
                continue

            seg = {
                "hand_type": hand_type,
                "segment_id": seg_idx,
                # Fall back to positional indices when valid_frame_ids is
                # shorter than the state sequence.
                "start_frame": valid_fids[s] if s < len(valid_fids) else s,
                "end_frame": (valid_fids[e - 1] if e - 1 < len(valid_fids) else e - 1),
                "states": states[s:e],
                "actions": actions[s:e] if actions else [],
                "valid_frame_ids": valid_fids[s:e],
            }
            if joints_world:
                seg["joints_world"] = joints_world[s:e]
            if joints_cam:
                seg["joints_cam"] = joints_cam[s:e]

            segments.append(seg)

        logger.debug(
            f"Segmented {hand_type} hand: {len(segments)} atomic actions "
            f"from {n_frames} frames, cut_points={cut_points}",
        )
        return segments

    # ------------------------------------------------------------------
    # Main entry
    # ------------------------------------------------------------------
    def process_single(self, sample=None, rank=None):
        """Segment the hand trajectories of one sample.

        Reads ``hand_action_field`` from the sample's meta and writes the
        resulting segment list to ``segment_field``. The sample is
        returned unchanged when it has no meta or no hand action data.

        :param sample: Sample dict to process.
        :param rank: Unused in this method.
        :return: The same sample object, possibly with the segment field
            added to its meta.
        """
        if Fields.meta not in sample:
            return sample

        meta = sample[Fields.meta]
        hand_action_list = meta.get(self.hand_action_field)
        if not hand_action_list:
            return sample

        # After reassembly, hand_action_list is [merged_result]
        # merged_result is a dict: {"right": {...}, "left": {...}}
        hand_types = ["right", "left"] if self.hand_type == "both" else [self.hand_type]

        all_segments = []
        for clip_result in hand_action_list:
            if not clip_result or not isinstance(clip_result, dict):
                continue
            for ht in hand_types:
                hand_data = clip_result.get(ht)
                if not hand_data or not hand_data.get("states"):
                    continue
                all_segments.extend(self._segment_hand(hand_data, ht))

        # Sort segments by start_frame for consistent ordering
        all_segments.sort(key=lambda s: (s["start_frame"], s["hand_type"]))

        meta[self.segment_field] = all_segments

        logger.info(
            f"Atomic action segmentation: {len(all_segments)} segments",
        )
        return sample