view src/ast_convert.rs @ 49:141f1769e1f0

Add ast.Pass.
author Bastien Orivel <eijebong@bananium.fr>
date Wed, 08 Jun 2016 17:18:28 +0200
parents 039f85b187f2
children 5edbc24b625f
line wrap: on
line source

use python_ast::{Module, stmt, expr, expr_context, cmpop, boolop, operator, unaryop, arguments, arg, alias, comprehension};

use cpython::{Python, PyObject};
use cpython::ObjectProtocol; //for call method

/// Converts a Python object to a Rust `String` using Python's `str()`.
///
/// Panics if `str()` raises or if the resulting Python string cannot be
/// converted to Rust (both results are unwrapped).
fn get_str(py: Python, object: PyObject) -> String {
    let pystring = object.str(py).unwrap();
    // `to_string` yields a `Cow<str>`; `into_owned` reuses an already-owned
    // buffer instead of forcing an owned copy and then cloning it a second time.
    pystring.to_string(py).unwrap().into_owned()
}

/// Reads the `ctx` attribute of a Python AST node and maps it to the matching
/// `expr_context` (`ast.Store`, `ast.Load` or `ast.Del`).
///
/// Panics if the node has no `ctx` attribute or its context is none of the
/// three kinds above.
fn get_ctx(py: Python, object: PyObject) -> expr_context {
    let isinstance = py.import("builtins").unwrap().get(py, "isinstance").unwrap();

    // Candidate context classes, checked in this order.
    let ast_mod = py.import("ast").unwrap();
    let contexts = vec![
        (ast_mod.get(py, "Store").unwrap(), expr_context::Store),
        (ast_mod.get(py, "Load").unwrap(), expr_context::Load),
        (ast_mod.get(py, "Del").unwrap(), expr_context::Del),
    ];

    let ctx = object.getattr(py, "ctx").unwrap();
    for (context_type, context) in contexts {
        if isinstance.call(py, (&ctx, &context_type), None).unwrap().is_true(py).unwrap() {
            return context;
        }
    }
    unreachable!();
}

/// Converts every item of the Python iterable `list` with `parse`, keeping
/// the original order.
///
/// Panics if `list` is not iterable or if fetching any item raises.
fn parse_list<T, F: Fn(Python, PyObject) -> T>(py: Python, list: PyObject, parse: F) -> Vec<T> {
    list.iter(py)
        .unwrap()
        .map(|element| parse(py, element.unwrap()))
        .collect()
}

/// Converts a Python `ast.alias` node (`name` with an optional `as asname`)
/// into an `alias`. A Python `None` `asname` becomes `None`.
fn parse_alias(py: Python, ast: PyObject) -> alias {
    let imported = get_str(py, ast.getattr(py, "name").unwrap());
    let renamed_obj = ast.getattr(py, "asname").unwrap();
    let renamed = if renamed_obj == py.None() {
        None
    } else {
        Some(get_str(py, renamed_obj))
    };
    alias{name: imported, asname: renamed}
}

/// Maps a Python unary-operator node (`ast.Invert`, `ast.Not`, `ast.UAdd`,
/// `ast.USub`) to the corresponding `unaryop`.
///
/// Panics if `ast` is not an `ast.AST` instance or is none of those operators.
fn parse_unaryop(py: Python, ast: PyObject) -> unaryop {
    let isinstance = py.import("builtins").unwrap().get(py, "isinstance").unwrap();
    let is_instance = |object: &PyObject, type_: &PyObject| -> bool {
        isinstance.call(py, (object, type_), None).unwrap().is_true(py).unwrap()
    };

    let ast_module = py.import("ast").unwrap();
    let lookup = |name: &str| ast_module.get(py, name).unwrap();
    let ast_type = lookup("AST");
    // (Python class, Rust variant) pairs, tried in order.
    let candidates = vec![
        (lookup("Invert"), unaryop::Invert),
        (lookup("Not"), unaryop::Not),
        (lookup("UAdd"), unaryop::UAdd),
        (lookup("USub"), unaryop::USub),
    ];

    assert!(is_instance(&ast, &ast_type));

    candidates
        .into_iter()
        .find(|candidate| is_instance(&ast, &candidate.0))
        .map(|candidate| candidate.1)
        .unwrap_or_else(|| unreachable!())
}

/// Maps a Python boolean-operator node (`ast.And` / `ast.Or`) to a `boolop`.
///
/// Panics if `ast` is not an `ast.AST` instance or is neither operator.
fn parse_boolop(py: Python, ast: PyObject) -> boolop {
    let isinstance = py.import("builtins").unwrap().get(py, "isinstance").unwrap();
    let is_instance = |object: &PyObject, type_: &PyObject| -> bool {
        isinstance.call(py, (object, type_), None).unwrap().is_true(py).unwrap()
    };

    let ast_module = py.import("ast").unwrap();
    let lookup = |name: &str| ast_module.get(py, name).unwrap();
    let ast_type = lookup("AST");
    // (Python class, Rust variant) pairs, tried in order.
    let candidates = vec![
        (lookup("And"), boolop::And),
        (lookup("Or"), boolop::Or),
    ];

    assert!(is_instance(&ast, &ast_type));

    candidates
        .into_iter()
        .find(|candidate| is_instance(&ast, &candidate.0))
        .map(|candidate| candidate.1)
        .unwrap_or_else(|| unreachable!())
}

/// Maps a Python comparison-operator node (`ast.Eq`, `ast.NotEq`, `ast.Lt`,
/// `ast.LtE`, `ast.Gt`, `ast.GtE`, `ast.Is`, `ast.IsNot`, `ast.In`,
/// `ast.NotIn`) to the corresponding `cmpop`.
///
/// Panics if `ast` is not an `ast.AST` instance or is none of those operators.
fn parse_cmpop(py: Python, ast: PyObject) -> cmpop {
    let isinstance = py.import("builtins").unwrap().get(py, "isinstance").unwrap();
    let is_instance = |object: &PyObject, type_: &PyObject| -> bool {
        isinstance.call(py, (object, type_), None).unwrap().is_true(py).unwrap()
    };

    let ast_module = py.import("ast").unwrap();
    let lookup = |name: &str| ast_module.get(py, name).unwrap();
    let ast_type = lookup("AST");
    // (Python class, Rust variant) pairs, tried in order.
    let candidates = vec![
        (lookup("Eq"), cmpop::Eq),
        (lookup("NotEq"), cmpop::NotEq),
        (lookup("Lt"), cmpop::Lt),
        (lookup("LtE"), cmpop::LtE),
        (lookup("Gt"), cmpop::Gt),
        (lookup("GtE"), cmpop::GtE),
        (lookup("Is"), cmpop::Is),
        (lookup("IsNot"), cmpop::IsNot),
        (lookup("In"), cmpop::In),
        (lookup("NotIn"), cmpop::NotIn),
    ];

    assert!(is_instance(&ast, &ast_type));

    candidates
        .into_iter()
        .find(|candidate| is_instance(&ast, &candidate.0))
        .map(|candidate| candidate.1)
        .unwrap_or_else(|| unreachable!())
}

/// Converts a Python `ast.comprehension` node (one `for target in iter if ...`
/// clause) into a `comprehension`.
fn parse_comprehension(py: Python, ast: PyObject) -> comprehension {
    // Read all three fields first, then convert them in the same order.
    let (target_obj, iter_obj, ifs_obj) = (
        ast.getattr(py, "target").unwrap(),
        ast.getattr(py, "iter").unwrap(),
        ast.getattr(py, "ifs").unwrap(),
    );

    comprehension {
        target: parse_expr(py, target_obj),
        iter: parse_expr(py, iter_obj),
        ifs: parse_list(py, ifs_obj, parse_expr),
    }
}

/// Maps a Python binary-operator node (`ast.Add`, `ast.Sub`, ..., `ast.FloorDiv`)
/// to the corresponding `operator`.
///
/// Panics if `ast` is not an `ast.AST` instance. For any other unknown node it
/// prints `operator <node>` to stdout and then panics.
fn parse_operator(py: Python, ast: PyObject) -> operator {
    let isinstance = py.import("builtins").unwrap().get(py, "isinstance").unwrap();
    let is_instance = |object: &PyObject, type_: &PyObject| -> bool {
        isinstance.call(py, (object, type_), None).unwrap().is_true(py).unwrap()
    };

    let ast_module = py.import("ast").unwrap();
    let lookup = |name: &str| ast_module.get(py, name).unwrap();
    let ast_type = lookup("AST");
    // (Python class, Rust variant) pairs, tried in order.
    let candidates = vec![
        (lookup("Add"), operator::Add),
        (lookup("Sub"), operator::Sub),
        (lookup("Mult"), operator::Mult),
        (lookup("MatMult"), operator::MatMult),
        (lookup("Div"), operator::Div),
        (lookup("Mod"), operator::Mod),
        (lookup("Pow"), operator::Pow),
        (lookup("LShift"), operator::LShift),
        (lookup("RShift"), operator::RShift),
        (lookup("BitOr"), operator::BitOr),
        (lookup("BitXor"), operator::BitXor),
        (lookup("BitAnd"), operator::BitAnd),
        (lookup("FloorDiv"), operator::FloorDiv),
    ];

    assert!(is_instance(&ast, &ast_type));

    candidates
        .into_iter()
        .find(|candidate| is_instance(&ast, &candidate.0))
        .map(|candidate| candidate.1)
        .unwrap_or_else(|| {
            println!("operator {}", ast);
            panic!()
        })
}

fn parse_expr(py: Python, ast: PyObject) -> expr {
    let builtins_module = py.import("builtins").unwrap();
    let isinstance = builtins_module.get(py, "isinstance").unwrap();

    let is_instance = |object: &PyObject, type_: &PyObject| {
        return isinstance.call(py, (object, type_), None).unwrap().is_true(py).unwrap();
    };

    let ast_module = py.import("ast").unwrap();
    let ast_type = ast_module.get(py, "AST").unwrap();
    let arg_type = ast_module.get(py, "arg").unwrap();
    let unary_op_type = ast_module.get(py, "UnaryOp").unwrap();
    let bool_op_type = ast_module.get(py, "BoolOp").unwrap();
    let bin_op_type = ast_module.get(py, "BinOp").unwrap();
    let name_constant_type = ast_module.get(py, "NameConstant").unwrap();
    let attribute_type = ast_module.get(py, "Attribute").unwrap();
    let name_type = ast_module.get(py, "Name").unwrap();
    let num_type = ast_module.get(py, "Num").unwrap();
    let str_type = ast_module.get(py, "Str").unwrap();
    let list_type = ast_module.get(py, "List").unwrap();
    let compare_type = ast_module.get(py, "Compare").unwrap();
    let call_type = ast_module.get(py, "Call").unwrap();
    let listcomp_type = ast_module.get(py, "ListComp").unwrap();
    let dictcomp_type = ast_module.get(py, "DictComp").unwrap();
    let tuple_type = ast_module.get(py, "Tuple").unwrap();

    assert!(is_instance(&ast, &ast_type));

    if is_instance(&ast, &arg_type) {
        let arg = ast.getattr(py, "arg").unwrap();
        let arg = get_str(py, arg);
        expr::Name(arg, get_ctx(py, ast))
    } else if is_instance(&ast, &attribute_type) {
        let value = ast.getattr(py, "value").unwrap();
        let attr = ast.getattr(py, "attr").unwrap();

        let value = parse_expr(py, value);
        let attr = get_str(py, attr);

        expr::Attribute(Box::new(value), attr, get_ctx(py, ast))
    } else if is_instance(&ast, &name_type) {
        let id = ast.getattr(py, "id").unwrap();
        let id = get_str(py, id);
        expr::Name(id, get_ctx(py, ast))
    } else if is_instance(&ast, &name_constant_type) {
        let value = ast.getattr(py, "value").unwrap();
        let value = get_str(py, value);
        expr::NameConstant(value)
    } else if is_instance(&ast, &num_type) {
        let n = ast.getattr(py, "n").unwrap();
        let n = get_str(py, n);
        expr::Num(n)
    } else if is_instance(&ast, &str_type) {
        let s = ast.getattr(py, "s").unwrap();
        let s = get_str(py, s);
        expr::Str(s)
    } else if is_instance(&ast, &list_type) {
        let elts = ast.getattr(py, "elts").unwrap();
        let elements = parse_list(py, elts, parse_expr);
        expr::List(elements, get_ctx(py, ast))
    } else if is_instance(&ast, &unary_op_type) {
        let op = ast.getattr(py, "op").unwrap();
        let operand = ast.getattr(py, "operand").unwrap();

        let op = parse_unaryop(py, op);
        let operand = parse_expr(py, operand);

        expr::UnaryOp(op, Box::new(operand))
    } else if is_instance(&ast, &bool_op_type) {
        let op = ast.getattr(py, "op").unwrap();
        let values = ast.getattr(py, "values").unwrap();

        let op = parse_boolop(py, op);
        let values = parse_list(py, values, parse_expr);

        expr::BoolOp(op, values)
    } else if is_instance(&ast, &bin_op_type) {
        let left = ast.getattr(py, "left").unwrap();
        let op = ast.getattr(py, "op").unwrap();
        let right = ast.getattr(py, "right").unwrap();

        let left = parse_expr(py, left);
        let op = parse_operator(py, op);
        let right = parse_expr(py, right);

        expr::BinOp(Box::new(left), op, Box::new(right))
    } else if is_instance(&ast, &call_type) {
        let func = ast.getattr(py, "func").unwrap();
        let args = ast.getattr(py, "args").unwrap();
        let keywords = ast.getattr(py, "keywords").unwrap();

        let func = parse_expr(py, func);
        let args = parse_list(py, args, parse_expr);

        /*
        let mut kwargs = vec!();
        for arg in keywords.iter(py).unwrap() {
            let arg = arg.unwrap();
            kwargs.push(parse_expr(py, arg));
        }
        */

        expr::Call(Box::new(func), args, vec!())
    } else if is_instance(&ast, &compare_type) {
        let left = ast.getattr(py, "left").unwrap();
        let ops = ast.getattr(py, "ops").unwrap();
        let comparators = ast.getattr(py, "comparators").unwrap();

        let left = parse_expr(py, left);
        let ops = parse_list(py, ops, parse_cmpop);
        let comparators = parse_list(py, comparators, parse_expr);

        expr::Compare(Box::new(left), ops, comparators)
    } else if is_instance(&ast, &listcomp_type) {
        let elt = ast.getattr(py, "elt").unwrap();
        let generators = ast.getattr(py, "generators").unwrap();

        let elt = parse_expr(py, elt);
        let generators = parse_list(py, generators, parse_comprehension);

        expr::ListComp(Box::new(elt), generators)
    } else if is_instance(&ast, &dictcomp_type) {
        let key = ast.getattr(py, "key").unwrap();
        let value = ast.getattr(py, "value").unwrap();
        let generators = ast.getattr(py, "generators").unwrap();

        let key = parse_expr(py, key);
        let value = parse_expr(py, value);
        let generators = parse_list(py, generators, parse_comprehension);

        expr::DictComp(Box::new(key), Box::new(value), generators)
    } else if is_instance(&ast, &tuple_type) {
        let elts = ast.getattr(py, "elts").unwrap();
        let elts = parse_list(py, elts, parse_expr);
        expr::Tuple(elts, get_ctx(py, ast))
    } else {
        println!("expr {}", ast);
        unreachable!()
    }
}

fn parse_arguments(py: Python, ast: PyObject) -> arguments {
    let builtins_module = py.import("builtins").unwrap();
    let isinstance = builtins_module.get(py, "isinstance").unwrap();

    let is_instance = |object: &PyObject, type_: &PyObject| {
        return isinstance.call(py, (object, type_), None).unwrap().is_true(py).unwrap();
    };

    let ast_module = py.import("ast").unwrap();
    let ast_type = ast_module.get(py, "AST").unwrap();
    let arguments_type = ast_module.get(py, "arguments").unwrap();
    let arg_type = ast_module.get(py, "arg").unwrap();

    assert!(is_instance(&ast, &ast_type));

    if is_instance(&ast, &arguments_type) {
        let args = arguments{
            //args: Vec<arg>,
            args: {
                let args = ast.getattr(py, "args").unwrap();
                let mut arguments = vec!();
                for arg in args.iter(py).unwrap() {
                    let arg = arg.unwrap();
                    assert!(is_instance(&arg, &arg_type));
                    let arg = get_str(py, arg);
                    arguments.push(arg{arg: arg, annotation: None});
                }
                arguments
            },
            //vararg: Option<arg>,
            vararg: None,
            //kwonlyargs: Vec<arg>,
            kwonlyargs: vec!(),
            //kw_defaults: Vec<expr>,
            kw_defaults: vec!(),
            //kwarg: Option<arg>,
            kwarg: None,
            //defaults: Vec<expr>
            defaults: vec!()
        };
        args
    } else {
        println!("arguments {}", ast);
        panic!()
    }
}

fn parse_statement(py: Python, ast: PyObject) -> stmt {
    //stmt::FunctionDef(expr::Name("function".to_string()), vec!(expr::Name("a".to_string()), expr::Name("b".to_string())), vec!())
    //stmt::If(expr::BinOp(BinOp::BinEq, Box::new(expr::Name("__name__".to_string())), Box::new(expr::Str("__main__".to_string()))), vec!(stmt::Expr(expr::Call(Box::new(expr::Name("function".to_string())), vec!(expr::Num(1), expr::Num(2))))))

    let builtins_module = py.import("builtins").unwrap();
    let isinstance = builtins_module.get(py, "isinstance").unwrap();

    let is_instance = |object: &PyObject, type_: &PyObject| {
        return isinstance.call(py, (object, type_), None).unwrap().is_true(py).unwrap();
    };

    let ast_module = py.import("ast").unwrap();
    let ast_type = ast_module.get(py, "AST").unwrap();
    let class_def_type = ast_module.get(py, "ClassDef").unwrap();
    let function_def_type = ast_module.get(py, "FunctionDef").unwrap();
    let global_type = ast_module.get(py, "Global").unwrap();
    let assign_type = ast_module.get(py, "Assign").unwrap();
    let aug_assign_type = ast_module.get(py, "AugAssign").unwrap();
    let return_type = ast_module.get(py, "Return").unwrap();
    let import_from_type = ast_module.get(py, "ImportFrom").unwrap();
    let import_type = ast_module.get(py, "Import").unwrap();
    let if_type = ast_module.get(py, "If").unwrap();
    let while_type = ast_module.get(py, "While").unwrap();
    let for_type = ast_module.get(py, "For").unwrap();
    let expr_type = ast_module.get(py, "Expr").unwrap();
    let break_type = ast_module.get(py, "Break").unwrap();
    let delete_type = ast_module.get(py, "Delete").unwrap();
    let pass_type = ast_module.get(py, "Pass").unwrap();

    assert!(is_instance(&ast, &ast_type));

    /*
    // TODO: implement Hash for PyObject. (trivial)
    let map = {
        let fields = ast.getattr(py, "_fields").unwrap();
        let mut map = HashMap::new();
        for field in fields.iter(py).unwrap() {
            let field = field.unwrap();
            let value = ast.getattr(py, field).unwrap();
            map.insert(field, value);
        }
        map
    };
    */

    if is_instance(&ast, &class_def_type) {
        let name = ast.getattr(py, "name").unwrap();
        let bases = ast.getattr(py, "bases").unwrap();
        //let keywords = ast.getattr(py, "keywords").unwrap();
        let body = ast.getattr(py, "body").unwrap();
        //let decorator_list = ast.getattr(py, "decorator_list").unwrap();

        let name = get_str(py, name);
        let bases = parse_list(py, bases, parse_expr);
        let body = parse_list(py, body, parse_statement);

        stmt::ClassDef(name, bases, vec!(), body, vec!())
    } else if is_instance(&ast, &function_def_type) {
        let name = ast.getattr(py, "name").unwrap();
        let args = ast.getattr(py, "args").unwrap();
        let body = ast.getattr(py, "body").unwrap();

        let name = get_str(py, name);
        let args = parse_arguments(py, args);
        let body = parse_list(py, body, parse_statement);

        let decorators = vec!();
        let returns = None;

        stmt::FunctionDef(name, args, body, decorators, returns)
    } else if is_instance(&ast, &global_type) {
        let names = ast.getattr(py, "names").unwrap();
        let names = parse_list(py, names, get_str);
        stmt::Global(names)
    } else if is_instance(&ast, &if_type) {
        let test = ast.getattr(py, "test").unwrap();
        let body = ast.getattr(py, "body").unwrap();
        let orelse = ast.getattr(py, "orelse").unwrap();

        let test = parse_expr(py, test);
        let body = parse_list(py, body, parse_statement);
        let orelse = parse_list(py, orelse, parse_statement);

        stmt::If(test, body, orelse)
    } else if is_instance(&ast, &while_type) {
        let test = ast.getattr(py, "test").unwrap();
        let body = ast.getattr(py, "body").unwrap();
        let orelse = ast.getattr(py, "orelse").unwrap();

        let test = parse_expr(py, test);
        let body = parse_list(py, body, parse_statement);
        let orelse = parse_list(py, orelse, parse_statement);

        stmt::While(test, body, orelse)
    } else if is_instance(&ast, &for_type) {
        let target = ast.getattr(py, "target").unwrap();
        let iter = ast.getattr(py, "iter").unwrap();
        let body = ast.getattr(py, "body").unwrap();
        let orelse = ast.getattr(py, "orelse").unwrap();

        let target = parse_expr(py, target);
        let iter = parse_expr(py, iter);
        let body = parse_list(py, body, parse_statement);
        let orelse = parse_list(py, orelse, parse_statement);

        stmt::For(target, iter, body, orelse)
    } else if is_instance(&ast, &assign_type) {
        let targets = ast.getattr(py, "targets").unwrap();
        let value = ast.getattr(py, "value").unwrap();

        let targets = parse_list(py, targets, parse_expr);
        let value = parse_expr(py, value);

        stmt::Assign(targets, value)
    } else if is_instance(&ast, &aug_assign_type) {
        let target = ast.getattr(py, "target").unwrap();
        let op = ast.getattr(py, "op").unwrap();
        let value = ast.getattr(py, "value").unwrap();

        let target = parse_expr(py, target);
        let op = parse_operator(py, op);
        let value = parse_expr(py, value);

        stmt::AugAssign(target, op, value)
    } else if is_instance(&ast, &import_from_type) {
        let module = ast.getattr(py, "module").unwrap();
        let names = ast.getattr(py, "names").unwrap();
        let level = ast.getattr(py, "level").unwrap();

        let module = get_str(py, module);
        let names = parse_list(py, names, parse_alias);

        if level == py.None() {
            stmt::ImportFrom(module, names, None)
        } else {
            let level = level.extract(py).unwrap();
            stmt::ImportFrom(module, names, Some(level))
        }
    } else if is_instance(&ast, &import_type) {
        let names = ast.getattr(py, "names").unwrap();
        let names = parse_list(py, names, parse_alias);

        stmt::Import(names)
    } else if is_instance(&ast, &return_type) {
        let value = ast.getattr(py, "value").unwrap();
        if value == py.None() {
            stmt::Return(None)
        } else {
            let value = parse_expr(py, value);
            stmt::Return(Some(value))
        }
    } else if is_instance(&ast, &expr_type) {
        let value = ast.getattr(py, "value").unwrap();
        let value = parse_expr(py, value);
        stmt::Expr(value)
    } else if is_instance(&ast, &break_type) {
        stmt::Break
    } else if is_instance(&ast, &delete_type) {
        let targets = ast.getattr(py, "targets").unwrap();
        let targets = parse_list(py, targets, parse_expr);
        stmt::Delete(targets)
    } else if is_instance(&ast, &pass_type) {
        stmt::Pass
    } else {
        println!("stmt {}", ast);
        panic!()
    }
}

/// Converts a Python `ast.Module` object into a `Module` called `name`.
///
/// Acquires the GIL for the duration of the conversion. Panics if `module`
/// is not an `ast.Module` instance or if any top-level statement cannot be
/// converted.
// NOTE(review): `dead_code` is allowed here presumably because no caller in
// this crate uses it yet; confirm and drop the attribute once it is used.
#[allow(dead_code)]
pub fn convert_ast(name: String, module: &PyObject) -> Module {
    let gil = Python::acquire_gil();
    let py = gil.python();

    let isinstance = py.import("builtins").unwrap().get(py, "isinstance").unwrap();
    let module_type = py.import("ast").unwrap().get(py, "Module").unwrap();

    assert!(isinstance.call(py, (module, module_type), None).unwrap().is_true(py).unwrap());

    let statements = parse_list(py, module.getattr(py, "body").unwrap(), parse_statement);
    Module{name: name, statements: statements}
}