import Cython.Compiler.Main
import Cython.Compiler.Errors
from Cython.Compiler.Symtab import BuiltinScope, ModuleScope
import tempfile
import os

from Cython.Compiler.CythonNode import Document, walkdom

from Cython.Compiler.XPathTransform import XPathTransform, template

def w3cfix(tree):
    """Re-parent every node of *tree* onto one freshly created ``Document``.

    Every node produced by ``walkdom(tree)`` whose ``ownerDocument`` is not
    the new document gets it reassigned. The new document's ``childNodes`` is
    then set to ``[tree]``. The tree is modified in place and nothing is
    returned.
    """
    owner = Document()
    # Lazily filtered, so each comparison still happens right before the
    # corresponding assignment, in walkdom order.
    mismatched = (n for n in walkdom(tree) if n.ownerDocument != owner)
    for n in mismatched:
        n.ownerDocument = owner
    owner.childNodes = [tree]


# The function below duplicates just enough of Main.py's behaviour to allow isolated
# experimentation with parsing and manipulating the tree. Sigh.
def parse_string_to_pyrex_tree(s):
    """Parse the source text *s* as module ``testbed`` and return its tree.

    The compiler's ``parse`` call takes a file path, so *s* is first written
    to a temporary file. That file is always deleted, even when parsing
    raises. Before being returned, the tree is passed through ``w3cfix``.
    """
    cython = Cython.Compiler.Main.Context([])
    fileno, source = tempfile.mkstemp(text=True)
    try:
        f = os.fdopen(fileno, "w")
        try:
            # Same output as the old ``print >>f, s`` (str(s) plus a
            # newline), but valid on both Python 2 and Python 3.
            f.write("%s\n" % (s,))
        finally:
            f.close()
        module_name = "testbed"
        initial_pos = (source, 1, 0)
        Cython.Compiler.Errors.open_listing_file(None)
        scope = cython.find_module(module_name, pos=initial_pos, need_pxd=0)
        tree = cython.parse(source, scope.type_names, pxd=0,
                            full_module_name=module_name)
    finally:
        # Previously the temp file leaked whenever find_module/parse raised.
        os.unlink(source)
    w3cfix(tree)
    return tree


def dumpxml(tree):
    """Pretty-print *tree* as XML using PyXML's ``PrettyPrint``.

    *tree* is imported (``deep=True``) into a fresh ``xml.dom.minidom``
    document. That document is then handed to ``PrettyPrint`` with its
    default output stream.
    """
    # xml.dom.ext comes from PyXML, not the standard library; this import
    # raises ImportError when PyXML is not installed.
    from xml.dom.ext import PrettyPrint
    # Aliased so it is not confused with the module-level CythonNode Document.
    from xml.dom.minidom import Document as MinidomDocument
    doc = MinidomDocument()
    # Fixed: dump the argument, not the module-global ``A``.
    doc.appendChild(doc.importNode(tree, deep=True))
    PrettyPrint(doc)

import Cython.Compiler.Nodes as Nodes
import Cython.Compiler.ExprNodes as ExprNodes

def _descending_step(step):
    """Return the positive magnitude of *step* if it is a negative int literal.

    Two parse-tree shapes are accepted for a literal such as ``-1``:
    - a ``UnaryMinusNode`` wrapping an ``IntNode``;
    - an ``IntNode`` whose value text starts with ``-``.

    Returns None for anything else, including steps whose sign is unknown at
    parse time.

    NOTE(review): which shape appears presumably depends on whether the
    Cython version's parser folds the sign into the literal — confirm.
    """
    if (isinstance(step, ExprNodes.UnaryMinusNode)
            and isinstance(step.operand, ExprNodes.IntNode)):
        return step.operand
    if isinstance(step, ExprNodes.IntNode):
        text = str(step.value)
        if text.startswith("-"):
            return ExprNodes.IntNode(pos=step.pos, value=text[1:])
    return None


class ForInToForFrom(XPathTransform):
    """Rewrite ``for ... in range(...)`` loops into ``for ... from`` loops.

    The template matches a ``ForInStatNode`` whose iterator sequence is a
    call to a name ``range``. The handler replaces it with a
    ``ForFromStatNode`` that reuses the original target, body and else
    clause.
    """

    @template("pyr:ForInStatNode[iterator/pyr:IteratorNode/sequence/pyr:SimpleCallNode/function/pyr:NameNode/@name = 'range']")
    def for_in_range_to_for_from_range(self, node):
        """Build the ``ForFromStatNode`` equivalent of the matched loop *node*.

        Mapping (``k`` is an integer literal):
            range(stop)              ->  0 <= target < stop
            range(start, stop)       ->  start <= target < stop
            range(start, stop, k)    ->  start <= target < stop, step k
            range(start, stop, -k)   ->  start >= target > stop, step k

        NOTE(review): a step whose sign is not a visible literal (e.g. a
        variable) still gets the ascending form, as before. That is only
        correct for positive steps.
        NOTE(review): the template matches any name ``range``; a shadowing
        definition is not detected here.
        """
        args = node.iterator.sequence.args
        relation1, relation2 = "<=", "<"
        step = None
        if len(args) >= 2:
            bound1, bound2 = args[0], args[1]
            if len(args) == 3:
                step = args[2]
                magnitude = _descending_step(step)
                if magnitude is not None:
                    # Previously a negative step kept "<=", "<" and the loop
                    # never ran. Use descending relations with a positive
                    # step instead. This assumes, as in Pyrex's for-from code
                    # generation, that the relations choose the direction in
                    # which the step is applied.
                    relation1, relation2 = ">=", ">"
                    step = magnitude
        else:
            # range(stop) counts from 0.
            bound1 = ExprNodes.IntNode(pos=node.pos, value=0)
            bound2 = args[0]
        return Nodes.ForFromStatNode(
            pos=node.pos,
            target=node.target,
            bound1=bound1,
            relation1=relation1,
            relation2=relation2,
            bound2=bound2,
            step=step,
            body=node.body,
            else_clause=node.else_clause,
        )

# Demo driver, executed at import time. It:
#   1. parses a tiny module containing a ``for ... in range(...)`` loop;
#   2. rewrites it with ForInToForFrom;
#   3. dumps the resulting tree as XML.
A = parse_string_to_pyrex_tree(
    "\n"
    "a = True\n"
    "def foo():\n"
    "    for i in range(10):\n"
    '        print "Hello"\n'
)

pt = ForInToForFrom()
pt.initialize("testbed")  # same module name parse_string_to_pyrex_tree uses
pt.process_tree(A)

dumpxml(A)


