1 module collections.treemap;
2 
3 import std.functional : binaryFun;
4 import collections.commons;
/**
 * Implements an AVL tree backed treemap.
 * Intended to be used as an alternative to D's own associative array.
 * If E is set to void, then it works more like a regular tree datastructure (a sorted collection of keys), and can
 * be indexed with any type that can be compared against K with `less`.
 * `nogcIndexing` changes the behavior of `opIndex` if no match is found. If set to true, indexing returns the default
 * value if no match is found, which will need some design consideration. If set to false, indexing throws an
 * exception if no match is found.
 * Nodes should have the lesser elements on the left side. Behavior can be changed with `less`.
 * Every node caches the height of its subtree, so lookups, insertions and removals need O(log n) comparisons.
 */
public struct TreeMap(K, E, bool nogcIndexing = true, alias less = "a < b") {
	private struct Node {
		K				key;		///Identifier key, also used for automatic sorting
		static if (!is(E == void))
			E			elem;		///The element stored in this field if exists
		Node*			left;		///The node that holds a key with a lesser value
		Node*			right;		///The node that holds a key with a greater value
		size_t			height = 1;	///Height of the subtree rooted at this node; a node without children has 1
		string toString() const {
			import std.conv : to;
			string result = "{K: " ~ to!string(key) ~ " ; ";
			static if (!is(E == void))
				result ~= "E: " ~ to!string(elem) ~ " ; ";
			if (left) result ~= "L: " ~ left.toString() ~ " ; ";
			if (right) result ~= "R: " ~ right.toString() ~ " ; ";
			return result ~ "}";
		}
		///Returns the cached height of `node`, where a missing (null) subtree has a height of 0.
		static size_t heightOf(const(Node)* node) @nogc @safe pure nothrow {
			return node is null ? 0 : node.height;
		}
		///Returns the balance of the node: the height of the right subtree minus the height of the left one.
		@property sizediff_t balance() @nogc @safe pure nothrow const {
			return cast(sizediff_t)heightOf(right) - cast(sizediff_t)heightOf(left);
		}
		///Recomputes the cached height from the (already up to date) cached heights of the children.
		void updateHeight() @nogc @safe pure nothrow {
			const size_t lhs = heightOf(left);
			const size_t rhs = heightOf(right);
			height = (lhs >= rhs ? lhs : rhs) + 1;
		}
	}
	private size_t		nOfElements;///Current number of elements in this collection
	private Node*		root;		///The root node of the tree, or null if the collection is empty

	/**
	 * Returns the node matching `key`, or null if there is none.
	 * A node matches when neither `less(key, node.key)` nor `less(node.key, key)` holds.
	 */
	private Node* findNode(T)(T key) {
		Node* crnt = root;
		while (crnt !is null) {
			if (binaryFun!less(key, crnt.key))			//key is smaller than current element's, look at lesser elements
				crnt = crnt.left;
			else if (binaryFun!less(crnt.key, key))	//key is greater than current element's, look at greater elements
				crnt = crnt.right;
			else									//match found
				return crnt;
		}
		return null;
	}

	static if (!is(E == void)) {
		static if (nogcIndexing) {
			/**
			 * @nogc capable indexing.
			 * Returns the found element if match found.
			 * Returns E.init if match not found.
			 */
			E opIndex(K key) @nogc @safe pure nothrow {
				Node* match = findNode(key);
				return match is null ? E.init : match.elem;
			}
			/**
			 * Returns the pointer of the element, or null if key not found.
			 */
			E* ptrOf(K key) @nogc @safe pure nothrow {
				Node* match = findNode(key);
				if (match is null)
					return null;
				return &match.elem;
			}
		} else {
			/**
			 * Indexing function that relies on the GC, and throws if no match found.
			 * Can be indexed with any type of value as long as it can be compared against K with `less`.
			 * Returns a reference to the found element if match found.
			 * Throws: ElementNotFoundException if no match found.
			 */
			ref E opIndex(T)(T key) @safe pure {
				if (auto match = findNode(key))
					return match.elem;
				throw new ElementNotFoundException("No match found");
			}
		}
	} else {
		static if (nogcIndexing) {
			/**
			 * @nogc capable indexing.
			 * Can be indexed with any type of value as long as it can be compared against K with `less`.
			 * Returns the stored key if match found.
			 * Returns K.init if match not found.
			 */
			K opIndex(T)(T key) @nogc @safe pure nothrow {
				Node* match = findNode(key);
				return match is null ? K.init : match.key;
			}
		} else {
			/**
			 * Indexing function that relies on the GC, and throws if no match found.
			 * Can be indexed with any type of value as long as it can be compared against K with `less`.
			 * Returns the stored key if match found.
			 * Throws: ElementNotFoundException if no match found.
			 */
			K opIndex(T)(T key) @safe pure {
				if (auto match = findNode(key))
					return match.key;
				throw new ElementNotFoundException("No match found");
			}
		}
	}
	static if (!is(E == void)) {
		/**
		 * Assigns a value to the given key.
		 * If key found, the value will be overwritten without node insertion.
		 * If key isn't found, a new node will be inserted, and the tree is rebalanced along the insertion path.
		 * Returns the assigned value.
		 */
		auto opIndexAssign(E elem, K key) @safe pure nothrow {
			if (insertAt(root, key, elem))
				nOfElements++;
			return elem;
		}
		/**
		 * Inserts `key`/`elem` into the subtree referenced by `node`, or overwrites the element of a matching key.
		 * Every node on the way back up is rebalanced.
		 * Returns true if a new node was created, false if an existing element was overwritten.
		 */
		private bool insertAt(ref Node* node, K key, E elem) @safe pure nothrow {
			if (node is null) {
				node = new Node(key, elem);
				return true;
			}
			bool created;
			if (binaryFun!less(key, node.key)) {		//Key is smaller, look at left hand side
				created = insertAt(node.left, key, elem);
			} else if (binaryFun!less(node.key, key)) {	//Key is greater, look at right hand side
				created = insertAt(node.right, key, elem);
			} else {									//Keymatch found: overwrite, the shape doesn't change
				node.elem = elem;
				return false;
			}
			if (created)
				balanceNode(node);
			return created;
		}
		/**
		 * Removes an item by key.
		 * Returns the removed item if found, or E.init if not.
		 */
		public E remove(T)(T key) @safe pure nothrow {
			Node* removed = removeAt(root, key);
			if (removed is null)
				return E.init;
			nOfElements--;
			return removed.elem;
		}
	} else {
		/**
		 * Puts an element into the TreeMap.
		 * Keys that compare equal to an already stored key are inserted as an additional node.
		 * Returns the inserted key.
		 */
		public K put(K key) @safe pure nothrow {
			insertAt(root, key);
			nOfElements++;
			return key;
		}
		/**
		 * Inserts `key` into the subtree referenced by `node`, rebalancing every node on the way back up.
		 * Keys that are not lesser than a node's key go to its right hand side.
		 */
		private void insertAt(ref Node* node, K key) @safe pure nothrow {
			if (node is null) {
				node = new Node(key);
				return;
			}
			if (binaryFun!less(key, node.key))		//Key is smaller, look at left hand side
				insertAt(node.left, key);
			else									//Key must be greater or equal, look at right hand side
				insertAt(node.right, key);
			balanceNode(node);
		}
		/**
		 * Removes an item by key.
		 * Returns the removed item if found, or K.init if not.
		 */
		public K remove(K key) @safe pure nothrow {
			Node* removed = removeAt(root, key);
			if (removed is null)
				return K.init;
			nOfElements--;
			return removed.key;
		}
	}
	/**
	 * Unlinks the node matching `key` from the subtree referenced by `node` and rebalances the affected path.
	 * A node with two children is replaced by its in-order successor, which is spliced in directly (no second search
	 * from the root, so duplicate keys cannot cause endless recursion).
	 * Returns the unlinked node (its element/key is still readable), or null if no node matched.
	 */
	private Node* removeAt(T)(ref Node* node, T key) @safe pure nothrow {
		if (node is null)
			return null;
		Node* removed;
		if (binaryFun!less(key, node.key)) {		//Key has a lesser value, search on the left
			removed = removeAt(node.left, key);
		} else if (binaryFun!less(node.key, key)) {	//Key has a greater value, search on the right
			removed = removeAt(node.right, key);
		} else {									//Keymatch found
			removed = node;
			if (node.left is null) {
				node = node.right;
			} else if (node.right is null) {
				node = node.left;
			} else {	//Two children: the smallest node of the right hand side takes this node's place
				Node* successor = detachMin(node.right);
				successor.left = node.left;
				successor.right = node.right;
				node = successor;
			}
		}
		if (removed !is null && node !is null)
			balanceNode(node);
		return removed;
	}
	/**
	 * Unlinks and returns the node with the smallest key of the (non-empty) subtree referenced by `node`,
	 * rebalancing the nodes on the path to it.
	 */
	private static Node* detachMin(ref Node* node) @nogc @safe pure nothrow {
		if (node.left is null) {
			Node* leftmost = node;
			node = node.right;
			return leftmost;
		}
		Node* found = detachMin(node.left);
		balanceNode(node);
		return found;
	}
	/**
	 * Rebalances the tree.
	 * Insertions and removals already keep the tree balanced, so calling this is never required; it recomputes every
	 * cached height and applies the AVL rotations bottom-up in a single O(n) pass.
	 */
	public void rebalance() @nogc @safe pure nothrow {
		static void rebalanceSubtree(ref Node* node) @nogc @safe pure nothrow {
			if (node is null)
				return;
			rebalanceSubtree(node.left);
			rebalanceSubtree(node.right);
			balanceNode(node);
		}
		rebalanceSubtree(root);
	}
	/**
	 * Refreshes the cached height of `node` and, if the heights of its subtrees differ by two, restores the AVL
	 * property with a single or double rotation. `node` is replaced by the new root of the subtree.
	 * Both child subtrees must already be balanced and have up-to-date heights.
	 */
	private static void balanceNode(ref Node* node) @nogc @safe pure nothrow {
		node.updateHeight();
		const sizediff_t bal = node.balance;
		if (bal > 1) {				//Right hand imbalance
			if (node.right.balance < 0)
				rotateLeftRight(node);
			else					//A balance of 0 (possible after removal) needs a single rotation too
				rotateLeft(node);
		} else if (bal < -1) {		//Left hand imbalance
			if (node.left.balance > 0)
				rotateRightLeft(node);
			else
				rotateRight(node);
		}
	}
	/**
	 * Recursive traversal shared by every `opApply` and `opApplyReverse` overload.
	 * Visits the subtree of `node` in ascending key order, or descending if `reverse` is set, and stops as soon as
	 * `dg` returns a non-zero value, which is passed back unchanged as required by `foreach`.
	 */
	private static int walk(bool reverse, Dg)(Node* node, scope Dg dg) {
		if (node is null)			//Empty subtree (also covers an empty map)
			return 0;
		static if (reverse) {
			Node* first = node.right;
			Node* last = node.left;
		} else {
			Node* first = node.left;
			Node* last = node.right;
		}
		if (auto result = walk!(reverse, Dg)(first, dg))
			return result;
		if (auto result = visit(node, dg))
			return result;
		return walk!(reverse, Dg)(last, dg);
	}
	static if (!is(E == void)) {
		///Calls `dg` with the element of `node`.
		private static int visit(Node* node, scope int delegate(ref E) dg) {
			return dg(node.elem);
		}
		///Calls `dg` with the key and the element of `node`.
		private static int visit(Node* node, scope int delegate(K, ref E) dg) {
			return dg(node.key, node.elem);
		}
		/**
		 * Implements a simple left-to-right (ascending) tree traversal.
		 */
		int opApply(scope int delegate(ref E) dg) {
			return walk!false(root, dg);
		}
		/**
		 * Implements a simple left-to-right (ascending) tree traversal.
		 */
		int opApply(scope int delegate(K, ref E) dg) {
			return walk!false(root, dg);
		}
		/**
		 * Implements a simple right-to-left (descending) tree traversal.
		 */
		int opApplyReverse(scope int delegate(ref E) dg) {
			return walk!true(root, dg);
		}
		/**
		 * Implements a simple right-to-left (descending) tree traversal.
		 */
		int opApplyReverse(scope int delegate(K, ref E) dg) {
			return walk!true(root, dg);
		}
	} else {
		///Calls `dg` with the key of `node`.
		private static int visit(Node* node, scope int delegate(K) dg) {
			return dg(node.key);
		}
		/**
		 * Implements a simple left-to-right (ascending) tree traversal.
		 */
		int opApply(scope int delegate(K) dg) {
			return walk!false(root, dg);
		}
		/**
		 * Implements a simple right-to-left (descending) tree traversal.
		 */
		int opApplyReverse(scope int delegate(K) dg) {
			return walk!true(root, dg);
		}
	}
	/**
	 * Tree rotation for rebalancing.
	 * Rotates the node to the left: its right child takes its place, and the node becomes that child's left child.
	 */
	private static void rotateLeft(ref Node* node) @nogc @safe pure nothrow {
		Node* pivot = node.right;
		node.right = pivot.left;
		pivot.left = node;
		node.updateHeight();		//The old node is now below the pivot, so it has to be refreshed first
		pivot.updateHeight();
		node = pivot;
	}
	/**
	 * Tree rotation for rebalancing.
	 * Rotates the node to the right: its left child takes its place, and the node becomes that child's right child.
	 */
	private static void rotateRight(ref Node* node) @nogc @safe pure nothrow {
		Node* pivot = node.left;
		node.left = pivot.right;
		pivot.right = node;
		node.updateHeight();		//The old node is now below the pivot, so it has to be refreshed first
		pivot.updateHeight();
		node = pivot;
	}
	/**
	 * Tree rotation for rebalancing a left-heavy node whose left child is right-heavy.
	 * Rotates the node's left child to the left, then the node to the right.
	 */
	private static void rotateRightLeft(ref Node* node) @nogc @safe pure nothrow {
		rotateLeft(node.left);
		rotateRight(node);
	}
	/**
	 * Tree rotation for rebalancing a right-heavy node whose right child is left-heavy.
	 * Rotates the node's right child to the right, then the node to the left.
	 */
	private static void rotateLeftRight(ref Node* node) @nogc @safe pure nothrow {
		rotateRight(node.right);
		rotateLeft(node);
	}
	/**
	 * Returns the number of currently held elements within the tree.
	 */
	public @property size_t length() @nogc @safe pure nothrow const {
		return nOfElements;
	}
	/**
	 * Returns the string representation of the tree, or "Empty" if it holds no elements.
	 */
	public string toString() const {
		if(root !is null)
			return root.toString;
		else
			return "Empty";
	}
}

unittest {
	alias Map = TreeMap!(int, int);
	//Checks the AVL invariant and the cached heights, returns the height of the subtree
	static size_t checkAvl(const(Map.Node)* node) {
		if (node is null)
			return 0;
		const size_t lhs = checkAvl(node.left);
		const size_t rhs = checkAvl(node.right);
		assert(lhs <= rhs + 1 && rhs <= lhs + 1, "AVL balance violated");
		assert(node.height == (lhs >= rhs ? lhs : rhs) + 1, "Stale cached height");
		return node.height;
	}
	Map map;
	size_t visited;
	foreach (k, e; map) visited++;			//Iterating an empty map must not dereference the null root
	foreach_reverse (k, e; map) visited++;
	assert(visited == 0);
	foreach (i; 0 .. 100) {
		map[i] = i * 10;
		cast(void) checkAvl(map.root);
	}
	assert(map.length == 100);
	int prev = -1;
	foreach (k, e; map) {					//Ascending order
		assert(k > prev && e == k * 10);
		prev = k;
	}
	prev = 100;
	foreach_reverse (k, e; map) {			//Descending order
		assert(k < prev);
		prev = k;
	}
	static int firstAbove(ref Map m, int limit) {	//`return` from inside foreach must leave the function
		foreach (k, e; m)
			if (k > limit)
				return k;
		return -1;
	}
	assert(firstAbove(map, 41) == 42);
	foreach (i; 0 .. 100) {
		if (i % 3 == 0)
			assert(map.remove(i) == i * 10);
		cast(void) checkAvl(map.root);
	}
	assert(map.length == 66);
	assert(map.remove(0) == 0);
	assert(map.ptrOf(3) is null && *map.ptrOf(4) == 40);
	//Duplicate keys in set mode used to make remove() recurse forever
	TreeMap!(int, void) set;
	foreach (_; 0 .. 3)
		set.put(5);
	assert(set.length == 3);
	foreach (_; 0 .. 3)
		assert(set.remove(5) == 5);
	assert(set.length == 0 && set.toString() == "Empty");
}
556 
unittest {
	import std.random : uniform, Random;
	import std.exception : assertThrown;
	{
		alias IntMap = TreeMap!(int, int, true);
		auto rng = Random(1337);	//Fixed seed, so that a failing run can be reproduced
		IntMap test0, test3;
		bool[int] inserted;
		for (int i ; i < 1024 ; i++) {	//Stress test to see if large number of elements would cause any issues
			const int key = uniform(0, 65536, rng);
			test0[key] = i;
			inserted[key] = true;
		}
		assert(test0.length == inserted.length, "TreeMap length mismatch");
		size_t visited;
		int prev = -1;
		foreach (k, e; test0) {			//Forward iteration must visit every key once, in ascending order
			assert(k > prev, "TreeMap iteration is out of order");
			assert(k in inserted, "TreeMap iteration returned an unknown key");
			prev = k;
			visited++;
		}
		assert(visited == test0.length, "TreeMap iteration skipped elements");
		visited = 0;
		foreach_reverse (k, e; test0)
			visited++;
		assert(visited == test0.length, "TreeMap reverse iteration skipped elements");
		IntMap single;
		assert(single.toString() == "Empty");
		single[7] = 3;
		assert(single.toString() == "{K: 7 ; E: 3 ; }");
		for (int i ; i < 64 ; i++)
			test3[i] = i;
		for (int i ; i < 64 ; i++)
			assert(test3[i] == i, "TreeMap lookup failure");
		bool[int] removed;
		for (int i ; i < 16 ; i++) {
			const int key = uniform(0, 64, rng);
			test3.remove(key);
			removed[key] = true;
		}
		assert(test3.length == 64 - removed.length, "Treemap item removal failure");
		for (int i ; i < 64 ; i++)	//Exactly the removed keys must be gone
			assert((test3.ptrOf(i) is null) == ((i in removed) !is null), "Treemap item removal failure");
	}
	{
		alias IntMap = TreeMap!(int, int, false);
		IntMap test0, test1;
		for (int i ; i < 64 ; i++)
			test0[i] = i;
		assert(test0.length == 64, "TreeMap length mismatch");
		assertThrown!ElementNotFoundException(test0[420]);
		assertThrown!ElementNotFoundException(test0[666]);
		for (int i ; i < 64 ; i++)
			test0.remove(i);
		assert(test0.length == 0, "Treemap item removal failure");
		for (int i ; i < 16 ; i++) {
			test1[i] = i;
			assert(test1.length == cast(size_t)(i + 1), "TreeMap length mismatch");
			assert(test1[i] == i, "TreeMap lookup failure");
		}
		test1[3] = 30;	//Assigning to an existing key overwrites it without inserting a new node
		assert(test1.length == 16 && test1[3] == 30, "TreeMap overwrite failure");
	}
	{
		alias IntMap = TreeMap!(int, void, true);
		IntMap test0;
		for (int i ; i < 64 ; i++)
			test0.put(i);
		assert(test0.length == 64, "TreeMap length mismatch");
		for (int i ; i < 64 ; i++)
			assert(test0[i] == i, "TreeMap lookup failure");
		int prev = -1;
		foreach (k; test0) {
			assert(k > prev, "TreeMap iteration is out of order");
			prev = k;
		}
		for (int i ; i < 64 ; i++)
			test0.remove(i);
		assert(test0.length == 0, "Treemap item removal failure");
		assert(test0.toString() == "Empty");
	}
	{
		alias IntMap = TreeMap!(int, void, false);
		IntMap test0;
		for (int i ; i < 64 ; i++)
			test0.put(i);
		assert(test0.length == 64, "TreeMap length mismatch");
		assertThrown!ElementNotFoundException(test0[420]);
		assertThrown!ElementNotFoundException(test0[666]);
		for (int i ; i < 64 ; i++)
			test0.remove(i);
		assert(test0.length == 0, "Treemap item removal failure");
		assert(test0.toString() == "Empty");
	}
}