s6.utils.tensor_registry
Registry and decorator for custom torch.Tensor conversion.
Provides a decorator that registers classes implementing a to_torch() method, and a utility that checks whether an object or class has been registered.
- s6.utils.tensor_registry.register_torch_converter(cls)
Class decorator to register a custom type for automatic torch conversion.
The class must implement a to_torch(self) -> torch.Tensor method.
Usage:

    @register_torch_converter
    class Vector2D:
        def __init__(self, x, y):
            self.x = x
            self.y = y

        def to_torch(self):
            # torch is imported here, so it is only loaded when a conversion actually runs
            import torch
            return torch.tensor([self.x, self.y], dtype=torch.float)
- s6.utils.tensor_registry.is_torch_convertible_type(obj_or_cls)
Return True if the given object or class has been registered for torch conversion.
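Since is_torch_convertible_type accepts either an instance or a class, one plausible approach is to normalize the argument to a class before looking it up. The sketch below assumes an exact-type lookup, so subclasses of a registered class are not treated as registered unless they are decorated themselves; the actual module may use issubclass semantics instead. The registry set and the `maybe_to_torch` helper are hypothetical and included only so the example runs on its own and shows how a caller might perform "automatic" conversion.

```python
# Hypothetical minimal registry so this sketch runs on its own.
_TORCH_CONVERTIBLE: set[type] = set()


def register_torch_converter(cls):
    _TORCH_CONVERTIBLE.add(cls)
    return cls


def is_torch_convertible_type(obj_or_cls) -> bool:
    """Accept either a class or an instance and check the registry."""
    cls = obj_or_cls if isinstance(obj_or_cls, type) else type(obj_or_cls)
    # Assumption: exact-type lookup; subclasses must be registered separately.
    return cls in _TORCH_CONVERTIBLE


def maybe_to_torch(value):
    """Hypothetical helper: convert registered instances, pass everything else through."""
    if not isinstance(value, type) and is_torch_convertible_type(value):
        return value.to_torch()
    return value


@register_torch_converter
class Vector2D:
    def __init__(self, x, y):
        self.x, self.y = x, y

    def to_torch(self):
        import torch  # only needed when a conversion actually happens
        return torch.tensor([self.x, self.y], dtype=torch.float)


print(is_torch_convertible_type(Vector2D))        # True
print(is_torch_convertible_type(Vector2D(1, 2)))  # True
print(is_torch_convertible_type(3.0))             # False
```

The class-object guard in `maybe_to_torch` matters because `is_torch_convertible_type(Vector2D)` is True for the class itself, yet `to_torch` can only be called on an instance.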