-
Notifications
You must be signed in to change notification settings - Fork 206
Add support for F16 in linalg::transpose #2672
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: branch-25.06
Are you sure you want to change the base?
Conversation
The current raft/cpp/tests/linalg/transpose.cu Lines 34 to 37 in 98b6fe0
|
@@ -81,7 +81,7 @@ template <typename T, typename IndexType, typename LayoutPolicy, typename Access | |||
auto transpose(raft::resources const& handle, | |||
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in, | |||
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out) | |||
-> std::enable_if_t<std::is_floating_point_v<T>, void> | |||
-> std::enable_if_t<std::is_floating_point_v<T> || std::is_same_v<T, half>, void> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Most of the primitives in raft only support 32-bit types, and I think it's really great we're starting to add support for half precision. Though I think it'll be a good idea to let the user know (in the doxygen docs) which primitives support half AND full-precision. Can you add a quick note to the doc for this transpose() function ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @cjnolet for reviewing the code. I added a quick note.
I'm in favor of alternative 1, as it will lead to less confusion if we have our own definition |
Thank you for your comment, @lowener. I added |
Required by rapidsai/cuvs#716
Currently, F16 support requires users to manually insert the following code, which may not be considered user-friendly.
In this PR, we define
raft::is_floating_point_v<T>
=std::is_floating_point_v<T> || std::is_same_v<T, half>
and use it instead ofstd::is_floating_point_v
.Alternative 1: Defineraft::is_floating_point_v<T>
=std::is_floating_point_v<T> || std::is_same_v<T, half>
and use it instead ofstd::is_floating_point_v
.Alternative 2: Add the
is_floating_point<half>
specialization above to some header file.