Source code for data_juicer.ops.mapper.image_mmpose_mapper

import importlib
import os
import subprocess
import sys
from typing import Dict, Optional, Sequence, Union

from loguru import logger

from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.model_utils import get_model, prepare_model

OP_NAME = "image_mmpose_mapper"


@OPERATORS.register_module(OP_NAME)
class ImageMMPoseMapper(Mapper):
    """Mapper to perform human keypoint detection inference using MMPose models.

    It requires three essential components for model initialization:

    - deploy_cfg (str): Path to the deployment configuration file (defines inference settings)
    - model_cfg (str): Path to the model configuration file (specifies model architecture)
    - model_files (List[str]): Model weight files including pre-trained weights and parameters

    The implementation follows the official MMPose deployment guidelines from MMDeploy.
    For detailed configuration requirements and usage examples, refer to:
    https://github.com/open-mmlab/mmdeploy/blob/main/docs/en/04-supported-codebases/mmpose.md
    """

    _accelerator = "cuda"
    def __init__(
        self,
        deploy_cfg: str = None,
        model_cfg: str = None,
        model_files: Optional[Union[str, Sequence[str]]] = None,
        pose_key: str = MetaKeys.pose_info,
        visualization_dir: str = None,
        *args,
        **kwargs,
    ):
        """
        Initialization method.

        :param deploy_cfg: MMPose deployment config file.
        :param model_cfg: MMPose model config file.
        :param model_files: Path to the model weight files.
        :param pose_key: Key to store pose information.
        :param visualization_dir: Directory to save visualization results.
        :param args: extra args
        :param kwargs: extra args
        """
        self._install_required_packages()
        super().__init__(*args, **kwargs)
        self.pose_key = pose_key
        self.deploy_cfg = deploy_cfg
        self.model_cfg = model_cfg
        if isinstance(model_files, str):
            self.model_files = [model_files]
        else:
            self.model_files = model_files
        self.visualization_dir = visualization_dir

        self.model_key = prepare_model(
            "mmlab",
            model_cfg=self.model_cfg,
            deploy_cfg=self.deploy_cfg,
            model_files=self.model_files,
        )
    def _install_required_packages(self):
        try:
            importlib.import_module("mim")
        except ImportError:
            logger.info("Installing openmim...")
            try:
                subprocess.run([sys.executable, "-m", "pip", "install", "openmim"], check=True)
            except Exception:
                raise ValueError(
                    "Failed to install openmim, please refer to the documentation at "
                    "https://github.com/open-mmlab/mim/blob/main/docs/en/installation.md "
                    "for installation instructions."
                )

        try:
            importlib.import_module("mmpose")
        except ImportError:
            logger.info("Installing mmpose...")
            try:
                subprocess.run(
                    [sys.executable, "-m", "pip", "install", "chumpy", "--no-build-isolation", "--no-deps"],
                    check=True,
                )
                subprocess.run([sys.executable, "-m", "mim", "install", "mmpose"], check=True)
            except Exception:
                raise ValueError(
                    "Failed to install mmpose, please refer to the documentation at "
                    "https://mmpose.readthedocs.io/en/latest/installation.html "
                    "for installation instructions."
                )

        try:
            importlib.import_module("mmdet")
        except ImportError:
            logger.info("Installing mmdet using mim...")
            try:
                subprocess.run([sys.executable, "-m", "mim", "install", "mmdet==3.2.0"], check=True)
            except Exception:
                raise ValueError(
                    "Failed to install mmdet, please refer to the documentation at "
                    "https://mmdetection.readthedocs.io/en/latest/get_started.html#installation "
                    "for installation instructions."
                )
    def parse_and_filter(self, data_sample) -> Dict:
        """Extract the elements needed to represent a prediction into a dictionary.

        The result should contain only basic data elements such as strings and
        numbers so that it remains JSON-serializable.

        Args:
            data_sample (:obj:`PoseDataSample`): Predictions of the model.

        Returns:
            dict: Prediction results.
        """
        from mmpose.structures import PoseDataSample

        assert isinstance(data_sample, PoseDataSample)

        result = {
            "keypoints": [],
            "keypoint_scores": [],
            "bboxes": [],
            "bbox_scores": [],
        }
        if "pred_instances" in data_sample:
            if "keypoints" in data_sample.pred_instances:
                result["keypoints"] = data_sample.pred_instances.keypoints
            if "keypoint_scores" in data_sample.pred_instances:
                result["keypoint_scores"] = data_sample.pred_instances.keypoint_scores
            if "bboxes" in data_sample.pred_instances:
                result["bboxes"] = data_sample.pred_instances.bboxes
            if "bbox_scores" in data_sample.pred_instances:
                result["bbox_scores"] = data_sample.pred_instances.bbox_scores

        return result
    def visualize_results(self, image, model, result, output_file):
        model.task_processor.visualize(
            image=image, model=model, result=result[0], window_name="visualize", output_file=output_file
        )
    def process_single(self, sample, rank=None):
        # skip samples whose pose info has already been generated
        if self.pose_key in sample[Fields.meta]:
            return sample

        model = get_model(self.model_key, rank, self.use_cuda())
        images = sample[self.image_key]

        from mmpose.apis.inference import dataset_meta_from_config

        dataset_meta = dataset_meta_from_config(model.task_processor.model_cfg, dataset_mode="test")
        keypoint_names = [dataset_meta["keypoint_id2name"][i] for i in range(dataset_meta["num_keypoints"])]

        results = [model(img) for img in images]
        pose_info = [self.parse_and_filter(res[0]) for res in results]
        for pinfo in pose_info:
            pinfo["keypoint_names"] = keypoint_names
        sample[Fields.meta][self.pose_key] = pose_info

        if self.visualization_dir:
            os.makedirs(self.visualization_dir, exist_ok=True)
            for i, img in enumerate(images):
                img_name = os.path.splitext(os.path.basename(img))[0]
                output_file = f"{self.visualization_dir}/{img_name}.png"
                self.visualize_results(img, model, results[i], output_file)

        return sample
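A minimal usage sketch: instantiating the mapper directly and applying it to one sample. The deploy_cfg, model_cfg, model_files, and image paths below are hypothetical placeholders and must point to real MMDeploy/MMPose config and weight files; the first instantiation may also install openmim, mmpose, and mmdet as done in _install_required_packages.

    # Usage sketch only; every path here is a hypothetical placeholder.
    from data_juicer.ops.mapper.image_mmpose_mapper import ImageMMPoseMapper
    from data_juicer.utils.constant import Fields

    op = ImageMMPoseMapper(
        deploy_cfg="path/to/mmdeploy_pose_deploy_cfg.py",   # hypothetical MMDeploy deployment config
        model_cfg="path/to/mmpose_model_cfg.py",            # hypothetical MMPose model config
        model_files=["path/to/exported_model.onnx"],        # hypothetical converted model weights
        visualization_dir="./pose_vis",                     # optional; omit to skip drawing results
    )

    # A sample holds image paths under op.image_key and a meta dict for op outputs.
    sample = {op.image_key: ["path/to/person.jpg"], Fields.meta: {}}
    sample = op.process_single(sample)
    print(sample[Fields.meta][op.pose_key])  # keypoints, scores, bboxes, and keypoint names per image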