text_buffer/
metric.rs

1use get_size2::GetSize;
2use smallvec::{SmallVec, smallvec};
3use std::{
4    fmt,
5    iter::Sum,
6    mem,
7    ops::{Add, AddAssign, RangeBounds, Sub, SubAssign},
8};
9
/// Maximum number of entries (children or leaf metrics) a single tree node holds.
const MAX: usize = 6;
/// Occupancy below which a node is treated as underfull: half of [`MAX`].
const MIN: usize = MAX / 2;
/// Byte threshold for one leaf entry: an insertion that would bring an entry
/// to this many bytes or more splits the entry instead of growing it.
/// Non-test builds use the large value, test builds the small one.
#[cfg(not(test))]
pub(crate) const MAX_LEAF: usize = 8000;
#[cfg(test)]
pub(crate) const MAX_LEAF: usize = 18;
16
/// Per-node list of metrics, stored inline for up to [`MAX`] entries.
type Metrics = SmallVec<[Metric; MAX]>;

/// An internal (non-leaf) tree node.
///
/// `metrics` and `children` are parallel lists: `metrics[i]` caches the
/// summed metric of the subtree stored in `children[i]`.
#[derive(Debug, Default, GetSize)]
struct Internal {
    /// Cached metric of each child; index `i` pairs with `children[i]`.
    #[get_size(size_fn = smallvec_size_helper)]
    metrics: Metrics,
    /// Boxed child nodes; index `i` pairs with `metrics[i]`.
    #[get_size(size_fn = smallvec_size_helper)]
    children: SmallVec<[Box<Node>; MAX]>,
}
26
27fn smallvec_size_helper<T>(slice: &[T]) -> usize
28where
29    T: GetSize,
30{
31    slice.iter().map(GetSize::get_heap_size).sum()
32}
33
34impl Internal {
35    fn len(&self) -> usize {
36        debug_assert_eq!(self.metrics.len(), self.children.len());
37        self.metrics.len()
38    }
39
40    fn push(&mut self, child: Box<Node>) {
41        let metric = child.metrics();
42        self.children.push(child);
43        self.metrics.push(metric);
44    }
45
46    fn insert(&mut self, idx: usize, child: Box<Node>) {
47        let metric = child.metrics();
48        self.children.insert(idx, child);
49        self.metrics.insert(idx, metric);
50    }
51
52    fn take(&mut self, other: &mut Self, range: impl RangeBounds<usize> + Clone) {
53        self.metrics.extend(other.metrics.drain(range.clone()));
54        self.children.extend(other.children.drain(range.clone()));
55    }
56
57    fn search_char_pos(&self, char_pos: usize) -> (usize, Metric) {
58        let mut acc = Metric::default();
59        let last = self.metrics.len() - 1;
60        for (i, metric) in self.metrics[..last].iter().enumerate() {
61            if char_pos < acc.chars + metric.chars {
62                return (i, acc);
63            }
64            acc += *metric;
65        }
66        (last, acc)
67    }
68
69    fn insert_node(&mut self, idx: usize, new_child: Box<Node>) -> Option<Box<Node>> {
70        // update the metrics for the current child
71        self.metrics[idx] = self.children[idx].metrics();
72        // shift idx to the right
73        let idx = idx + 1;
74        if self.len() < MAX {
75            // If there is room in this node then insert the
76            // node before the current one
77            self.insert(idx, new_child);
78            None
79        } else {
80            assert_eq!(self.len(), MAX);
81            // split this node into two and return the left one
82            let middle = MAX / 2;
83
84            let mut right = Internal {
85                metrics: self.metrics.drain(middle..).collect(),
86                children: self.children.drain(middle..).collect(),
87            };
88            if idx < middle {
89                self.insert(idx, new_child);
90            } else {
91                right.insert(idx - middle, new_child);
92            }
93            Some(Box::new(Node::Internal(right)))
94        }
95    }
96
97    /// Balance the node by either stealing from it's siblings or merging with
98    ///
99    /// Return true if the node still underfull
100    fn balance_node(&mut self, idx: usize) -> bool {
101        let missing = MIN.saturating_sub(self.children[idx].len());
102        if missing == 0 {
103            return false;
104        }
105
106        #[expect(clippy::borrowed_box)]
107        let free_nodes = |x: &Box<Node>| x.len().saturating_sub(MIN);
108
109        let left_free = if idx == 0 { 0 } else { free_nodes(&self.children[idx - 1]) };
110        let right_free = self.children.get(idx + 1).map_or(0, free_nodes);
111        if left_free + right_free >= missing {
112            // short circuit
113            let failed = self.try_steal_left(idx) && self.try_steal_right(idx);
114            debug_assert!(!failed);
115            false
116        } else {
117            self.merge_children(idx)
118        }
119    }
120
121    /// Merge this node with it's siblings.
122    ///
123    /// Returns true if the merged node is still is underfull
124    fn merge_children(&mut self, idx: usize) -> bool {
125        if self.len() <= 1 {
126            // no siblings to merge
127            return true;
128        }
129        let right_idx = if idx == 0 { idx + 1 } else { idx };
130        let left_idx = right_idx - 1;
131        let (left, right) = self.children.split_at_mut(right_idx);
132        let underfull = left[left_idx].merge_sibling(&mut right[0]);
133        self.children.remove(right_idx);
134        let right_metric = self.metrics.remove(right_idx);
135        self.metrics[left_idx] += right_metric;
136        underfull
137    }
138
139    fn try_steal_left(&mut self, idx: usize) -> bool {
140        assert!(idx < self.children.len());
141        assert!(idx < self.metrics.len());
142        let Some(left_idx) = idx.checked_sub(1) else { return true };
143
144        while self.children[idx].len() < MIN {
145            let left_node = self.children[left_idx].steal(false);
146            if let Some((node, node_metric)) = left_node {
147                self.children[idx].merge_node(node, node_metric, 0);
148                self.metrics[idx] += node_metric;
149                self.metrics[left_idx] -= node_metric;
150            } else {
151                return true;
152            }
153        }
154        false
155    }
156
157    fn try_steal_right(&mut self, idx: usize) -> bool {
158        assert_eq!(self.children.len(), self.metrics.len());
159        let right_idx = idx + 1;
160        if right_idx >= self.children.len() {
161            return true;
162        }
163
164        while self.children[idx].len() < MIN {
165            let right_node = self.children[right_idx].steal(true);
166            if let Some((node, node_metric)) = right_node {
167                let underfull_child = &mut self.children[idx];
168                let len = underfull_child.len();
169                underfull_child.merge_node(node, node_metric, len);
170                self.metrics[idx] += node_metric;
171                self.metrics[right_idx] -= node_metric;
172            } else {
173                return true;
174            }
175        }
176        false
177    }
178}
179
/// A leaf tree node: holds a list of metrics directly, with no children.
#[derive(Debug, Default, GetSize)]
struct Leaf {
    /// The metrics stored in this leaf, one per entry.
    #[get_size(size_fn = smallvec_size_helper)]
    metrics: Metrics,
}
185
186impl Leaf {
187    fn len(&self) -> usize {
188        self.metrics.len()
189    }
190
191    fn insert_at(&mut self, idx: usize, pos: Metric, data: Metric) -> Option<Box<Node>> {
192        if (self.metrics[idx].bytes + data.bytes) < MAX_LEAF {
193            self.metrics[idx] += data;
194            return None;
195        }
196        let left_metric = pos;
197        let right_metric = self.metrics[idx] - left_metric;
198
199        let new = if left_metric.bytes <= right_metric.bytes {
200            self.metrics[idx] = left_metric + data;
201            right_metric
202        } else {
203            self.metrics[idx] = left_metric;
204            right_metric + data
205        };
206        // shift idx to the right
207        let idx = idx + 1;
208        if self.len() < MAX {
209            // If there is room in this node then insert the
210            // leaf before the current one, splitting the
211            // size
212            self.metrics.insert(idx, new);
213            None
214        } else {
215            assert_eq!(self.len(), MAX);
216            // split this node into two and return the left one
217            let middle = MAX / 2;
218            let mut right_metrics: Metrics = self.metrics.drain(middle..).collect();
219            if idx < middle {
220                self.metrics.insert(idx, new);
221            } else {
222                right_metrics.insert(idx - middle, new);
223            }
224            let right = Node::Leaf(Leaf { metrics: right_metrics });
225            Some(Box::new(right))
226        }
227    }
228
229    fn push(&mut self, metric: Metric) -> Option<Box<Node>> {
230        if self.len() < MAX {
231            // If there is room in this node then insert the
232            // leaf before the current one, splitting the
233            // size
234            self.metrics.push(metric);
235            None
236        } else {
237            assert_eq!(self.len(), MAX);
238            // split this node into two and return the left one
239            let right = Node::Leaf(Leaf { metrics: smallvec![metric] });
240            Some(Box::new(right))
241        }
242    }
243}
244
/// A tree of [`Metric`] values supporting lookup by char or byte position,
/// insertion, and range deletion.
#[derive(Debug, Default, GetSize)]
pub(crate) struct BufferMetrics {
    /// Root node of the tree.
    root: Node,
}
249
250impl BufferMetrics {
251    pub(crate) fn search_char(&self, chars: usize) -> (Metric, usize) {
252        self.root.search_char(chars)
253    }
254
255    pub(crate) fn search_byte(&self, bytes: usize) -> (Metric, usize) {
256        self.root.search_byte(bytes)
257    }
258
259    pub(crate) fn len(&self) -> Metric {
260        self.root.metrics()
261    }
262
263    pub(crate) fn build(metrics: impl Iterator<Item = Metric>) -> Self {
264        // build the base layer of leaf nodes
265        let len = metrics.size_hint().0;
266        let remainder = len % MAX;
267        let split_idx = if remainder != 0 && remainder != len && remainder < MIN {
268            // If that last node is too small then merge it with the
269            // previous one by splitting it early
270            len - MIN - 1
271        } else {
272            // index will never equal len
273            len
274        };
275        let mut leaf = Leaf::default();
276        let cap = (len / MAX) + 1;
277        let mut nodes = Vec::with_capacity(cap);
278        for (idx, metric) in metrics.enumerate() {
279            leaf.push(metric);
280            if leaf.len() == MAX || idx == split_idx {
281                nodes.push(Box::new(Node::Leaf(leaf)));
282                leaf = Leaf::default();
283            }
284        }
285        if leaf.len() > 0 {
286            nodes.push(Box::new(Node::Leaf(leaf)));
287        }
288        // build each layer of internal nodes from the bottom up
289        let mut next_level = Vec::with_capacity((nodes.len() / MAX) + 1);
290        while nodes.len() > 1 {
291            let len = nodes.len();
292            let parent_count = len / MAX;
293            let remainder = len % MAX;
294            let split_idx = if remainder != 0 && remainder != len && remainder < MIN {
295                // If that last node is too small then merge it with the
296                // previous one by splitting it early
297                len - MIN - 1
298            } else {
299                // index will never equal len
300                len
301            };
302            let mut int = Internal::default();
303            for (idx, node) in nodes.drain(..).enumerate() {
304                int.metrics.push(node.metrics());
305                int.children.push(node);
306                if int.len() == MAX || idx == split_idx {
307                    next_level.push(Box::new(Node::Internal(int)));
308                    int = Internal::default();
309                }
310            }
311            debug_assert_eq!(next_level.len(), parent_count);
312            mem::swap(&mut nodes, &mut next_level);
313
314            if int.len() > 0 {
315                nodes.push(Box::new(Node::Internal(int)));
316            }
317        }
318        let root = *nodes.pop().unwrap_or_default();
319        let built = Self { root };
320        built.assert_invariants();
321        built
322    }
323
324    pub(crate) fn insert(&mut self, pos: Metric, data: impl Iterator<Item = Metric>) {
325        let size = data.size_hint().0;
326        let len = self.root.metrics();
327        assert!(pos.bytes <= len.bytes);
328
329        if size == 0 {
330            return;
331        }
332        if len.bytes == 0 {
333            // empty tree
334            let _ = mem::replace(self, Self::build(data));
335            return;
336        }
337
338        if size < 6 {
339            let mut pos = pos;
340            for metric in data {
341                let offset = metric;
342                self.root.insert_at(pos, metric);
343                pos += offset;
344            }
345        } else {
346            // build a new tree and splice it in
347            let new = Self::build(data);
348            if pos.bytes == 0 {
349                // append at the start by swapping self and new
350                let new_pos = new.root.metrics().chars;
351                let right = mem::replace(self, new);
352                self.root.append(right.root);
353                self.root.fix_seam(new_pos);
354            } else if len.bytes == pos.bytes {
355                // append at the end
356                self.root.append(new.root);
357                self.root.fix_seam(pos.chars);
358            } else {
359                // splice in the middle
360                let right_metric = self.root.metrics() - pos;
361                let new_metric = new.root.metrics();
362                let right = self.root.split(pos);
363                debug_assert_eq!(self.root.metrics(), pos);
364                debug_assert_eq!(right.metrics(), right_metric);
365                self.root.append(new.root);
366                self.root.append(right);
367                debug_assert_eq!(self.root.metrics(), pos + right_metric + new_metric);
368                self.root.fix_seam(pos.chars);
369                self.root.fix_seam(pos.chars + new_metric.chars);
370            }
371            self.root.collapse();
372        }
373        self.assert_invariants();
374    }
375
376    pub(crate) fn delete(&mut self, start: Metric, end: Metric) {
377        debug_assert!(start.bytes <= end.bytes);
378        debug_assert!(start.chars <= end.chars);
379        if start.bytes == end.bytes {
380            return;
381        }
382        if start.bytes == 0 && end.bytes == self.root.metrics().bytes {
383            // delete the whole tree
384            self.root = Node::default();
385            return;
386        }
387
388        let fix_seam = self.root.delete_impl(start, end);
389        if fix_seam {
390            self.root.fix_seam(start.chars);
391        }
392        self.root.collapse();
393        self.assert_invariants();
394    }
395
396    fn assert_invariants(&self) {
397        if cfg!(debug_assertions) {
398            self.root.assert_integrity();
399            self.root.assert_node_size(true);
400            self.root.assert_balance();
401        }
402    }
403}
404
/// A tree node: either a leaf holding metrics directly, or an internal node
/// holding child nodes together with one cached metric per child.
#[derive(Debug, GetSize)]
enum Node {
    /// Node that stores a list of metrics and has no children.
    Leaf(Leaf),
    /// Node that stores boxed children and a metric for each of them.
    Internal(Internal),
}
410
411impl Node {
412    fn metric_slice(&self) -> &[Metric] {
413        match self {
414            Self::Internal(x) => &x.metrics,
415            Self::Leaf(x) => &x.metrics,
416        }
417    }
418
419    fn is_underfull(&self) -> bool {
420        match self {
421            Node::Leaf(leaf) => leaf.len() < MIN,
422            Node::Internal(int) => int.len() < MIN,
423        }
424    }
425
426    fn metrics(&self) -> Metric {
427        let metrics = match self {
428            Self::Leaf(x) => &x.metrics,
429            Self::Internal(x) => &x.metrics,
430        };
431        metrics.iter().copied().sum()
432    }
433
434    fn len(&self) -> usize {
435        match self {
436            Self::Leaf(x) => x.len(),
437            Self::Internal(x) => x.len(),
438        }
439    }
440
441    fn depth(&self) -> usize {
442        match self {
443            Self::Leaf(_) => 0,
444            Self::Internal(x) => 1 + x.children[0].depth(),
445        }
446    }
447
448    fn search_char_pos(&self, char_pos: usize) -> (usize, Metric) {
449        let metrics = self.metric_slice();
450        let mut acc = Metric::default();
451        let last = metrics.len() - 1;
452        for (i, metric) in metrics[..last].iter().enumerate() {
453            if char_pos < acc.chars + metric.chars {
454                return (i, acc);
455            }
456            acc += *metric;
457        }
458        (last, acc)
459    }
460
461    /// If only a single child remains, collapse the node into that child.
462    fn collapse(&mut self) {
463        while self.len() == 1 {
464            match self {
465                Node::Internal(int) => {
466                    let child = int.children.pop().unwrap();
467                    let _ = mem::replace(self, *child);
468                }
469                Node::Leaf(_) => break,
470            }
471        }
472    }
473
474    fn insert_at(&mut self, pos: Metric, data: Metric) {
475        let len = self.metrics();
476        assert!(pos.bytes <= len.bytes);
477        if self.len() == 0 {
478            assert!(pos.bytes == 0);
479            let Node::Leaf(leaf) = self else { unreachable!() };
480            leaf.metrics.push(data);
481            return;
482        }
483        let new = self.insert_impl(pos, data);
484        if let Some(right) = new {
485            // split the root, making the old root the left child
486            let left = mem::replace(self, Node::Internal(Internal::default()));
487            let Node::Internal(int) = self else { unreachable!() };
488            int.metrics = smallvec![left.metrics(), right.metrics()];
489            int.children = smallvec![Box::new(left), right];
490        }
491    }
492
493    fn insert_impl(&mut self, pos: Metric, data: Metric) -> Option<Box<Node>> {
494        self.assert_node_integrity();
495        let (idx, metric) = self.search_char_pos(pos.chars);
496        let offset = pos - metric;
497        match self {
498            Node::Leaf(leaf) => leaf.insert_at(idx, offset, data),
499            Node::Internal(int) => {
500                if let Some(new) = int.children[idx].insert_impl(offset, data) {
501                    int.insert_node(idx, new)
502                } else {
503                    int.metrics[idx] += data;
504                    None
505                }
506            }
507        }
508    }
509
510    fn delete_impl(&mut self, start: Metric, end: Metric) -> bool {
511        self.assert_node_integrity();
512        assert!(start.chars <= end.chars);
513        let ((start_idx, start), (end_idx, end)) = self.get_delete_indices(start, end);
514
515        match self {
516            Node::Internal(int) => {
517                if start_idx == end_idx {
518                    // delete range is in a single child
519                    let idx = start_idx;
520                    let metrics = &mut int.metrics;
521                    let fix_seam = int.children[idx].delete_impl(start, end);
522                    metrics[idx] -= end - start;
523                    if int.children[idx].is_underfull() {
524                        let fix = int.balance_node(idx);
525                        debug_assert!(!fix);
526                    }
527                    fix_seam
528                } else {
529                    // if the byte index covers the entire node, delete the
530                    // whole thing
531                    let start_delete = if start.bytes == 0 { start_idx } else { start_idx + 1 };
532                    let end_size = int.metrics[end_idx].bytes;
533                    let end_delete = if end.bytes == end_size { end_idx + 1 } else { end_idx };
534                    // Delete nodes in the middle
535                    if start_delete < end_delete {
536                        int.children.drain(start_delete..end_delete);
537                        int.metrics.drain(start_delete..end_delete);
538                    }
539                    // since we might have deleted nodes in the middle, the
540                    // index is now 1 more then start.
541                    let mut fix_seam = false;
542                    let mut merge_left = false;
543                    // has a left child
544                    if start_delete > start_idx {
545                        fix_seam |=
546                            int.children[start_idx].delete_impl(start, int.metrics[start_idx]);
547                        int.metrics[start_idx] = start;
548                        if int.children[start_idx].is_underfull() {
549                            merge_left = true;
550                        }
551                    }
552                    // has a right child
553                    if end_delete <= end_idx {
554                        let end_idx = if start_idx == start_delete {
555                            start_idx
556                        } else {
557                            debug_assert_eq!(
558                                end_idx,
559                                start_idx + 1 + end_delete.saturating_sub(start_delete)
560                            );
561                            start_idx + 1
562                        };
563                        fix_seam |= int.children[end_idx].delete_impl(Metric::default(), end);
564                        int.metrics[end_idx] -= end;
565                        // merge right child first so that the index of left is not changed
566                        if int.children[end_idx].is_underfull() {
567                            fix_seam |= int.balance_node(end_idx);
568                        }
569                    }
570                    if merge_left {
571                        fix_seam |= int.balance_node(start_idx);
572                    }
573                    fix_seam
574                }
575            }
576            Node::Leaf(leaf) => {
577                if start_idx == end_idx {
578                    let chunk = end - start;
579                    if chunk == leaf.metrics[start_idx] {
580                        leaf.metrics.remove(start_idx);
581                    } else {
582                        leaf.metrics[start_idx] -= chunk;
583                    }
584                } else {
585                    let start_delete = if start.bytes == 0 { start_idx } else { start_idx + 1 };
586                    let end_size = leaf.metrics[end_idx].bytes;
587                    let end_delete = if end_size == end.bytes { end_idx + 1 } else { end_idx };
588
589                    leaf.metrics[end_idx] -= end;
590                    leaf.metrics[start_idx] = start;
591                    if start_delete < end_delete {
592                        leaf.metrics.drain(start_delete..end_delete);
593                    }
594                }
595                false
596            }
597        }
598    }
599
600    fn get_delete_indices(&self, start: Metric, end: Metric) -> ((usize, Metric), (usize, Metric)) {
601        let (mut start, mut end) = (start, end);
602        let mut start_idx = None;
603        let mut end_idx = None;
604        for idx in 0..self.len() {
605            let metric = self.metric_slice()[idx];
606            if start_idx.is_none() && (start.chars < metric.chars || start.chars == 0) {
607                start_idx = Some(idx);
608            }
609            if end.chars <= metric.chars {
610                end_idx = Some(idx);
611                break;
612            }
613            if start_idx.is_none() {
614                start -= metric;
615            }
616            end -= metric;
617        }
618        ((start_idx.unwrap(), start), (end_idx.unwrap(), end))
619    }
620
621    fn merge_node(&mut self, node: Option<Box<Node>>, metric: Metric, idx: usize) {
622        match (self, node) {
623            // TODO don't recalculate the metric
624            (Node::Internal(int), Some(node)) => int.insert(idx, node),
625            (Node::Leaf(leaf), None) => leaf.metrics.insert(idx, metric),
626            _ => unreachable!("cannot merge internal and leaf nodes"),
627        }
628    }
629
630    fn merge_sibling(&mut self, right: &mut Self) -> bool {
631        assert!(self.len() + right.len() <= MAX);
632        match (self, right) {
633            (Node::Internal(left), Node::Internal(right)) => {
634                left.metrics.append(&mut right.metrics);
635                left.children.append(&mut right.children);
636                left.len() < MIN
637            }
638            (Node::Leaf(left), Node::Leaf(right)) => {
639                left.metrics.append(&mut right.metrics);
640                left.len() < MIN
641            }
642            _ => unreachable!("cannot merge internal and leaf nodes"),
643        }
644    }
645
646    fn steal(&mut self, first: bool) -> Option<(Option<Box<Node>>, Metric)> {
647        let idx = if first { 0 } else { self.len() - 1 };
648        match self {
649            Node::Internal(int) if int.len() > MIN => {
650                let metric = int.metrics.remove(idx);
651                let child = int.children.remove(idx);
652                Some((Some(child), metric))
653            }
654            Node::Leaf(leaf) if leaf.len() > MIN => {
655                let metric = leaf.metrics.remove(idx);
656                Some((None, metric))
657            }
658            _ => None,
659        }
660    }
661
662    fn fix_seam(&mut self, char_pos: usize) -> bool {
663        if let Node::Internal(int) = self {
664            let prev_len = int.len();
665            loop {
666                let (idx, metric) = int.search_char_pos(char_pos);
667
668                if int.children[idx].is_underfull() {
669                    int.balance_node(idx);
670                }
671                let on_seam = metric.chars == char_pos && idx > 0;
672                if on_seam && int.children[idx - 1].is_underfull() {
673                    // we are on a seam and there is a left sibling
674                    int.balance_node(idx - 1);
675                }
676
677                // recalculate because position might have changed
678                let (idx, metric) = int.search_char_pos(char_pos);
679
680                let mut retry = false;
681                // recurse into children
682                let on_seam = metric.chars == char_pos && idx > 0;
683                if on_seam {
684                    let new_pos = int.metrics[idx - 1].chars;
685                    retry |= int.children[idx - 1].fix_seam(new_pos);
686                }
687
688                let new_pos = char_pos - metric.chars;
689                retry |= int.children[idx].fix_seam(new_pos);
690                // If one of the children was underfull we need to retry the
691                // loop to merge it again
692                if !retry {
693                    break;
694                }
695            }
696            let len = int.len();
697            len < prev_len && len < MIN
698        } else {
699            false
700        }
701    }
702
703    /// Split the tree at the given point. Returns the right side of the split.
704    /// Note that the resulting trees may have underfull nodes and will need to
705    /// be fixed later.
706    fn split(&mut self, pos: Metric) -> Node {
707        let (idx, metric) = self.search_char_pos(pos.chars);
708        match self {
709            Node::Leaf(leaf) => {
710                let offset = pos - metric;
711                let mut right;
712                if offset.bytes == 0 {
713                    right = leaf.metrics.drain(idx..).collect();
714                } else {
715                    let right_node = leaf.metrics[idx] - offset;
716                    leaf.metrics[idx] = offset;
717                    right = smallvec![right_node];
718                    right.extend(leaf.metrics.drain(idx + 1..));
719                }
720                Node::Leaf(Leaf { metrics: right })
721            }
722            Node::Internal(int) => {
723                let offset = pos - metric;
724                let mut right;
725                if offset.bytes == 0 {
726                    right = Internal {
727                        metrics: int.metrics.drain(idx..).collect(),
728                        children: int.children.drain(idx..).collect(),
729                    };
730                } else {
731                    let right_node = int.children[idx].split(offset);
732                    let right_metric = int.metrics[idx] - offset;
733                    int.metrics[idx] = offset;
734                    right = Internal {
735                        metrics: smallvec![right_metric],
736                        children: smallvec![Box::new(right_node)],
737                    };
738                    right.take(int, idx + 1..);
739                }
740                Node::Internal(right)
741            }
742        }
743    }
744
745    fn append(&mut self, other: Self) {
746        let self_depth = self.depth();
747        let other_depth = other.depth();
748        if other_depth <= self_depth {
749            let new = self.append_at_depth(other, self_depth - other_depth);
750
751            if let Some(right) = new {
752                // split the root, making the old root the left child
753                let left = mem::replace(self, Node::Internal(Internal::default()));
754                let Node::Internal(int) = self else { unreachable!() };
755                int.metrics = smallvec![left.metrics(), right.metrics()];
756                int.children = smallvec![Box::new(left), right];
757            }
758        } else {
759            let left = mem::replace(self, other);
760            let new = self.prepend_at_depth(left, other_depth - self_depth);
761
762            if let Some(left) = new {
763                // split the root, making the old root the right child
764                let right = mem::replace(self, Node::Internal(Internal::default()));
765                let Node::Internal(int) = self else { unreachable!() };
766                int.metrics = smallvec![left.metrics(), right.metrics()];
767                int.children = smallvec![left, Box::new(right)];
768            }
769        }
770    }
771
772    fn append_at_depth(&mut self, other: Self, depth: usize) -> Option<Box<Node>> {
773        if depth == 0 {
774            match (self, other) {
775                (Node::Leaf(left), Node::Leaf(mut right)) => {
776                    if left.len() + right.len() <= MAX {
777                        left.metrics.extend(right.metrics.drain(..));
778                        None
779                    } else {
780                        Some(Box::new(Node::Leaf(right)))
781                    }
782                }
783                (Node::Internal(left), Node::Internal(mut right)) => {
784                    if left.len() + right.len() <= MAX {
785                        left.take(&mut right, ..);
786                        None
787                    } else {
788                        Some(Box::new(Node::Internal(right)))
789                    }
790                }
791                _ => unreachable!("siblings have different types"),
792            }
793        } else if let Node::Internal(int) = self {
794            match int.children.last_mut().unwrap().append_at_depth(other, depth - 1) {
795                Some(new) if int.len() < MAX => {
796                    int.push(new);
797                    None
798                }
799                Some(new) => Some(Box::new(Node::Internal(Internal {
800                    metrics: smallvec![new.metrics()],
801                    children: smallvec![new],
802                }))),
803                None => {
804                    let update = int.children.last().unwrap().metrics();
805                    *int.metrics.last_mut().unwrap() = update;
806                    None
807                }
808            }
809        } else {
810            unreachable!("reached leaf node while depth was non-zero");
811        }
812    }
813
814    fn prepend_at_depth(&mut self, other: Self, depth: usize) -> Option<Box<Node>> {
815        if depth == 0 {
816            match (other, self) {
817                (Node::Leaf(mut left), Node::Leaf(right)) => {
818                    if left.len() + right.len() <= MAX {
819                        left.metrics.extend(right.metrics.drain(..));
820                        *right = left;
821                        None
822                    } else {
823                        Some(Box::new(Node::Leaf(left)))
824                    }
825                }
826                (Node::Internal(mut left), Node::Internal(right)) => {
827                    if left.len() + right.len() <= MAX {
828                        left.take(right, ..);
829                        *right = left;
830                        None
831                    } else {
832                        Some(Box::new(Node::Internal(left)))
833                    }
834                }
835                _ => unreachable!("siblings have different types"),
836            }
837        } else if let Node::Internal(int) = self {
838            match int.children.first_mut().unwrap().prepend_at_depth(other, depth - 1) {
839                Some(new) if int.len() < MAX => {
840                    int.insert(0, new);
841                    None
842                }
843                Some(new) => Some(Box::new(Node::Internal(Internal {
844                    metrics: smallvec![new.metrics()],
845                    children: smallvec![new],
846                }))),
847                None => {
848                    let update = int.children[0].metrics();
849                    int.metrics[0] = update;
850                    None
851                }
852            }
853        } else {
854            unreachable!("reached leaf node while depth was non-zero");
855        }
856    }
857
858    fn search_char(&self, chars: usize) -> (Metric, usize) {
859        self.search_impl(chars, |x| x.chars)
860    }
861
862    fn search_byte(&self, bytes: usize) -> (Metric, usize) {
863        self.search_impl(bytes, |x| x.bytes)
864    }
865
866    fn search_impl(&self, needle: usize, getter: impl Fn(&Metric) -> usize) -> (Metric, usize) {
867        self.assert_node_integrity();
868        let mut needle = needle;
869        let mut sum = Metric::default();
870        for (idx, metric) in self.metric_slice().iter().enumerate() {
871            // fast path if we happen get the exact position in the node
872            if needle == 0 {
873                break;
874            }
875            let pos = getter(metric);
876            if needle < pos {
877                // if it is ascii then we can just calculate the offset
878                if metric.is_ascii() {
879                    let offset = Metric { bytes: needle, chars: needle };
880                    return (sum + offset, 0);
881                }
882                let child_sum = match &self {
883                    Node::Internal(int) => {
884                        let (metric, offset) = int.children[idx].search_impl(needle, getter);
885                        (sum + metric, offset)
886                    }
887                    Node::Leaf(_) => (sum, needle),
888                };
889                return child_sum;
890            }
891            sum += *metric;
892            needle -= pos;
893        }
894        // we are beyond total size of the tree
895        (sum, needle)
896    }
897
898    fn assert_node_integrity(&self) {
899        if cfg!(debug_assertions) {
900            match self {
901                Node::Internal(int) => {
902                    assert!(!int.metrics.is_empty());
903                    assert!(int.metrics.len() <= MAX);
904                    assert_eq!(int.metrics.len(), int.children.len());
905                    for i in 0..int.children.len() {
906                        assert_eq!(int.children[i].metrics(), int.metrics[i]);
907                    }
908                }
909                Node::Leaf(leaf) => {
910                    assert!(leaf.metrics.len() <= MAX);
911                }
912            }
913        }
914    }
915
916    fn assert_integrity(&self) {
917        match self {
918            Node::Leaf(_) => {}
919            Node::Internal(int) => {
920                assert_eq!(int.metrics.len(), int.children.len());
921                for i in 0..int.children.len() {
922                    assert_eq!(int.children[i].metrics(), int.metrics[i]);
923                    int.children[i].assert_integrity();
924                }
925            }
926        }
927    }
928
929    fn assert_balance(&self) -> usize {
930        match self {
931            Node::Leaf(_) => 1,
932            Node::Internal(int) => {
933                let first_depth = int.children[0].assert_balance();
934                for node in &int.children[1..] {
935                    assert_eq!(node.assert_balance(), first_depth);
936                }
937                first_depth + 1
938            }
939        }
940    }
941
942    fn assert_node_size(&self, is_root: bool) {
943        match self {
944            Node::Leaf(leaf) => {
945                assert!(leaf.len() <= MAX);
946                if !is_root {
947                    assert!(leaf.len() >= MIN);
948                }
949            }
950            Node::Internal(int) => {
951                assert!(int.len() <= MAX);
952                assert!(int.len() >= 2);
953                if !is_root {
954                    assert!(int.len() >= MIN);
955                }
956                for node in &int.children {
957                    node.assert_node_size(false);
958                }
959            }
960        }
961    }
962}
963
964impl Default for Node {
965    fn default() -> Self {
966        Self::Leaf(Leaf::default())
967    }
968}
969
970impl fmt::Display for Node {
971    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
972        // print the children level by level by adding them to a pair of
973        // alternating arrays for each level
974        let mut current = Vec::new();
975        let mut next: Vec<&Self> = Vec::new();
976        current.push(self);
977        let mut level = 0;
978        while !current.is_empty() {
979            next.clear();
980            write!(f, "level {level}:")?;
981            for node in &current {
982                write!(f, " [")?;
983                match node {
984                    Node::Internal(int) => {
985                        for metric in &int.metrics {
986                            write!(f, "({metric}) ")?;
987                        }
988                        for child in &int.children {
989                            next.push(child);
990                        }
991                    }
992                    Node::Leaf(leaf) => {
993                        for metric in &leaf.metrics {
994                            write!(f, "({metric}) ")?;
995                        }
996                    }
997                }
998                write!(f, "]")?;
999            }
1000            writeln!(f)?;
1001            level += 1;
1002            mem::swap(&mut current, &mut next);
1003        }
1004        Ok(())
1005    }
1006}
1007
1008#[derive(Debug, Default, Copy, Clone, Eq, GetSize)]
1009pub(crate) struct Metric {
1010    pub(crate) bytes: usize,
1011    pub(crate) chars: usize,
1012}
1013
1014impl PartialEq for Metric {
1015    fn eq(&self, other: &Self) -> bool {
1016        let eq = self.bytes == other.bytes;
1017        if eq {
1018            debug_assert_eq!(self.chars, other.chars);
1019        } else {
1020            debug_assert_ne!(self.chars, other.chars);
1021        }
1022        eq
1023    }
1024}
1025
1026impl Metric {
1027    fn is_ascii(&self) -> bool {
1028        self.bytes == self.chars
1029    }
1030}
1031
1032impl fmt::Display for Metric {
1033    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1034        write!(f, "b:{}, c:{}", self.bytes, self.chars)
1035    }
1036}
1037
1038impl Sum for Metric {
1039    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
1040        iter.fold(Self::default(), |a, b| Self {
1041            bytes: a.bytes + b.bytes,
1042            chars: a.chars + b.chars,
1043        })
1044    }
1045}
1046
1047impl Add for Metric {
1048    type Output = Self;
1049
1050    fn add(self, rhs: Self) -> Self::Output {
1051        Self { bytes: self.bytes + rhs.bytes, chars: self.chars + rhs.chars }
1052    }
1053}
1054
1055impl Sub for Metric {
1056    type Output = Self;
1057
1058    fn sub(self, rhs: Self) -> Self::Output {
1059        Self { bytes: self.bytes - rhs.bytes, chars: self.chars - rhs.chars }
1060    }
1061}
1062
1063impl AddAssign for Metric {
1064    fn add_assign(&mut self, rhs: Self) {
1065        self.bytes += rhs.bytes;
1066        self.chars += rhs.chars;
1067    }
1068}
1069
1070impl SubAssign for Metric {
1071    fn sub_assign(&mut self, rhs: Self) {
1072        self.bytes -= rhs.bytes;
1073        self.chars -= rhs.chars;
1074    }
1075}
1076
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a mock metric in which every char is two bytes wide. For non-zero `x` the
    /// byte and char counts therefore never match.
    fn metric(x: usize) -> Metric {
        Metric { bytes: x * 2, chars: x }
    }

    /// Runs a char search on `root` and folds the leftover offset back into the
    /// returned metric, using the same two-bytes-per-char ratio as `metric`.
    fn mock_search_char(root: &Node, needle: usize) -> Metric {
        let (metric, offset) = root.search_char(needle);
        Metric { bytes: metric.bytes + offset * 2, chars: metric.chars + offset }
    }

    /// Iterator that yields `count` copies of `metric(step)`.
    struct TreeBuilderBasic {
        count: usize,
        step: usize,
    }

    impl Iterator for TreeBuilderBasic {
        type Item = Metric;

        fn next(&mut self) -> Option<Self::Item> {
            if self.count == 0 {
                None
            } else {
                self.count -= 1;
                Some(metric(self.step))
            }
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            (self.count, Some(self.count))
        }
    }

    /// Inserts four one-char metrics at the front of a single-entry buffer and checks
    /// that every position is still found.
    #[test]
    fn test_insert_empty() {
        let mut buffer = BufferMetrics::build(&mut TreeBuilderBasic { count: 1, step: 5 });
        let builder = &mut TreeBuilderBasic { count: 4, step: 1 };
        buffer.insert(metric(0), builder);
        for i in 0..10 {
            println!("searching for {i}");
            let cmp = mock_search_char(&buffer.root, i);
            assert_eq!(cmp, metric(i));
        }
    }

    /// Fills an empty buffer, inserts one more metric in the middle, and checks that
    /// every position is found.
    #[test]
    fn test_insert() {
        let mut buffer = BufferMetrics::default();
        let builder = &mut TreeBuilderBasic { count: 10, step: 1 };
        buffer.insert(metric(0), builder);
        buffer.root.insert_at(metric(5), metric(5));
        println!("{}", buffer.root);
        for i in 0..15 {
            println!("searching for {i}");
            let cmp = mock_search_char(&buffer.root, i);
            assert_eq!(cmp, metric(i));
        }
    }

    /// Searches every char position of a freshly built twenty-entry tree.
    #[test]
    fn test_search() {
        let builder = &mut TreeBuilderBasic { count: 20, step: 1 };
        let root = BufferMetrics::build(builder);
        for i in 0..20 {
            println!("searching for {i}");
            let cmp = mock_search_char(&root.root, i);
            assert_eq!(cmp, metric(i));
        }
    }

    /// Searches every char position of a freshly built twenty-entry tree.
    // NOTE(review): this body is identical to `test_search`; one of the two was
    // presumably meant to cover a different search - confirm and adjust.
    #[test]
    fn test_search_chars() {
        let builder = &mut TreeBuilderBasic { count: 20, step: 1 };
        let root = BufferMetrics::build(builder);
        for i in 0..20 {
            println!("searching for {i}");
            let cmp = mock_search_char(&root.root, i);
            assert_eq!(cmp, metric(i));
        }
    }

    /// Deletes several ranges from a small buffer and checks the remaining total after
    /// each step.
    #[test]
    fn test_delete_range_leaf() {
        // shouldn't need more than a single leaf node
        let builder = &mut TreeBuilderBasic { count: 3, step: 4 };
        let mut buffer = BufferMetrics::build(builder);
        assert_eq!(buffer.root.metrics(), metric(12));
        println!("init: {}", buffer.root);
        buffer.delete(metric(1), metric(3));
        assert_eq!(buffer.root.metrics(), metric(10));
        println!("after: {}", buffer.root);
        buffer.delete(metric(2), metric(6));
        assert_eq!(buffer.root.metrics(), metric(6));
        println!("after: {}", buffer.root);
        buffer.delete(metric(1), metric(4));
        assert_eq!(buffer.root.metrics(), metric(3));
        println!("after: {}", buffer.root);
        buffer.delete(metric(0), metric(1));
        assert_eq!(buffer.root.metrics(), metric(2));
        println!("after: {}", buffer.root);
    }

    /// Deletes the first half and, on a fresh buffer, the second half of a six-entry
    /// buffer, checking that half of the total remains each time.
    #[test]
    fn test_delete_range_internal() {
        let builder = &mut TreeBuilderBasic { count: 6, step: 4 };
        let mut buffer = BufferMetrics::build(builder);
        println!("init: {}", buffer.root);
        buffer.delete(metric(0), metric(12));
        assert_eq!(buffer.root.metrics(), metric(12));
        println!("after: {}", buffer.root);

        let builder = &mut TreeBuilderBasic { count: 6, step: 4 };
        let mut buffer = BufferMetrics::build(builder);
        println!("init: {}", buffer.root);
        buffer.delete(metric(12), metric(24));
        assert_eq!(buffer.root.metrics(), metric(12));
        println!("after: {}", buffer.root);
    }

    /// Splits a twenty-entry tree in the middle and checks that both halves have the
    /// same total and are searchable on their own.
    #[test]
    fn test_split() {
        let builder = &mut TreeBuilderBasic { count: 20, step: 1 };
        let mut buffer = BufferMetrics::build(builder);
        println!("init: {}", buffer.root);
        let right = buffer.root.split(metric(10));
        println!("left: {}", buffer.root);
        // Print the split-off half; printing `buffer.root` again would show the left
        // half under the "right" label.
        println!("right: {}", right);
        assert_eq!(buffer.root.metrics(), right.metrics());
        for i in 0..10 {
            println!("searching for {i}");
            let cmp = mock_search_char(&buffer.root, i);
            assert_eq!(cmp, metric(i));
            let cmp = mock_search_char(&right, i);
            assert_eq!(cmp, metric(i));
        }
    }

    /// Appends one ten-entry tree to another and checks that all twenty positions are
    /// found in the result.
    #[test]
    fn test_append() {
        let builder = &mut TreeBuilderBasic { count: 10, step: 1 };
        let mut buffer = BufferMetrics::build(builder);
        println!("init: {}", buffer.root);
        let builder = &mut TreeBuilderBasic { count: 10, step: 1 };
        let right = BufferMetrics::build(builder);
        println!("right: {}", right.root);
        buffer.root.append(right.root);
        println!("after: {}", buffer.root);
        for i in 0..20 {
            println!("searching for {i}");
            let cmp = mock_search_char(&buffer.root, i);
            assert_eq!(cmp, metric(i));
        }
    }

    /// Builds trees from zero, one and twenty entries and checks the root length or
    /// searchability respectively.
    #[test]
    fn test_build() {
        {
            let builder = &mut TreeBuilderBasic { count: 0, step: 0 };
            let buffer = BufferMetrics::build(builder);
            assert_eq!(buffer.root.len(), 0);
        }
        {
            let builder = &mut TreeBuilderBasic { count: 1, step: 1 };
            let buffer = BufferMetrics::build(builder);
            assert_eq!(buffer.root.len(), 1);
        }
        let builder = &mut TreeBuilderBasic { count: 20, step: 1 };
        let buffer = BufferMetrics::build(builder);
        println!("{}", buffer.root);
        for i in 0..20 {
            println!("searching for {i}");
            let cmp = mock_search_char(&buffer.root, i);
            assert_eq!(cmp, metric(i));
        }
    }
}