# Partition reduce on GPU

By using `ReducePartitionBy`, per-partition (per-group) reductions can be computed on the GPU in a streaming (single-pass) fashion.

using CUDA
using CUDA: @allowscalar

# Number of random samples to draw.
n = 2^24
# Draw the samples on the GPU when one is available; otherwise fall back to the CPU.
xs = has_cuda_gpu() ? CUDA.randn(n) : randn(n)

`ReducePartitionBy` expects each partition to be contiguous; e.g., the input is sorted by the key. We will use `floor` as the key, so a plain `sort!` is sufficient in this example.

# Sort in place so that elements sharing the same `floor` key become contiguous.
sort!(xs)

On the GPU, it is convenient to know the output location before the computation. So, let us build a unique index for each partition using `cumsum!`:

"""
    buildindices(f, xs)

Assign a partition index to every element of `xs`.

A new partition starts at the first element and at every element whose key `f(x)`
differs (in the sense of `isequal`) from the key of the element before it. The indices
start at 1 and increase by one at each partition boundary, so the last entry equals
the number of partitions.

Returns the array created by `similar(xs, Int32)`, filled with the partition indices.
For an empty `xs`, the (empty) index array is returned as-is.
"""
function buildindices(f, xs)
    partitionindices = similar(xs, Int32)
    # An empty input has no first element to mark as an edge; return before touching
    # `bounds[1]`, which would otherwise throw a `BoundsError`.
    isempty(xs) && return partitionindices

    # `isedge(x, y)` is true when adjacent elements belong to different partitions.
    isedge(x, y) = !isequal(f(x), f(y))
    bounds = similar(xs, Bool)
    # bounds[i] (i >= 2) marks whether xs[i] starts a new partition relative to xs[i-1].
    @views map!(isedge, bounds[2:end], xs[1:end-1], xs[2:end])
    # The first element always opens the first partition.
    @allowscalar bounds[1] = true
    # The running count of edges seen so far is the partition index of each element.
    return cumsum!(partitionindices, bounds)
end

# One partition index per element of `xs`, using `floor` as the partition key.
partitionindices_xs = buildindices(floor, xs)

### Counting the size of each partition

import FoldsCUDA  # register the executor
using FLoops
using Transducers

"""
    countparts(partitionindices; ex = nothing)

Count the number of elements in each partition.

# Arguments
- `partitionindices`: partition index of each element; equal indices must be
  contiguous, and the last entry is used as the number of partitions.

# Keywords
- `ex`: executor forwarded to `@floop` (default `nothing`).

# Returns
An array created by `similar(partitionindices, nparts)` whose `p`-th entry is the size
of partition `p`. An empty `partitionindices` yields an empty result.
"""
function countparts(partitionindices; ex = nothing)
    # Without any element there is no `partitionindices[end]` to read.
    isempty(partitionindices) && return similar(partitionindices, 0)

    # The last partition index equals the number of partitions.
    nparts = @allowscalar partitionindices[end]
    ys = similar(partitionindices, nparts)

    # The intra-partition reducing function that reduces each partition to
    # a 2-tuple of index and count:
    rf_partition = Map(p -> (p, 1))'(ProductRF(right, +))

    index_and_count =
        partitionindices |>
        ReducePartitionBy(
            identity,  # partition by partitionindices
            rf_partition,
            (-1, 0),
        )

    # Each partition produces exactly one `(p, c)` pair, written to its own slot.
    @floop ex for (p, c) in index_and_count
        @inbounds ys[p] = c
    end

    return ys
end

# Number of elements of `xs` in each `floor` partition.
c_xs = countparts(partitionindices_xs)
12-element Vector{Int32}:
2
513
21913
358689
2279978
5722851
5732407
2279275
358969
22074
543
2

### Computing the average of each partition

"""
    meanparts(xs, partitionindices; ex = nothing)

Compute the mean of the elements of `xs` in each partition.

# Arguments
- `xs`: the values to average; must be indexable at every index of `partitionindices`.
- `partitionindices`: partition index of each element; equal indices must be
  contiguous, and the last entry is used as the number of partitions.

# Keywords
- `ex`: executor forwarded to `@floop` (default `nothing`).

# Returns
An array created by `similar(xs, float(eltype(xs)), nparts)` whose `p`-th entry is the
mean of partition `p`. An empty `partitionindices` yields an empty result.

# Throws
- `BoundsError` if some index of `partitionindices` is not a valid index of `xs`.
"""
function meanparts(xs, partitionindices; ex = nothing)
    # `xs[i]` is read under `@inbounds` below, so validate the indices up front.
    checkbounds(xs, eachindex(partitionindices))
    # Without any element there is no `partitionindices[end]` to read.
    isempty(partitionindices) && return similar(xs, float(eltype(xs)), 0)

    # The last partition index equals the number of partitions.
    nparts = @allowscalar partitionindices[end]
    ys = similar(xs, float(eltype(xs)), nparts)

    # The intra-partition reducing function that reduces each partition to
    # a 3-tuple of index, count and sum:
    rf_partition = Map(((i, p),) -> (p, 1, (@inbounds xs[i])))'(ProductRF(right, +, +))

    index_count_and_sum =
        pairs(partitionindices) |>
        ReducePartitionBy(
            ((_, p),) -> p,  # partition by partitionindices
            rf_partition,
            (-1, 0, zero(eltype(ys))),
        )

    # Each partition produces exactly one `(p, c, s)` triple; its mean is `s / c`.
    @floop ex for (p, c, s) in index_count_and_sum
        @inbounds ys[p] = s / c
    end

    return ys
end

# Mean of the elements of `xs` in each `floor` partition.
m_xs = meanparts(xs, partitionindices_xs)
12-element Vector{Float64}:
-5.100679611366887
-4.230969259920362
-3.2588360090767297
-2.316408552266089
-1.383281382470574
-0.46001513502063335
0.45979683048961356
1.3833539445522247
2.316060866464498
3.261666575428892
4.205593222946841
5.118959352723174