Partition reduce on GPU

By using ReducePartitionBy, per-partition (per-group) reductions can be computed on the GPU in a streaming (single-pass) fashion.

using CUDA
using CUDA: @allowscalar

n = 2^24
# Draw the input on the GPU when one is usable; otherwise fall back to a CPU array.
xs = has_cuda_gpu() ? CUDA.randn(n) : randn(n)

ReducePartitionBy expects each partition to be contiguous, e.g., the input sorted by the key. We will use floor as the key. Since floor is monotonic, sorting by value also groups equal keys together, so a plain sort! works in this example.

# Sort in place so that elements sharing the same `floor` key are contiguous.
sort!(xs)

On the GPU, it is convenient to know the output location before the computation. So, let us build a unique index for each partition using cumsum!:

"""
    buildindices(f, xs) -> partitionindices

Assign a partition index to every element of `xs`. A new partition starts at
element `i` whenever `f(xs[i])` differs (by `isequal`) from `f(xs[i-1])`.

The result is an `Int32` array shaped like `xs` whose entries are consecutive
partition indices starting at 1, so the last entry equals the number of
partitions. `xs` must already be ordered so that equal keys are contiguous
(e.g. sorted by the key). An empty `xs` yields an empty result.
"""
function buildindices(f, xs)
    partitionindices = similar(xs, Int32)
    # Nothing to partition; `bounds[1]` below would otherwise throw a BoundsError.
    isempty(xs) && return partitionindices

    isedge(x, y) = !isequal(f(x), f(y))
    # `bounds[i]` is true exactly when element `i` starts a new partition.
    bounds = similar(xs, Bool)
    @views map!(isedge, bounds[2:end], xs[1:end-1], xs[2:end])
    # The first element always starts a partition. Scalar indexing must be
    # explicitly allowed for GPU arrays.
    @allowscalar bounds[1] = true
    # A running count of partition starts gives each element its 1-based
    # partition index.
    return cumsum!(partitionindices, bounds)
end

# Partition index (1, 2, ...) of every element of `xs`, using `floor` as the key.
partitionindices_xs = buildindices(floor, xs)

Counting the size of each partition

import FoldsCUDA  # register the executor
using FLoops
using Transducers

"""
    countparts(partitionindices; ex = nothing) -> counts

Count the number of elements in each partition.

`partitionindices` must hold consecutive partition indices starting at 1, with
each partition contiguous (as produced by `buildindices`). `counts[p]` is then
the size of partition `p`. `ex` is forwarded to `@floop` as the executor; the
default `nothing` presumably lets FLoops choose one based on the input.
An empty `partitionindices` yields an empty result.
"""
function countparts(partitionindices; ex = nothing)
    # No partitions; `partitionindices[end]` would otherwise throw a BoundsError.
    isempty(partitionindices) && return similar(partitionindices, 0)

    # Indices are consecutive from 1, so the last one is the number of partitions.
    nparts = @allowscalar partitionindices[end]
    ys = similar(partitionindices, nparts)

    # The intra-partition reducing function that reduces each partition to
    # a 2-tuple of index and count:
    rf_partition = Map(p -> (p, 1))'(ProductRF(right, +))

    index_and_count =
        partitionindices |>
        ReducePartitionBy(
            identity,  # partition by partitionindices
            rf_partition,
            # `right` discards the placeholder index -1; counts start at 0.
            (-1, 0),
        )

    # Each partition index `p` lies in 1:nparts, so the store is in bounds.
    @floop ex for (p, c) in index_and_count
        @inbounds ys[p] = c
    end

    return ys
end

# Number of elements in each `floor`-keyed partition of `xs`.
c_xs = countparts(partitionindices_xs)
12-element Vector{Int32}:
       2
     513
   21913
  358689
 2279978
 5722851
 5732407
 2279275
  358969
   22074
     543
       2

Computing the average of each partition

"""
    meanparts(xs, partitionindices; ex = nothing) -> means

Compute the mean of `xs` within each partition.

`partitionindices` must have the same axes as `xs`. It must hold consecutive
partition indices starting at 1, with each partition contiguous (as produced by
`buildindices`). `means[p]` is the average of the elements of `xs` in partition
`p`, with element type `float(eltype(xs))`. `ex` is forwarded to `@floop` as the
executor; the default `nothing` presumably lets FLoops choose one based on the
input.

Throws `DimensionMismatch` if the axes of `xs` and `partitionindices` differ.
"""
function meanparts(xs, partitionindices; ex = nothing)
    # `xs[i]` is read under `@inbounds` with `i` drawn from the indices of
    # `partitionindices`, so mismatched axes would be an unchecked OOB read.
    if axes(xs) != axes(partitionindices)
        throw(DimensionMismatch(
            "axes of xs $(axes(xs)) and partitionindices " *
            "$(axes(partitionindices)) must match",
        ))
    end
    # No partitions; `partitionindices[end]` would otherwise throw a BoundsError.
    isempty(partitionindices) && return similar(xs, float(eltype(xs)), 0)

    # Indices are consecutive from 1, so the last one is the number of partitions.
    nparts = @allowscalar partitionindices[end]
    ys = similar(xs, float(eltype(xs)), nparts)

    # The intra-partition reducing function that reduces each partition to
    # a 3-tuple of index, count and sum:
    rf_partition = Map(((i, p),) -> (p, 1, (@inbounds xs[i])))'(ProductRF(right, +, +))

    index_count_and_sum =
        pairs(partitionindices) |>
        ReducePartitionBy(
            ((_, p),) -> p,  # partition by partitionindices
            rf_partition,
            # `right` discards the placeholder index -1; count and sum start at 0.
            (-1, 0, zero(eltype(ys))),
        )

    # Each partition index `p` lies in 1:nparts, so the store is in bounds.
    @floop ex for (p, c, s) in index_count_and_sum
        @inbounds ys[p] = s / c
    end

    return ys
end

# Mean of the elements of `xs` in each `floor`-keyed partition.
m_xs = meanparts(xs, partitionindices_xs)
12-element Vector{Float64}:
 -5.100679611366887
 -4.230969259920362
 -3.2588360090767297
 -2.316408552266089
 -1.383281382470574
 -0.46001513502063335
  0.45979683048961356
  1.3833539445522247
  2.316060866464498
  3.261666575428892
  4.205593222946841
  5.118959352723174

This page was generated using Literate.jl.