diff --git a/src/hyperrectangles.jl b/src/hyperrectangles.jl
index 1254ff4..7ce8c5a 100644
--- a/src/hyperrectangles.jl
+++ b/src/hyperrectangles.jl
@@ -38,5 +38,98 @@ end
 get_max_distance_no_end(m, rec, point) =
     get_min_max_distance_no_end(distance_function_max, m, rec, point)
 
+# Compute per-dimension contributions for max distance
+function get_max_distance_contributions(m::Metric, rec::HyperRectangle{V}, point::AbstractVector{T}) where {V,T}
+    p = Distances.parameters(m)
+    return V(
+        @inbounds begin
+            v = distance_function_max(point[dim], rec.maxes[dim], rec.mins[dim])
+            p === nothing ? eval_op(m, v, zero(T)) : eval_op(m, v, zero(T), p[dim])
+        end for dim in eachindex(point)
+    )
+end
+
+# Compute single dimension contribution for max distance
+function get_max_distance_contribution_single(m::Metric, point_dim, min_bound::T, max_bound::T, dim::Integer) where {T}
+    v = distance_function_max(point_dim, max_bound, min_bound)
+    p = Distances.parameters(m)
+    return p === nothing ? eval_op(m, v, zero(T)) : eval_op(m, v, zero(T), p[dim])
+end
+
 get_min_distance_no_end(m, rec, point) =
     get_min_max_distance_no_end(distance_function_min, m, rec, point)
+
+# Combine all per-dimension contributions into final distance
+function eval_reduce_all(m::Metric, contributions::SVector{N, T}) where {N, T}
+    if m isa Chebyshev
+        return maximum(contributions)
+    else
+        # For Lp norms, sum all contributions
+        s = zero(T)
+        for contrib in contributions
+            s = eval_reduce(m, s, contrib)
+        end
+        return s
+    end
+end
+
+# O(1) incremental update: remove old contribution, add new contribution
+# NOTE(review): subtracting and re-adding floating-point contributions can round
+# differently from a full recomputation; confirm callers tolerate the drift.
+function update_max_distance_incremental(m::Metric, current_max_dist, old_contrib, new_contrib)
+    # For Lp norms: current_max_dist - old_contrib + new_contrib
+    temp = eval_reduce_inv(m, current_max_dist, old_contrib) # Remove old
+    return eval_reduce(m, temp, new_contrib) # Add new
+end
+
+# Inverse of eval_reduce for Lp norms (subtract contribution)
+function eval_reduce_inv(m::Metric, current_sum, contrib_to_remove)
+    # For Lp norms, this is just subtraction since eval_reduce is addition
+    return current_sum - contrib_to_remove
+end
+
+# Incrementally update the min distance to a rectangle whose bounds along `dim` moved
+# from (`old_min`, `old_max`) to (`new_min`, `new_max`); `current_dist` is the reduced
+# min distance to the old rectangle. At most one bound is expected to differ.
+function update_min_distance_no_end(m::Metric, current_dist, point::AbstractVector{T},
+        old_min, old_max, new_min, new_max, dim
+    ) where {T}
+    # A split value equal to the old bound leaves the rectangle unchanged along `dim`.
+    # Without this guard an unchanged min bound falls into the max-bound branch below,
+    # which then uses `new_max` as the split value.
+    if new_min == old_min && new_max == old_max
+        return current_dist
+    end
+
+    p_dim = point[dim]
+
+    if new_min != old_min
+        # Min boundary changed - split_val is the new min
+        split_val = new_min
+        split_diff = p_dim - split_val
+        if split_diff > 0
+            # Point is to the right of split_val
+            ddiff = max(zero(T), p_dim - new_max)
+        else
+            # Point is to the left of split_val
+            ddiff = max(zero(T), old_min - p_dim)
+        end
+    else
+        # Max boundary changed - split_val is the new max
+        split_val = new_max
+        split_diff = p_dim - split_val
+        if split_diff > 0
+            # Point is to the right of split_val
+            ddiff = max(zero(T), p_dim - old_max)
+        else
+            # Point is to the left of split_val
+            ddiff = max(zero(T), new_min - p_dim)
+        end
+    end
+
+    split_diff_pow = eval_pow(m, split_diff)
+    ddiff_pow = eval_pow(m, ddiff)
+    diff_tot = eval_diff(m, split_diff_pow, ddiff_pow, dim)
+
+    return eval_reduce(m, current_dist, diff_tot)
+end
diff --git a/src/kd_tree.jl b/src/kd_tree.jl
index 5518d7d..1b14b80 100644
--- a/src/kd_tree.jl
+++ b/src/kd_tree.jl
@@ -63,18 +63,10 @@ function KDTree(data::AbstractVector{V},
         indices = indices_reordered
     end
 
-    if metric isa Distances.UnionMetrics
-        p = parameters(metric)
-        if p !== nothing && length(p) != length(V)
-            throw(ArgumentError(
-                "dimension of input points:$(length(V)) and metric parameter:$(length(p)) must agree"))
-        end
-    end
-
     KDTree(storedata ? data : similar(data, 0), hyper_rec, indices, metric, split_vals, split_dims, tree_data, reorder)
 end
 
- function KDTree(data::AbstractVecOrMat{T},
+function KDTree(data::AbstractVecOrMat{T},
                 metric::M = Euclidean();
                 leafsize::Int = 25,
                 storedata::Bool = true,
@@ -112,16 +104,7 @@ function build_KDTree(index::Int,
 
     mid_idx = find_split(first(range), tree_data.leafsize, n_p)
 
-    split_dim = 1
-    max_spread = zero(T)
-    # Find dimension and spread where the spread is maximal
-    for d in 1:length(V)
-        spread = hyper_rec.maxes[d] - hyper_rec.mins[d]
-        if spread > max_spread
-            max_spread = spread
-            split_dim = d
-        end
-    end
+    split_dim = argmax(d -> hyper_rec.maxes[d] - hyper_rec.mins[d], 1:length(V))
 
     select_spec!(indices, mid_idx, first(range), last(range), data, split_dim)
 
@@ -183,61 +166,84 @@ function knn_kernel!(tree::KDTree{V},
         far = getleft(index)
         hyper_rec_far = HyperRectangle(hyper_rec.mins, @inbounds setindex(hyper_rec.maxes, split_val, split_dim))
         hyper_rec_close = HyperRectangle(@inbounds(setindex(hyper_rec.mins, split_val, split_dim)), hyper_rec.maxes)
-        ddiff = max(zero(eltype(V)), p_dim - hi)
     else
         close = getleft(index)
         far = getright(index)
         hyper_rec_far = HyperRectangle(@inbounds(setindex(hyper_rec.mins, split_val, split_dim)), hyper_rec.maxes)
         hyper_rec_close = HyperRectangle(hyper_rec.mins, @inbounds setindex(hyper_rec.maxes, split_val, split_dim))
-        ddiff = max(zero(eltype(V)), lo - p_dim)
     end
     # Always call closer sub tree
     knn_kernel!(tree, close, point, best_idxs, best_dists, min_dist, hyper_rec_close, skip)
 
-    split_diff_pow = eval_pow(M, split_diff)
-    ddiff_pow = eval_pow(M, ddiff)
-    diff_tot = eval_diff(M, split_diff_pow, ddiff_pow, split_dim)
-    new_min = eval_reduce(M, min_dist, diff_tot)
+    # Compute new min distance for far subtree using incremental update
+    if M isa Chebyshev
+        new_min = get_min_distance_no_end(M, hyper_rec_far, point)
+    else
+        # Try to update min distance incrementally for far subtree
+        if split_diff > 0
+            # Point is to the right, far subtree has split_val as new max
+            new_min = update_min_distance_no_end(M, min_dist, point, hyper_rec.mins[split_dim], hyper_rec.maxes[split_dim], hyper_rec.mins[split_dim], split_val, split_dim)
+        else
+            # Point is to the left, far subtree has split_val as new min
+            new_min = update_min_distance_no_end(M, min_dist, point, hyper_rec.mins[split_dim], hyper_rec.maxes[split_dim], split_val, hyper_rec.maxes[split_dim], split_dim)
+        end
+    end
+
     if new_min < best_dists[1]
         knn_kernel!(tree, far, point, best_idxs, best_dists, new_min, hyper_rec_far, skip)
     end
     return
 end
 
-function _inrange(tree::KDTree,
-                  point::AbstractVector,
-                  radius::Number,
-                  idx_in_ball::Union{Nothing, Vector{<:Integer}} = Int[])
+function _inrange(
+        tree::KDTree,
+        point::AbstractVector,
+        radius::Number,
+        idx_in_ball::Union{Nothing, Vector{<:Integer}} = Int[]
+    )
     init_min = get_min_distance_no_end(tree.metric, tree.hyper_rec, point)
-    return inrange_kernel!(tree, 1, point, eval_pow(tree.metric, radius), idx_in_ball,
-                           tree.hyper_rec, init_min)
+    init_max_contribs = get_max_distance_contributions(tree.metric, tree.hyper_rec, point)
+    init_max = eval_reduce_all(tree.metric, init_max_contribs)
+    return inrange_kernel!(
+        tree, 1, point, eval_pow(tree.metric, radius), idx_in_ball,
+        tree.hyper_rec, init_min, init_max_contribs, init_max
+    )
 end
 
 # Explicitly check the distance between leaf node and point while traversing
-function inrange_kernel!(tree::KDTree,
-                         index::Int,
-                         point::AbstractVector,
-                         r::Number,
-                         idx_in_ball::Union{Nothing, Vector{<:Integer}},
-                         hyper_rec::HyperRectangle,
-                         min_dist)
+function inrange_kernel!(
+        tree::KDTree,
+        index::Int,
+        point::AbstractVector,
+        r::Number,
+        idx_in_ball::Union{Nothing, Vector{<:Integer}},
+        hyper_rec::HyperRectangle,
+        min_dist,
+        max_dist_contribs::SVector,
+        max_dist
+    )
     # Point is outside hyper rectangle, skip the whole sub tree
     if min_dist > r
         return 0
     end
 
+    M = tree.metric
+
     # At a leaf node. Go through all points in node and add those in range
     if isleaf(tree.tree_data.n_internal_nodes, index)
         return add_points_inrange!(idx_in_ball, tree, index, point, r)
     end
 
+    if max_dist < r
+        return addall(tree, index, idx_in_ball)
+    end
+
     split_val = tree.split_vals[index]
     split_dim = tree.split_dims[index]
     lo = hyper_rec.mins[split_dim]
     hi = hyper_rec.maxes[split_dim]
     p_dim = point[split_dim]
     split_diff = p_dim - split_val
-    M = tree.metric
 
     count = 0
 
@@ -254,19 +260,60 @@ function inrange_kernel!(tree::KDTree,
         hyper_rec_close = HyperRectangle(hyper_rec.mins, @inbounds setindex(hyper_rec.maxes, split_val, split_dim))
         ddiff = max(zero(lo - p_dim), lo - p_dim)
     end
+    # Update per-dimension contributions for close subtree
+    old_contrib = max_dist_contribs[split_dim]
+    if split_diff > 0
+        # Point is to the right, close subtree has split_val as new min
+        new_contrib_close = get_max_distance_contribution_single(M, point[split_dim], split_val, hyper_rec.maxes[split_dim], split_dim)
+    else
+        # Point is to the left, close subtree has split_val as new max
+        new_contrib_close = get_max_distance_contribution_single(M, point[split_dim], hyper_rec.mins[split_dim], split_val, split_dim)
+    end
+
+    new_max_contribs_close = setindex(max_dist_contribs, new_contrib_close, split_dim)
+
+    # For Chebyshev, recompute from scratch; for others, use O(1) incremental update
+    if M isa Chebyshev
+        new_max_dist_close = eval_reduce_all(M, new_max_contribs_close)
+    else
+        new_max_dist_close = update_max_distance_incremental(M, max_dist, old_contrib, new_contrib_close)
+    end
+
     # Call closer sub tree
-    count += inrange_kernel!(tree, close, point, r, idx_in_ball, hyper_rec_close, min_dist)
-
-    # TODO: We could potentially also keep track of the max distance
-    # between the point and the hyper rectangle and add the whole sub tree
-    # in case of the max distance being <= r similarly to the BallTree inrange method.
-    # It would be interesting to benchmark this on some different data sets.
-
-    # Call further sub tree with the new min distance
-    split_diff_pow = eval_pow(M, split_diff)
-    ddiff_pow = eval_pow(M, ddiff)
-    diff_tot = eval_diff(M, split_diff_pow, ddiff_pow, split_dim)
-    new_min = eval_reduce(M, min_dist, diff_tot)
-    count += inrange_kernel!(tree, far, point, r, idx_in_ball, hyper_rec_far, new_min)
+    count += inrange_kernel!(tree, close, point, r, idx_in_ball, hyper_rec_close, min_dist, new_max_contribs_close, new_max_dist_close)
+
+    # Compute new min distance for far subtree using incremental update
+    if M isa Chebyshev
+        new_min = get_min_distance_no_end(M, hyper_rec_far, point)
+    else
+        # Try to update min distance incrementally for far subtree
+        if split_diff > 0
+            # Point is to the right, far subtree has split_val as new max
+            new_min = update_min_distance_no_end(M, min_dist, point, hyper_rec.mins[split_dim], hyper_rec.maxes[split_dim], hyper_rec.mins[split_dim], split_val, split_dim)
+        else
+            # Point is to the left, far subtree has split_val as new min
+            new_min = update_min_distance_no_end(M, min_dist, point, hyper_rec.mins[split_dim], hyper_rec.maxes[split_dim], split_val, hyper_rec.maxes[split_dim], split_dim)
+        end
+    end
+
+    # Update per-dimension contributions for far subtree
+    if split_diff > 0
+        # Point is to the right, far subtree has split_val as new max
+        new_contrib_far = get_max_distance_contribution_single(M, point[split_dim], hyper_rec.mins[split_dim], split_val, split_dim)
+    else
+        # Point is to the left, far subtree has split_val as new min
+        new_contrib_far = get_max_distance_contribution_single(M, point[split_dim], split_val, hyper_rec.maxes[split_dim], split_dim)
+    end
+
+    new_max_contribs_far = setindex(max_dist_contribs, new_contrib_far, split_dim)
+
+    # For Chebyshev, recompute from scratch; for others, use O(1) incremental update
+    if M isa Chebyshev
+        new_max_dist_far = eval_reduce_all(M, new_max_contribs_far)
+    else
+        new_max_dist_far = update_max_distance_incremental(M, max_dist, old_contrib, new_contrib_far)
+    end
+
+    count += inrange_kernel!(tree, far, point, r, idx_in_ball, hyper_rec_far, new_min, new_max_contribs_far, new_max_dist_far)
     return count
 end