1package avl
2
3import (
4 "sort"
5 "strings"
6 "testing"
7
8 "gno.land/p/demo/ufmt"
9)
10
11func TestTraverseByOffset(t *testing.T) {
12 const testStrings = `Alfa
13Alfred
14Alpha
15Alphabet
16Beta
17Beth
18Book
19Browser`
20 tt := []struct {
21 name string
22 asc bool
23 }{
24 {"ascending", true},
25 {"descending", false},
26 }
27
28 for _, tt := range tt {
29 t.Run(tt.name, func(t *testing.T) {
30 // use sl to insert the values, and reversed to match the values
31 // we do this to ensure that the order of TraverseByOffset is independent
32 // from the insertion order
33 sl := strings.Split(testStrings, "\n")
34 sort.Strings(sl)
35 reversed := append([]string{}, sl...)
36 reverseSlice(reversed)
37
38 if !tt.asc {
39 sl, reversed = reversed, sl
40 }
41
42 r := NewNode(reversed[0], nil)
43 for _, v := range reversed[1:] {
44 r, _ = r.Set(v, nil)
45 }
46
47 var result []string
48 for i := 0; i < len(sl); i++ {
49 r.TraverseByOffset(i, 1, tt.asc, true, func(n *Node) bool {
50 result = append(result, n.Key())
51 return false
52 })
53 }
54
55 if !slicesEqual(sl, result) {
56 t.Errorf("want %v got %v", sl, result)
57 }
58
59 for l := 2; l <= len(sl); l++ {
60 // "slices"
61 for i := 0; i <= len(sl); i++ {
62 max := i + l
63 if max > len(sl) {
64 max = len(sl)
65 }
66 exp := sl[i:max]
67 actual := []string{}
68
69 r.TraverseByOffset(i, l, tt.asc, true, func(tr *Node) bool {
70 actual = append(actual, tr.Key())
71 return false
72 })
73 if !slicesEqual(exp, actual) {
74 t.Errorf("want %v got %v", exp, actual)
75 }
76 }
77 }
78 })
79 }
80}
81
82func TestHas(t *testing.T) {
83 tests := []struct {
84 name string
85 input []string
86 hasKey string
87 expected bool
88 }{
89 {
90 "has key in non-empty tree",
91 []string{"C", "A", "B", "E", "D"},
92 "B",
93 true,
94 },
95 {
96 "does not have key in non-empty tree",
97 []string{"C", "A", "B", "E", "D"},
98 "F",
99 false,
100 },
101 {
102 "has key in single-node tree",
103 []string{"A"},
104 "A",
105 true,
106 },
107 {
108 "does not have key in single-node tree",
109 []string{"A"},
110 "B",
111 false,
112 },
113 {
114 "does not have key in empty tree",
115 []string{},
116 "A",
117 false,
118 },
119 }
120
121 for _, tt := range tests {
122 t.Run(tt.name, func(t *testing.T) {
123 var tree *Node
124 for _, key := range tt.input {
125 tree, _ = tree.Set(key, nil)
126 }
127
128 result := tree.Has(tt.hasKey)
129
130 if result != tt.expected {
131 t.Errorf("Expected %v, got %v", tt.expected, result)
132 }
133 })
134 }
135}
136
137func TestGet(t *testing.T) {
138 tests := []struct {
139 name string
140 input []string
141 getKey string
142 expectIdx int
143 expectVal interface{}
144 expectExists bool
145 }{
146 {
147 "get existing key",
148 []string{"C", "A", "B", "E", "D"},
149 "B",
150 1,
151 nil,
152 true,
153 },
154 {
155 "get non-existent key (smaller)",
156 []string{"C", "A", "B", "E", "D"},
157 "@",
158 0,
159 nil,
160 false,
161 },
162 {
163 "get non-existent key (larger)",
164 []string{"C", "A", "B", "E", "D"},
165 "F",
166 5,
167 nil,
168 false,
169 },
170 {
171 "get from empty tree",
172 []string{},
173 "A",
174 0,
175 nil,
176 false,
177 },
178 }
179
180 for _, tt := range tests {
181 t.Run(tt.name, func(t *testing.T) {
182 var tree *Node
183 for _, key := range tt.input {
184 tree, _ = tree.Set(key, nil)
185 }
186
187 idx, val, exists := tree.Get(tt.getKey)
188
189 if idx != tt.expectIdx {
190 t.Errorf("Expected index %d, got %d", tt.expectIdx, idx)
191 }
192
193 if val != tt.expectVal {
194 t.Errorf("Expected value %v, got %v", tt.expectVal, val)
195 }
196
197 if exists != tt.expectExists {
198 t.Errorf("Expected exists %t, got %t", tt.expectExists, exists)
199 }
200 })
201 }
202}
203
204func TestGetByIndex(t *testing.T) {
205 tests := []struct {
206 name string
207 input []string
208 idx int
209 expectKey string
210 expectVal interface{}
211 expectPanic bool
212 }{
213 {
214 "get by valid index",
215 []string{"C", "A", "B", "E", "D"},
216 2,
217 "C",
218 nil,
219 false,
220 },
221 {
222 "get by valid index (smallest)",
223 []string{"C", "A", "B", "E", "D"},
224 0,
225 "A",
226 nil,
227 false,
228 },
229 {
230 "get by valid index (largest)",
231 []string{"C", "A", "B", "E", "D"},
232 4,
233 "E",
234 nil,
235 false,
236 },
237 {
238 "get by invalid index (negative)",
239 []string{"C", "A", "B", "E", "D"},
240 -1,
241 "",
242 nil,
243 true,
244 },
245 {
246 "get by invalid index (out of range)",
247 []string{"C", "A", "B", "E", "D"},
248 5,
249 "",
250 nil,
251 true,
252 },
253 }
254
255 for _, tt := range tests {
256 t.Run(tt.name, func(t *testing.T) {
257 var tree *Node
258 for _, key := range tt.input {
259 tree, _ = tree.Set(key, nil)
260 }
261
262 if tt.expectPanic {
263 defer func() {
264 if r := recover(); r == nil {
265 t.Errorf("Expected a panic but didn't get one")
266 }
267 }()
268 }
269
270 key, val := tree.GetByIndex(tt.idx)
271
272 if !tt.expectPanic {
273 if key != tt.expectKey {
274 t.Errorf("Expected key %s, got %s", tt.expectKey, key)
275 }
276
277 if val != tt.expectVal {
278 t.Errorf("Expected value %v, got %v", tt.expectVal, val)
279 }
280 }
281 })
282 }
283}
284
285func TestRemove(t *testing.T) {
286 tests := []struct {
287 name string
288 input []string
289 removeKey string
290 expected []string
291 }{
292 {
293 "remove leaf node",
294 []string{"C", "A", "B", "D"},
295 "B",
296 []string{"A", "C", "D"},
297 },
298 {
299 "remove node with one child",
300 []string{"C", "A", "B", "D"},
301 "A",
302 []string{"B", "C", "D"},
303 },
304 {
305 "remove node with two children",
306 []string{"C", "A", "B", "E", "D"},
307 "C",
308 []string{"A", "B", "D", "E"},
309 },
310 {
311 "remove root node",
312 []string{"C", "A", "B", "E", "D"},
313 "C",
314 []string{"A", "B", "D", "E"},
315 },
316 {
317 "remove non-existent key",
318 []string{"C", "A", "B", "E", "D"},
319 "F",
320 []string{"A", "B", "C", "D", "E"},
321 },
322 }
323
324 for _, tt := range tests {
325 t.Run(tt.name, func(t *testing.T) {
326 var tree *Node
327 for _, key := range tt.input {
328 tree, _ = tree.Set(key, nil)
329 }
330
331 tree, _, _, _ = tree.Remove(tt.removeKey)
332
333 result := make([]string, 0)
334 tree.Iterate("", "", func(n *Node) bool {
335 result = append(result, n.Key())
336 return false
337 })
338
339 if !slicesEqual(tt.expected, result) {
340 t.Errorf("want %v got %v", tt.expected, result)
341 }
342 })
343 }
344}
345
346func TestTraverse(t *testing.T) {
347 tests := []struct {
348 name string
349 input []string
350 expected []string
351 }{
352 {
353 "empty tree",
354 []string{},
355 []string{},
356 },
357 {
358 "single node tree",
359 []string{"A"},
360 []string{"A"},
361 },
362 {
363 "small tree",
364 []string{"C", "A", "B", "E", "D"},
365 []string{"A", "B", "C", "D", "E"},
366 },
367 {
368 "large tree",
369 []string{"H", "D", "L", "B", "F", "J", "N", "A", "C", "E", "G", "I", "K", "M", "O"},
370 []string{"A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O"},
371 },
372 }
373
374 for _, tt := range tests {
375 t.Run(tt.name, func(t *testing.T) {
376 var tree *Node
377 for _, key := range tt.input {
378 tree, _ = tree.Set(key, nil)
379 }
380
381 t.Run("iterate", func(t *testing.T) {
382 var result []string
383 tree.Iterate("", "", func(n *Node) bool {
384 result = append(result, n.Key())
385 return false
386 })
387 if !slicesEqual(tt.expected, result) {
388 t.Errorf("want %v got %v", tt.expected, result)
389 }
390 })
391
392 t.Run("ReverseIterate", func(t *testing.T) {
393 var result []string
394 tree.ReverseIterate("", "", func(n *Node) bool {
395 result = append(result, n.Key())
396 return false
397 })
398 expected := make([]string, len(tt.expected))
399 copy(expected, tt.expected)
400 for i, j := 0, len(expected)-1; i < j; i, j = i+1, j-1 {
401 expected[i], expected[j] = expected[j], expected[i]
402 }
403 if !slicesEqual(expected, result) {
404 t.Errorf("want %v got %v", expected, result)
405 }
406 })
407
408 t.Run("TraverseInRange", func(t *testing.T) {
409 var result []string
410 start, end := "C", "M"
411 tree.TraverseInRange(start, end, true, true, func(n *Node) bool {
412 result = append(result, n.Key())
413 return false
414 })
415 expected := make([]string, 0)
416 for _, key := range tt.expected {
417 if key >= start && key < end {
418 expected = append(expected, key)
419 }
420 }
421 if !slicesEqual(expected, result) {
422 t.Errorf("want %v got %v", expected, result)
423 }
424 })
425
426 t.Run("early termination", func(t *testing.T) {
427 if len(tt.input) == 0 {
428 return // Skip for empty tree
429 }
430
431 var result []string
432 var count int
433 tree.Iterate("", "", func(n *Node) bool {
434 count++
435 result = append(result, n.Key())
436 return true // Stop after first item
437 })
438
439 if count != 1 {
440 t.Errorf("Expected callback to be called exactly once, got %d calls", count)
441 }
442 if len(result) != 1 {
443 t.Errorf("Expected exactly one result, got %d items", len(result))
444 }
445 if len(result) > 0 && result[0] != tt.expected[0] {
446 t.Errorf("Expected first item to be %v, got %v", tt.expected[0], result[0])
447 }
448 })
449 })
450 }
451}
452
453func TestRotateWhenHeightDiffers(t *testing.T) {
454 tests := []struct {
455 name string
456 input []string
457 expected []string
458 }{
459 {
460 "right rotation when left subtree is higher",
461 []string{"E", "C", "A", "B", "D"},
462 []string{"A", "B", "C", "D", "E"},
463 },
464 {
465 "left rotation when right subtree is higher",
466 []string{"A", "C", "E", "D", "F"},
467 []string{"A", "C", "D", "E", "F"},
468 },
469 {
470 "left-right rotation",
471 []string{"E", "A", "C", "B", "D"},
472 []string{"A", "B", "C", "D", "E"},
473 },
474 {
475 "right-left rotation",
476 []string{"A", "E", "C", "B", "D"},
477 []string{"A", "B", "C", "D", "E"},
478 },
479 }
480
481 for _, tt := range tests {
482 t.Run(tt.name, func(t *testing.T) {
483 var tree *Node
484 for _, key := range tt.input {
485 tree, _ = tree.Set(key, nil)
486 }
487
488 // perform rotation or balance
489 tree = tree.balance()
490
491 // check tree structure
492 var result []string
493 tree.Iterate("", "", func(n *Node) bool {
494 result = append(result, n.Key())
495 return false
496 })
497
498 if !slicesEqual(tt.expected, result) {
499 t.Errorf("want %v got %v", tt.expected, result)
500 }
501 })
502 }
503}
504
505func TestRotateAndBalance(t *testing.T) {
506 tests := []struct {
507 name string
508 input []string
509 expected []string
510 }{
511 {
512 "right rotation",
513 []string{"A", "B", "C", "D", "E"},
514 []string{"A", "B", "C", "D", "E"},
515 },
516 {
517 "left rotation",
518 []string{"E", "D", "C", "B", "A"},
519 []string{"A", "B", "C", "D", "E"},
520 },
521 {
522 "left-right rotation",
523 []string{"C", "A", "E", "B", "D"},
524 []string{"A", "B", "C", "D", "E"},
525 },
526 {
527 "right-left rotation",
528 []string{"C", "E", "A", "D", "B"},
529 []string{"A", "B", "C", "D", "E"},
530 },
531 }
532
533 for _, tt := range tests {
534 t.Run(tt.name, func(t *testing.T) {
535 var tree *Node
536 for _, key := range tt.input {
537 tree, _ = tree.Set(key, nil)
538 }
539
540 tree = tree.balance()
541
542 var result []string
543 tree.Iterate("", "", func(n *Node) bool {
544 result = append(result, n.Key())
545 return false
546 })
547
548 if !slicesEqual(tt.expected, result) {
549 t.Errorf("want %v got %v", tt.expected, result)
550 }
551 })
552 }
553}
554
555func TestRemoveFromEmptyTree(t *testing.T) {
556 var tree *Node
557 newTree, _, val, removed := tree.Remove("NonExistent")
558 if newTree != nil {
559 t.Errorf("Removing from an empty tree should still be nil tree.")
560 }
561 if val != nil || removed {
562 t.Errorf("Expected no value and removed=false when removing from empty tree.")
563 }
564}
565
566func TestBalanceAfterRemoval(t *testing.T) {
567 tests := []struct {
568 name string
569 insertKeys []string
570 removeKey string
571 expectedBalance int
572 }{
573 {
574 name: "balance after removing right node",
575 insertKeys: []string{"B", "A", "D", "C", "E"},
576 removeKey: "E",
577 expectedBalance: 0,
578 },
579 {
580 name: "balance after removing left node",
581 insertKeys: []string{"D", "B", "E", "A", "C"},
582 removeKey: "A",
583 expectedBalance: 0,
584 },
585 {
586 name: "ensure no lean after removal",
587 insertKeys: []string{"C", "B", "E", "A", "D", "F"},
588 removeKey: "F",
589 expectedBalance: -1,
590 },
591 {
592 name: "descending order insert, remove middle node",
593 insertKeys: []string{"E", "D", "C", "B", "A"},
594 removeKey: "C",
595 expectedBalance: 0,
596 },
597 {
598 name: "ascending order insert, remove middle node",
599 insertKeys: []string{"A", "B", "C", "D", "E"},
600 removeKey: "C",
601 expectedBalance: 0,
602 },
603 {
604 name: "duplicate key insert, remove the duplicated key",
605 insertKeys: []string{"C", "B", "C", "A", "D"},
606 removeKey: "C",
607 expectedBalance: 1,
608 },
609 {
610 name: "complex rotation case",
611 insertKeys: []string{"H", "B", "A", "C", "E", "D", "F", "G"},
612 removeKey: "B",
613 expectedBalance: 0,
614 },
615 }
616
617 for _, tt := range tests {
618 t.Run(tt.name, func(t *testing.T) {
619 var tree *Node
620 for _, key := range tt.insertKeys {
621 tree, _ = tree.Set(key, nil)
622 }
623
624 tree, _, _, _ = tree.Remove(tt.removeKey)
625
626 balance := tree.calcBalance()
627 if balance != tt.expectedBalance {
628 t.Errorf("Expected balance factor %d, got %d", tt.expectedBalance, balance)
629 }
630
631 if balance < -1 || balance > 1 {
632 t.Errorf("Tree is unbalanced with factor %d", balance)
633 }
634
635 if errMsg := checkSubtreeBalance(t, tree); errMsg != "" {
636 t.Errorf("AVL property violation after removal: %s", errMsg)
637 }
638 })
639 }
640}
641
642func TestBSTProperty(t *testing.T) {
643 var tree *Node
644 keys := []string{"D", "B", "F", "A", "C", "E", "G"}
645 for _, key := range keys {
646 tree, _ = tree.Set(key, nil)
647 }
648
649 var result []string
650 inorderTraversal(t, tree, &result)
651
652 for i := 1; i < len(result); i++ {
653 if result[i] < result[i-1] {
654 t.Errorf("BST property violated: %s < %s (index %d)",
655 result[i], result[i-1], i)
656 }
657 }
658}
659
660// inorderTraversal performs an inorder traversal of the tree and returns the keys in a list.
661func inorderTraversal(t *testing.T, node *Node, result *[]string) {
662 t.Helper()
663
664 if node == nil {
665 return
666 }
667 // leaf
668 if node.height == 0 {
669 *result = append(*result, node.key)
670 return
671 }
672 inorderTraversal(t, node.leftNode, result)
673 inorderTraversal(t, node.rightNode, result)
674}
675
676// checkSubtreeBalance checks if all nodes under the given node satisfy the AVL tree conditions.
677// The balance factor of all nodes must be ∈ [-1, +1]
678func checkSubtreeBalance(t *testing.T, node *Node) string {
679 t.Helper()
680
681 if node == nil {
682 return ""
683 }
684
685 if node.IsLeaf() {
686 // leaf node must be height=0, size=1
687 if node.height != 0 {
688 return ufmt.Sprintf("Leaf node %s has height %d, expected 0", node.Key(), node.height)
689 }
690 if node.size != 1 {
691 return ufmt.Sprintf("Leaf node %s has size %d, expected 1", node.Key(), node.size)
692 }
693 return ""
694 }
695
696 // check balance factor for current node
697 balanceFactor := node.calcBalance()
698 if balanceFactor < -1 || balanceFactor > 1 {
699 return ufmt.Sprintf("Node %s is unbalanced: balanceFactor=%d", node.Key(), balanceFactor)
700 }
701
702 // check height / size relationship for children
703 left, right := node.getLeftNode(), node.getRightNode()
704 expectedHeight := maxInt8(left.height, right.height) + 1
705 if node.height != expectedHeight {
706 return ufmt.Sprintf("Node %s has incorrect height %d, expected %d", node.Key(), node.height, expectedHeight)
707 }
708 expectedSize := left.Size() + right.Size()
709 if node.size != expectedSize {
710 return ufmt.Sprintf("Node %s has incorrect size %d, expected %d", node.Key(), node.size, expectedSize)
711 }
712
713 // recursively check the left/right subtree
714 if errMsg := checkSubtreeBalance(t, left); errMsg != "" {
715 return errMsg
716 }
717 if errMsg := checkSubtreeBalance(t, right); errMsg != "" {
718 return errMsg
719 }
720
721 return ""
722}
723
// slicesEqual reports whether w1 and w2 have the same length and hold equal
// strings at every position. A nil slice and an empty slice compare equal.
func slicesEqual(w1, w2 []string) bool {
	if len(w1) != len(w2) {
		return false
	}
	for i, s := range w1 {
		if s != w2[i] {
			return false
		}
	}
	return true
}
735
// reverseSlice reverses ss in place by swapping elements from both ends
// towards the middle.
func reverseSlice(ss []string) {
	for lo, hi := 0, len(ss)-1; lo < hi; lo, hi = lo+1, hi-1 {
		ss[lo], ss[hi] = ss[hi], ss[lo]
	}
}
node_test.gno
15.32 Kb · 741 lines