import java.util.*; class AVLNode extends BNode { public int lh, rh; public static int max(int x, int y) { if (x > y) return x; return y; } public int fixHt() { lh = rh = -1; AVLNode l = (AVLNode) left; AVLNode r = (AVLNode) right; if (l != null) lh = 1 + max(l.lh, l.rh); if (r != null) rh = 1 + max(r.lh, r.rh); if (lh > rh) return lh; return rh; } public AVLNode(Object d, BNode p, BNode l, BNode r) { super(d, p, l, r); fixHt(); } } class AVLTree extends BST { public AVLTree() { super();} public void rebalance(AVLNode n) { if (n == null) return; int oldMax = AVLNode.max(n.lh, n.rh); int newMax = n.fixHt(); if (newMax > n.lh + 1 || newMax > n.rh + 1) { AVLNode cl = (AVLNode) n.getLeft(); AVLNode cr = (AVLNode) n.getRight(); if (n.lh > n.rh && cl.lh >= cl.rh) // single rotate at left rebuildNode(n, cl, cl.getLeft(), n, cl.getLeft().getLeft(), cl.getLeft().getRight(), cl.getRight(), cr); else if (n.lh > n.rh) // double rotate at left rebuildNode(n, cl.getRight(), cl, n, cl.getLeft(), cl.getRight().getLeft(), cl.getRight().getRight(), cr); else if (cr.rh >= cr.lh) // single rotate at right rebuildNode(n, cr, n, cr.getRight(), cl, cr.getLeft(), cr.getRight().getLeft(), cr.getRight().getRight()); else // double rotate at right rebuildNode(n, cr.getLeft(), n, cr, cl, cr.getLeft().getLeft(), cr.getLeft().getRight(), cr.getRight()); } newMax = n.fixHt(); if (oldMax != newMax) rebalance((AVLNode) n.getParent()); } public void rebuildNode(BNode n, BNode d, BNode dl, BNode dr, BNode t1, BNode t2, BNode t3, BNode t4) { AVLNode l = new AVLNode(dl.getData(), n, t1, t2); AVLNode r = new AVLNode(dr.getData(), n, t3, t4); if (t1 != null) t1.setParent(l); if (t2 != null) t2.setParent(l); if (t3 != null) t3.setParent(r); if (t4 != null) t4.setParent(r); l.fixHt(); r.fixHt(); n.setData(d.getData()); n.setLeft(l); n.setRight(r); } public void addRoot(Object x) { if (root != null) throw new RuntimeException(); root = new AVLNode(x, null, null, null); size++; } public void addLeft(BNode n, Object x) { if (n.getLeft() != null) throw new RuntimeException(); n.setLeft(new AVLNode(x, n, null, null)); size++; rebalance((AVLNode) n); } public void addRight(BNode n, Object x) { if (n.getRight() != null) throw new RuntimeException(); n.setRight(new AVLNode(x, n, null, null)); size++; rebalance((AVLNode) n); } public void remove(Comparable x) { BNode n = findNode(x); if (n == null || !n.getData().equals(x)) return; // x is not present AVLNode p = (AVLNode) removeNode(n); rebalance(p); } }