s6.nn package

Subpackages

Submodules

s6.nn.keypoint_detector_v2 module

class s6.nn.keypoint_detector_v2.KeypointDetectorV2(ckpt_or_onnx_path: str = None)

Bases: object

Inference wrapper that supports ONNX or PyTorch checkpoint models.
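The description implies the wrapper picks a backend based on the path it receives. A minimal sketch of one way such a dispatch could look, assuming the choice is made by file extension and that a missing file is reported early. `pick_backend` and both suffix sets are hypothetical and not part of `s6.nn`; the actual class may decide differently.

```python
from pathlib import Path

# Hypothetical extension sets; the real wrapper may use other rules.
ONNX_SUFFIXES = {".onnx"}
TORCH_SUFFIXES = {".pt", ".pth", ".ckpt"}


def pick_backend(ckpt_or_onnx_path: str) -> str:
    """Return "onnx" or "torch" for an existing model file."""
    path = Path(ckpt_or_onnx_path)
    if not path.is_file():
        # Fail early instead of deferring the error to the first inference call.
        raise FileNotFoundError(f"model file not found: {path}")
    suffix = path.suffix.lower()
    if suffix in ONNX_SUFFIXES:
        return "onnx"
    if suffix in TORCH_SUFFIXES:
        return "torch"
    raise ValueError(f"unrecognized model file extension: {suffix!r}")
```

Checking the path before loading keeps the two failure modes (missing file, unknown format) distinct from errors raised by the runtime itself.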

s6.nn.keypoints module

class s6.nn.keypoints.KeypointDetector(onnx_model_path)

Bases: object

postprocess_keypoints(keypoints, original_size)

Rescales keypoints to the original image size.

Parameters:
  • keypoints (np.ndarray) – Keypoints in normalized [0, 1) coordinates. Shape: (num_keypoints, 2)

  • crop_info (tuple) – (crop_coords, scale_factors). Listed in the docstring, but not a parameter in the signature shown above.

  • original_size (tuple) – Original image size (width, height).

Returns:

rescaled_keypoints (np.ndarray) – Keypoints in original image coordinates. Shape: (num_keypoints, 2)
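The rescaling described above can be sketched as follows. This assumes the normalized coordinates span the full image with no crop or padding, and that each keypoint is stored as (x, y). `rescale_keypoints` is a hypothetical standalone helper, not the method itself.

```python
import numpy as np


def rescale_keypoints(keypoints: np.ndarray, original_size: tuple) -> np.ndarray:
    """Map normalized (x, y) keypoints to pixel coordinates.

    Assumes the normalization covered the whole image (no crop or padding).
    """
    width, height = original_size
    scale = np.array([width, height], dtype=np.float32)
    # Broadcasting multiplies every x by width and every y by height.
    return np.asarray(keypoints, dtype=np.float32) * scale


normalized = np.array([[0.5, 0.25], [0.0, 0.75]])
pixels = rescale_keypoints(normalized, original_size=(640, 480))
# pixels has shape (2, 2); the first row is (320.0, 120.0)
```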

preprocess_image(image: ndarray)

Preprocesses the input image for model inference.

Parameters:

image (np.ndarray) – Image array in HxWxC format. The docstring also allows a path to an image (str), although the signature is annotated as ndarray only.

Returns:
  • input_tensor (np.ndarray) – Preprocessed image tensor in NCHW format.

  • original_image (np.ndarray) – Original image before preprocessing.

  • crop_info (tuple) – Information about cropping and scaling.
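The HxWxC to NCHW layout change can be sketched as below. This covers only the layout conversion; the division by 255 is an assumed normalization, and any resizing or cropping the real method performs is omitted. `to_nchw` is a hypothetical helper.

```python
import numpy as np


def to_nchw(image: np.ndarray) -> np.ndarray:
    """Convert an HxWxC uint8 image to a 1xCxHxW float32 tensor in [0, 1]."""
    if image.ndim != 3:
        raise ValueError(f"expected HxWxC image, got shape {image.shape}")
    scaled = image.astype(np.float32) / 255.0      # assumed normalization
    chw = np.transpose(scaled, (2, 0, 1))          # HWC -> CHW
    return np.ascontiguousarray(chw[np.newaxis])   # add batch axis -> NCHW


image = np.zeros((480, 640, 3), dtype=np.uint8)
tensor = to_nchw(image)
# tensor.shape == (1, 3, 480, 640), dtype float32
```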

run_inference(input_tensor)

Runs inference on the input tensor using the ONNX model.

Parameters:

input_tensor (np.ndarray) – Preprocessed input tensor.

Returns:
  • keypoints (np.ndarray) – Detected keypoints. Shape: (batch_size, num_keypoints, 2)

  • heatmaps (np.ndarray) – Heatmaps output by the model. Shape: (batch_size, num_keypoints, H, W)
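The two return shapes suggest that keypoints are derived from the heatmaps. The documentation does not say how, so the following is only a sketch of one common approach: take the argmax of each heatmap and normalize by the heatmap size. `heatmaps_to_keypoints` is a hypothetical helper, and the real model may output keypoints directly or use a different decoding.

```python
import numpy as np


def heatmaps_to_keypoints(heatmaps: np.ndarray) -> np.ndarray:
    """Return the (x, y) peak of each heatmap, normalized to [0, 1).

    heatmaps: (batch_size, num_keypoints, H, W)
    returns:  (batch_size, num_keypoints, 2)
    """
    batch, num_kp, height, width = heatmaps.shape
    flat = heatmaps.reshape(batch, num_kp, height * width)
    idx = flat.argmax(axis=-1)        # flat index of each peak
    ys, xs = np.divmod(idx, width)    # row and column of each peak
    return np.stack([xs / width, ys / height], axis=-1).astype(np.float32)


heatmaps = np.zeros((1, 2, 4, 8), dtype=np.float32)
heatmaps[0, 0, 1, 2] = 1.0   # peak at row 1, column 2
heatmaps[0, 1, 3, 4] = 1.0   # peak at row 3, column 4
keypoints = heatmaps_to_keypoints(heatmaps)
# keypoints.shape == (1, 2, 2)
```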

s6.nn.runtime module

s6.nn.runtime.onnx_load_model(onnx_model_path: str) InferenceSession

s6.nn.test_keypoint_detector_v2 module

s6.nn.test_keypoint_detector_v2.test_checkpoint_missing_file()
s6.nn.test_keypoint_detector_v2.test_init_missing_file()
s6.nn.test_keypoint_detector_v2.test_onnx_inference(tmp_path)

s6.nn.utils module

s6.nn.utils.torch_device() device

Module contents

Neural-network utilities for S6.