Skip to content

Commit a3f2d98

Browse files
authored
[Feat] add cudaMemcpy2DAsync wrapper (#2674)
Authors: - rhdong (https://github.com/rhdong) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: #2674
1 parent 21da2bd commit a3f2d98

File tree

2 files changed

+84
-0
lines changed

2 files changed

+84
-0
lines changed

cpp/include/raft/util/cudart_utils.hpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,42 @@ void copy(Type* dst, const Type* src, size_t len, rmm::cuda_stream_view stream)
148148
RAFT_CUDA_TRY(cudaMemcpyAsync(dst, src, len * sizeof(Type), cudaMemcpyDefault, stream));
149149
}
150150

151+
/**
152+
* @brief Generic matrix copy method with pitch support
153+
*
154+
* Performs an asynchronous 2D memory copy from source to destination, where each row
155+
* may include padding (i.e., the pitch is larger than the row width). This is useful
156+
* when working with pitched memory allocations or copying submatrices.
157+
*
158+
* @tparam Type Data type of the elements
159+
* @param dst Destination pointer
160+
* @param dst_pitch Pitch (number of elements) between consecutive rows in the destination
161+
* @param src Source pointer
162+
* @param src_pitch Pitch (number of elements) between consecutive rows in the source
163+
* @param width Number of elements to copy per row
164+
* @param height Number of rows to copy
165+
* @param stream CUDA stream used to perform the asynchronous copy
166+
*/
167+
template <typename Type>
168+
void copy_matrix(Type* dst,
169+
size_t dst_pitch,
170+
const Type* src,
171+
size_t src_pitch,
172+
size_t width,
173+
size_t height,
174+
rmm::cuda_stream_view stream)
175+
{
176+
constexpr size_t elem_size = sizeof(Type);
177+
RAFT_CUDA_TRY(cudaMemcpy2DAsync(dst,
178+
dst_pitch * elem_size,
179+
src,
180+
src_pitch * elem_size,
181+
width * elem_size,
182+
height,
183+
cudaMemcpyDefault,
184+
stream));
185+
}
186+
151187
/**
152188
* @defgroup Copy Copy methods
153189
* These are here along with the generic 'copy' method in order to improve

cpp/tests/util/cudart_utils.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <raft/core/resources.hpp>
1919
#include <raft/util/cudart_utils.hpp>
2020

21+
#include <rmm/cuda_stream_pool.hpp>
2122
#include <rmm/device_uvector.hpp>
2223

2324
#include <gtest/gtest.h>
@@ -99,4 +100,51 @@ TEST(Raft, GetDeviceForAddress)
99100
ASSERT_EQ(0, raft::get_device_for_address(d.data()));
100101
}
101102

103+
TEST(Raft, Copy2DAsync)
104+
{
105+
using DType = float;
106+
107+
constexpr size_t rows = 4;
108+
constexpr size_t cols = 5;
109+
constexpr size_t pitch = 8;
110+
constexpr size_t elem_size = sizeof(DType);
111+
constexpr size_t width = cols;
112+
constexpr size_t height = rows;
113+
114+
rmm::cuda_stream_pool pool{1};
115+
auto stream = pool.get_stream();
116+
117+
rmm::device_uvector<DType> d_src(pitch * elem_size * rows, stream);
118+
rmm::device_uvector<DType> d_dst(pitch * elem_size * rows, stream);
119+
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
120+
121+
std::vector<DType> h_src(rows * pitch, -1.0f);
122+
std::vector<DType> h_dst(rows * pitch, 0.0f);
123+
std::vector<DType> h_dst_baseline(rows * pitch, 0.0f);
124+
125+
for (size_t r = 0; r < rows; ++r) {
126+
for (size_t c = 0; c < pitch; ++c) {
127+
h_src[r * pitch + c] = static_cast<DType>(r * pitch + c);
128+
if (r < height && c < cols) {
129+
h_dst_baseline[r * pitch + c] = static_cast<DType>(r * pitch + c);
130+
}
131+
}
132+
}
133+
RAFT_CUDA_TRY(
134+
cudaMemcpy(d_src.data(), h_src.data(), pitch * elem_size * rows, cudaMemcpyHostToDevice));
135+
RAFT_CUDA_TRY(
136+
cudaMemcpy(d_dst.data(), h_dst.data(), pitch * elem_size * rows, cudaMemcpyHostToDevice));
137+
138+
raft::copy_matrix(d_dst.data(), pitch, d_src.data(), pitch, width, height, stream);
139+
RAFT_CUDA_TRY(
140+
cudaMemcpy(h_dst.data(), d_dst.data(), pitch * elem_size * rows, cudaMemcpyDeviceToHost));
141+
142+
for (size_t r = 0; r < rows; ++r) {
143+
for (size_t c = 0; c < pitch; ++c) {
144+
ASSERT_EQ(h_dst[r * pitch + c], h_dst_baseline[r * pitch + c])
145+
<< "Mismatch at row " << r << " col " << c;
146+
}
147+
}
148+
}
149+
102150
} // namespace raft

0 commit comments

Comments
 (0)