diff --git a/example/guess_number/main.zig b/example/guess_number/main.zig index b70a890858..084674fe93 100644 --- a/example/guess_number/main.zig +++ b/example/guess_number/main.zig @@ -27,6 +27,8 @@ pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 { print_u64(answer); print_str("\n"); + return 0; + /* while (true) { const line = readline("\nGuess a number between 1 and 100: "); @@ -45,6 +47,4 @@ pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 { } } */ - - return 0; } diff --git a/src/analyze.cpp b/src/analyze.cpp index 162c2ab582..8174c2c0d6 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -12,6 +12,8 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import, BlockContext *context, TypeTableEntry *expected_type, AstNode *node); +static TypeTableEntry *eval_const_expr(CodeGen *g, BlockContext *context, + AstNode *node, AstNodeNumberLiteral *out_number_literal); static AstNode *first_executing_node(AstNode *node) { switch (node->type) { @@ -284,6 +286,98 @@ static TypeTableEntry *get_unknown_size_array_type(CodeGen *g, ImportTableEntry } } +static TypeTableEntry *eval_const_expr_bin_op(CodeGen *g, BlockContext *context, + AstNode *node, AstNodeNumberLiteral *out_number_literal) +{ + AstNodeNumberLiteral op1_lit; + AstNodeNumberLiteral op2_lit; + TypeTableEntry *op1_type = eval_const_expr(g, context, node->data.bin_op_expr.op1, &op1_lit); + TypeTableEntry *op2_type = eval_const_expr(g, context, node->data.bin_op_expr.op1, &op2_lit); + + if (op1_type->id == TypeTableEntryIdInvalid || + op2_type->id == TypeTableEntryIdInvalid) + { + return g->builtin_types.entry_invalid; + } + + // TODO complete more of this function instead of returning invalid + // returning invalid makes the "unable to evaluate constant expression" error + + switch (node->data.bin_op_expr.bin_op) { + case BinOpTypeCmpNotEq: + { + if (is_num_lit_unsigned(op1_lit.kind) && + is_num_lit_unsigned(op2_lit.kind)) + { + out_number_literal->kind = NumLitU8; + out_number_literal->overflow = false; + out_number_literal->data.x_uint = (op1_lit.data.x_uint != op2_lit.data.x_uint); + return node->codegen_node->expr_node.type_entry; + } else { + return g->builtin_types.entry_invalid; + } + } + case BinOpTypeCmpLessThan: + { + if (is_num_lit_unsigned(op1_lit.kind) && + is_num_lit_unsigned(op2_lit.kind)) + { + out_number_literal->kind = NumLitU8; + out_number_literal->overflow = false; + out_number_literal->data.x_uint = (op1_lit.data.x_uint < op2_lit.data.x_uint); + return node->codegen_node->expr_node.type_entry; + } else { + return g->builtin_types.entry_invalid; + } + } + case BinOpTypeMod: + { + if (is_num_lit_unsigned(op1_lit.kind) && + is_num_lit_unsigned(op2_lit.kind)) + { + out_number_literal->kind = NumLitU64; + out_number_literal->overflow = false; + out_number_literal->data.x_uint = (op1_lit.data.x_uint % op2_lit.data.x_uint); + return node->codegen_node->expr_node.type_entry; + } else { + return g->builtin_types.entry_invalid; + } + } + case BinOpTypeBoolOr: + case BinOpTypeBoolAnd: + case BinOpTypeCmpEq: + case BinOpTypeCmpGreaterThan: + case BinOpTypeCmpLessOrEq: + case BinOpTypeCmpGreaterOrEq: + case BinOpTypeBinOr: + case BinOpTypeBinXor: + case BinOpTypeBinAnd: + case BinOpTypeBitShiftLeft: + case BinOpTypeBitShiftRight: + case BinOpTypeAdd: + case BinOpTypeSub: + case BinOpTypeMult: + case BinOpTypeDiv: + return g->builtin_types.entry_invalid; + case BinOpTypeInvalid: + case BinOpTypeAssign: + case BinOpTypeAssignTimes: + case BinOpTypeAssignDiv: + case BinOpTypeAssignMod: + case BinOpTypeAssignPlus: + case BinOpTypeAssignMinus: + case BinOpTypeAssignBitShiftLeft: + case BinOpTypeAssignBitShiftRight: + case BinOpTypeAssignBitAnd: + case BinOpTypeAssignBitXor: + case BinOpTypeAssignBitOr: + case BinOpTypeAssignBoolAnd: + case BinOpTypeAssignBoolOr: + zig_unreachable(); + } + zig_unreachable(); +} + static TypeTableEntry *eval_const_expr(CodeGen *g, BlockContext *context, AstNode *node, AstNodeNumberLiteral *out_number_literal) { @@ -291,9 +385,11 @@ static TypeTableEntry *eval_const_expr(CodeGen *g, BlockContext *context, case NodeTypeNumberLiteral: *out_number_literal = node->data.number_literal; return node->codegen_node->expr_node.type_entry; + case NodeTypeBoolLiteral: + out_number_literal->data.x_uint = node->data.bool_literal ? 1 : 0; + return node->codegen_node->expr_node.type_entry; case NodeTypeBinOpExpr: - zig_panic("TODO eval_const_expr bin op expr"); - break; + return eval_const_expr_bin_op(g, context, node, out_number_literal); case NodeTypeCompilerFnType: { Buf *name = &node->data.compiler_fn_type.name; @@ -1133,8 +1229,12 @@ BlockContext *new_block_context(AstNode *node, BlockContext *parent) { context->variable_table.init(8); if (parent) { - context->break_allowed = parent->break_allowed || parent->next_child_break_allowed; - parent->next_child_break_allowed = false; + if (parent->next_child_parent_loop_node) { + context->parent_loop_node = parent->next_child_parent_loop_node; + parent->next_child_parent_loop_node = nullptr; + } else { + context->parent_loop_node = parent->parent_loop_node; + } } if (node && node->type == NodeTypeFnDef) { @@ -1690,20 +1790,45 @@ static TypeTableEntry *analyze_struct_val_expr(CodeGen *g, ImportTableEntry *imp static TypeTableEntry *analyze_while_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context, TypeTableEntry *expected_type, AstNode *node) { - analyze_expression(g, import, context, g->builtin_types.entry_bool, node->data.while_expr.condition); + AstNode *condition_node = node->data.while_expr.condition; + AstNode *while_body_node = node->data.while_expr.body; + TypeTableEntry *condition_type = analyze_expression(g, import, context, + g->builtin_types.entry_bool, condition_node); - context->next_child_break_allowed = true; - analyze_expression(g, import, context, g->builtin_types.entry_void, node->data.while_expr.body); + context->next_child_parent_loop_node = node; + analyze_expression(g, import, context, g->builtin_types.entry_void, while_body_node); - return g->builtin_types.entry_void; + + TypeTableEntry *expr_return_type = g->builtin_types.entry_void; + + if (condition_type->id == TypeTableEntryIdInvalid) { + expr_return_type = g->builtin_types.entry_invalid; + } else { + // if the condition is a simple constant expression and there are no break statements + // then the return type is unreachable + AstNodeNumberLiteral number_literal; + TypeTableEntry *resolved_type = eval_const_expr(g, context, condition_node, &number_literal); + if (resolved_type->id != TypeTableEntryIdInvalid) { + assert(resolved_type->id == TypeTableEntryIdBool); + bool constant_cond_value = number_literal.data.x_uint; + if (constant_cond_value && !node->codegen_node->data.while_node.contains_break) { + expr_return_type = g->builtin_types.entry_unreachable; + } + } + } + + return expr_return_type; } static TypeTableEntry *analyze_break_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context, TypeTableEntry *expected_type, AstNode *node) { - if (!context->break_allowed) { + AstNode *loop_node = context->parent_loop_node; + if (loop_node) { + loop_node->codegen_node->data.while_node.contains_break = true; + } else { add_node_error(g, node, - buf_sprintf("'break' expression not in loop")); + buf_sprintf("'break' expression outside loop")); } return g->builtin_types.entry_unreachable; } @@ -1711,9 +1836,9 @@ static TypeTableEntry *analyze_break_expr(CodeGen *g, ImportTableEntry *import, static TypeTableEntry *analyze_continue_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context, TypeTableEntry *expected_type, AstNode *node) { - if (!context->break_allowed) { + if (!context->parent_loop_node) { add_node_error(g, node, - buf_sprintf("'continue' expression not in loop")); + buf_sprintf("'continue' expression outside loop")); } return g->builtin_types.entry_unreachable; } diff --git a/src/analyze.hpp b/src/analyze.hpp index 8e3f569302..d14ae714d7 100644 --- a/src/analyze.hpp +++ b/src/analyze.hpp @@ -244,8 +244,8 @@ struct BlockContext { HashMap variable_table; ZigList cast_expr_alloca_list; ZigList struct_val_expr_alloca_list; - bool break_allowed; - bool next_child_break_allowed; + AstNode *parent_loop_node; + AstNode *next_child_parent_loop_node; LLVMZigDIScope *di_scope; }; @@ -340,6 +340,10 @@ struct ImportNode { ImportTableEntry *import; }; +struct WhileNode { + bool contains_break; +}; + struct CodeGenNode { union { TypeNode type_node; // for NodeTypeType @@ -358,6 +362,7 @@ struct CodeGenNode { IfVarNode if_var_node; // for NodeTypeStructValueExpr ParamDeclNode param_decl_node; // for NodeTypeParamDecl ImportNode import_node; // for NodeTypeUse + WhileNode while_node; // for NodeTypeWhileExpr } data; ExprNode expr_node; // for all the expression nodes }; diff --git a/src/codegen.cpp b/src/codegen.cpp index 3e555a5778..3f4262364e 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -1157,29 +1157,52 @@ static LLVMValueRef gen_while_expr(CodeGen *g, AstNode *node) { assert(node->data.while_expr.condition); assert(node->data.while_expr.body); - LLVMBasicBlockRef cond_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileCond"); - LLVMBasicBlockRef body_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileBody"); - LLVMBasicBlockRef end_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileEnd"); + if (get_expr_type(node)->id == TypeTableEntryIdUnreachable) { + // generate a forever loop. guarantees no break statements - add_debug_source_node(g, node); - LLVMBuildBr(g->builder, cond_block); + LLVMBasicBlockRef body_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileBody"); - LLVMPositionBuilderAtEnd(g->builder, cond_block); - LLVMValueRef cond_val = gen_expr(g, node->data.while_expr.condition); - add_debug_source_node(g, node->data.while_expr.condition); - LLVMBuildCondBr(g->builder, cond_val, body_block, end_block); + add_debug_source_node(g, node); + LLVMBuildBr(g->builder, body_block); - LLVMPositionBuilderAtEnd(g->builder, body_block); - g->break_block_stack.append(end_block); - g->continue_block_stack.append(cond_block); - gen_expr(g, node->data.while_expr.body); - g->break_block_stack.pop(); - g->continue_block_stack.pop(); - if (get_expr_type(node->data.while_expr.body)->id != TypeTableEntryIdUnreachable) { + LLVMPositionBuilderAtEnd(g->builder, body_block); + g->continue_block_stack.append(body_block); + gen_expr(g, node->data.while_expr.body); + g->continue_block_stack.pop(); + + if (get_expr_type(node->data.while_expr.body)->id != TypeTableEntryIdUnreachable) { + add_debug_source_node(g, node); + LLVMBuildBr(g->builder, body_block); + } + } else { + // generate a normal while loop + + LLVMBasicBlockRef cond_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileCond"); + LLVMBasicBlockRef body_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileBody"); + LLVMBasicBlockRef end_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileEnd"); + + add_debug_source_node(g, node); LLVMBuildBr(g->builder, cond_block); + + LLVMPositionBuilderAtEnd(g->builder, cond_block); + LLVMValueRef cond_val = gen_expr(g, node->data.while_expr.condition); + add_debug_source_node(g, node->data.while_expr.condition); + LLVMBuildCondBr(g->builder, cond_val, body_block, end_block); + + LLVMPositionBuilderAtEnd(g->builder, body_block); + g->break_block_stack.append(end_block); + g->continue_block_stack.append(cond_block); + gen_expr(g, node->data.while_expr.body); + g->break_block_stack.pop(); + g->continue_block_stack.pop(); + if (get_expr_type(node->data.while_expr.body)->id != TypeTableEntryIdUnreachable) { + add_debug_source_node(g, node); + LLVMBuildBr(g->builder, cond_block); + } + + LLVMPositionBuilderAtEnd(g->builder, end_block); } - LLVMPositionBuilderAtEnd(g->builder, end_block); return nullptr; } diff --git a/std/rand.zig b/std/rand.zig index 8d30c74c94..0362e2d8bc 100644 --- a/std/rand.zig +++ b/std/rand.zig @@ -67,9 +67,6 @@ pub struct Rand { return start + (rand_val % range); } } - // TODO detect simple constant in while loop and no breaks and turn it into unreachable - // type. then we can remove this unreachable. - unreachable; } fn generate_numbers(r: &Rand) { diff --git a/test/run_tests.cpp b/test/run_tests.cpp index 661100f6b4..c41832e47d 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -683,7 +683,12 @@ pub fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 { print_str("loop\n"); i += 1; } - return 0; + return f(); +} +fn f() -> i32 { + while (true) { + return 0; + } } )SOURCE", "loop\nloop\nloop\nloop\n"); @@ -1168,13 +1173,13 @@ fn f() { fn f() { break; } - )SOURCE", 1, ".tmp_source.zig:3:5: error: 'break' expression not in loop"); + )SOURCE", 1, ".tmp_source.zig:3:5: error: 'break' expression outside loop"); add_compile_fail_case("invalid continue expression", R"SOURCE( fn f() { continue; } - )SOURCE", 1, ".tmp_source.zig:3:5: error: 'continue' expression not in loop"); + )SOURCE", 1, ".tmp_source.zig:3:5: error: 'continue' expression outside loop"); add_compile_fail_case("invalid maybe type", R"SOURCE( fn f() {