diff --git a/doc/langref.md b/doc/langref.md index 439ef6a500..334d0ad3d1 100644 --- a/doc/langref.md +++ b/doc/langref.md @@ -140,7 +140,7 @@ ArrayAccessExpression : token(LBracket) Expression token(RBracket) PrefixOp : token(Not) | token(Dash) | token(Tilde) -PrimaryExpression : token(Number) | token(String) | KeywordLiteral | GroupedExpression | token(Symbol) | Goto +PrimaryExpression : token(Number) | token(String) | KeywordLiteral | GroupedExpression | token(Symbol) | Goto | BlockExpression Goto: token(Goto) token(Symbol) diff --git a/src/codegen.cpp b/src/codegen.cpp index e5f938a52c..881ef2be18 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -478,8 +478,8 @@ static LLVMValueRef gen_bool_and_expr(CodeGen *g, AstNode *node) { assert(node->type == NodeTypeBinOpExpr); LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1); + LLVMBasicBlockRef post_val1_block = LLVMGetInsertBlock(g->builder); - LLVMBasicBlockRef orig_block = LLVMGetInsertBlock(g->builder); // block for when val1 == true LLVMBasicBlockRef true_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "BoolAndTrue"); // block for when val1 == false (don't even evaluate the second part) @@ -490,6 +490,8 @@ static LLVMValueRef gen_bool_and_expr(CodeGen *g, AstNode *node) { LLVMPositionBuilderAtEnd(g->builder, true_block); LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2); + LLVMBasicBlockRef post_val2_block = LLVMGetInsertBlock(g->builder); + add_debug_source_node(g, node); LLVMBuildBr(g->builder, false_block); @@ -497,7 +499,7 @@ static LLVMValueRef gen_bool_and_expr(CodeGen *g, AstNode *node) { add_debug_source_node(g, node); LLVMValueRef phi = LLVMBuildPhi(g->builder, LLVMInt1Type(), ""); LLVMValueRef incoming_values[2] = {val1, val2}; - LLVMBasicBlockRef incoming_blocks[2] = {orig_block, true_block}; + LLVMBasicBlockRef incoming_blocks[2] = {post_val1_block, post_val2_block}; LLVMAddIncoming(phi, incoming_values, incoming_blocks, 2); return phi; @@ -507,8 +509,7 @@ static LLVMValueRef gen_bool_or_expr(CodeGen *g, AstNode *expr_node) { assert(expr_node->type == NodeTypeBinOpExpr); LLVMValueRef val1 = gen_expr(g, expr_node->data.bin_op_expr.op1); - - LLVMBasicBlockRef orig_block = LLVMGetInsertBlock(g->builder); + LLVMBasicBlockRef post_val1_block = LLVMGetInsertBlock(g->builder); // block for when val1 == false LLVMBasicBlockRef false_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "BoolOrFalse"); @@ -520,6 +521,9 @@ static LLVMValueRef gen_bool_or_expr(CodeGen *g, AstNode *expr_node) { LLVMPositionBuilderAtEnd(g->builder, false_block); LLVMValueRef val2 = gen_expr(g, expr_node->data.bin_op_expr.op2); + + LLVMBasicBlockRef post_val2_block = LLVMGetInsertBlock(g->builder); + add_debug_source_node(g, expr_node); LLVMBuildBr(g->builder, true_block); @@ -527,7 +531,7 @@ static LLVMValueRef gen_bool_or_expr(CodeGen *g, AstNode *expr_node) { add_debug_source_node(g, expr_node); LLVMValueRef phi = LLVMBuildPhi(g->builder, LLVMInt1Type(), ""); LLVMValueRef incoming_values[2] = {val1, val2}; - LLVMBasicBlockRef incoming_blocks[2] = {orig_block, false_block}; + LLVMBasicBlockRef incoming_blocks[2] = {post_val1_block, post_val2_block}; LLVMAddIncoming(phi, incoming_values, incoming_blocks, 2); return phi; diff --git a/src/parser.cpp b/src/parser.cpp index 7b20bd7811..8b7ae50ae4 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -572,6 +572,7 @@ static void ast_invalid_token_error(ParseContext *pc, Token *token) { static AstNode *ast_parse_expression(ParseContext *pc, int *token_index, bool mandatory); static AstNode *ast_parse_block(ParseContext *pc, int *token_index, bool mandatory); static AstNode *ast_parse_if_expr(ParseContext *pc, int *token_index, bool mandatory); +static AstNode *ast_parse_block_expr(ParseContext *pc, int *token_index, bool mandatory); static void ast_expect_token(ParseContext *pc, Token *token, TokenId token_id) { if (token->id != token_id) { @@ -809,7 +810,7 @@ static AstNode *ast_parse_grouped_expr(ParseContext *pc, int *token_index, bool } /* -PrimaryExpression : token(Number) | token(String) | token(Unreachable) | GroupedExpression | token(Symbol) | Goto +PrimaryExpression : token(Number) | token(String) | KeywordLiteral | GroupedExpression | token(Symbol) | Goto | BlockExpression */ static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool mandatory) { Token *token = &pc->tokens->at(*token_index); @@ -864,6 +865,11 @@ static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool return grouped_expr_node; } + AstNode *block_expr_node = ast_parse_block_expr(pc, token_index, false); + if (block_expr_node) { + return block_expr_node; + } + if (!mandatory) return nullptr; diff --git a/test/run_tests.cpp b/test/run_tests.cpp index a737298ef8..d4e0cc3314 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -431,6 +431,29 @@ export fn main(argc : isize, argv : *mut *mut u8, env : *mut *mut u8) -> i32 { } )SOURCE", "OK\n"); + add_simple_case("short circuit", R"SOURCE( +use "std.zig"; + +export fn main(argc : isize, argv : *mut *mut u8, env : *mut *mut u8) -> i32 { + if true || { print_str("BAD 1\n" as string); false } { + print_str("OK 1\n" as string); + } + if false || { print_str("OK 2\n" as string); false } { + print_str("BAD 2\n" as string); + } + + if true && { print_str("OK 3\n" as string); false } { + print_str("BAD 3\n" as string); + } + if false && { print_str("BAD 4\n" as string); false } { + } else { + print_str("OK 4\n" as string); + } + + return 0; +} + )SOURCE", "OK 1\nOK 2\nOK 3\nOK 4\n"); + } static void add_compile_failure_test_cases(void) {