codegen and tests for modify operators. closes #16

This commit is contained in:
Josh Wolfe 2015-12-12 19:47:37 -07:00
parent 5cb5f5dbf6
commit 0f02e29a2b
3 changed files with 74 additions and 45 deletions

View File

@ -324,30 +324,33 @@ static LLVMValueRef gen_cast_expr(CodeGen *g, AstNode *node) {
zig_unreachable();
}
static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
static LLVMValueRef gen_arithmetic_bin_op(CodeGen *g,
LLVMValueRef val1, LLVMValueRef val2,
TypeTableEntry *op1_type, TypeTableEntry *op2_type,
AstNode *node)
{
assert(node->type == NodeTypeBinOpExpr);
LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
TypeTableEntry *op1_type = get_expr_type(node->data.bin_op_expr.op1);
TypeTableEntry *op2_type = get_expr_type(node->data.bin_op_expr.op2);
assert(op1_type == op2_type);
switch (node->data.bin_op_expr.bin_op) {
case BinOpTypeBinOr:
case BinOpTypeAssignBitOr:
add_debug_source_node(g, node);
return LLVMBuildOr(g->builder, val1, val2, "");
case BinOpTypeBinXor:
case BinOpTypeAssignBitXor:
add_debug_source_node(g, node);
return LLVMBuildXor(g->builder, val1, val2, "");
case BinOpTypeBinAnd:
case BinOpTypeAssignBitAnd:
add_debug_source_node(g, node);
return LLVMBuildAnd(g->builder, val1, val2, "");
case BinOpTypeBitShiftLeft:
case BinOpTypeAssignBitShiftLeft:
add_debug_source_node(g, node);
return LLVMBuildShl(g->builder, val1, val2, "");
case BinOpTypeBitShiftRight:
case BinOpTypeAssignBitShiftRight:
add_debug_source_node(g, node);
if (op1_type->id == TypeTableEntryIdInt) {
return LLVMBuildAShr(g->builder, val1, val2, "");
@ -355,6 +358,7 @@ static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
return LLVMBuildLShr(g->builder, val1, val2, "");
}
case BinOpTypeAdd:
case BinOpTypeAssignPlus:
add_debug_source_node(g, node);
if (op1_type->id == TypeTableEntryIdFloat) {
return LLVMBuildFAdd(g->builder, val1, val2, "");
@ -362,6 +366,7 @@ static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
return LLVMBuildNSWAdd(g->builder, val1, val2, "");
}
case BinOpTypeSub:
case BinOpTypeAssignMinus:
add_debug_source_node(g, node);
if (op1_type->id == TypeTableEntryIdFloat) {
return LLVMBuildFSub(g->builder, val1, val2, "");
@ -369,6 +374,7 @@ static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
return LLVMBuildNSWSub(g->builder, val1, val2, "");
}
case BinOpTypeMult:
case BinOpTypeAssignTimes:
add_debug_source_node(g, node);
if (op1_type->id == TypeTableEntryIdFloat) {
return LLVMBuildFMul(g->builder, val1, val2, "");
@ -376,6 +382,7 @@ static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
return LLVMBuildNSWMul(g->builder, val1, val2, "");
}
case BinOpTypeDiv:
case BinOpTypeAssignDiv:
add_debug_source_node(g, node);
if (op1_type->id == TypeTableEntryIdFloat) {
return LLVMBuildFDiv(g->builder, val1, val2, "");
@ -388,6 +395,7 @@ static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
}
}
case BinOpTypeMod:
case BinOpTypeAssignMod:
add_debug_source_node(g, node);
if (op1_type->id == TypeTableEntryIdFloat) {
return LLVMBuildFRem(g->builder, val1, val2, "");
@ -409,22 +417,23 @@ static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
case BinOpTypeCmpGreaterOrEq:
case BinOpTypeInvalid:
case BinOpTypeAssign:
case BinOpTypeAssignTimes:
case BinOpTypeAssignDiv:
case BinOpTypeAssignMod:
case BinOpTypeAssignPlus:
case BinOpTypeAssignMinus:
case BinOpTypeAssignBitShiftLeft:
case BinOpTypeAssignBitShiftRight:
case BinOpTypeAssignBitAnd:
case BinOpTypeAssignBitXor:
case BinOpTypeAssignBitOr:
case BinOpTypeAssignBoolAnd:
case BinOpTypeAssignBoolOr:
zig_unreachable();
}
zig_unreachable();
}
static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeBinOpExpr);
LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
TypeTableEntry *op1_type = get_expr_type(node->data.bin_op_expr.op1);
TypeTableEntry *op2_type = get_expr_type(node->data.bin_op_expr.op2);
return gen_arithmetic_bin_op(g, val1, val2, op1_type, op2_type, node);
}
static LLVMIntPredicate cmp_op_to_int_predicate(BinOpType cmp_op, bool is_signed) {
switch (cmp_op) {
@ -555,11 +564,8 @@ static LLVMValueRef gen_assign_expr(CodeGen *g, AstNode *node) {
AstNode *lhs_node = node->data.bin_op_expr.op1;
bool is_read_first = node->data.bin_op_expr.bin_op != BinOpTypeAssign;
if (is_read_first) {
zig_panic("TODO: implement modify assignment ops");
}
LLVMValueRef target_ref;
TypeTableEntry *op1_type;
if (lhs_node->type == NodeTypeSymbol) {
LocalVariableTableEntry *var = find_local_variable(node->codegen_node->expr_node.block_context,
&lhs_node->data.symbol);
@ -567,33 +573,30 @@ static LLVMValueRef gen_assign_expr(CodeGen *g, AstNode *node) {
// semantic checking ensures no variables are constant
assert(!var->is_const);
LLVMValueRef value = gen_expr(g, node->data.bin_op_expr.op2);
add_debug_source_node(g, node);
return LLVMBuildStore(g->builder, value, var->value_ref);
op1_type = var->type;
target_ref = var->value_ref;
} else if (lhs_node->type == NodeTypeArrayAccessExpr) {
LLVMValueRef ptr = gen_array_ptr(g, lhs_node);
LLVMValueRef value = gen_expr(g, node->data.bin_op_expr.op2);
add_debug_source_node(g, node);
return LLVMBuildStore(g->builder, value, ptr);
} else if (lhs_node->type == NodeTypeFieldAccessExpr) {
/*
LLVMValueRef ptr = gen_field_ptr(g, lhs_node);
LLVMValueRef value = gen_expr(g, node->data.bin_op_expr.op2);
add_debug_source_node(g, node);
return LLVMBuildStore(g->builder, value, ptr);
*/
LLVMValueRef struct_val = gen_expr(g, lhs_node->data.field_access_expr.struct_expr);
assert(struct_val);
FieldAccessNode *codegen_field_access = &lhs_node->codegen_node->data.field_access_node;
assert(codegen_field_access->field_index >= 0);
LLVMValueRef value = gen_expr(g, node->data.bin_op_expr.op2);
add_debug_source_node(g, node);
return LLVMBuildInsertValue(g->builder, struct_val, value, codegen_field_access->field_index, "");
TypeTableEntry *array_type = get_expr_type(lhs_node->data.array_access_expr.array_ref_expr);
assert(array_type->id == TypeTableEntryIdArray);
op1_type = array_type->data.array.child_type;
target_ref = gen_array_ptr(g, lhs_node);
} else {
zig_panic("bad assign target");
}
LLVMValueRef value = gen_expr(g, node->data.bin_op_expr.op2);
if (node->data.bin_op_expr.bin_op == BinOpTypeAssign) {
// value is ready as is
} else {
add_debug_source_node(g, node->data.bin_op_expr.op1);
LLVMValueRef left_value = LLVMBuildLoad(g->builder, target_ref, "");
TypeTableEntry *op2_type = get_expr_type(node->data.bin_op_expr.op2);
value = gen_arithmetic_bin_op(g, left_value, value, op1_type, op2_type, node);
}
add_debug_source_node(g, node);
return LLVMBuildStore(g->builder, value, target_ref);
}
static LLVMValueRef gen_bin_op_expr(CodeGen *g, AstNode *node) {

View File

@ -402,6 +402,7 @@ void tokenize(Buf *buf, Tokenization *out) {
t.cur_tok->id = TokenIdBitShiftRightEq;
end_token(&t);
t.state = TokenizeStateStart;
break;
default:
t.pos -= 1;
end_token(&t);
@ -415,6 +416,7 @@ void tokenize(Buf *buf, Tokenization *out) {
t.cur_tok->id = TokenIdCmpLessOrEq;
end_token(&t);
t.state = TokenizeStateStart;
break;
case '<':
t.cur_tok->id = TokenIdBitShiftLeft;
t.state = TokenizeStateSawLessThanLessThan;
@ -432,6 +434,7 @@ void tokenize(Buf *buf, Tokenization *out) {
t.cur_tok->id = TokenIdBitShiftLeftEq;
end_token(&t);
t.state = TokenizeStateStart;
break;
default:
t.pos -= 1;
end_token(&t);

View File

@ -454,6 +454,29 @@ export fn main(argc : isize, argv : *mut *mut u8, env : *mut *mut u8) -> i32 {
}
)SOURCE", "OK 1\nOK 2\nOK 3\nOK 4\n");
add_simple_case("modify operators", R"SOURCE(
use "std.zig";
export fn main(argc : isize, argv : *mut *mut u8, env : *mut *mut u8) -> i32 {
let mut i : i32 = 0;
i += 5; if i != 5 { print_str("BAD +=\n" as string); }
i -= 2; if i != 3 { print_str("BAD -=\n" as string); }
i *= 20; if i != 60 { print_str("BAD *=\n" as string); }
i /= 3; if i != 20 { print_str("BAD /=\n" as string); }
i %= 11; if i != 9 { print_str("BAD %=\n" as string); }
i <<= 1; if i != 18 { print_str("BAD <<=\n" as string); }
i >>= 2; if i != 4 { print_str("BAD >>=\n" as string); }
i = 6;
i &= 5; if i != 4 { print_str("BAD &=\n" as string); }
i ^= 6; if i != 2 { print_str("BAD ^=\n" as string); }
i = 6;
i |= 3; if i != 7 { print_str("BAD |=\n" as string); }
print_str("OK\n" as string);
return 0;
}
)SOURCE", "OK\n");
}
static void add_compile_failure_test_cases(void) {