Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions vortex-array/benches/aggregate_grouped.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use vortex_array::VortexSessionExecute;
use vortex_array::aggregate_fn::AggregateFnVTable;
use vortex_array::aggregate_fn::DynGroupedAccumulator;
use vortex_array::aggregate_fn::EmptyOptions;
use vortex_array::aggregate_fn::GroupIds;
use vortex_array::aggregate_fn::GroupedAccumulator;
use vortex_array::aggregate_fn::fns::count::Count;
use vortex_array::aggregate_fn::fns::sum::Sum;
Expand Down Expand Up @@ -45,24 +46,22 @@ fn total_element_count(group_sizes: &[usize]) -> usize {

struct DenseGroupedInput {
values: ArrayRef,
group_ids: Vec<u32>,
num_groups: usize,
group_ids: GroupIds,
}

fn dense_grouped_input(values: ArrayRef, group_sizes: &[usize]) -> DenseGroupedInput {
assert_eq!(values.len(), total_element_count(group_sizes));

let group_ids = group_sizes
.iter()
.enumerate()
.flat_map(|(group_id, &size)| std::iter::repeat_n(group_id as u32, size))
.collect();
let group_ids = GroupIds::from_iter(
group_sizes
.iter()
.enumerate()
.flat_map(|(group_id, &size)| std::iter::repeat_n(group_id as u32, size)),
group_sizes.len(),
)
.unwrap();

DenseGroupedInput {
values,
group_ids,
num_groups: group_sizes.len(),
}
DenseGroupedInput { values, group_ids }
}

fn i32_nullable_all_valid_input() -> DenseGroupedInput {
Expand Down Expand Up @@ -142,14 +141,14 @@ where
{
let mut acc =
GroupedAccumulator::try_new(vtable, EmptyOptions, input.values.dtype().clone()).unwrap();
let num_groups = input.group_ids.num_groups();
acc.accumulate(
&input.values,
&input.group_ids,
input.num_groups,
&mut LEGACY_SESSION.create_execution_ctx(),
)
.unwrap();
divan::black_box(acc.finish(input.num_groups).unwrap())
divan::black_box(acc.finish(num_groups).unwrap())
}

#[divan::bench]
Expand Down
195 changes: 130 additions & 65 deletions vortex-array/src/aggregate_fn/accumulator_grouped.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use vortex_error::vortex_ensure;
use vortex_error::vortex_err;

use crate::ArrayRef;
use crate::Columnar;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::aggregate_fn::Accumulator;
Expand All @@ -17,16 +16,107 @@ use crate::aggregate_fn::AggregateFnVTable;
use crate::aggregate_fn::DynAccumulator;
use crate::aggregate_fn::kernels::GroupedAggregateKernelResult;
use crate::aggregate_fn::session::AggregateFnSessionExt;
use crate::array::ArrayId;
use crate::arrays::PrimitiveArray;
use crate::builders::builder_with_capacity;
use crate::columnar::AnyColumnar;
use crate::dtype::DType;
use crate::dtype::Nullability;
use crate::dtype::PType;
use crate::executor::max_iterations;
use crate::scalar::Scalar;
use crate::validity::Validity;

/// Reference-counted type-erased grouped accumulator.
pub type GroupedAccumulatorRef = Box<dyn DynGroupedAccumulator>;

/// An accumulator used for computing aggregates over dense group ids.
/// Encoded group ids parallel to a grouped aggregate input batch.
///
/// The array must contain non-null `u32` ordinals. The ordinals are dense state slots in
/// `0..num_groups`, not raw group keys. Range validation may require executing the encoded array,
/// so kernels that can prove the invariant from encoded metadata should avoid materializing and
/// otherwise call [`Self::validated_ids`] before indexing group state.
#[derive(Clone, Debug)]
pub struct GroupIds {
ids: ArrayRef,
num_groups: usize,
}

impl GroupIds {
/// Create group ids from an encoded non-null `u32` array.
pub fn new(ids: ArrayRef, num_groups: usize) -> VortexResult<Self> {
validate_num_groups(num_groups)?;
vortex_ensure!(
ids.dtype() == &DType::Primitive(PType::U32, Nullability::NonNullable),
"Group ids must be non-nullable u32, got {}",
ids.dtype()
);
Ok(Self { ids, num_groups })
}

/// Create group ids from a materialized buffer.
pub fn from_buffer(ids: Buffer<u32>, num_groups: usize) -> VortexResult<Self> {
Self::new(
PrimitiveArray::new(ids, Validity::NonNullable).into_array(),
num_groups,
)
}

/// Create group ids from materialized values.
pub fn from_iter(ids: impl IntoIterator<Item = u32>, num_groups: usize) -> VortexResult<Self> {
Self::from_buffer(Buffer::from_iter(ids), num_groups)
}

/// Create group ids containing `0..num_groups`.
pub fn range(num_groups: usize) -> VortexResult<Self> {
validate_num_groups(num_groups)?;
if num_groups == 0 {
return Self::from_buffer(Buffer::<u32>::empty(), num_groups);
}

let last = u32::try_from(num_groups - 1).map_err(|_| {
vortex_err!(
"num_groups {} exceeds dense u32 group id capacity",
num_groups
)
})?;
Self::from_buffer((0..=last).collect(), num_groups)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't we want a sequence array?

}

/// Return the encoded ids array.
pub fn ids(&self) -> &ArrayRef {
&self.ids
}

/// Return the number of dense group state slots.
pub fn num_groups(&self) -> usize {
self.num_groups
}

/// Return the number of ids.
pub fn len(&self) -> usize {
self.ids.len()
}

/// Return whether there are no ids.
pub fn is_empty(&self) -> bool {
self.ids.is_empty()
}

/// Return the encoding id for kernel dispatch.
pub fn encoding_id(&self) -> ArrayId {
self.ids.encoding_id()
}

/// Execute the ids to a native buffer and validate every id is in range.
pub fn validated_ids(&self, ctx: &mut ExecutionCtx) -> VortexResult<Buffer<u32>> {
let ids = self.ids.clone().execute::<Buffer<u32>>(ctx)?;
validate_group_ids(ids.as_ref(), self.num_groups)?;
Ok(ids)
}
}

/// An accumulator used for computing aggregates over group ids.
///
/// Group ids are caller-assigned `u32` ordinals in the dense range `0..num_groups`. Input batches
/// may repeat, omit, and reorder those ids, but every id must identify a state slot rather than a
Expand Down Expand Up @@ -88,54 +178,30 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
Ok(())
}

fn validate_group_ids(&self, group_ids: &[u32], num_groups: usize) -> VortexResult<()> {
validate_num_groups(num_groups)?;
for &group_id in group_ids {
vortex_ensure!(
(group_id as usize) < num_groups,
"Group id {} out of range for {} groups",
group_id,
num_groups
);
}
Ok(())
}

fn accumulate_kernel_result(
&mut self,
result: GroupedAggregateKernelResult,
num_groups: usize,
ctx: &mut ExecutionCtx,
) -> VortexResult<()> {
self.accumulate_partials(result.partials(), result.group_ids(), num_groups, ctx)
self.accumulate_partials(result.partials(), result.group_ids(), ctx)
}

fn try_accumulate_kernel(
&mut self,
batch: &ArrayRef,
group_ids: &[u32],
num_groups: usize,
group_ids: &GroupIds,
ctx: &mut ExecutionCtx,
) -> VortexResult<bool> {
let session = ctx.session().clone();

if let Some(kernel) = session
.aggregate_fns()
.find_grouped_encoding_kernel(batch.encoding_id(), self.aggregate_fn.id())
&& let Some(result) =
kernel.grouped_aggregate(&self.aggregate_fn, batch, group_ids, num_groups, ctx)?
{
self.accumulate_kernel_result(result, num_groups, ctx)?;
return Ok(true);
}

if let Some(kernel) = session
.aggregate_fns()
.find_grouped_kernel(self.aggregate_fn.id())
&& let Some(result) =
kernel.grouped_aggregate(&self.aggregate_fn, batch, group_ids, num_groups, ctx)?
if let Some(kernel) = session.aggregate_fns().find_grouped_kernel(
self.aggregate_fn.id(),
batch.encoding_id(),
group_ids.encoding_id(),
) && let Some(result) =
kernel.grouped_aggregate(&self.aggregate_fn, batch, group_ids, ctx)?
{
self.accumulate_kernel_result(result, num_groups, ctx)?;
self.accumulate_kernel_result(result, ctx)?;
return Ok(true);
}

Expand Down Expand Up @@ -198,18 +264,31 @@ fn validate_num_groups(num_groups: usize) -> VortexResult<()> {
Ok(())
}

fn validate_group_ids(group_ids: &[u32], num_groups: usize) -> VortexResult<()> {
validate_num_groups(num_groups)?;
for &group_id in group_ids {
vortex_ensure!(
(group_id as usize) < num_groups,
"Group id {} out of range for {} groups",
group_id,
num_groups
);
}
Ok(())
}

/// A trait object for type-erased grouped accumulators, used for dynamic dispatch when the
/// aggregate function is not known at compile time.
pub trait DynGroupedAccumulator: 'static + Send {
/// Accumulate a values batch into dense group state.
///
/// `group_ids` is parallel to `batch`. Each id must be a caller-assigned group ordinal in
/// `0..num_groups`; ids may repeat, appear out of order, or be absent from a given batch.
/// `0..group_ids.num_groups()`; ids may repeat, appear out of order, or be absent from a
/// given batch.
fn accumulate(
&mut self,
batch: &ArrayRef,
group_ids: &[u32],
num_groups: usize,
group_ids: &GroupIds,
ctx: &mut ExecutionCtx,
) -> VortexResult<()>;

Expand All @@ -220,8 +299,7 @@ pub trait DynGroupedAccumulator: 'static + Send {
fn accumulate_partials(
&mut self,
partials: &ArrayRef,
group_ids: &[u32],
num_groups: usize,
group_ids: &GroupIds,
ctx: &mut ExecutionCtx,
) -> VortexResult<()>;

Expand Down Expand Up @@ -254,10 +332,10 @@ impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
fn accumulate(
&mut self,
batch: &ArrayRef,
group_ids: &[u32],
num_groups: usize,
group_ids: &GroupIds,
ctx: &mut ExecutionCtx,
) -> VortexResult<()> {
let num_groups = group_ids.num_groups();
vortex_ensure!(
batch.dtype() == &self.dtype,
"Input DType mismatch: expected {}, got {}",
Expand All @@ -271,56 +349,43 @@ impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
group_ids.len()
);

self.validate_group_ids(group_ids, num_groups)?;
self.ensure_groups(num_groups)?;

if self.try_accumulate_kernel(batch, group_ids, num_groups, ctx)? {
return Ok(());
}

if self.vtable.try_accumulate_grouped(
&mut self.partials[..num_groups],
batch,
group_ids,
ctx,
)? {
if self.try_accumulate_kernel(batch, group_ids, ctx)? {
return Ok(());
}

let input = batch.clone();
let mut batch = batch.clone();
let mut tried_current = true;
for _ in 0..max_iterations() {
if batch.is::<AnyColumnar>() {
break;
}

if self.try_accumulate_kernel(&batch, group_ids, num_groups, ctx)? {
if !tried_current && self.try_accumulate_kernel(&batch, group_ids, ctx)? {
return Ok(());
}

batch = batch.execute(ctx)?;
tried_current = false;
}

let columnar = batch.clone().execute::<Columnar>(ctx)?;
if self.vtable.accumulate_grouped(
&mut self.partials[..num_groups],
&columnar,
group_ids,
ctx,
)? {
if !tried_current && self.try_accumulate_kernel(&batch, group_ids, ctx)? {
return Ok(());
}

self.accumulate_fallback(&input, group_ids, ctx)
let group_ids = group_ids.validated_ids(ctx)?;
self.accumulate_fallback(&input, group_ids.as_ref(), ctx)
}

fn accumulate_partials(
&mut self,
partials: &ArrayRef,
group_ids: &[u32],
num_groups: usize,
group_ids: &GroupIds,
ctx: &mut ExecutionCtx,
) -> VortexResult<()> {
let num_groups = group_ids.num_groups();
vortex_ensure!(
partials.dtype() == &self.partial_dtype,
"Partial DType mismatch: expected {}, got {}",
Expand All @@ -334,7 +399,7 @@ impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
group_ids.len()
);

self.validate_group_ids(group_ids, num_groups)?;
let group_ids = group_ids.validated_ids(ctx)?;
self.ensure_groups(num_groups)?;

for (row_idx, &group_id) in group_ids.iter().enumerate() {
Expand Down
Loading
Loading