Partition reduce on GPU
By using `ReducePartitionBy`, per-partition (group) reductions can be computed on GPU in a streaming (single-pass) fashion.
using CUDA
using CUDA: @allowscalar
# Number of random samples (2^24 ≈ 16.7 million).
n = 2^24
# Draw the samples on the GPU when a CUDA device is usable; otherwise use a
# regular CPU array so the example still runs.
xs = has_cuda_gpu() ? CUDA.randn(n) : randn(n)
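`ReducePartitionBy` expects each partition to be contiguous, e.g., the input sorted by the key. We will use `floor` as the key. Since `floor` is monotonically non-decreasing, sorting the values also groups equal keys together, so a plain `sort!` works in this example.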
# Sort in place so that elements sharing the same `floor` key are adjacent, as
# required by `ReducePartitionBy` (`floor` is non-decreasing, so sorting the
# values also sorts the keys).
sort!(xs)
On GPU, it is convenient to know the output location before the computation. So, let us build a unique index for each partition using `cumsum!`:
"""
    buildindices(f, xs) -> partitionindices

Assign a 1-based partition index to every element of `xs`.

A new partition starts wherever the key `f(x)` changes between neighbouring
elements (compared with `isequal`). The result is an `Int32` array of the
same kind and shape as `xs` (created with `similar`). Its values are
non-decreasing, start at 1, and grow by one at each partition boundary. So
`partitionindices[end]` is the number of partitions. An empty `xs` yields an
empty result.

`xs` is expected to already have equal keys stored contiguously (e.g. sorted
by key); otherwise the same key may receive several partition indices.
"""
function buildindices(f, xs)
    # Edge between two neighbours: their keys differ.
    isedge(x, y) = !isequal(f(x), f(y))
    partitionindices = similar(xs, Int32)
    # Guard: `bounds[1]` below would throw a BoundsError on an empty input.
    isempty(xs) && return partitionindices
    bounds = similar(xs, Bool)
    # bounds[i] (i ≥ 2) is true when xs[i] starts a new partition.
    @views map!(isedge, bounds[2:end], xs[1:end-1], xs[2:end])
    # The first element always starts a partition. Broadcasting into a
    # one-element range avoids scalar indexing, which GPU arrays reject by default.
    bounds[1:1] .= true
    # Running count of partition starts = partition index of each element.
    return cumsum!(partitionindices, bounds)
end
# Partition index (1, 2, …) of every element of the sorted `xs`, keyed by `floor`.
partitionindices_xs = buildindices(floor, xs)
Counting the size of each partition
import FoldsCUDA # register the executor
using FLoops
using Transducers
"""
    countparts(partitionindices; ex = nothing) -> counts

Count the number of elements in each partition.

`partitionindices` must hold contiguous 1-based partition indices, such as the
output of `buildindices`: non-decreasing, starting at 1 and increasing by one at
each partition boundary. `counts[p]` is the size of partition `p`. The result
is created with `similar(partitionindices, nparts)`, so it has the same array
type and element type. An empty `partitionindices` yields an empty result.

`ex` is forwarded to `@floop` as its executor. With `nothing`, FLoops presumably
picks one from the input type (e.g. the GPU executor registered by `FoldsCUDA`);
confirm against the FLoops documentation.
"""
function countparts(partitionindices; ex = nothing)
    # Guard: `partitionindices[end]` would throw on an empty input.
    isempty(partitionindices) && return similar(partitionindices, 0)
    # Indices are contiguous and start at 1, so the last one is the partition count.
    nparts = @allowscalar partitionindices[end]
    ys = similar(partitionindices, nparts)
    # The intra-partition reducing function that reduces each partition to
    # a 2-tuple of index and count: `right` keeps the (constant) partition
    # index, `+` adds one per element.
    rf_partition = Map(p -> (p, 1))'(ProductRF(right, +))
    index_and_count =
        partitionindices |>
        ReducePartitionBy(
            identity,  # partition by partitionindices
            rf_partition,
            (-1, 0),   # -1 is only a placeholder index; `right` overwrites it
        )
    # Every partition has a distinct index p, so no two iterations write the same slot.
    @floop ex for (p, c) in index_and_count
        @inbounds ys[p] = c
    end
    return ys
end
# Number of elements falling into each `floor` bucket.
c_xs = countparts(partitionindices_xs)
12-element Vector{Int32}:
2
513
21913
358689
2279978
5722851
5732407
2279275
358969
22074
543
2
Computing the average of each partition
"""
    meanparts(xs, partitionindices; ex = nothing) -> means

Compute the mean of `xs` over each partition.

`partitionindices` must have the same axes as `xs` and hold contiguous 1-based
partition indices, such as the output of `buildindices`. `means[p]` is the
average of the elements `xs[i]` with `partitionindices[i] == p`. The result is
created with `similar(xs, float(eltype(xs)), nparts)`. An empty input yields an
empty result.

`ex` is forwarded to `@floop` as its executor (`nothing` leaves the choice to
FLoops).

Throws `DimensionMismatch` if `xs` and `partitionindices` have different axes.
"""
function meanparts(xs, partitionindices; ex = nothing)
    # The reduction reads `xs[i]` under `@inbounds` using indices of
    # `partitionindices`, so the two arrays must share axes.
    axes(xs) == axes(partitionindices) || throw(DimensionMismatch(
        "xs and partitionindices must have the same axes; " *
        "got $(axes(xs)) and $(axes(partitionindices))",
    ))
    # Guard: `partitionindices[end]` would throw on an empty input.
    isempty(partitionindices) && return similar(xs, float(eltype(xs)), 0)
    # Indices are contiguous and start at 1, so the last one is the partition count.
    nparts = @allowscalar partitionindices[end]
    ys = similar(xs, float(eltype(xs)), nparts)
    # The intra-partition reducing function that reduces each partition to
    # a 3-tuple of index, count and sum:
    rf_partition = Map(((i, p),) -> (p, 1, (@inbounds xs[i])))'(ProductRF(right, +, +))
    index_count_and_sum =
        pairs(partitionindices) |>
        ReducePartitionBy(
            ((_, p),) -> p,  # partition by partitionindices
            rf_partition,
            (-1, 0, zero(eltype(ys))),  # -1 is only a placeholder index
        )
    # Each partition is non-empty (c ≥ 1) and has a distinct index p.
    @floop ex for (p, c, s) in index_count_and_sum
        @inbounds ys[p] = s / c
    end
    return ys
end
# Mean of the samples falling into each `floor` bucket.
m_xs = meanparts(xs, partitionindices_xs)
12-element Vector{Float64}:
-5.100679611366887
-4.230969259920362
-3.2588360090767297
-2.316408552266089
-1.383281382470574
-0.46001513502063335
0.45979683048961356
1.3833539445522247
2.316060866464498
3.261666575428892
4.205593222946841
5.118959352723174
This page was generated using Literate.jl.