Skip to content

Add support for F16 in linalg::transpose #2672

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: branch-25.06
Choose a base branch
from

Conversation

enp1s0
Copy link
Member

@enp1s0 enp1s0 commented May 15, 2025

Required by rapidsai/cuvs#716

Currently, F16 support requires users to manually insert the following code, which may not be considered user-friendly.

 namespace std { 
 template <> 
 struct is_floating_point<half> : std::true_type {}; 
 }  // namespace std 

In this PR, we define raft::is_floating_point_v<T> = std::is_floating_point_v<T> || std::is_same_v<T, half> and use it instead of std::is_floating_point_v.

Alternative 1: Define raft::is_floating_point_v<T> = std::is_floating_point_v<T> || std::is_same_v<T, half> and use it instead of std::is_floating_point_v.

Alternative 2: Add the is_floating_point<half> specialization above to some header file.

  • Updade .cuh
  • Update test

@enp1s0 enp1s0 requested a review from a team as a code owner May 15, 2025 16:54
@github-actions github-actions bot added the cpp label May 15, 2025
@enp1s0 enp1s0 self-assigned this May 15, 2025
@enp1s0 enp1s0 added improvement Improvement / enhancement to an existing function non-breaking Non-breaking change labels May 15, 2025
@enp1s0
Copy link
Member Author

enp1s0 commented May 15, 2025

The current transpose passes the test for F16 due to this definition in the test code.

namespace std {
template <>
struct is_floating_point<half> : std::true_type {};
} // namespace std

@@ -81,7 +81,7 @@ template <typename T, typename IndexType, typename LayoutPolicy, typename Access
auto transpose(raft::resources const& handle,
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
-> std::enable_if_t<std::is_floating_point_v<T>, void>
-> std::enable_if_t<std::is_floating_point_v<T> || std::is_same_v<T, half>, void>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most of the primitives in raft only support 32-bit types, and I think it's really great we're starting to add support for half precision. Though I think it'll be a good idea to let the user know (in the doxygen docs) which primitives support half AND full-precision. Can you add a quick note to the doc for this transpose() function ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @cjnolet for reviewing the code. I added a quick note.

@lowener
Copy link
Contributor

lowener commented May 19, 2025

I'm in favor of alternative 1, as it will lead to less confusion if we have our own definition

@enp1s0
Copy link
Member Author

enp1s0 commented May 20, 2025

Thank you for your comment, @lowener. I added raft::is_floating_point. I'm not sure where to put type_traits.hpp, which defines it. Is it okay in raft/core/?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cpp improvement Improvement / enhancement to an existing function non-breaking Non-breaking change
Projects
Development

Successfully merging this pull request may close these issues.

4 participants