s6.utils.tensor_registry

Registry and decorator for custom torch.Tensor conversion.

Provides a decorator to register classes that implement a to_torch() method, and a utility to check registration status.

s6.utils.tensor_registry.register_torch_converter(cls)

Class decorator to register a custom type for automatic torch conversion.

The class must implement a to_torch(self) -> torch.Tensor method.

Usage:

    @register_torch_converter
    class Vector2D:
        def __init__(self, x, y):
            self.x = x
            self.y = y

        def to_torch(self):
            import torch
            return torch.tensor([self.x, self.y], dtype=torch.float)

s6.utils.tensor_registry.is_torch_convertible_type(obj_or_cls)

Return True if the given object or class has been registered for torch conversion.