From 59d7be736e9462d138d80fd5640fa686c5ac84a2 Mon Sep 17 00:00:00 2001 From: Onur Satici Date: Mon, 22 Jun 2026 15:28:52 +0100 Subject: [PATCH 1/2] group ids as array ref, multi encoding kernel lookup Signed-off-by: Onur Satici --- vortex-array/benches/aggregate_grouped.rs | 27 ++- .../src/aggregate_fn/accumulator_grouped.rs | 195 ++++++++++++------ .../src/aggregate_fn/fns/count/grouped.rs | 45 +++- .../src/aggregate_fn/fns/count/mod.rs | 38 ++-- .../src/aggregate_fn/fns/sum/grouped.rs | 65 ++++++ vortex-array/src/aggregate_fn/fns/sum/mod.rs | 36 +--- vortex-array/src/aggregate_fn/kernels.rs | 25 ++- vortex-array/src/aggregate_fn/session.rs | 151 ++++++++++---- vortex-array/src/aggregate_fn/vtable.rs | 31 --- 9 files changed, 405 insertions(+), 208 deletions(-) diff --git a/vortex-array/benches/aggregate_grouped.rs b/vortex-array/benches/aggregate_grouped.rs index 2d46a5cce8a..e99e4619143 100644 --- a/vortex-array/benches/aggregate_grouped.rs +++ b/vortex-array/benches/aggregate_grouped.rs @@ -15,6 +15,7 @@ use vortex_array::VortexSessionExecute; use vortex_array::aggregate_fn::AggregateFnVTable; use vortex_array::aggregate_fn::DynGroupedAccumulator; use vortex_array::aggregate_fn::EmptyOptions; +use vortex_array::aggregate_fn::GroupIds; use vortex_array::aggregate_fn::GroupedAccumulator; use vortex_array::aggregate_fn::fns::count::Count; use vortex_array::aggregate_fn::fns::sum::Sum; @@ -45,24 +46,22 @@ fn total_element_count(group_sizes: &[usize]) -> usize { struct DenseGroupedInput { values: ArrayRef, - group_ids: Vec, - num_groups: usize, + group_ids: GroupIds, } fn dense_grouped_input(values: ArrayRef, group_sizes: &[usize]) -> DenseGroupedInput { assert_eq!(values.len(), total_element_count(group_sizes)); - let group_ids = group_sizes - .iter() - .enumerate() - .flat_map(|(group_id, &size)| std::iter::repeat_n(group_id as u32, size)) - .collect(); + let group_ids = GroupIds::from_iter( + group_sizes + .iter() + .enumerate() + .flat_map(|(group_id, &size)| std::iter::repeat_n(group_id as u32, size)), + group_sizes.len(), + ) + .unwrap(); - DenseGroupedInput { - values, - group_ids, - num_groups: group_sizes.len(), - } + DenseGroupedInput { values, group_ids } } fn i32_nullable_all_valid_input() -> DenseGroupedInput { @@ -142,14 +141,14 @@ where { let mut acc = GroupedAccumulator::try_new(vtable, EmptyOptions, input.values.dtype().clone()).unwrap(); + let num_groups = input.group_ids.num_groups(); acc.accumulate( &input.values, &input.group_ids, - input.num_groups, &mut LEGACY_SESSION.create_execution_ctx(), ) .unwrap(); - divan::black_box(acc.finish(input.num_groups).unwrap()) + divan::black_box(acc.finish(num_groups).unwrap()) } #[divan::bench] diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index 7a614ceed63..2e51fb17782 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -7,7 +7,6 @@ use vortex_error::vortex_ensure; use vortex_error::vortex_err; use crate::ArrayRef; -use crate::Columnar; use crate::ExecutionCtx; use crate::IntoArray; use crate::aggregate_fn::Accumulator; @@ -17,16 +16,107 @@ use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; use crate::aggregate_fn::kernels::GroupedAggregateKernelResult; use crate::aggregate_fn::session::AggregateFnSessionExt; +use crate::array::ArrayId; +use crate::arrays::PrimitiveArray; use crate::builders::builder_with_capacity; use crate::columnar::AnyColumnar; use crate::dtype::DType; +use crate::dtype::Nullability; +use crate::dtype::PType; use crate::executor::max_iterations; use crate::scalar::Scalar; +use crate::validity::Validity; /// Reference-counted type-erased grouped accumulator. pub type GroupedAccumulatorRef = Box; -/// An accumulator used for computing aggregates over dense group ids. +/// Encoded group ids parallel to a grouped aggregate input batch. +/// +/// The array must contain non-null `u32` ordinals. The ordinals are dense state slots in +/// `0..num_groups`, not raw group keys. Range validation may require executing the encoded array, +/// so kernels that can prove the invariant from encoded metadata should avoid materializing and +/// otherwise call [`Self::validated_ids`] before indexing group state. +#[derive(Clone, Debug)] +pub struct GroupIds { + ids: ArrayRef, + num_groups: usize, +} + +impl GroupIds { + /// Create group ids from an encoded non-null `u32` array. + pub fn new(ids: ArrayRef, num_groups: usize) -> VortexResult { + validate_num_groups(num_groups)?; + vortex_ensure!( + ids.dtype() == &DType::Primitive(PType::U32, Nullability::NonNullable), + "Group ids must be non-nullable u32, got {}", + ids.dtype() + ); + Ok(Self { ids, num_groups }) + } + + /// Create group ids from a materialized buffer. + pub fn from_buffer(ids: Buffer, num_groups: usize) -> VortexResult { + Self::new( + PrimitiveArray::new(ids, Validity::NonNullable).into_array(), + num_groups, + ) + } + + /// Create group ids from materialized values. + pub fn from_iter(ids: impl IntoIterator, num_groups: usize) -> VortexResult { + Self::from_buffer(Buffer::from_iter(ids), num_groups) + } + + /// Create group ids containing `0..num_groups`. + pub fn range(num_groups: usize) -> VortexResult { + validate_num_groups(num_groups)?; + if num_groups == 0 { + return Self::from_buffer(Buffer::::empty(), num_groups); + } + + let last = u32::try_from(num_groups - 1).map_err(|_| { + vortex_err!( + "num_groups {} exceeds dense u32 group id capacity", + num_groups + ) + })?; + Self::from_buffer((0..=last).collect(), num_groups) + } + + /// Return the encoded ids array. + pub fn ids(&self) -> &ArrayRef { + &self.ids + } + + /// Return the number of dense group state slots. + pub fn num_groups(&self) -> usize { + self.num_groups + } + + /// Return the number of ids. + pub fn len(&self) -> usize { + self.ids.len() + } + + /// Return whether there are no ids. + pub fn is_empty(&self) -> bool { + self.ids.is_empty() + } + + /// Return the encoding id for kernel dispatch. + pub fn encoding_id(&self) -> ArrayId { + self.ids.encoding_id() + } + + /// Execute the ids to a native buffer and validate every id is in range. + pub fn validated_ids(&self, ctx: &mut ExecutionCtx) -> VortexResult> { + let ids = self.ids.clone().execute::>(ctx)?; + validate_group_ids(ids.as_ref(), self.num_groups)?; + Ok(ids) + } +} + +/// An accumulator used for computing aggregates over group ids. /// /// Group ids are caller-assigned `u32` ordinals in the dense range `0..num_groups`. Input batches /// may repeat, omit, and reorder those ids, but every id must identify a state slot rather than a @@ -88,54 +178,30 @@ impl GroupedAccumulator { Ok(()) } - fn validate_group_ids(&self, group_ids: &[u32], num_groups: usize) -> VortexResult<()> { - validate_num_groups(num_groups)?; - for &group_id in group_ids { - vortex_ensure!( - (group_id as usize) < num_groups, - "Group id {} out of range for {} groups", - group_id, - num_groups - ); - } - Ok(()) - } - fn accumulate_kernel_result( &mut self, result: GroupedAggregateKernelResult, - num_groups: usize, ctx: &mut ExecutionCtx, ) -> VortexResult<()> { - self.accumulate_partials(result.partials(), result.group_ids(), num_groups, ctx) + self.accumulate_partials(result.partials(), result.group_ids(), ctx) } fn try_accumulate_kernel( &mut self, batch: &ArrayRef, - group_ids: &[u32], - num_groups: usize, + group_ids: &GroupIds, ctx: &mut ExecutionCtx, ) -> VortexResult { let session = ctx.session().clone(); - if let Some(kernel) = session - .aggregate_fns() - .find_grouped_encoding_kernel(batch.encoding_id(), self.aggregate_fn.id()) - && let Some(result) = - kernel.grouped_aggregate(&self.aggregate_fn, batch, group_ids, num_groups, ctx)? - { - self.accumulate_kernel_result(result, num_groups, ctx)?; - return Ok(true); - } - - if let Some(kernel) = session - .aggregate_fns() - .find_grouped_kernel(self.aggregate_fn.id()) - && let Some(result) = - kernel.grouped_aggregate(&self.aggregate_fn, batch, group_ids, num_groups, ctx)? + if let Some(kernel) = session.aggregate_fns().find_grouped_kernel( + self.aggregate_fn.id(), + batch.encoding_id(), + group_ids.encoding_id(), + ) && let Some(result) = + kernel.grouped_aggregate(&self.aggregate_fn, batch, group_ids, ctx)? { - self.accumulate_kernel_result(result, num_groups, ctx)?; + self.accumulate_kernel_result(result, ctx)?; return Ok(true); } @@ -198,18 +264,31 @@ fn validate_num_groups(num_groups: usize) -> VortexResult<()> { Ok(()) } +fn validate_group_ids(group_ids: &[u32], num_groups: usize) -> VortexResult<()> { + validate_num_groups(num_groups)?; + for &group_id in group_ids { + vortex_ensure!( + (group_id as usize) < num_groups, + "Group id {} out of range for {} groups", + group_id, + num_groups + ); + } + Ok(()) +} + /// A trait object for type-erased grouped accumulators, used for dynamic dispatch when the /// aggregate function is not known at compile time. pub trait DynGroupedAccumulator: 'static + Send { /// Accumulate a values batch into dense group state. /// /// `group_ids` is parallel to `batch`. Each id must be a caller-assigned group ordinal in - /// `0..num_groups`; ids may repeat, appear out of order, or be absent from a given batch. + /// `0..group_ids.num_groups()`; ids may repeat, appear out of order, or be absent from a + /// given batch. fn accumulate( &mut self, batch: &ArrayRef, - group_ids: &[u32], - num_groups: usize, + group_ids: &GroupIds, ctx: &mut ExecutionCtx, ) -> VortexResult<()>; @@ -220,8 +299,7 @@ pub trait DynGroupedAccumulator: 'static + Send { fn accumulate_partials( &mut self, partials: &ArrayRef, - group_ids: &[u32], - num_groups: usize, + group_ids: &GroupIds, ctx: &mut ExecutionCtx, ) -> VortexResult<()>; @@ -254,10 +332,10 @@ impl DynGroupedAccumulator for GroupedAccumulator { fn accumulate( &mut self, batch: &ArrayRef, - group_ids: &[u32], - num_groups: usize, + group_ids: &GroupIds, ctx: &mut ExecutionCtx, ) -> VortexResult<()> { + let num_groups = group_ids.num_groups(); vortex_ensure!( batch.dtype() == &self.dtype, "Input DType mismatch: expected {}, got {}", @@ -271,56 +349,43 @@ impl DynGroupedAccumulator for GroupedAccumulator { group_ids.len() ); - self.validate_group_ids(group_ids, num_groups)?; self.ensure_groups(num_groups)?; - if self.try_accumulate_kernel(batch, group_ids, num_groups, ctx)? { - return Ok(()); - } - - if self.vtable.try_accumulate_grouped( - &mut self.partials[..num_groups], - batch, - group_ids, - ctx, - )? { + if self.try_accumulate_kernel(batch, group_ids, ctx)? { return Ok(()); } let input = batch.clone(); let mut batch = batch.clone(); + let mut tried_current = true; for _ in 0..max_iterations() { if batch.is::() { break; } - if self.try_accumulate_kernel(&batch, group_ids, num_groups, ctx)? { + if !tried_current && self.try_accumulate_kernel(&batch, group_ids, ctx)? { return Ok(()); } batch = batch.execute(ctx)?; + tried_current = false; } - let columnar = batch.clone().execute::(ctx)?; - if self.vtable.accumulate_grouped( - &mut self.partials[..num_groups], - &columnar, - group_ids, - ctx, - )? { + if !tried_current && self.try_accumulate_kernel(&batch, group_ids, ctx)? { return Ok(()); } - self.accumulate_fallback(&input, group_ids, ctx) + let group_ids = group_ids.validated_ids(ctx)?; + self.accumulate_fallback(&input, group_ids.as_ref(), ctx) } fn accumulate_partials( &mut self, partials: &ArrayRef, - group_ids: &[u32], - num_groups: usize, + group_ids: &GroupIds, ctx: &mut ExecutionCtx, ) -> VortexResult<()> { + let num_groups = group_ids.num_groups(); vortex_ensure!( partials.dtype() == &self.partial_dtype, "Partial DType mismatch: expected {}, got {}", @@ -334,7 +399,7 @@ impl DynGroupedAccumulator for GroupedAccumulator { group_ids.len() ); - self.validate_group_ids(group_ids, num_groups)?; + let group_ids = group_ids.validated_ids(ctx)?; self.ensure_groups(num_groups)?; for (row_idx, &group_id) in group_ids.iter().enumerate() { diff --git a/vortex-array/src/aggregate_fn/fns/count/grouped.rs b/vortex-array/src/aggregate_fn/fns/count/grouped.rs index 03e2b1b49ae..5902910f166 100644 --- a/vortex-array/src/aggregate_fn/fns/count/grouped.rs +++ b/vortex-array/src/aggregate_fn/fns/count/grouped.rs @@ -3,20 +3,51 @@ use vortex_error::VortexResult; +use super::Count; use crate::ArrayRef; use crate::ExecutionCtx; +use crate::IntoArray; +use crate::aggregate_fn::AggregateFnRef; +use crate::aggregate_fn::GroupIds; +use crate::aggregate_fn::kernels::DynGroupedAggregateKernel; +use crate::aggregate_fn::kernels::GroupedAggregateKernelResult; +use crate::arrays::PrimitiveArray; -pub(super) fn try_accumulate_grouped( - states: &mut [u64], +#[derive(Debug)] +pub(crate) struct CountGroupedKernel; + +impl DynGroupedAggregateKernel for CountGroupedKernel { + fn grouped_aggregate( + &self, + aggregate_fn: &AggregateFnRef, + batch: &ArrayRef, + group_ids: &GroupIds, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + if aggregate_fn.as_opt::().is_none() { + return Ok(None); + } + + let partials = accumulate_grouped(batch, group_ids, ctx)?; + Ok(Some(GroupedAggregateKernelResult::dense( + PrimitiveArray::from_iter(partials).into_array(), + group_ids.num_groups(), + )?)) + } +} + +fn accumulate_grouped( batch: &ArrayRef, - group_ids: &[u32], + group_ids: &GroupIds, ctx: &mut ExecutionCtx, -) -> VortexResult { +) -> VortexResult> { + let ids = group_ids.validated_ids(ctx)?; + let mut partials = vec![0u64; group_ids.num_groups()]; let validity = batch.validity()?.execute_mask(batch.len(), ctx)?; - for (&group_id, valid) in group_ids.iter().zip(validity.iter()) { + for (&group_id, valid) in ids.iter().zip(validity.iter()) { if valid { - states[group_id as usize] += 1; + partials[group_id as usize] += 1; } } - Ok(true) + Ok(partials) } diff --git a/vortex-array/src/aggregate_fn/fns/count/mod.rs b/vortex-array/src/aggregate_fn/fns/count/mod.rs index e53a378b5a9..73ad3ac5fbc 100644 --- a/vortex-array/src/aggregate_fn/fns/count/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/count/mod.rs @@ -2,6 +2,7 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors mod grouped; +pub(crate) use grouped::CountGroupedKernel; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -95,16 +96,6 @@ impl AggregateFnVTable for Count { Ok(true) } - fn try_accumulate_grouped( - &self, - states: &mut [Self::Partial], - batch: &ArrayRef, - group_ids: &[u32], - ctx: &mut ExecutionCtx, - ) -> VortexResult { - grouped::try_accumulate_grouped(states, batch, group_ids, ctx) - } - fn accumulate( &self, _partial: &mut Self::Partial, @@ -139,6 +130,7 @@ mod tests { use crate::aggregate_fn::DynAccumulator; use crate::aggregate_fn::DynGroupedAccumulator; use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::GroupIds; use crate::aggregate_fn::GroupedAccumulator; use crate::aggregate_fn::fns::count::Count; use crate::arrays::ChunkedArray; @@ -258,10 +250,10 @@ mod tests { num_groups: usize, ) -> VortexResult { let mut acc = GroupedAccumulator::try_new(Count, EmptyOptions, values.dtype().clone())?; + let group_ids = GroupIds::from_iter(group_ids.iter().copied(), num_groups)?; acc.accumulate( values, - group_ids, - num_groups, + &group_ids, &mut LEGACY_SESSION.create_execution_ctx(), )?; acc.finish(num_groups) @@ -307,13 +299,30 @@ mod tests { Ok(()) } + #[test] + fn grouped_count_constant_group_ids() -> VortexResult<()> { + let values = + PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4)]).into_array(); + let group_ids = GroupIds::new(ConstantArray::new(1u32, values.len()).into_array(), 3)?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut acc = GroupedAccumulator::try_new(Count, EmptyOptions, values.dtype().clone())?; + + acc.accumulate(&values, &group_ids, &mut ctx)?; + let actual = acc.finish(3)?; + + let expected = PrimitiveArray::from_iter([0u64, 3, 0]).into_array(); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + #[test] fn grouped_count_rejects_out_of_range_group_id() -> VortexResult<()> { let values = PrimitiveArray::new(buffer![1i32, 2], Validity::NonNullable).into_array(); let mut acc = GroupedAccumulator::try_new(Count, EmptyOptions, values.dtype().clone())?; let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let group_ids = GroupIds::from_iter([0u32, 2], 2)?; - assert!(acc.accumulate(&values, &[0, 2], 2, &mut ctx).is_err()); + assert!(acc.accumulate(&values, &group_ids, &mut ctx).is_err()); Ok(()) } @@ -324,7 +333,8 @@ mod tests { let mut ctx = LEGACY_SESSION.create_execution_ctx(); let mut left = GroupedAccumulator::try_new(Count, EmptyOptions, dtype.clone())?; - left.accumulate_partials(&partials, &[0, 1, 1], 2, &mut ctx)?; + let group_ids = GroupIds::from_iter([0u32, 1, 1], 2)?; + left.accumulate_partials(&partials, &group_ids, &mut ctx)?; let mut right = GroupedAccumulator::try_new(Count, EmptyOptions, dtype)?; right.merge_group(0, &left, 1)?; diff --git a/vortex-array/src/aggregate_fn/fns/sum/grouped.rs b/vortex-array/src/aggregate_fn/fns/sum/grouped.rs index 81304f1eb9f..2c9728b62aa 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/grouped.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/grouped.rs @@ -5,10 +5,12 @@ use num_traits::AsPrimitive; use num_traits::ToPrimitive; use vortex_error::VortexExpect; use vortex_error::VortexResult; +use vortex_error::vortex_bail; use vortex_error::vortex_panic; use vortex_mask::AllOr; use vortex_mask::Mask; +use super::Sum; use super::SumPartial; use super::SumState; use super::checked_add_i64; @@ -16,7 +18,15 @@ use super::checked_add_u64; use super::primitive::sum_float_all; use super::primitive::sum_signed_all; use super::primitive::sum_unsigned_all; +use crate::ArrayRef; +use crate::Canonical; +use crate::Columnar; use crate::ExecutionCtx; +use crate::aggregate_fn::AggregateFnRef; +use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::GroupIds; +use crate::aggregate_fn::kernels::DynGroupedAggregateKernel; +use crate::aggregate_fn::kernels::GroupedAggregateKernelResult; use crate::arrays::BoolArray; use crate::arrays::PrimitiveArray; use crate::arrays::bool::BoolArrayExt; @@ -25,6 +35,61 @@ use crate::match_each_native_ptype; const MIN_AVG_RUN_LENGTH_FOR_GROUPED_SUM_RUNS: usize = 4; +#[derive(Debug)] +pub(crate) struct SumGroupedKernel; + +impl DynGroupedAggregateKernel for SumGroupedKernel { + fn grouped_aggregate( + &self, + aggregate_fn: &AggregateFnRef, + batch: &ArrayRef, + group_ids: &GroupIds, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + let Some(options) = aggregate_fn.as_opt::() else { + return Ok(None); + }; + + let columnar = batch.clone().execute::(ctx)?; + match &columnar { + Columnar::Canonical(Canonical::Primitive(_)) + | Columnar::Canonical(Canonical::Bool(_)) => {} + // Decimal and constants still use the universal grouped fallback. + Columnar::Canonical(Canonical::Decimal(_)) | Columnar::Constant(_) => return Ok(None), + Columnar::Canonical(_) => { + vortex_bail!("Unsupported canonical type for sum: {}", columnar.dtype()) + } + } + + let partial_dtype = Sum + .partial_dtype(options, batch.dtype()) + .ok_or_else(|| vortex_error::vortex_err!("Unsupported sum dtype: {}", batch.dtype()))?; + let ids = group_ids.validated_ids(ctx)?; + let mut partials = (0..group_ids.num_groups()) + .map(|_| Sum.empty_partial(options, batch.dtype())) + .collect::>>()?; + + match &columnar { + Columnar::Canonical(Canonical::Primitive(p)) => { + accumulate_grouped_primitive(&mut partials, p, ids.as_ref(), ctx)?; + } + Columnar::Canonical(Canonical::Bool(b)) => { + accumulate_grouped_bool(&mut partials, b, ids.as_ref(), ctx)?; + } + Columnar::Canonical(Canonical::Decimal(_)) | Columnar::Constant(_) => unreachable!(), + Columnar::Canonical(_) => unreachable!(), + } + + let Some(partials) = Sum.partials_to_array(&partials, &partial_dtype)? else { + return Ok(None); + }; + Ok(Some(GroupedAggregateKernelResult::dense( + partials, + group_ids.num_groups(), + )?)) + } +} + fn for_each_valid_idx(validity: &Mask, len: usize, mut f: impl FnMut(usize)) { match validity.indices() { AllOr::All => { diff --git a/vortex-array/src/aggregate_fn/fns/sum/mod.rs b/vortex-array/src/aggregate_fn/fns/sum/mod.rs index eff487d55e8..5219933bc20 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/mod.rs @@ -7,6 +7,7 @@ mod decimal; mod grouped; mod primitive; +pub(crate) use grouped::SumGroupedKernel; use vortex_buffer::Buffer; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -281,30 +282,6 @@ impl AggregateFnVTable for Sum { Ok(()) } - fn accumulate_grouped( - &self, - partials: &mut [Self::Partial], - batch: &Columnar, - group_ids: &[u32], - ctx: &mut ExecutionCtx, - ) -> VortexResult { - match batch { - Columnar::Canonical(Canonical::Primitive(p)) => { - grouped::accumulate_grouped_primitive(partials, p, group_ids, ctx)?; - Ok(true) - } - Columnar::Canonical(Canonical::Bool(b)) => { - grouped::accumulate_grouped_bool(partials, b, group_ids, ctx)?; - Ok(true) - } - // Decimal and constants still use the universal grouped fallback. - Columnar::Canonical(Canonical::Decimal(_)) | Columnar::Constant(_) => Ok(false), - Columnar::Canonical(_) => { - vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()) - } - } - } - fn finalize(&self, partials: ArrayRef) -> VortexResult { Ok(partials) } @@ -439,6 +416,7 @@ mod tests { use crate::aggregate_fn::DynAccumulator; use crate::aggregate_fn::DynGroupedAccumulator; use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::GroupIds; use crate::aggregate_fn::GroupedAccumulator; use crate::aggregate_fn::fns::sum::Sum; use crate::aggregate_fn::fns::sum::sum; @@ -616,10 +594,10 @@ mod tests { num_groups: usize, ) -> VortexResult { let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, values.dtype().clone())?; + let group_ids = GroupIds::from_iter(group_ids.iter().copied(), num_groups)?; acc.accumulate( values, - group_ids, - num_groups, + &group_ids, &mut LEGACY_SESSION.create_execution_ctx(), )?; acc.finish(num_groups) @@ -689,14 +667,16 @@ mod tests { let values1 = PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array(); - acc.accumulate(&values1, &[0, 0, 1, 1], 2, &mut ctx)?; + let group_ids1 = GroupIds::from_iter([0u32, 0, 1, 1], 2)?; + acc.accumulate(&values1, &group_ids1, &mut ctx)?; let result1 = acc.finish(2)?; let expected1 = PrimitiveArray::from_option_iter([Some(3i64), Some(7i64)]).into_array(); assert_arrays_eq!(&result1, &expected1); let values2 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); - acc.accumulate(&values2, &[0, 0], 1, &mut ctx)?; + let group_ids2 = GroupIds::from_iter([0u32, 0], 1)?; + acc.accumulate(&values2, &group_ids2, &mut ctx)?; let result2 = acc.finish(1)?; let expected2 = PrimitiveArray::from_option_iter([Some(30i64)]).into_array(); diff --git a/vortex-array/src/aggregate_fn/kernels.rs b/vortex-array/src/aggregate_fn/kernels.rs index e0b1d42e41e..ff1d56dc615 100644 --- a/vortex-array/src/aggregate_fn/kernels.rs +++ b/vortex-array/src/aggregate_fn/kernels.rs @@ -6,12 +6,12 @@ use std::fmt::Debug; -use vortex_buffer::Buffer; use vortex_error::VortexResult; use crate::ArrayRef; use crate::ExecutionCtx; use crate::aggregate_fn::AggregateFnRef; +use crate::aggregate_fn::GroupIds; use crate::scalar::Scalar; /// A pluggable kernel for an aggregate function. @@ -35,20 +35,27 @@ pub trait DynAggregateKernel: 'static + Send + Sync + Debug { /// batch through `accumulate_partials`. #[derive(Clone, Debug)] pub struct GroupedAggregateKernelResult { - group_ids: Buffer, + group_ids: GroupIds, partials: ArrayRef, } impl GroupedAggregateKernelResult { - pub fn new(group_ids: Buffer, partials: ArrayRef) -> Self { + pub fn new(group_ids: GroupIds, partials: ArrayRef) -> Self { Self { group_ids, partials, } } - pub fn group_ids(&self) -> &[u32] { - self.group_ids.as_ref() + pub fn dense(partials: ArrayRef, num_groups: usize) -> VortexResult { + Ok(Self { + group_ids: GroupIds::range(num_groups)?, + partials, + }) + } + + pub fn group_ids(&self) -> &GroupIds { + &self.group_ids } pub fn partials(&self) -> &ArrayRef { @@ -58,9 +65,8 @@ impl GroupedAggregateKernelResult { /// A pluggable kernel for batch aggregation of many groups. /// -/// A grouped kernel can be registered for an aggregate function regardless of input encoding, or -/// for a specific aggregate function and array encoding. Encoding-specific kernels are matched on -/// the values array, not on a pre-grouped list wrapper. +/// A grouped kernel can be registered for an aggregate function regardless of input encodings, or +/// for a specific aggregate function plus values and/or group-id encoding. /// /// Kernels receive the same dense group ordinals that the caller passed to the grouped accumulator /// and may aggregate directly in the encoded domain. @@ -72,8 +78,7 @@ pub trait DynGroupedAggregateKernel: 'static + Send + Sync + Debug { &self, aggregate_fn: &AggregateFnRef, batch: &ArrayRef, - group_ids: &[u32], - num_groups: usize, + group_ids: &GroupIds, ctx: &mut ExecutionCtx, ) -> VortexResult>; } diff --git a/vortex-array/src/aggregate_fn/session.rs b/vortex-array/src/aggregate_fn/session.rs index 78b139bf36f..359b7e3c910 100644 --- a/vortex-array/src/aggregate_fn/session.rs +++ b/vortex-array/src/aggregate_fn/session.rs @@ -18,6 +18,8 @@ use crate::aggregate_fn::fns::all_non_null::AllNonNull; use crate::aggregate_fn::fns::all_null::AllNull; use crate::aggregate_fn::fns::bounded_max::BoundedMax; use crate::aggregate_fn::fns::bounded_min::BoundedMin; +use crate::aggregate_fn::fns::count::Count; +use crate::aggregate_fn::fns::count::CountGroupedKernel; use crate::aggregate_fn::fns::first::First; use crate::aggregate_fn::fns::is_constant::IsConstant; use crate::aggregate_fn::fns::is_sorted::IsSorted; @@ -28,6 +30,7 @@ use crate::aggregate_fn::fns::min_max::MinMax; use crate::aggregate_fn::fns::nan_count::NanCount; use crate::aggregate_fn::fns::null_count::NullCount; use crate::aggregate_fn::fns::sum::Sum; +use crate::aggregate_fn::fns::sum::SumGroupedKernel; use crate::aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes; use crate::aggregate_fn::kernels::DynAggregateKernel; use crate::aggregate_fn::kernels::DynGroupedAggregateKernel; @@ -51,9 +54,7 @@ pub struct AggregateFnSession { registry: ArcSwapMap, kernels: ArcSwapMap, - grouped_kernels: ArcSwapMap, - grouped_encoding_kernels: - ArcSwapMap, + grouped_kernels: ArcSwapMap, } impl SessionVar for AggregateFnSession { @@ -67,7 +68,7 @@ impl SessionVar for AggregateFnSession { } type AggregateKernelKey = (ArrayId, Option); -type GroupedEncodingKernelKey = (ArrayId, AggregateFnId); +type GroupedAggregateKernelKey = (AggregateFnId, Option, Option); impl Default for AggregateFnSession { fn default() -> Self { @@ -75,7 +76,6 @@ impl Default for AggregateFnSession { registry: ArcSwapMap::default(), kernels: ArcSwapMap::default(), grouped_kernels: ArcSwapMap::default(), - grouped_encoding_kernels: ArcSwapMap::default(), }; // Register the built-in aggregate functions @@ -103,6 +103,8 @@ impl Default for AggregateFnSession { this.register_aggregate_kernel(Dict.id(), Some(MinMax.id()), &DictMinMaxKernel); this.register_aggregate_kernel(Dict.id(), Some(IsConstant.id()), &DictIsConstantKernel); this.register_aggregate_kernel(Dict.id(), Some(IsSorted.id()), &DictIsSortedKernel); + this.register_grouped_kernel(Count.id(), None, None, &CountGroupedKernel); + this.register_grouped_kernel(Sum.id(), None, None, &SumGroupedKernel); this } @@ -156,54 +158,44 @@ impl AggregateFnSession { self.kernels.insert(id, kernel); } - /// Returns the grouped aggregate kernel registered for `agg_fn_id`, if any. + /// Returns the grouped aggregate kernel registered for this aggregate and pair of encodings. /// - /// These kernels are independent of the element encoding and are checked for each element - /// representation, after any kernel registered for the current element encoding. + /// Lookup first checks the exact `(aggregate, values encoding, group ids encoding)` key, then + /// falls back through `(aggregate, values encoding, any group ids)`, `(aggregate, any values, + /// group ids encoding)`, and finally `(aggregate, any values, any group ids)`. pub fn find_grouped_kernel( &self, agg_fn_id: impl Into, + values_id: impl Into, + group_ids_id: impl Into, ) -> Option<&'static dyn DynGroupedAggregateKernel> { let fn_id = agg_fn_id.into(); - self.grouped_kernels - .read(|kernels| kernels.get(&fn_id).copied()) - } - - /// Registers a grouped aggregate kernel for an aggregate function. - pub fn register_grouped_kernel( - &self, - agg_fn_id: impl Into, - kernel: &'static dyn DynGroupedAggregateKernel, - ) { - let fn_id = agg_fn_id.into(); - self.grouped_kernels.insert(fn_id, kernel) + let values_id = values_id.into(); + let group_ids_id = group_ids_id.into(); + self.grouped_kernels.read(|kernels| { + kernels + .get(&(fn_id, Some(values_id), Some(group_ids_id))) + .or_else(|| kernels.get(&(fn_id, Some(values_id), None))) + .or_else(|| kernels.get(&(fn_id, None, Some(group_ids_id)))) + .or_else(|| kernels.get(&(fn_id, None, None))) + .copied() + }) } - /// Returns the grouped aggregate kernel registered for `array_id` and `agg_fn_id`, if any. + /// Registers a grouped aggregate kernel. /// - /// These kernels are matched against each intermediate element encoding while the grouped - /// accumulator executes the element array. - pub fn find_grouped_encoding_kernel( - &self, - array_id: impl Into, - agg_fn_id: impl Into, - ) -> Option<&'static dyn DynGroupedAggregateKernel> { - let id = array_id.into(); - let fn_id = agg_fn_id.into(); - self.grouped_encoding_kernels - .read(|kernels| kernels.get(&(id, fn_id)).copied()) - } - - /// Registers a grouped aggregate kernel for a specific aggregate function and array encoding. - pub fn register_grouped_encoding_kernel( + /// `values_id` and `group_ids_id` are optional wildcards. Passing `None` for either dimension + /// makes the kernel a fallback for that encoding dimension. + pub fn register_grouped_kernel( &self, - array_id: impl Into, agg_fn_id: impl Into, + values_id: Option, + group_ids_id: Option, kernel: &'static dyn DynGroupedAggregateKernel, ) { - let id = array_id.into(); let fn_id = agg_fn_id.into(); - self.grouped_encoding_kernels.insert((id, fn_id), kernel) + self.grouped_kernels + .insert((fn_id, values_id, group_ids_id), kernel) } } @@ -215,3 +207,84 @@ pub trait AggregateFnSessionExt: SessionExt { } } impl AggregateFnSessionExt for S {} + +#[cfg(test)] +mod tests { + use vortex_error::VortexResult; + + use super::*; + use crate::ArrayRef; + use crate::ExecutionCtx; + use crate::aggregate_fn::AggregateFnRef; + use crate::aggregate_fn::GroupIds; + use crate::aggregate_fn::kernels::GroupedAggregateKernelResult; + use crate::arrays::Constant; + use crate::arrays::Primitive; + + #[derive(Debug)] + struct TestGroupedKernel; + + impl DynGroupedAggregateKernel for TestGroupedKernel { + fn grouped_aggregate( + &self, + _aggregate_fn: &AggregateFnRef, + _batch: &ArrayRef, + _group_ids: &GroupIds, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + Ok(None) + } + } + + static GENERIC_KERNEL: TestGroupedKernel = TestGroupedKernel; + static GROUP_IDS_KERNEL: TestGroupedKernel = TestGroupedKernel; + static VALUES_KERNEL: TestGroupedKernel = TestGroupedKernel; + static EXACT_KERNEL: TestGroupedKernel = TestGroupedKernel; + + fn assert_same_kernel( + actual: Option<&'static dyn DynGroupedAggregateKernel>, + expected: &'static dyn DynGroupedAggregateKernel, + ) { + assert!(std::ptr::eq( + actual.expect("expected registered grouped kernel"), + expected + )); + } + + #[test] + fn grouped_kernel_lookup_prefers_exact_then_value_then_group_ids() { + let session = AggregateFnSession::default(); + let aggregate_id = AggregateFnId::new("test.grouped_lookup"); + let values_id = Primitive.id(); + let group_ids_id = Constant.id(); + + session.register_grouped_kernel(aggregate_id, None, None, &GENERIC_KERNEL); + assert_same_kernel( + session.find_grouped_kernel(aggregate_id, values_id, group_ids_id), + &GENERIC_KERNEL, + ); + + session.register_grouped_kernel(aggregate_id, None, Some(group_ids_id), &GROUP_IDS_KERNEL); + assert_same_kernel( + session.find_grouped_kernel(aggregate_id, values_id, group_ids_id), + &GROUP_IDS_KERNEL, + ); + + session.register_grouped_kernel(aggregate_id, Some(values_id), None, &VALUES_KERNEL); + assert_same_kernel( + session.find_grouped_kernel(aggregate_id, values_id, group_ids_id), + &VALUES_KERNEL, + ); + + session.register_grouped_kernel( + aggregate_id, + Some(values_id), + Some(group_ids_id), + &EXACT_KERNEL, + ); + assert_same_kernel( + session.find_grouped_kernel(aggregate_id, values_id, group_ids_id), + &EXACT_KERNEL, + ); + } +} diff --git a/vortex-array/src/aggregate_fn/vtable.rs b/vortex-array/src/aggregate_fn/vtable.rs index 09eab6c5a9c..ab9edae5862 100644 --- a/vortex-array/src/aggregate_fn/vtable.rs +++ b/vortex-array/src/aggregate_fn/vtable.rs @@ -157,37 +157,6 @@ pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync { ctx: &mut ExecutionCtx, ) -> VortexResult<()>; - /// Try to accumulate a raw values batch into dense per-group states before decompression. - /// - /// `group_ids` is parallel to `batch` and contains caller-assigned dense ordinals in - /// `0..states.len()`. Ids may repeat, appear out of order, or be absent from the batch. - /// Returns `true` when the batch was fully handled. - fn try_accumulate_grouped( - &self, - _states: &mut [Self::Partial], - _batch: &ArrayRef, - _group_ids: &[u32], - _ctx: &mut ExecutionCtx, - ) -> VortexResult { - Ok(false) - } - - /// Accumulate a canonical values batch into dense per-group states. - /// - /// `group_ids` is parallel to `batch` and contains caller-assigned dense ordinals in - /// `0..states.len()`. Ids may repeat, appear out of order, or be absent from the batch. - /// Returns `true` when the batch was fully handled. The provided default preserves universal - /// correctness through [`crate::aggregate_fn::GroupedAccumulator`]'s fallback. - fn accumulate_grouped( - &self, - _states: &mut [Self::Partial], - _batch: &Columnar, - _group_ids: &[u32], - _ctx: &mut ExecutionCtx, - ) -> VortexResult { - Ok(false) - } - /// Finalize an array of accumulator states into an array of aggregate results. /// /// The provides `states` array has dtype as specified by `state_dtype`, the result array From 7c748a8c5d6f237be80e60b393f9f5d7e8f6cca8 Mon Sep 17 00:00:00 2001 From: Onur Satici Date: Thu, 25 Jun 2026 11:54:56 +0100 Subject: [PATCH 2/2] comments Signed-off-by: Onur Satici --- .../src/aggregate_fn/accumulator_grouped.rs | 36 ++---- .../src/aggregate_fn/fns/count/grouped.rs | 51 +++----- .../src/aggregate_fn/fns/count/mod.rs | 2 +- .../src/aggregate_fn/fns/sum/grouped.rs | 72 ++++------- vortex-array/src/aggregate_fn/fns/sum/mod.rs | 2 +- vortex-array/src/aggregate_fn/kernels.rs | 117 +++++++++++++----- vortex-array/src/aggregate_fn/session.rs | 70 ++++++++--- 7 files changed, 193 insertions(+), 157 deletions(-) diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index 2e51fb17782..46064e3b000 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -14,7 +14,6 @@ use crate::aggregate_fn::AggregateFn; use crate::aggregate_fn::AggregateFnRef; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; -use crate::aggregate_fn::kernels::GroupedAggregateKernelResult; use crate::aggregate_fn::session::AggregateFnSessionExt; use crate::array::ArrayId; use crate::arrays::PrimitiveArray; @@ -67,22 +66,6 @@ impl GroupIds { Self::from_buffer(Buffer::from_iter(ids), num_groups) } - /// Create group ids containing `0..num_groups`. - pub fn range(num_groups: usize) -> VortexResult { - validate_num_groups(num_groups)?; - if num_groups == 0 { - return Self::from_buffer(Buffer::::empty(), num_groups); - } - - let last = u32::try_from(num_groups - 1).map_err(|_| { - vortex_err!( - "num_groups {} exceeds dense u32 group id capacity", - num_groups - ) - })?; - Self::from_buffer((0..=last).collect(), num_groups) - } - /// Return the encoded ids array. pub fn ids(&self) -> &ArrayRef { &self.ids @@ -178,14 +161,6 @@ impl GroupedAccumulator { Ok(()) } - fn accumulate_kernel_result( - &mut self, - result: GroupedAggregateKernelResult, - ctx: &mut ExecutionCtx, - ) -> VortexResult<()> { - self.accumulate_partials(result.partials(), result.group_ids(), ctx) - } - fn try_accumulate_kernel( &mut self, batch: &ArrayRef, @@ -198,10 +173,13 @@ impl GroupedAccumulator { self.aggregate_fn.id(), batch.encoding_id(), group_ids.encoding_id(), - ) && let Some(result) = - kernel.grouped_aggregate(&self.aggregate_fn, batch, group_ids, ctx)? - { - self.accumulate_kernel_result(result, ctx)?; + ) && kernel.grouped_accumulate( + &self.aggregate_fn, + batch, + group_ids, + &mut self.partials, + ctx, + )? { return Ok(true); } diff --git a/vortex-array/src/aggregate_fn/fns/count/grouped.rs b/vortex-array/src/aggregate_fn/fns/count/grouped.rs index 5902910f166..68ea4e05d26 100644 --- a/vortex-array/src/aggregate_fn/fns/count/grouped.rs +++ b/vortex-array/src/aggregate_fn/fns/count/grouped.rs @@ -6,48 +6,33 @@ use vortex_error::VortexResult; use super::Count; use crate::ArrayRef; use crate::ExecutionCtx; -use crate::IntoArray; -use crate::aggregate_fn::AggregateFnRef; +use crate::aggregate_fn::EmptyOptions; use crate::aggregate_fn::GroupIds; -use crate::aggregate_fn::kernels::DynGroupedAggregateKernel; -use crate::aggregate_fn::kernels::GroupedAggregateKernelResult; -use crate::arrays::PrimitiveArray; +use crate::aggregate_fn::kernels::GroupedAggregateKernel; +use crate::aggregate_fn::kernels::GroupedAggregateKernelAdapter; + +pub(crate) static COUNT_GROUPED_KERNEL: GroupedAggregateKernelAdapter = + GroupedAggregateKernelAdapter::new(CountGroupedKernel); #[derive(Debug)] pub(crate) struct CountGroupedKernel; -impl DynGroupedAggregateKernel for CountGroupedKernel { - fn grouped_aggregate( +impl GroupedAggregateKernel for CountGroupedKernel { + fn grouped_accumulate( &self, - aggregate_fn: &AggregateFnRef, + _options: &EmptyOptions, + states: &mut [u64], batch: &ArrayRef, group_ids: &GroupIds, ctx: &mut ExecutionCtx, - ) -> VortexResult> { - if aggregate_fn.as_opt::().is_none() { - return Ok(None); - } - - let partials = accumulate_grouped(batch, group_ids, ctx)?; - Ok(Some(GroupedAggregateKernelResult::dense( - PrimitiveArray::from_iter(partials).into_array(), - group_ids.num_groups(), - )?)) - } -} - -fn accumulate_grouped( - batch: &ArrayRef, - group_ids: &GroupIds, - ctx: &mut ExecutionCtx, -) -> VortexResult> { - let ids = group_ids.validated_ids(ctx)?; - let mut partials = vec![0u64; group_ids.num_groups()]; - let validity = batch.validity()?.execute_mask(batch.len(), ctx)?; - for (&group_id, valid) in ids.iter().zip(validity.iter()) { - if valid { - partials[group_id as usize] += 1; + ) -> VortexResult { + let group_ids = group_ids.validated_ids(ctx)?; + let validity = batch.validity()?.execute_mask(batch.len(), ctx)?; + for (&group_id, valid) in group_ids.iter().zip(validity.iter()) { + if valid { + states[group_id as usize] += 1; + } } + Ok(true) } - Ok(partials) } diff --git a/vortex-array/src/aggregate_fn/fns/count/mod.rs b/vortex-array/src/aggregate_fn/fns/count/mod.rs index 73ad3ac5fbc..c6a9c27d52f 100644 --- a/vortex-array/src/aggregate_fn/fns/count/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/count/mod.rs @@ -2,7 +2,7 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors mod grouped; -pub(crate) use grouped::CountGroupedKernel; +pub(crate) use grouped::COUNT_GROUPED_KERNEL; use vortex_error::VortexExpect; use vortex_error::VortexResult; diff --git a/vortex-array/src/aggregate_fn/fns/sum/grouped.rs b/vortex-array/src/aggregate_fn/fns/sum/grouped.rs index 2c9728b62aa..e7a73059fc3 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/grouped.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/grouped.rs @@ -5,7 +5,6 @@ use num_traits::AsPrimitive; use num_traits::ToPrimitive; use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_error::vortex_bail; use vortex_error::vortex_panic; use vortex_mask::AllOr; use vortex_mask::Mask; @@ -19,15 +18,14 @@ use super::primitive::sum_float_all; use super::primitive::sum_signed_all; use super::primitive::sum_unsigned_all; use crate::ArrayRef; -use crate::Canonical; -use crate::Columnar; use crate::ExecutionCtx; -use crate::aggregate_fn::AggregateFnRef; -use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::EmptyOptions; use crate::aggregate_fn::GroupIds; -use crate::aggregate_fn::kernels::DynGroupedAggregateKernel; -use crate::aggregate_fn::kernels::GroupedAggregateKernelResult; +use crate::aggregate_fn::kernels::GroupedAggregateKernel; +use crate::aggregate_fn::kernels::GroupedAggregateKernelAdapter; +use crate::arrays::Bool; use crate::arrays::BoolArray; +use crate::arrays::Primitive; use crate::arrays::PrimitiveArray; use crate::arrays::bool::BoolArrayExt; use crate::dtype::NativePType; @@ -35,58 +33,36 @@ use crate::match_each_native_ptype; const MIN_AVG_RUN_LENGTH_FOR_GROUPED_SUM_RUNS: usize = 4; +pub(crate) static SUM_GROUPED_KERNEL: GroupedAggregateKernelAdapter = + GroupedAggregateKernelAdapter::new(SumGroupedKernel); + #[derive(Debug)] pub(crate) struct SumGroupedKernel; -impl DynGroupedAggregateKernel for SumGroupedKernel { - fn grouped_aggregate( +impl GroupedAggregateKernel for SumGroupedKernel { + fn grouped_accumulate( &self, - aggregate_fn: &AggregateFnRef, + _options: &EmptyOptions, + partials: &mut [SumPartial], batch: &ArrayRef, group_ids: &GroupIds, ctx: &mut ExecutionCtx, - ) -> VortexResult> { - let Some(options) = aggregate_fn.as_opt::() else { - return Ok(None); - }; - - let columnar = batch.clone().execute::(ctx)?; - match &columnar { - Columnar::Canonical(Canonical::Primitive(_)) - | Columnar::Canonical(Canonical::Bool(_)) => {} - // Decimal and constants still use the universal grouped fallback. - Columnar::Canonical(Canonical::Decimal(_)) | Columnar::Constant(_) => return Ok(None), - Columnar::Canonical(_) => { - vortex_bail!("Unsupported canonical type for sum: {}", columnar.dtype()) - } + ) -> VortexResult { + if let Some(primitive) = batch.as_opt::() { + let group_ids = group_ids.validated_ids(ctx)?; + let primitive = primitive.into_owned(); + accumulate_grouped_primitive(partials, &primitive, group_ids.as_ref(), ctx)?; + return Ok(true); } - let partial_dtype = Sum - .partial_dtype(options, batch.dtype()) - .ok_or_else(|| vortex_error::vortex_err!("Unsupported sum dtype: {}", batch.dtype()))?; - let ids = group_ids.validated_ids(ctx)?; - let mut partials = (0..group_ids.num_groups()) - .map(|_| Sum.empty_partial(options, batch.dtype())) - .collect::>>()?; - - match &columnar { - Columnar::Canonical(Canonical::Primitive(p)) => { - accumulate_grouped_primitive(&mut partials, p, ids.as_ref(), ctx)?; - } - Columnar::Canonical(Canonical::Bool(b)) => { - accumulate_grouped_bool(&mut partials, b, ids.as_ref(), ctx)?; - } - Columnar::Canonical(Canonical::Decimal(_)) | Columnar::Constant(_) => unreachable!(), - Columnar::Canonical(_) => unreachable!(), + if let Some(bools) = batch.as_opt::() { + let group_ids = group_ids.validated_ids(ctx)?; + let bools = bools.into_owned(); + accumulate_grouped_bool(partials, &bools, group_ids.as_ref(), ctx)?; + return Ok(true); } - let Some(partials) = Sum.partials_to_array(&partials, &partial_dtype)? else { - return Ok(None); - }; - Ok(Some(GroupedAggregateKernelResult::dense( - partials, - group_ids.num_groups(), - )?)) + Ok(false) } } diff --git a/vortex-array/src/aggregate_fn/fns/sum/mod.rs b/vortex-array/src/aggregate_fn/fns/sum/mod.rs index 5219933bc20..207b8140922 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/mod.rs @@ -7,7 +7,7 @@ mod decimal; mod grouped; mod primitive; -pub(crate) use grouped::SumGroupedKernel; +pub(crate) use grouped::SUM_GROUPED_KERNEL; use vortex_buffer::Buffer; use vortex_error::VortexExpect; use vortex_error::VortexResult; diff --git a/vortex-array/src/aggregate_fn/kernels.rs b/vortex-array/src/aggregate_fn/kernels.rs index ff1d56dc615..51d47c33a2e 100644 --- a/vortex-array/src/aggregate_fn/kernels.rs +++ b/vortex-array/src/aggregate_fn/kernels.rs @@ -4,13 +4,18 @@ //! Pluggable aggregate function kernels used to provide encoding-specific implementations of //! aggregate functions. +use std::any::Any; use std::fmt::Debug; +use std::marker::PhantomData; use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; use crate::ArrayRef; use crate::ExecutionCtx; use crate::aggregate_fn::AggregateFnRef; +use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::GroupIds; use crate::scalar::Scalar; @@ -27,39 +32,47 @@ pub trait DynAggregateKernel: 'static + Send + Sync + Debug { ) -> VortexResult>; } -/// Partial grouped aggregate output produced by an encoding-specific grouped kernel. +/// A typed grouped aggregate kernel. /// -/// `group_ids` is parallel to `partials`: each row in `partials` is a partial state for the -/// corresponding dense group ordinal. The ids may repeat, omit, and reorder groups, but must be -/// valid slots in the accumulator's `0..num_groups` range. The grouped accumulator merges this -/// batch through `accumulate_partials`. -#[derive(Clone, Debug)] -pub struct GroupedAggregateKernelResult { - group_ids: GroupIds, - partials: ArrayRef, +/// Implementations receive the concrete aggregate options and typed partial state. Return +/// `Ok(false)` when the kernel cannot handle the current values or group-id encodings. +pub trait GroupedAggregateKernel: 'static + Send + Sync + Debug { + /// Accumulate `batch` into `states` according to `group_ids`. + fn grouped_accumulate( + &self, + options: &V::Options, + states: &mut [V::Partial], + batch: &ArrayRef, + group_ids: &GroupIds, + ctx: &mut ExecutionCtx, + ) -> VortexResult; +} + +/// Bridges a typed [`GroupedAggregateKernel`] to type-erased grouped kernel dispatch. +pub struct GroupedAggregateKernelAdapter { + kernel: K, + _phantom: PhantomData V>, } -impl GroupedAggregateKernelResult { - pub fn new(group_ids: GroupIds, partials: ArrayRef) -> Self { +impl GroupedAggregateKernelAdapter { + /// Create a new adapter around `kernel`. + pub const fn new(kernel: K) -> Self { Self { - group_ids, - partials, + kernel, + _phantom: PhantomData, } } +} - pub fn dense(partials: ArrayRef, num_groups: usize) -> VortexResult { - Ok(Self { - group_ids: GroupIds::range(num_groups)?, - partials, - }) - } - - pub fn group_ids(&self) -> &GroupIds { - &self.group_ids - } - - pub fn partials(&self) -> &ArrayRef { - &self.partials +impl Debug for GroupedAggregateKernelAdapter +where + V: AggregateFnVTable, + K: GroupedAggregateKernel, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("GroupedAggregateKernelAdapter") + .field("kernel", &self.kernel) + .finish() } } @@ -71,14 +84,58 @@ impl GroupedAggregateKernelResult { /// Kernels receive the same dense group ordinals that the caller passed to the grouped accumulator /// and may aggregate directly in the encoded domain. /// -/// Return `Ok(None)` if the kernel cannot be applied to the given aggregate function. +/// Return `Ok(false)` if the kernel cannot be applied to the given aggregate function or input +/// encodings. pub trait DynGroupedAggregateKernel: 'static + Send + Sync + Debug { - /// Aggregate values into a partial-state batch keyed by dense group ordinal. - fn grouped_aggregate( + /// Accumulate values into type-erased partial state. + fn grouped_accumulate( &self, aggregate_fn: &AggregateFnRef, batch: &ArrayRef, group_ids: &GroupIds, + states: &mut dyn Any, ctx: &mut ExecutionCtx, - ) -> VortexResult>; + ) -> VortexResult; +} + +impl DynGroupedAggregateKernel for GroupedAggregateKernelAdapter +where + V: AggregateFnVTable, + K: GroupedAggregateKernel, +{ + fn grouped_accumulate( + &self, + aggregate_fn: &AggregateFnRef, + batch: &ArrayRef, + group_ids: &GroupIds, + states: &mut dyn Any, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + let Some(options) = aggregate_fn.as_opt::() else { + return Ok(false); + }; + + let Some(states) = states.downcast_mut::>() else { + vortex_bail!( + "Grouped aggregate kernel for {} received incompatible partial state", + aggregate_fn.id() + ); + }; + + vortex_ensure!( + states.len() >= group_ids.num_groups(), + "Grouped aggregate kernel for {} received {} partial states for {} groups", + aggregate_fn.id(), + states.len(), + group_ids.num_groups() + ); + + self.kernel.grouped_accumulate( + options, + &mut states[..group_ids.num_groups()], + batch, + group_ids, + ctx, + ) + } } diff --git a/vortex-array/src/aggregate_fn/session.rs b/vortex-array/src/aggregate_fn/session.rs index 359b7e3c910..14a5ccb261d 100644 --- a/vortex-array/src/aggregate_fn/session.rs +++ b/vortex-array/src/aggregate_fn/session.rs @@ -18,8 +18,8 @@ use crate::aggregate_fn::fns::all_non_null::AllNonNull; use crate::aggregate_fn::fns::all_null::AllNull; use crate::aggregate_fn::fns::bounded_max::BoundedMax; use crate::aggregate_fn::fns::bounded_min::BoundedMin; +use crate::aggregate_fn::fns::count::COUNT_GROUPED_KERNEL; use crate::aggregate_fn::fns::count::Count; -use crate::aggregate_fn::fns::count::CountGroupedKernel; use crate::aggregate_fn::fns::first::First; use crate::aggregate_fn::fns::is_constant::IsConstant; use crate::aggregate_fn::fns::is_sorted::IsSorted; @@ -29,8 +29,8 @@ use crate::aggregate_fn::fns::min::Min; use crate::aggregate_fn::fns::min_max::MinMax; use crate::aggregate_fn::fns::nan_count::NanCount; use crate::aggregate_fn::fns::null_count::NullCount; +use crate::aggregate_fn::fns::sum::SUM_GROUPED_KERNEL; use crate::aggregate_fn::fns::sum::Sum; -use crate::aggregate_fn::fns::sum::SumGroupedKernel; use crate::aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes; use crate::aggregate_fn::kernels::DynAggregateKernel; use crate::aggregate_fn::kernels::DynGroupedAggregateKernel; @@ -68,7 +68,27 @@ impl SessionVar for AggregateFnSession { } type AggregateKernelKey = (ArrayId, Option); -type GroupedAggregateKernelKey = (AggregateFnId, Option, Option); + +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +struct GroupedAggregateKernelKey { + aggregate_id: AggregateFnId, + values_id: Option, + group_ids_id: Option, +} + +impl GroupedAggregateKernelKey { + fn new( + aggregate_id: AggregateFnId, + values_id: Option, + group_ids_id: Option, + ) -> Self { + Self { + aggregate_id, + values_id, + group_ids_id, + } + } +} impl Default for AggregateFnSession { fn default() -> Self { @@ -103,8 +123,8 @@ impl Default for AggregateFnSession { this.register_aggregate_kernel(Dict.id(), Some(MinMax.id()), &DictMinMaxKernel); this.register_aggregate_kernel(Dict.id(), Some(IsConstant.id()), &DictIsConstantKernel); this.register_aggregate_kernel(Dict.id(), Some(IsSorted.id()), &DictIsSortedKernel); - this.register_grouped_kernel(Count.id(), None, None, &CountGroupedKernel); - this.register_grouped_kernel(Sum.id(), None, None, &SumGroupedKernel); + this.register_grouped_kernel(Count.id(), None, None, &COUNT_GROUPED_KERNEL); + this.register_grouped_kernel(Sum.id(), None, None, &SUM_GROUPED_KERNEL); this } @@ -174,10 +194,26 @@ impl AggregateFnSession { let group_ids_id = group_ids_id.into(); self.grouped_kernels.read(|kernels| { kernels - .get(&(fn_id, Some(values_id), Some(group_ids_id))) - .or_else(|| kernels.get(&(fn_id, Some(values_id), None))) - .or_else(|| kernels.get(&(fn_id, None, Some(group_ids_id)))) - .or_else(|| kernels.get(&(fn_id, None, None))) + .get(&GroupedAggregateKernelKey::new( + fn_id, + Some(values_id), + Some(group_ids_id), + )) + .or_else(|| { + kernels.get(&GroupedAggregateKernelKey::new( + fn_id, + Some(values_id), + None, + )) + }) + .or_else(|| { + kernels.get(&GroupedAggregateKernelKey::new( + fn_id, + None, + Some(group_ids_id), + )) + }) + .or_else(|| kernels.get(&GroupedAggregateKernelKey::new(fn_id, None, None))) .copied() }) } @@ -194,8 +230,10 @@ impl AggregateFnSession { kernel: &'static dyn DynGroupedAggregateKernel, ) { let fn_id = agg_fn_id.into(); - self.grouped_kernels - .insert((fn_id, values_id, group_ids_id), kernel) + self.grouped_kernels.insert( + GroupedAggregateKernelKey::new(fn_id, values_id, group_ids_id), + kernel, + ) } } @@ -210,6 +248,8 @@ impl AggregateFnSessionExt for S {} #[cfg(test)] mod tests { + use std::any::Any; + use vortex_error::VortexResult; use super::*; @@ -217,7 +257,6 @@ mod tests { use crate::ExecutionCtx; use crate::aggregate_fn::AggregateFnRef; use crate::aggregate_fn::GroupIds; - use crate::aggregate_fn::kernels::GroupedAggregateKernelResult; use crate::arrays::Constant; use crate::arrays::Primitive; @@ -225,14 +264,15 @@ mod tests { struct TestGroupedKernel; impl DynGroupedAggregateKernel for TestGroupedKernel { - fn grouped_aggregate( + fn grouped_accumulate( &self, _aggregate_fn: &AggregateFnRef, _batch: &ArrayRef, _group_ids: &GroupIds, + _states: &mut dyn Any, _ctx: &mut ExecutionCtx, - ) -> VortexResult> { - Ok(None) + ) -> VortexResult { + Ok(false) } }