sbrunk / storch

GPU accelerated deep learning and numeric computing for Scala 3.

Home page: https://storch.dev

Raw tensor type

sbrunk opened this issue

With compile-time tracked tensor shapes, as discussed in #63, we'll probably need three type parameters, as in Tensor[DType, Shape, Device]. While this is great for type safety, it makes the Tensor type a bit more convoluted.

For use cases like prototyping, it can be useful to have an escape hatch: some kind of raw tensor type, similar to upstream PyTorch or NumPy, where these attributes are only tracked at runtime. This would of course be less safe, and we need to think about how both variants could coexist and how we can convert between them.

Design considerations

If all our tensor type parameters were covariant, i.e. Tensor[+D <: DType, +S <: Shape, +DE <: Device], a raw tensor could perhaps be Tensor[DType, Shape, Device] (DType, Shape and Device being the upper bounds of the type parameters). Currently they aren't covariant, though, and I'm not sure it's feasible: tensors are currently mutable, and even with immutable tensors, covariance isn't without its own issues (I'm by no means a type system expert; this is just my current understanding).
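A minimal sketch of the covariant variant, assuming simplified, hypothetical marker traits (`Float32`, `Int64`, `Rank1`, `CPU`) and a `Vector[Double]` payload in place of real native storage; none of this is storch's actual API. It shows the appeal (`RawTensor` is just an alias using the upper bounds, and typed tensors upcast to it without a cast) and one possible issue with covariance: a generic `add` meant to require matching dtypes still accepts a `Float32` and an `Int64` tensor, because the compiler widens `D` to a common supertype.

```scala
// Hypothetical marker types standing in for DType, Shape and Device.
trait DType
trait Float32 extends DType
trait Int64 extends DType
trait Shape
trait Rank1 extends Shape
trait Device
trait CPU extends Device

// Hypothetical immutable tensor with covariant phantom type parameters.
// A mutating method such as `def copyFrom(src: Tensor[D, S, DE]): Unit`
// would be rejected here: a covariant type parameter cannot appear in a
// method parameter (contravariant) position.
final class Tensor[+D <: DType, +S <: Shape, +DE <: Device](val data: Vector[Double])

// With covariance, the upper bounds give a "raw" tensor type for free.
type RawTensor = Tensor[DType, Shape, Device]

// Accepts any tensor, typed or raw.
def total(t: RawTensor): Double = t.data.sum

// Intended to require identical dtypes, but covariance lets D widen.
def add[D <: DType, S <: Shape, DE <: Device](
    a: Tensor[D, S, DE],
    b: Tensor[D, S, DE]
): Tensor[D, S, DE] =
  new Tensor[D, S, DE](a.data.lazyZip(b.data).map(_ + _))
```

With `f: Tensor[Float32, Rank1, CPU]` and `i: Tensor[Int64, Rank1, CPU]`, `val raw: RawTensor = f` compiles, and so does `add(f, i)`. Going back from `RawTensor` needs `asInstanceOf`, which always succeeds at runtime because the type parameters are erased.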

Without covariance, we could have a RawTensor type as a super-type of Tensor, with "unsafe" operations defined only as extension methods that need to be imported explicitly. Any such operation that returns a tensor would always return a RawTensor. That would probably make it quite easy (too easy?) to run unsafe operations on any tensor. Going from a RawTensor to a typed tensor would always need an explicit unsafe cast. We still need to figure out whether this could cause name clashes, though.
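A minimal sketch of the super-type variant, again with hypothetical names (`RawTensor`, `unsafe`, `addRaw`, `unsafeAs`) and a `Vector[Double]` payload rather than storch's real types. `Tensor` keeps invariant type parameters and extends `RawTensor`; the unsafe ops are extension methods in a separate object, so they are only visible after `import unsafe.*`. Naming the raw op `addRaw` instead of `+` sidesteps the name-clash question rather than answering it.

```scala
// Hypothetical marker types; not storch's actual DType/Shape/Device.
trait DType
trait Float32 extends DType
trait Shape
trait Rank1 extends Shape
trait Device
trait CPU extends Device

// Untyped super-type: dtype, shape and device are only known at runtime.
class RawTensor(val data: Vector[Double], val shape: Seq[Int])

// Typed tensor with invariant type parameters, extending RawTensor.
final class Tensor[D <: DType, S <: Shape, DE <: Device](values: Vector[Double], dims: Seq[Int])
    extends RawTensor(values, dims):
  // Safe op: both operands must have exactly the same type parameters.
  def +(other: Tensor[D, S, DE]): Tensor[D, S, DE] =
    new Tensor[D, S, DE](data.lazyZip(other.data).map(_ + _), shape)

// Unsafe ops live in their own object and must be imported explicitly.
object unsafe:
  extension (t: RawTensor)
    // The result is always a RawTensor, even when both inputs are typed.
    def addRaw(other: RawTensor): RawTensor =
      require(t.shape == other.shape, s"shape mismatch: ${t.shape} vs ${other.shape}")
      new RawTensor(t.data.lazyZip(other.data).map(_ + _), t.shape)

    // Explicit, unchecked cast back to a typed tensor.
    def unsafeAs[D <: DType, S <: Shape, DE <: Device]: Tensor[D, S, DE] =
      new Tensor[D, S, DE](t.data, t.shape)
```

With `import unsafe.*` in scope, `a.addRaw(b)` works on typed tensors `a` and `b` and yields a `RawTensor`; getting a typed tensor back requires `unsafeAs[Float32, Rank1, CPU]`.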

Yet another option would be a stricter separation without an inheritance hierarchy: RawTensor would live in a different package, with explicit conversions in both directions. For convenience, unsafe ops could still accept typed tensors too.
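A minimal sketch of the stricter separation, with objects standing in for the two packages because a single-file example can't declare several; all names (`typed`, `untyped`, `toRaw`, `unsafeTyped`, `AnyTensor`, `add`) are hypothetical. `Tensor` and `RawTensor` are unrelated types connected only by explicit conversions, and the unsafe `add` accepts either kind through a union type.

```scala
// Stand-in for the typed package (hypothetical).
object typed:
  trait DType
  trait Float32 extends DType
  trait Shape
  trait Rank1 extends Shape
  trait Device
  trait CPU extends Device

  // Typed tensor; deliberately NOT a subtype of RawTensor.
  final class Tensor[D <: DType, S <: Shape, DE <: Device](val data: Vector[Double], val shape: Seq[Int]):
    def +(other: Tensor[D, S, DE]): Tensor[D, S, DE] =
      new Tensor[D, S, DE](data.lazyZip(other.data).map(_ + _), shape)

    // Explicit conversion to the raw representation (always safe).
    def toRaw: untyped.RawTensor = new untyped.RawTensor(data, shape)

// Stand-in for the separate raw package (hypothetical).
object untyped:
  import typed.*

  final class RawTensor(val data: Vector[Double], val shape: Seq[Int]):
    // Explicit, unchecked conversion back to a typed tensor.
    def unsafeTyped[D <: DType, S <: Shape, DE <: Device]: Tensor[D, S, DE] =
      new Tensor[D, S, DE](data, shape)

  // Unsafe ops accept raw or typed tensors for convenience.
  type AnyTensor = RawTensor | Tensor[?, ?, ?]

  private def asRaw(t: AnyTensor): RawTensor = t match
    case r: RawTensor => r
    case typedT: Tensor[?, ?, ?] => typedT.toRaw

  def add(a: AnyTensor, b: AnyTensor): RawTensor =
    val x = asRaw(a)
    val y = asRaw(b)
    require(x.shape == y.shape, s"shape mismatch: ${x.shape} vs ${y.shape}")
    new RawTensor(x.data.lazyZip(y.data).map(_ + _), x.shape)
```

Compared with the inheritance-based approach, a `RawTensor` can never be passed where a `Tensor` is expected, so every crossing between the two worlds shows up as an explicit call.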

Perhaps there are other options as well?

An open question is whether we'd also need to support raw tensor types in torch.nn (modules etc.), and whether we could generate or derive that support to avoid extra overhead, but that's out of scope for this issue.