s6.nn package
Subpackages
Submodules
s6.nn.keypoint_detector_v2 module
- class s6.nn.keypoint_detector_v2.KeypointDetectorV2(ckpt_or_onnx_path: str = None)
Bases: object
Inference wrapper that supports both ONNX models and PyTorch checkpoints.
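The class accepts either file type through the single `ckpt_or_onnx_path` argument. Below is a minimal sketch of how such a wrapper might choose its backend, assuming it dispatches on the file extension. The helper `select_backend`, the extension table and the exceptions it raises are all hypothetical. The module's tests `test_init_missing_file` and `test_checkpoint_missing_file` suggest that missing files are rejected, so the sketch raises `FileNotFoundError`, but the real class may behave differently.

```python
import os
from typing import Optional

# Hypothetical extension table; the real KeypointDetectorV2 may use other rules.
_BACKENDS = {".onnx": "onnx", ".pt": "torch", ".pth": "torch", ".ckpt": "torch"}


def select_backend(ckpt_or_onnx_path: Optional[str] = None) -> str:
    """Return "onnx" or "torch" for a model file (illustrative sketch only)."""
    if ckpt_or_onnx_path is None:
        raise ValueError("a checkpoint or ONNX model path is required")
    # Assumed behaviour: reject paths that do not point at an existing file.
    if not os.path.isfile(ckpt_or_onnx_path):
        raise FileNotFoundError(ckpt_or_onnx_path)
    ext = os.path.splitext(ckpt_or_onnx_path)[1].lower()
    if ext not in _BACKENDS:
        raise ValueError(f"unsupported model file extension: {ext!r}")
    return _BACKENDS[ext]
```

Dispatching on the extension lets one constructor argument cover both formats. If file names are unreliable, a wrapper could inspect the file contents instead.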
s6.nn.keypoints module
- class s6.nn.keypoints.KeypointDetector(onnx_model_path)
Bases: object
- postprocess_keypoints(keypoints, original_size)
Rescales keypoints to the original image size.
- Parameters:
keypoints (np.ndarray) – Keypoints in normalized [0,1) coordinates. Shape: (num_keypoints, 2)
original_size (tuple) – Original image size (width, height).
- Returns:
rescaled_keypoints (np.ndarray) – Keypoints in original image coordinates. Shape: (num_keypoints, 2)
- Return type:
np.ndarray
- preprocess_image(image: ndarray)
Preprocesses the input image for model inference.
- Parameters:
image (str or np.ndarray) – Path to the image, or an image array in HxWxC format.
- Returns:
input_tensor (np.ndarray) – Preprocessed image tensor in NCHW format.
original_image (np.ndarray) – Original image before preprocessing.
crop_info (tuple) – Information about cropping and scaling.
- Return type:
tuple of (np.ndarray, np.ndarray, tuple)
- run_inference(input_tensor)
Runs inference on the input tensor using the ONNX model.
- Parameters:
input_tensor (np.ndarray) – Preprocessed input tensor.
- Returns:
keypoints (np.ndarray) – Detected keypoints. Shape: (batch_size, num_keypoints, 2)
heatmaps (np.ndarray) – Heatmaps output by the model. Shape: (batch_size, num_keypoints, H, W)
- Return type:
tuple of (np.ndarray, np.ndarray)
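The three methods form a pipeline: preprocess_image, then run_inference, then postprocess_keypoints. Below is a stdlib-only sketch of each stage under stated assumptions. Plain nested lists stand in for np.ndarray. Preprocessing is shown as a plain HxWxC to NCHW reorder with division by 255, and the source's cropping step (crop_info) is omitted. The source does not say how run_inference derives keypoints from heatmaps, so the sketch uses per-heatmap argmax, one common choice, normalized so coordinates fall in [0, 1) as postprocess_keypoints expects. Keypoints are assumed to be stored in (x, y) order. The names `to_nchw`, `decode_heatmap` and `rescale_keypoints` are hypothetical.

```python
from typing import List, Sequence, Tuple

# Plain nested lists stand in for np.ndarray so the sketch runs without NumPy.
Array3D = List[List[List[float]]]  # e.g. H x W x C pixels, or C x H x W
Heatmap = List[List[float]]        # H x W scores for one keypoint


def to_nchw(image: Array3D, scale: float = 255.0) -> List[Array3D]:
    """HxWxC -> 1xCxHxW, dividing by `scale` (assumed normalization, no crop)."""
    height, width, channels = len(image), len(image[0]), len(image[0][0])
    chw = [
        [[image[h][w][c] / scale for w in range(width)] for h in range(height)]
        for c in range(channels)
    ]
    return [chw]  # prepend the batch axis, N = 1


def decode_heatmap(heatmap: Heatmap) -> Tuple[float, float]:
    """Argmax of one heatmap as normalized (x, y) in [0, 1)."""
    height, width = len(heatmap), len(heatmap[0])
    row, col = max(
        ((r, c) for r in range(height) for c in range(width)),
        key=lambda rc: heatmap[rc[0]][rc[1]],
    )
    return col / width, row / height


def rescale_keypoints(
    keypoints: Sequence[Tuple[float, float]], original_size: Tuple[int, int]
) -> List[Tuple[float, float]]:
    """Map normalized (x, y) keypoints to pixels; original_size is (width, height)."""
    width, height = original_size
    return [(x * width, y * height) for x, y in keypoints]


if __name__ == "__main__":
    heatmap = [
        [0.0, 0.1, 0.0, 0.0],
        [0.0, 0.2, 0.9, 0.0],  # peak at row 1, column 2
        [0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0],
    ]
    keypoint = decode_heatmap(heatmap)
    print(keypoint)                                   # (0.5, 0.25)
    print(rescale_keypoints([keypoint], (640, 480)))  # [(320.0, 120.0)]
```

With real arrays, the layout step is usually a single NumPy call such as `image.transpose(2, 0, 1)[None]`. The nested loops above only make the axis reordering explicit.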
s6.nn.runtime module
- s6.nn.runtime.onnx_load_model(onnx_model_path: str) → InferenceSession
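onnx_load_model returns an onnxruntime `InferenceSession`. Below is a minimal sketch of an equivalent loader, assuming it wraps the `InferenceSession` constructor and prefers CUDA when it is available. `load_onnx_session`, `run_session` and the provider preference order are hypothetical and are not the module's actual code. The `onnxruntime` import is deferred so that the file-existence check works even where the package is not installed.

```python
import os
from typing import Any, List


def load_onnx_session(onnx_model_path: str) -> Any:
    """Hypothetical stand-in for s6.nn.runtime.onnx_load_model."""
    if not os.path.isfile(onnx_model_path):
        raise FileNotFoundError(onnx_model_path)
    import onnxruntime as ort  # deferred: only needed once a real file is found

    # Assumed preference: GPU first, with the CPU provider as the fallback.
    available = ort.get_available_providers()
    providers = [
        p for p in ("CUDAExecutionProvider", "CPUExecutionProvider") if p in available
    ]
    return ort.InferenceSession(onnx_model_path, providers=providers)


def run_session(session: Any, input_tensor: Any) -> List[Any]:
    """Feed one NCHW tensor to the model's first input and return all outputs."""
    input_name = session.get_inputs()[0].name
    return session.run(None, {input_name: input_tensor})  # None = every output
```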
s6.nn.test_keypoint_detector_v2 module
- s6.nn.test_keypoint_detector_v2.test_checkpoint_missing_file()
- s6.nn.test_keypoint_detector_v2.test_init_missing_file()
- s6.nn.test_keypoint_detector_v2.test_onnx_inference(tmp_path)
s6.nn.utils module
- s6.nn.utils.torch_device() → device
Module contents
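torch_device() returns a `torch.device`. Below is a sketch of a typical selection policy, assuming the helper prefers CUDA, then Apple's MPS backend, then CPU. The source does not say whether the real helper checks MPS. The decision logic sits in a pure function, `choose_device_name` (hypothetical), so it can be exercised without PyTorch installed. `torch_device_sketch` is likewise a hypothetical name.

```python
from typing import Any


def choose_device_name(cuda_available: bool, mps_available: bool) -> str:
    """Assumed priority: CUDA, then MPS, then CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def torch_device_sketch() -> Any:
    """Hypothetical equivalent of s6.nn.utils.torch_device()."""
    import torch  # deferred so choose_device_name is usable without PyTorch

    mps_backend = getattr(torch.backends, "mps", None)  # absent in older PyTorch builds
    mps_ok = mps_backend is not None and mps_backend.is_available()
    return torch.device(choose_device_name(torch.cuda.is_available(), mps_ok))
```

A model and its input tensors would then be moved with `.to(device)` before inference.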
Neural-network utilities for S6.