|
18 | 18 | #include <raft/core/resources.hpp>
|
19 | 19 | #include <raft/util/cudart_utils.hpp>
|
20 | 20 |
|
| 21 | +#include <rmm/cuda_stream_pool.hpp> |
21 | 22 | #include <rmm/device_uvector.hpp>
|
22 | 23 |
|
23 | 24 | #include <gtest/gtest.h>
|
@@ -99,4 +100,51 @@ TEST(Raft, GetDeviceForAddress)
|
99 | 100 | ASSERT_EQ(0, raft::get_device_for_address(d.data()));
|
100 | 101 | }
|
101 | 102 |
|
| 103 | +TEST(Raft, Copy2DAsync) |
| 104 | +{ |
| 105 | + using DType = float; |
| 106 | + |
| 107 | + constexpr size_t rows = 4; |
| 108 | + constexpr size_t cols = 5; |
| 109 | + constexpr size_t pitch = 8; |
| 110 | + constexpr size_t elem_size = sizeof(DType); |
| 111 | + constexpr size_t width = cols; |
| 112 | + constexpr size_t height = rows; |
| 113 | + |
| 114 | + rmm::cuda_stream_pool pool{1}; |
| 115 | + auto stream = pool.get_stream(); |
| 116 | + |
| 117 | + rmm::device_uvector<DType> d_src(pitch * elem_size * rows, stream); |
| 118 | + rmm::device_uvector<DType> d_dst(pitch * elem_size * rows, stream); |
| 119 | + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); |
| 120 | + |
| 121 | + std::vector<DType> h_src(rows * pitch, -1.0f); |
| 122 | + std::vector<DType> h_dst(rows * pitch, 0.0f); |
| 123 | + std::vector<DType> h_dst_baseline(rows * pitch, 0.0f); |
| 124 | + |
| 125 | + for (size_t r = 0; r < rows; ++r) { |
| 126 | + for (size_t c = 0; c < pitch; ++c) { |
| 127 | + h_src[r * pitch + c] = static_cast<DType>(r * pitch + c); |
| 128 | + if (r < height && c < cols) { |
| 129 | + h_dst_baseline[r * pitch + c] = static_cast<DType>(r * pitch + c); |
| 130 | + } |
| 131 | + } |
| 132 | + } |
| 133 | + RAFT_CUDA_TRY( |
| 134 | + cudaMemcpy(d_src.data(), h_src.data(), pitch * elem_size * rows, cudaMemcpyHostToDevice)); |
| 135 | + RAFT_CUDA_TRY( |
| 136 | + cudaMemcpy(d_dst.data(), h_dst.data(), pitch * elem_size * rows, cudaMemcpyHostToDevice)); |
| 137 | + |
| 138 | + raft::copy_matrix(d_dst.data(), pitch, d_src.data(), pitch, width, height, stream); |
| 139 | + RAFT_CUDA_TRY( |
| 140 | + cudaMemcpy(h_dst.data(), d_dst.data(), pitch * elem_size * rows, cudaMemcpyDeviceToHost)); |
| 141 | + |
| 142 | + for (size_t r = 0; r < rows; ++r) { |
| 143 | + for (size_t c = 0; c < pitch; ++c) { |
| 144 | + ASSERT_EQ(h_dst[r * pitch + c], h_dst_baseline[r * pitch + c]) |
| 145 | + << "Mismatch at row " << r << " col " << c; |
| 146 | + } |
| 147 | + } |
| 148 | +} |
| 149 | + |
102 | 150 | } // namespace raft
|
0 commit comments