From 65a03c5859e57820d2c28ad2952dda3fd4ac7d9c Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sat, 6 Feb 2016 16:36:49 -0700 Subject: [PATCH] implement %defer and ?defer see #110 --- src/all_types.hpp | 10 +++++++ src/analyze.cpp | 46 ++++++------------------------- src/codegen.cpp | 68 +++++++++++++++++++++++++++++++++++++++------- test/run_tests.cpp | 36 ++++++++++++++++++++++++ 4 files changed, 113 insertions(+), 47 deletions(-) diff --git a/src/all_types.hpp b/src/all_types.hpp index 9b5fc79575..7378e4d6e4 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -78,8 +78,18 @@ struct ConstExprValue { } data; }; +enum ReturnKnowledge { + ReturnKnowledgeUnknown, + ReturnKnowledgeKnownError, + ReturnKnowledgeKnownNonError, + ReturnKnowledgeKnownNull, + ReturnKnowledgeKnownNonNull, + ReturnKnowledgeSkipDefers, +}; + struct Expr { TypeTableEntry *type_entry; + ReturnKnowledge return_knowledge; LLVMValueRef const_llvm_val; ConstExprValue const_val; diff --git a/src/analyze.cpp b/src/analyze.cpp index 03e5ad1899..06bb1b612e 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -3693,11 +3693,13 @@ static TypeTableEntry *analyze_cast_expr(CodeGen *g, ImportTableEntry *import, B // explicit cast from child type of maybe type to maybe type if (wanted_type->id == TypeTableEntryIdMaybe) { if (types_match_const_cast_only(wanted_type->data.maybe.child_type, actual_type)) { + get_resolved_expr(node)->return_knowledge = ReturnKnowledgeKnownNonNull; return resolve_cast(g, context, node, expr_node, wanted_type, CastOpMaybeWrap, true); } else if (actual_type->id == TypeTableEntryIdNumLitInt || actual_type->id == TypeTableEntryIdNumLitFloat) { if (num_lit_fits_in_other_type(g, expr_node, wanted_type->data.maybe.child_type)) { + get_resolved_expr(node)->return_knowledge = ReturnKnowledgeKnownNonNull; return resolve_cast(g, context, node, expr_node, wanted_type, CastOpMaybeWrap, true); } else { return g->builtin_types.entry_invalid; @@ -3708,11 +3710,13 @@ static TypeTableEntry *analyze_cast_expr(CodeGen *g, ImportTableEntry *import, B // explicit cast from child type of error type to error type if (wanted_type->id == TypeTableEntryIdErrorUnion) { if (types_match_const_cast_only(wanted_type->data.error.child_type, actual_type)) { + get_resolved_expr(node)->return_knowledge = ReturnKnowledgeKnownNonError; return resolve_cast(g, context, node, expr_node, wanted_type, CastOpErrorWrap, true); } else if (actual_type->id == TypeTableEntryIdNumLitInt || actual_type->id == TypeTableEntryIdNumLitFloat) { if (num_lit_fits_in_other_type(g, expr_node, wanted_type->data.error.child_type)) { + get_resolved_expr(node)->return_knowledge = ReturnKnowledgeKnownNonError; return resolve_cast(g, context, node, expr_node, wanted_type, CastOpErrorWrap, true); } else { return g->builtin_types.entry_invalid; @@ -3724,6 +3728,7 @@ static TypeTableEntry *analyze_cast_expr(CodeGen *g, ImportTableEntry *import, B if (wanted_type->id == TypeTableEntryIdErrorUnion && actual_type->id == TypeTableEntryIdPureError) { + get_resolved_expr(node)->return_knowledge = ReturnKnowledgeKnownError; return resolve_cast(g, context, node, expr_node, wanted_type, CastOpPureErrorWrap, false); } @@ -4602,44 +4607,11 @@ static TypeTableEntry *analyze_defer(CodeGen *g, ImportTableEntry *import, Block node->data.defer.child_block = new_block_context(node, parent_context); - switch (node->data.defer.kind) { - case ReturnKindUnconditional: - { - TypeTableEntry *resolved_type = analyze_expression(g, import, parent_context, nullptr, - node->data.defer.expr); - validate_voided_expr(g, node->data.defer.expr, resolved_type); + TypeTableEntry *resolved_type = analyze_expression(g, import, parent_context, nullptr, + node->data.defer.expr); + validate_voided_expr(g, node->data.defer.expr, resolved_type); - return g->builtin_types.entry_void; - } - case ReturnKindError: - { - TypeTableEntry *resolved_type = analyze_expression(g, import, parent_context, nullptr, - node->data.defer.expr); - if (resolved_type->id == TypeTableEntryIdInvalid) { - // OK - } else if (resolved_type->id == TypeTableEntryIdErrorUnion) { - // OK - } else { - add_node_error(g, node->data.defer.expr, - buf_sprintf("expected error type, got '%s'", buf_ptr(&resolved_type->name))); - } - return g->builtin_types.entry_void; - } - case ReturnKindMaybe: - { - TypeTableEntry *resolved_type = analyze_expression(g, import, parent_context, nullptr, - node->data.defer.expr); - if (resolved_type->id == TypeTableEntryIdInvalid) { - // OK - } else if (resolved_type->id == TypeTableEntryIdMaybe) { - // OK - } else { - add_node_error(g, node->data.defer.expr, - buf_sprintf("expected maybe type, got '%s'", buf_ptr(&resolved_type->name))); - } - return g->builtin_types.entry_void; - } - } + return g->builtin_types.entry_void; } static TypeTableEntry *analyze_string_literal_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context, diff --git a/src/codegen.cpp b/src/codegen.cpp index 4cab9fa13d..76364833dc 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -483,6 +483,7 @@ static LLVMValueRef gen_cast_expr(CodeGen *g, AstNode *node) { } case CastOpPureErrorWrap: assert(wanted_type->id == TypeTableEntryIdErrorUnion); + if (!type_has_bits(wanted_type->data.error.child_type)) { return expr_val; } else { @@ -1593,18 +1594,48 @@ static LLVMValueRef gen_unwrap_err_expr(CodeGen *g, AstNode *node) { return phi; } -static void gen_defers_for_block(CodeGen *g, BlockContext *inner_block, BlockContext *outer_block) { +static void gen_defers_for_block(CodeGen *g, BlockContext *inner_block, BlockContext *outer_block, + bool gen_error_defers, bool gen_maybe_defers) +{ while (inner_block != outer_block) { - if (inner_block->node->type == NodeTypeDefer) { + if (inner_block->node->type == NodeTypeDefer && + ((inner_block->node->data.defer.kind == ReturnKindUnconditional) || + (gen_error_defers && inner_block->node->data.defer.kind == ReturnKindError) || + (gen_maybe_defers && inner_block->node->data.defer.kind == ReturnKindMaybe))) + { gen_expr(g, inner_block->node->data.defer.expr); } inner_block = inner_block->parent; } } -static LLVMValueRef gen_return(CodeGen *g, AstNode *source_node, LLVMValueRef value) { - gen_defers_for_block(g, source_node->block_context, - source_node->block_context->fn_entry->fn_def_node->block_context); +static int get_conditional_defer_count(BlockContext *inner_block, BlockContext *outer_block) { + int result = 0; + while (inner_block != outer_block) { + if (inner_block->node->type == NodeTypeDefer && + (inner_block->node->data.defer.kind == ReturnKindError || + inner_block->node->data.defer.kind == ReturnKindMaybe)) + { + result += 1; + } + inner_block = inner_block->parent; + } + return result; +} + +static LLVMValueRef gen_return(CodeGen *g, AstNode *source_node, LLVMValueRef value, ReturnKnowledge rk) { + BlockContext *defer_inner_block = source_node->block_context; + BlockContext *defer_outer_block = source_node->block_context->fn_entry->fn_def_node->block_context; + if (rk == ReturnKnowledgeUnknown) { + if (get_conditional_defer_count(defer_inner_block, defer_outer_block) > 0) { + // generate branching code that checks the return value and generates defers + // if the return value is error + zig_panic("TODO"); + } + } else if (rk != ReturnKnowledgeSkipDefers) { + gen_defers_for_block(g, defer_inner_block, defer_outer_block, + rk == ReturnKnowledgeKnownError, rk == ReturnKnowledgeKnownNull); + } TypeTableEntry *return_type = g->cur_fn->type_entry->data.fn.fn_type_id.return_type; if (handle_is_ptr(return_type)) { @@ -1628,7 +1659,23 @@ static LLVMValueRef gen_return_expr(CodeGen *g, AstNode *node) { switch (node->data.return_expr.kind) { case ReturnKindUnconditional: { - return gen_return(g, node, value); + Expr *expr = get_resolved_expr(param_node); + if (expr->const_val.ok) { + if (value_type->id == TypeTableEntryIdErrorUnion) { + if (expr->const_val.data.x_err.err) { + expr->return_knowledge = ReturnKnowledgeKnownError; + } else { + expr->return_knowledge = ReturnKnowledgeKnownNonError; + } + } else if (value_type->id == TypeTableEntryIdMaybe) { + if (expr->const_val.data.x_maybe) { + expr->return_knowledge = ReturnKnowledgeKnownNonNull; + } else { + expr->return_knowledge = ReturnKnowledgeKnownNull; + } + } + } + return gen_return(g, node, value, expr->return_knowledge); } case ReturnKindError: { @@ -1653,7 +1700,7 @@ static LLVMValueRef gen_return_expr(CodeGen *g, AstNode *node) { LLVMPositionBuilderAtEnd(g->builder, return_block); TypeTableEntry *return_type = g->cur_fn->type_entry->data.fn.fn_type_id.return_type; if (return_type->id == TypeTableEntryIdPureError) { - gen_return(g, node, err_val); + gen_return(g, node, err_val, ReturnKnowledgeKnownError); } else if (return_type->id == TypeTableEntryIdErrorUnion) { if (type_has_bits(return_type->data.error.child_type)) { assert(g->cur_ret_ptr); @@ -1663,7 +1710,7 @@ static LLVMValueRef gen_return_expr(CodeGen *g, AstNode *node) { LLVMBuildStore(g->builder, err_val, tag_ptr); LLVMBuildRetVoid(g->builder); } else { - gen_return(g, node, err_val); + gen_return(g, node, err_val, ReturnKnowledgeKnownError); } } else { zig_unreachable(); @@ -1834,10 +1881,11 @@ static LLVMValueRef gen_block(CodeGen *g, AstNode *block_node, TypeTableEntry *i return nullptr; } - gen_defers_for_block(g, block_node->data.block.nested_block, block_node->data.block.child_block); + gen_defers_for_block(g, block_node->data.block.nested_block, block_node->data.block.child_block, + false, false); if (implicit_return_type) { - return gen_return(g, block_node, return_value); + return gen_return(g, block_node, return_value, ReturnKnowledgeSkipDefers); } else { return return_value; } diff --git a/test/run_tests.cpp b/test/run_tests.cpp index 50fa09f5a9..bc3dcefeed 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -1545,6 +1545,42 @@ pub fn main(args: [][]u8) -> %void { } )SOURCE", "before\ndefer2\ndefer1\n"); + + add_simple_case("%defer and it fails", R"SOURCE( +import "std.zig"; +pub fn main(args: [][]u8) -> %void { + do_test() %% return; +} +fn do_test() -> %void { + %%stdout.printf("before\n"); + defer %%stdout.printf("defer1\n"); + %defer %%stdout.printf("deferErr\n"); + %return its_gonna_fail(); + defer %%stdout.printf("defer3\n"); + %%stdout.printf("after\n"); +} +error IToldYouItWouldFail; +fn its_gonna_fail() -> %void { + return error.IToldYouItWouldFail; +} + )SOURCE", "before\ndeferErr\ndefer1\n"); + + + add_simple_case("%defer and it passes", R"SOURCE( +import "std.zig"; +pub fn main(args: [][]u8) -> %void { + do_test() %% return; +} +fn do_test() -> %void { + %%stdout.printf("before\n"); + defer %%stdout.printf("defer1\n"); + %defer %%stdout.printf("deferErr\n"); + %return its_gonna_pass(); + defer %%stdout.printf("defer3\n"); + %%stdout.printf("after\n"); +} +fn its_gonna_pass() -> %void { } + )SOURCE", "before\nafter\ndefer3\ndefer1\n"); }