#!/usr/bin/python
import PIL.Image, PIL.ImageFont, PIL.ImageDraw
import copy

class Node(object):
	"""A labelled node of an ordered tree.

	Each node stores a label, an ordered list of child nodes, a free-form
	attribute dictionary and a back-reference to its parent (``None`` for a
	root). Iterating a node yields the node itself followed by all of its
	descendants in pre-order (depth-first, children left to right).
	"""
	def __init__(self, label):
		self.label = label
		self.childNodes = []  # ordered children, maintained by addChild/removeChild
		self.attributes = {}  # arbitrary key -> value data, see setAttr/getAttr
		self.parent = None    # set by addChild; None for a root
	def __iter__(self):
		"""Return a pre-order iterator over this node and its descendants.

		An explicit stack is used instead of recursion, so deep trees do not
		hit the recursion limit. The iterator is a generator, so it supports
		both the Python 2 (``.next()``) and Python 3 (``__next__``) protocols.
		Traversal only follows child lists (never parent links), so it stays
		inside this node's subtree. A node's children are read when iteration
		resumes after that node has been yielded.
		"""
		stack = [self]
		while stack:
			node = stack.pop()
			yield node
			# Push in reverse so the leftmost child is popped (visited) first.
			stack.extend(reversed(node.getChildNodes()))
	def __str__(self):
		"""Return the node's label as a string (non-string labels are converted)."""
		return str(self.label)
	def addChild(self, child):
		"""Append ``child`` as the last child of this node and set its parent."""
		self.childNodes.append(child)
		child.parent = self
	def newChild(self, label):
		"""Create a Node labelled ``label``, append it as a child and return it."""
		node = Node(label)
		self.addChild(node)
		return node
	def getParentNode(self):
		"""Return the parent node, or None if this node is a root."""
		return self.parent
	def getRootNode(self):
		"""Return the topmost ancestor reached by following parent links."""
		parent = self
		while parent.getParentNode() is not None:
			parent = parent.getParentNode()
		return parent
	def getDepth(self):
		"""Return the number of parent links up to the root (a root has depth 0)."""
		parent = self
		depth = 0
		while parent.getParentNode() is not None:
			parent = parent.getParentNode()
			depth += 1
		return depth
	def subtree(self):
		"""Return a shallow copy of this node, detached from its parent.

		The copy's parent is None, so it reports itself as a root. Because the
		copy is shallow, it shares its ``childNodes`` list and ``attributes``
		dict with this node, and the children's ``parent`` links still point
		at this (original) node.
		"""
		mcopy = copy.copy(self) # Shallow copy!
		# Detach the copy so that it acts as the root of its own tree.
		mcopy.parent = None
		return mcopy
	def removeChild(self, child):
		"""Remove ``child`` from this node's children and clear its parent link.

		Raises:
			ValueError: if ``child`` is not among this node's children.
		"""
		self.childNodes.remove(child)
		# Mirror addChild(): without this the detached node would still report
		# its old parent, root and depth. Only clear the link if it points here.
		if child.parent is self:
			child.parent = None
	def setLabel(self, label):
		"""Replace the node's label."""
		self.label = label
	def getLabel(self):
		"""Return the node's label as stored (not converted to a string)."""
		return self.label
	def getChildNodes(self):
		"""Return the (live, mutable) list of child nodes."""
		return self.childNodes
	def setAttr(self, attr, val):
		"""Store ``val`` under key ``attr`` in the node's attribute dict."""
		self.attributes[attr] = val
	def getAttr(self, attr):
		"""Return the attribute stored under ``attr``, or None if it is unset."""
		return self.attributes.get(attr)
	
	def _plot_calcRequiredWidth(self):
		"""Return the widest level of this subtree, measured in label characters.

		For each depth level of the subtree (including this node's own level),
		the label lengths on that level are summed; the largest sum is returned.
		"""
		maxDepth = max(node.getDepth() for node in self)
		levels = [0] * (maxDepth + 1)  # indexed by absolute depth
		for node in self:
			levels[node.getDepth()] += len(str(node))
		return max(levels)
	def plotNode(self):
		"""Render the subtree rooted at this node into a new PIL image.

		Each node is drawn as its label text with a small red dot at the
		horizontal middle of the label, 10px below its top edge; a red line
		joins every node's dot to its parent's dot. Depth levels are ``yScale``
		pixels apart and each label character is budgeted ``xScale`` pixels.
		When called on a non-root node, that node is treated as the root of
		the plot and depths are measured relative to it.

		Returns:
			A new RGB ``PIL.Image.Image`` with a white background.
		"""
		yScale = 50  # pixels between depth levels
		xScale = 15  # pixels budgeted per label character
		baseDepth = self.getDepth()
		# Deepest level of the plotted subtree, relative to this node.
		maxDepth = max(node.getDepth() for node in self) - baseDepth
		# Horizontal space (in characters) needed by every node's subtree.
		widthRequirements = {}
		for node in self:
			widthRequirements[node] = node._plot_calcRequiredWidth()
		positions = {}
		for node in self:
			# Reserve the width of every subtree lying left of this node's
			# ancestor chain, walking up until the plotted root is reached.
			leftMust = 0
			current = node
			while current is not self:
				parent = current.getParentNode()
				if parent is None:
					break
				for sibling in parent.getChildNodes():
					if sibling is current:
						break
					leftMust += widthRequirements[sibling]
				current = parent
			if node is self:
				leftMust += widthRequirements[node]
			else:
				parent = node.getParentNode()
				siblingChars = sum(len(str(x)) for x in parent.getChildNodes())
				# NOTE(review): by operator precedence only the sibling label
				# total is halved, i.e. A - (S // 2) rather than (A - S) // 2.
				# Kept as-is to preserve the existing layout; confirm intent.
				# Floor division keeps the integer pixel coordinates PIL needs.
				leftMust += widthRequirements[parent] - siblingChars // 2
			positions[node] = (leftMust * xScale, (node.getDepth() - baseDepth) * yScale)
		imageWidth = max(positions[node][0] + len(str(node)) * xScale for node in self)
		image = PIL.Image.new("RGB", (imageWidth, (maxDepth + 1) * yScale), "#fff")
		drawer = PIL.ImageDraw.Draw(image)
		font = PIL.ImageFont.load_default()
		for node in self:
			nodePos = positions[node]
			# Anchor point: middle of the label, 10px below its top edge.
			anchorX = nodePos[0] + len(str(node)) * xScale // 2
			anchorY = nodePos[1] + 10
			if node is not self:
				parent = node.getParentNode()
				parentPos = positions[parent]
				drawer.line([ (anchorX, anchorY),
					(parentPos[0] + len(str(parent)) * xScale // 2, parentPos[1] + 10) ], fill="#f00")
			drawer.ellipse([ anchorX - 3, anchorY - 3, anchorX + 3, anchorY + 3 ], fill="#f00")
			drawer.text(nodePos, str(node), font=font, fill="#000")
		del drawer
		return image

if __name__ == "__main__":
	# Pick which demo tree to render into out.png (1, 2 or 3).
	demo = 3
	tree = None
	if demo == 1:
		# Fixed tree whose right branch holds a nested two-leaf subtree.
		tree = Node("Baum")
		child1 = tree.newChild("Kind 1")
		child2 = tree.newChild("Kind 2")
		child1.newChild("Blatt")
		child3 = child2.newChild("Knoten")
		child3.newChild("Blatt")
		child3.newChild("Blatt")
	elif demo == 2:
		# A deep left chain next to a shallow two-leaf right branch.
		tree = Node("Baum")
		tree.newChild("L").newChild("LeftChild").newChild("LeftChild").newChild("Left")
		right = tree.newChild("Right")
		right.newChild("Test")
		right.newChild("Test 2")
	elif demo == 3:
		# Random tree: after adding each numbered node, randomly descend into
		# it (1), climb to the parent if there is one (2), or stay put (3).
		import random
		tree = Node("Baum")
		active = tree
		for i in range(30):
			sub = active.newChild(str(i))
			action = random.randint(1, 3)
			if action == 1:
				active = sub
			elif action == 2 and active.getParentNode() is not None:
				active = active.getParentNode()
	if tree is not None:
		img = tree.plotNode()
		img.save("out.png", "PNG")
