From f75450e9bb7bf20b43aa261e0735ccf35d0e079e Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Mon, 22 Dec 2025 21:10:25 +0800 Subject: [PATCH 1/2] feat: implement count-min sketch --- src/countmin/mod.rs | 27 +++ src/countmin/serialization.rs | 32 +++ src/countmin/sketch.rs | 369 ++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + tests/countmin_test.rs | 129 ++++++++++++ 5 files changed, 558 insertions(+) create mode 100644 src/countmin/mod.rs create mode 100644 src/countmin/serialization.rs create mode 100644 src/countmin/sketch.rs create mode 100644 tests/countmin_test.rs diff --git a/src/countmin/mod.rs b/src/countmin/mod.rs new file mode 100644 index 0000000..b7de4b8 --- /dev/null +++ b/src/countmin/mod.rs @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Count-Min sketch implementation for frequency estimation. +//! +//! The Count-Min sketch provides approximate frequency counts for streaming data +//! with configurable relative error and confidence bounds. 
+ +mod serialization; + +mod sketch; +pub use self::sketch::CountMinSketch; +pub use self::sketch::DEFAULT_SEED; diff --git a/src/countmin/serialization.rs b/src/countmin/serialization.rs new file mode 100644 index 0000000..e941790 --- /dev/null +++ b/src/countmin/serialization.rs @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::hash::MurmurHash3X64128; +use std::hash::Hasher; + +/// Number of 8-byte longs in the serialized preamble. +pub(super) const PREAMBLE_LONGS_SHORT: u8 = 2; +/// Serial format version written to, and required from, the preamble. +pub(super) const SERIAL_VERSION: u8 = 1; +/// Family id byte identifying a Count-Min sketch. +pub(super) const COUNTMIN_FAMILY_ID: u8 = 18; +/// Flag bit set in the preamble flags byte when the sketch is empty. +pub(super) const FLAGS_IS_EMPTY: u8 = 1 << 0; +/// Size in bytes of one preamble long. +pub(super) const LONG_SIZE_BYTES: usize = 8; + +/// Computes the 16-bit seed hash stored in the serialized preamble. +/// +/// Feeds the little-endian bytes of `seed` to a `MurmurHash3X64128` hasher created with +/// seed 0 and keeps the low 16 bits of the first value returned by `finish128`. +pub(super) fn compute_seed_hash(seed: u64) -> u16 { + let mut hasher = MurmurHash3X64128::with_seed(0); + hasher.write(&seed.to_le_bytes()); + let (h1, _) = hasher.finish128(); + (h1 & 0xffff) as u16 +} diff --git a/src/countmin/sketch.rs b/src/countmin/sketch.rs new file mode 100644 index 0000000..eee1ec1 --- /dev/null +++ b/src/countmin/sketch.rs @@ -0,0 +1,369 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::countmin::serialization::{ + COUNTMIN_FAMILY_ID, FLAGS_IS_EMPTY, LONG_SIZE_BYTES, PREAMBLE_LONGS_SHORT, SERIAL_VERSION, + compute_seed_hash, +}; +use crate::error::SerdeError; +use crate::hash::MurmurHash3X64128; +use byteorder::{LE, ReadBytesExt}; +use std::hash::{Hash, Hasher}; +use std::io::Cursor; +use std::mem::size_of; + +/// Exclusive upper bound on the number of counters (`num_hashes * num_buckets`) in a sketch. +const MAX_TABLE_ENTRIES: usize = 1 << 30; + +/// Default seed used by the sketch. +pub const DEFAULT_SEED: u64 = 9001; + +/// Count-Min sketch for estimating item frequencies. +/// +/// The sketch provides upper and lower bounds on estimated item frequencies +/// with configurable relative error and confidence. +#[derive(Debug, Clone, PartialEq)] +pub struct CountMinSketch { + /// Number of hash functions, i.e. rows of the counter table. + num_hashes: u8, + /// Number of buckets (columns) in each row. + num_buckets: u32, + /// Sketch seed from which the per-row hash seeds are derived. + seed: u64, + /// Running sum of the absolute values of all update weights. + total_weight: i64, + /// Row-major counter table holding `num_hashes * num_buckets` entries. + counts: Vec, + /// One derived hash seed per row. + hash_seeds: Vec, +} + +impl CountMinSketch { + /// Creates a new Count-Min sketch with the default seed. + /// + /// # Panics + /// + /// Panics if `num_hashes` is 0, `num_buckets` is less than 3, or the + /// total table size (`num_hashes * num_buckets`) is 2^30 or more. + pub fn new(num_hashes: u8, num_buckets: u32) -> Self { + Self::with_seed(num_hashes, num_buckets, DEFAULT_SEED) + } + + /// Creates a new Count-Min sketch with the provided seed. + /// + /// # Panics + /// + /// Panics if `num_hashes` is 0, `num_buckets` is less than 3, or the + /// total table size (`num_hashes * num_buckets`) is 2^30 or more.
+ pub fn with_seed(num_hashes: u8, num_buckets: u32, seed: u64) -> Self { + let entries = entries_for_config(num_hashes, num_buckets); + Self::make(num_hashes, num_buckets, seed, entries) + } + + /// Returns the number of hash functions used by the sketch. + pub fn num_hashes(&self) -> u8 { + self.num_hashes + } + + /// Returns the number of buckets per hash function. + pub fn num_buckets(&self) -> u32 { + self.num_buckets + } + + /// Returns the seed used by the sketch. + pub fn seed(&self) -> u64 { + self.seed + } + + /// Returns the total weight inserted into the sketch. + /// + /// Updates contribute the absolute value of their weight, so negative + /// weights increase this total as well. + pub fn total_weight(&self) -> i64 { + self.total_weight + } + + /// Returns the relative error (epsilon) implied by the number of buckets. + /// + /// Computed as `e / num_buckets`. + pub fn relative_error(&self) -> f64 { + std::f64::consts::E / self.num_buckets as f64 + } + + /// Returns true if the sketch has not seen any updates. + /// + /// Emptiness is decided solely by `total_weight` being 0. + pub fn is_empty(&self) -> bool { + self.total_weight == 0 + } + + /// Suggests the number of buckets to achieve the given relative error. + /// + /// Computed as `ceil(e / relative_error)`. + /// + /// NOTE(review): a `relative_error` of exactly 0.0 passes the assertion; the + /// division then yields infinity, which the `as u32` cast saturates to `u32::MAX`. + /// + /// # Panics + /// + /// Panics if `relative_error` is negative or NaN. + pub fn suggest_num_buckets(relative_error: f64) -> u32 { + assert!(relative_error >= 0.0, "relative_error must be at least 0"); + (std::f64::consts::E / relative_error).ceil() as u32 + } + + /// Suggests the number of hashes to achieve the given confidence. + /// + /// Computed as `ceil(ln(1 / (1 - confidence)))`, capped at 127; a confidence of + /// exactly 1.0 returns 127. + /// + /// NOTE(review): a confidence of 0.0 passes the assertion and yields 0, a value + /// that `CountMinSketch::new` rejects. + /// + /// # Panics + /// + /// Panics if `confidence` is not in [0, 1]. + pub fn suggest_num_hashes(confidence: f64) -> u8 { + assert!( + (0.0..=1.0).contains(&confidence), + "confidence must be between 0 and 1.0 (inclusive)" + ); + if confidence == 1.0 { + return 127; + } + let hashes = (1.0 / (1.0 - confidence)).ln().ceil(); + hashes.min(127.0) as u8 + } + + /// Updates the sketch with a single occurrence of the item. + pub fn update(&mut self, item: T) { + self.update_with_weight(item, 1); + } + + /// Updates the sketch with the given item and weight. + /// + /// A weight of 0 is a no-op. The signed weight is added to one counter per row, + /// while `total_weight` grows by the absolute value of the weight.
+ pub fn update_with_weight(&mut self, item: T, weight: i64) { + if weight == 0 { + return; + } + /* total_weight accumulates the absolute weight; the counters below receive the signed weight. */ + let abs_weight = abs_i64(weight); + self.total_weight = self.total_weight.wrapping_add(abs_weight); + let num_buckets = self.num_buckets as usize; + for (row, seed) in self.hash_seeds.iter().enumerate() { + let bucket = self.bucket_index(&item, *seed); + let index = row * num_buckets + bucket; + self.counts[index] = self.counts[index].wrapping_add(weight); + } + } + + /// Returns the estimated frequency of the given item. + /// + /// The estimate is the smallest counter value among the item's buckets, one per row. + pub fn estimate(&self, item: T) -> i64 { + let num_buckets = self.num_buckets as usize; + let mut min = i64::MAX; + for (row, seed) in self.hash_seeds.iter().enumerate() { + let bucket = self.bucket_index(&item, *seed); + let index = row * num_buckets + bucket; + let value = self.counts[index]; + if value < min { + min = value; + } + } + min + } + + /// Returns the lower bound on the true frequency of the given item. + /// + /// This is the same value as `estimate`. + pub fn lower_bound(&self, item: T) -> i64 { + self.estimate(item) + } + + /// Returns the upper bound on the true frequency of the given item. + /// + /// Adds `relative_error() * total_weight`, truncated to an integer by the cast, + /// to the estimate using wrapping addition. + pub fn upper_bound(&self, item: T) -> i64 { + let estimate = self.estimate(item); + let error = (self.relative_error() * self.total_weight as f64) as i64; + estimate.wrapping_add(error) + } + + /// Merges another sketch into this one. + /// + /// Every counter and the total weight of `other` are added to this sketch with + /// wrapping arithmetic. + /// + /// NOTE(review): `self` is borrowed mutably and `other` immutably, so safe callers + /// presumably cannot pass the same sketch twice; the self-merge check looks + /// unreachable. Confirm before removing it. + /// + /// # Panics + /// + /// Panics if the sketches differ in `num_hashes`, `num_buckets` or `seed`, or if + /// `other` is the same object as `self`. + pub fn merge(&mut self, other: &CountMinSketch) { + if std::ptr::eq(self, other) { + panic!("Cannot merge a sketch with itself."); + } + if self.num_hashes != other.num_hashes + || self.num_buckets != other.num_buckets + || self.seed != other.seed + { + panic!("Incompatible sketch configuration."); + } + for (dst, src) in self.counts.iter_mut().zip(other.counts.iter()) { + *dst = dst.wrapping_add(*src); + } + self.total_weight = self.total_weight.wrapping_add(other.total_weight); + } + + /// Serializes this sketch into the DataSketches Count-Min format. + /// + /// The output is a two-long preamble, followed for non-empty sketches by the + /// total weight and every counter as little-endian 64-bit integers.
+ pub fn serialize(&self) -> Vec { + let header_size = PREAMBLE_LONGS_SHORT as usize * LONG_SIZE_BYTES; + let payload_size = if self.is_empty() { + 0 + } else { + LONG_SIZE_BYTES + (self.counts.len() * size_of::()) + }; + let mut bytes = Vec::with_capacity(header_size + payload_size); + + /* First preamble long: preamble_longs, serial version, family id, flags, 4 unused bytes. */ + bytes.push(PREAMBLE_LONGS_SHORT); + bytes.push(SERIAL_VERSION); + bytes.push(COUNTMIN_FAMILY_ID); + bytes.push(if self.is_empty() { FLAGS_IS_EMPTY } else { 0 }); + bytes.extend_from_slice(&0u32.to_le_bytes()); + + /* Second preamble long: num_buckets, num_hashes, seed hash, 1 unused byte. */ + bytes.extend_from_slice(&self.num_buckets.to_le_bytes()); + bytes.push(self.num_hashes); + bytes.extend_from_slice(&compute_seed_hash(self.seed).to_le_bytes()); + bytes.push(0u8); + + /* An empty sketch is serialized as the preamble only. */ + if self.is_empty() { + return bytes; + } + + /* Payload: total_weight followed by every counter, all little-endian. */ + bytes.extend_from_slice(&self.total_weight.to_le_bytes()); + for count in &self.counts { + bytes.extend_from_slice(&count.to_le_bytes()); + } + bytes + } + + /// Deserializes a sketch from bytes using the default seed. + /// + /// # Errors + /// + /// Returns any error produced by `deserialize_with_seed`. + pub fn deserialize(bytes: &[u8]) -> Result { + Self::deserialize_with_seed(bytes, DEFAULT_SEED) + } + + /// Deserializes a sketch from bytes using the provided seed. + /// + /// # Errors + /// + /// Returns `SerdeError::InsufficientData` if the input ends early, + /// `SerdeError::InvalidFamily`, `SerdeError::UnsupportedVersion` or + /// `SerdeError::MalformedData` for an unexpected family id, serial version or + /// preamble size, and `SerdeError::InvalidParameter` for a seed hash mismatch + /// or an invalid table configuration.
+ pub fn deserialize_with_seed(bytes: &[u8], seed: u64) -> Result { + /* Maps a failed read to SerdeError::InsufficientData tagged with the field being read. */ + fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> SerdeError { + move |_| SerdeError::InsufficientData(tag.to_string()) + } + + /* First preamble long: preamble_longs, serial_version, family_id, flags and 4 unused bytes. */ + let mut cursor = Cursor::new(bytes); + let preamble_longs = cursor.read_u8().map_err(make_error("preamble_longs"))?; + let serial_version = cursor.read_u8().map_err(make_error("serial_version"))?; + let family_id = cursor.read_u8().map_err(make_error("family_id"))?; + let flags = cursor.read_u8().map_err(make_error("flags"))?; + cursor.read_u32::().map_err(make_error("unused32"))?; + + if family_id != COUNTMIN_FAMILY_ID { + return Err(SerdeError::InvalidFamily(format!( + "expected {} (CountMinSketch), got {}", + COUNTMIN_FAMILY_ID, family_id + ))); + } + if serial_version != SERIAL_VERSION { + return Err(SerdeError::UnsupportedVersion(format!( + "expected {}, got {}", + SERIAL_VERSION, serial_version + ))); + } + if preamble_longs != PREAMBLE_LONGS_SHORT { + return Err(SerdeError::MalformedData(format!( + "unsupported preamble_longs {preamble_longs}" + ))); + } + + /* Second preamble long: num_buckets, num_hashes, seed_hash and 1 unused byte. */ + let num_buckets = cursor.read_u32::().map_err(make_error("num_buckets"))?; + let num_hashes = cursor.read_u8().map_err(make_error("num_hashes"))?; + let seed_hash = cursor.read_u16::().map_err(make_error("seed_hash"))?; + cursor.read_u8().map_err(make_error("unused8"))?; + + let expected_seed_hash = compute_seed_hash(seed); + if seed_hash != expected_seed_hash { + return Err(SerdeError::InvalidParameter(format!( + "incompatible seed hash: expected {}, got {}", + expected_seed_hash, seed_hash + ))); + } + + /* NOTE(review): the counter table is allocated from header fields before the payload length is checked, so a 16-byte input can request close to MAX_TABLE_ENTRIES counters; consider validating the remaining byte count first. */ + let entries = entries_for_config_checked(num_hashes, num_buckets)?; + let mut sketch = Self::make(num_hashes, num_buckets, seed, entries); + if (flags & FLAGS_IS_EMPTY) != 0 { + return Ok(sketch); + } + + /* Non-empty payload: total_weight followed by one value per counter. */ + sketch.total_weight = cursor + .read_i64::() + .map_err(make_error("total_weight"))?; + for count in sketch.counts.iter_mut() { + *count =
cursor.read_i64::().map_err(make_error("counts"))?; + } + Ok(sketch) + } + + /* Builds a zeroed sketch with `entries` counters and one derived hash seed per row; the visible callers validate the configuration before calling. */ + fn make(num_hashes: u8, num_buckets: u32, seed: u64, entries: usize) -> Self { + let counts = vec![0i64; entries]; + let hash_seeds = make_hash_seeds(seed, num_hashes); + CountMinSketch { + num_hashes, + num_buckets, + seed, + total_weight: 0, + counts, + hash_seeds, + } + } + + /* Hashes `item` with a hasher created from the row seed and reduces the first value returned by finish128 modulo num_buckets. */ + fn bucket_index(&self, item: &T, seed: u64) -> usize { + let mut hasher = MurmurHash3X64128::with_seed(seed); + item.hash(&mut hasher); + let (h1, _) = hasher.finish128(); + (h1 % self.num_buckets as u64) as usize + } +} + +/* Validates a configuration and returns the counter count (num_hashes times num_buckets), panicking on invalid input. */ +fn entries_for_config(num_hashes: u8, num_buckets: u32) -> usize { + assert!(num_hashes > 0, "num_hashes must be at least 1"); + assert!(num_buckets >= 3, "num_buckets must be at least 3"); + let entries = (num_hashes as usize) + .checked_mul(num_buckets as usize) + .expect("num_hashes * num_buckets overflows usize"); + assert!( + entries < MAX_TABLE_ENTRIES, + "num_hashes * num_buckets must be < {}", + MAX_TABLE_ENTRIES + ); + entries +} + +/* Same checks as entries_for_config, reported as SerdeError::InvalidParameter instead of a panic. */ +fn entries_for_config_checked(num_hashes: u8, num_buckets: u32) -> Result { + if num_hashes == 0 { + return Err(SerdeError::InvalidParameter( + "num_hashes must be at least 1".to_string(), + )); + } + if num_buckets < 3 { + return Err(SerdeError::InvalidParameter( + "num_buckets must be at least 3".to_string(), + )); + } + let entries = (num_hashes as usize) + .checked_mul(num_buckets as usize) + .ok_or_else(|| { + SerdeError::InvalidParameter("num_hashes * num_buckets overflows usize".to_string()) + })?; + if entries >= MAX_TABLE_ENTRIES { + return Err(SerdeError::InvalidParameter(format!( + "num_hashes * num_buckets must be < {}", + MAX_TABLE_ENTRIES + ))); + } + Ok(entries) +} + +/* Returns one hash seed per row, each derived from the sketch seed and the row index. */ +fn make_hash_seeds(seed: u64, num_hashes: u8) -> Vec { + let mut seeds = Vec::with_capacity(num_hashes as usize); + for i in 0..num_hashes { + // Derive per-row hash seeds deterministically from the sketch seed.
+ let mut hasher = MurmurHash3X64128::with_seed(seed); + hasher.write(&u64::from(i).to_le_bytes()); + let (h1, _) = hasher.finish128(); + seeds.push(h1); + } + seeds +} + +/* Absolute value that wraps instead of overflowing: i64::MIN maps to itself. */ +fn abs_i64(value: i64) -> i64 { + if value >= 0 { + value + } else { + value.wrapping_neg() + } +} diff --git a/src/lib.rs b/src/lib.rs index c566727..c7c2a14 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,6 +30,7 @@ #[cfg(target_endian = "big")] compile_error!("datasketches does not support big-endian targets"); +pub mod countmin; pub mod error; pub mod hll; pub mod tdigest; diff --git a/tests/countmin_test.rs b/tests/countmin_test.rs new file mode 100644 index 0000000..24d9153 --- /dev/null +++ b/tests/countmin_test.rs @@ -0,0 +1,129 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
+ +use datasketches::countmin::{CountMinSketch, DEFAULT_SEED}; + +#[test] +fn test_init_defaults() { + let sketch = CountMinSketch::new(3, 5); + assert_eq!(sketch.num_hashes(), 3); + assert_eq!(sketch.num_buckets(), 5); + assert_eq!(sketch.seed(), DEFAULT_SEED); + assert!(sketch.is_empty()); + assert_eq!(sketch.total_weight(), 0); + assert_eq!(sketch.estimate("missing"), 0); +} + +#[test] +fn test_parameter_suggestions() { + assert_eq!(CountMinSketch::suggest_num_buckets(0.2), 14); + assert_eq!(CountMinSketch::suggest_num_buckets(0.1), 28); + assert_eq!(CountMinSketch::suggest_num_buckets(0.05), 55); + assert_eq!(CountMinSketch::suggest_num_buckets(0.01), 272); + + assert_eq!(CountMinSketch::suggest_num_hashes(0.682689492), 2); + assert_eq!(CountMinSketch::suggest_num_hashes(0.954499736), 4); + assert_eq!(CountMinSketch::suggest_num_hashes(0.997300204), 6); + + let buckets = CountMinSketch::suggest_num_buckets(0.1); + let sketch = CountMinSketch::with_seed(3, buckets, DEFAULT_SEED); + assert!(sketch.relative_error() <= 0.1); +} + +#[test] +fn test_update_and_bounds() { + let mut sketch = CountMinSketch::with_seed(3, 128, 123); + sketch.update("x"); + sketch.update_with_weight("x", 9); + assert_eq!(sketch.estimate("x"), 10); + assert_eq!(sketch.total_weight(), 10); + let estimate = sketch.estimate("x"); + let upper = sketch.upper_bound("x"); + let lower = sketch.lower_bound("x"); + assert!(lower <= estimate); + assert!(estimate <= upper); +} + +#[test] +fn test_negative_weights() { + /* total_weight tracks absolute weight, so a -1 update raises it to 1 while the counter drops to -1. */ + let mut sketch = CountMinSketch::with_seed(2, 32, 123); + sketch.update_with_weight("y", -1); + assert_eq!(sketch.total_weight(), 1); + assert_eq!(sketch.estimate("y"), -1); + sketch.update_with_weight("x", 2); + assert_eq!(sketch.total_weight(), 3); +} + +#[test] +fn test_merge() { + /* left receives 10 updates of "a"; right receives 4 of "a" and 4 of "b": 18 updates in total. */ + let mut left = CountMinSketch::with_seed(3, 64, DEFAULT_SEED); + let mut right = CountMinSketch::with_seed(3, 64, DEFAULT_SEED); + for _ in 0..10 { + left.update("a"); + } + for _ in 0..4 { +
right.update("a"); + right.update("b"); + } + left.merge(&right); + assert_eq!(left.total_weight(), 18); + assert!(left.estimate("a") >= 14); + assert!(left.estimate("b") >= 4); +} + +#[test] +fn test_serialize_deserialize_empty() { + let sketch = CountMinSketch::with_seed(2, 5, 123); + let bytes = sketch.serialize(); + let decoded = CountMinSketch::deserialize_with_seed(&bytes, 123).unwrap(); + assert!(decoded.is_empty()); + assert_eq!(decoded.num_hashes(), 2); + assert_eq!(decoded.num_buckets(), 5); + assert_eq!(decoded.seed(), 123); +} + +#[test] +fn test_serialize_deserialize_non_empty() { + let mut sketch = CountMinSketch::with_seed(3, 32, 123); + for i in 0..100i64 { + sketch.update(i); + } + let bytes = sketch.serialize(); + let decoded = CountMinSketch::deserialize_with_seed(&bytes, 123).unwrap(); + assert_eq!(decoded.total_weight(), sketch.total_weight()); + assert_eq!(decoded.estimate(42i64), sketch.estimate(42i64)); +} + +#[test] +#[should_panic(expected = "num_hashes must be at least 1")] +fn test_invalid_hashes() { + CountMinSketch::new(0, 5); +} + +#[test] +#[should_panic(expected = "num_buckets must be at least 3")] +fn test_invalid_buckets() { + CountMinSketch::new(1, 2); +} + +#[test] +#[should_panic(expected = "Incompatible sketch configuration.")] +fn test_merge_incompatible() { + let mut left = CountMinSketch::with_seed(3, 64, DEFAULT_SEED); + let right = CountMinSketch::with_seed(2, 64, DEFAULT_SEED); + left.merge(&right); +} From bb8e6674f8e9d6ccf8668279e1acccae510af882 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Mon, 22 Dec 2025 21:14:33 +0800 Subject: [PATCH 2/2] test: add count-min cases from rust-count-min-sketch --- tests/countmin_test.rs | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/countmin_test.rs b/tests/countmin_test.rs index 24d9153..1c9719a 100644 --- a/tests/countmin_test.rs +++ b/tests/countmin_test.rs @@ -127,3 +127,23 @@ fn test_merge_incompatible() { let right =
CountMinSketch::with_seed(2, 64, DEFAULT_SEED); left.merge(&right); } + +#[test] +fn test_increment_single_key_like_rust_count_min_sketch() { + /* Only one key is inserted, so its counters hold exactly the 300 updates. */ + let mut sketch = CountMinSketch::with_seed(4, 32, DEFAULT_SEED); + for _ in 0..300 { + sketch.update("key"); + } + assert_eq!(sketch.estimate("key"), 300); +} + +#[test] +fn test_increment_multi_like_rust_count_min_sketch() { + /* 1_000_000 updates spread over 100 keys: each key is inserted 10_000 times. */ + let mut sketch = CountMinSketch::with_seed(6, 128, DEFAULT_SEED); + for i in 0..1_000_000u64 { + sketch.update(i % 100); + } + for key in 0..100u64 { + assert!(sketch.estimate(key) >= 9_000); + } +}