diff --git a/src/all_types.hpp b/src/all_types.hpp index 4b84254c7b..f2f52403d3 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -46,7 +46,7 @@ struct IrExecutable { ZigList basic_block_list; size_t mem_slot_count; size_t next_debug_id; - size_t backward_branch_count; + size_t *backward_branch_count; size_t backward_branch_quota; bool invalid; ZigList all_labels; @@ -968,6 +968,7 @@ struct FnTableEntry { FnAnalState anal_state; IrExecutable ir_executable; IrExecutable analyzed_executable; + size_t prealloc_bbc; AstNode *fn_no_inline_set_node; AstNode *fn_export_set_node; @@ -1311,6 +1312,7 @@ struct IrBasicBlock { size_t debug_id; size_t ref_count; LLVMBasicBlockRef llvm_block; + LLVMBasicBlockRef llvm_exit_block; }; enum IrInstructionId { diff --git a/src/analyze.cpp b/src/analyze.cpp index c913d4a256..2619be5a3d 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -873,43 +873,9 @@ TypeTableEntry *get_underlying_type(TypeTableEntry *type_entry) { } } -static IrInstruction *analyze_const_value(CodeGen *g, Scope *scope, AstNode *node, - TypeTableEntry *expected_type) -{ - IrExecutable ir_executable = {0}; - ir_executable.is_inline = true; - ir_gen(g, node, scope, &ir_executable); - - if (ir_executable.invalid) - return g->invalid_instruction; - - if (g->verbose) { - fprintf(stderr, "\nSource: "); - ast_render(stderr, node, 4); - fprintf(stderr, "\n{ // (IR)\n"); - ir_print(stderr, &ir_executable, 4); - fprintf(stderr, "}\n"); - } - IrExecutable analyzed_executable = {0}; - analyzed_executable.is_inline = true; - analyzed_executable.backward_branch_quota = default_backward_branch_quota; - TypeTableEntry *result_type = ir_analyze(g, &ir_executable, &analyzed_executable, expected_type, node); - if (result_type->id == TypeTableEntryIdInvalid) - return g->invalid_instruction; - - if (g->verbose) { - fprintf(stderr, "{ // (analyzed)\n"); - ir_print(stderr, &analyzed_executable, 4); - fprintf(stderr, "}\n"); - } - - IrInstruction *result = ir_exec_const_result(&analyzed_executable); - if (!result) { - add_node_error(g, node, buf_sprintf("unable to evaluate constant expression")); - return g->invalid_instruction; - } - - return result; +static IrInstruction *analyze_const_value(CodeGen *g, Scope *scope, AstNode *node, TypeTableEntry *type_entry) { + size_t backward_branch_count = 0; + return ir_eval_const_value(g, scope, node, type_entry, &backward_branch_count, default_backward_branch_quota); } static TypeTableEntry *analyze_type_expr(CodeGen *g, Scope *scope, AstNode *node) { @@ -1403,9 +1369,10 @@ static void resolve_decl_fn(CodeGen *g, TldFn *tld_fn) { } FnTableEntry *fn_table_entry = allocate(1); + fn_table_entry->analyzed_executable.backward_branch_count = &fn_table_entry->prealloc_bbc; fn_table_entry->analyzed_executable.backward_branch_quota = default_backward_branch_quota; - fn_table_entry->ir_executable.fn_entry = fn_table_entry; fn_table_entry->analyzed_executable.fn_entry = fn_table_entry; + fn_table_entry->ir_executable.fn_entry = fn_table_entry; fn_table_entry->import_entry = import; fn_table_entry->proto_node = proto_node; fn_table_entry->fn_def_node = fn_def_node; diff --git a/src/codegen.cpp b/src/codegen.cpp index 8b8650a546..af62f57ebe 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -1758,7 +1758,7 @@ static LLVMValueRef ir_render_phi(CodeGen *g, IrExecutable *executable, IrInstru LLVMBasicBlockRef *incoming_blocks = allocate(instruction->incoming_count); for (size_t i = 0; i < instruction->incoming_count; i += 1) { incoming_values[i] = ir_llvm_value(g, instruction->incoming_values[i]); - incoming_blocks[i] = instruction->incoming_blocks[i]->llvm_block; + incoming_blocks[i] = instruction->incoming_blocks[i]->llvm_exit_block; } LLVMAddIncoming(phi, incoming_values, incoming_blocks, instruction->incoming_count); return phi; @@ -1877,6 +1877,7 @@ static void ir_render(CodeGen *g, FnTableEntry *fn_entry) { continue; instruction->llvm_value = ir_render_instruction(g, executable, instruction); } + current_block->llvm_exit_block = LLVMGetInsertBlock(g->builder); } } diff --git a/src/ir.cpp b/src/ir.cpp index f759c887aa..e7c29d800d 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -6,6 +6,7 @@ */ #include "analyze.hpp" +#include "ast_render.hpp" #include "error.hpp" #include "eval.hpp" #include "ir.hpp" @@ -2814,7 +2815,7 @@ IrInstruction *ir_gen(CodeGen *codegen, AstNode *node, Scope *scope, IrExecutabl return return_instruction; } -IrInstruction *ir_gen_fn(CodeGen *codegn, FnTableEntry *fn_entry) { +IrInstruction *ir_gen_fn(CodeGen *codegen, FnTableEntry *fn_entry) { assert(fn_entry); IrExecutable *ir_executable = &fn_entry->ir_executable; @@ -2826,18 +2827,30 @@ IrInstruction *ir_gen_fn(CodeGen *codegn, FnTableEntry *fn_entry) { assert(fn_entry->child_scope); Scope *child_scope = fn_entry->child_scope; - return ir_gen(codegn, body_node, child_scope, ir_executable); + return ir_gen(codegen, body_node, child_scope, ir_executable); } static ErrorMsg *ir_add_error(IrAnalyze *ira, IrInstruction *source_instruction, Buf *msg) { + ira->new_irb.exec->invalid = true; return add_node_error(ira->codegen, source_instruction->source_node, msg); } -static IrInstruction *ir_eval_fn(IrAnalyze *ira, IrInstruction *source_instruction, - size_t arg_count, IrInstruction **args) -{ - // TODO count this as part of the backward branch quota - zig_panic("TODO ir_eval_fn"); +static IrInstruction *ir_exec_const_result(IrExecutable *exec) { + if (exec->basic_block_list.length != 1) + return nullptr; + + IrBasicBlock *bb = exec->basic_block_list.at(0); + if (bb->instruction_list.length != 1) + return nullptr; + + IrInstruction *only_inst = bb->instruction_list.at(0); + if (only_inst->id != IrInstructionIdReturn) + return nullptr; + + IrInstructionReturn *ret_inst = (IrInstructionReturn *)only_inst; + IrInstruction *value = ret_inst->value; + assert(value->static_value.special != ConstValSpecialRuntime); + return value; } static bool ir_emit_global_runtime_side_effect(IrAnalyze *ira, IrInstruction *source_instruction) { @@ -3140,14 +3153,26 @@ static TypeTableEntry *ir_unreach_error(IrAnalyze *ira) { return ira->codegen->builtin_types.entry_unreachable; } +static bool ir_emit_backward_branch(IrAnalyze *ira, IrInstruction *source_instruction) { + size_t *bbc = ira->new_irb.exec->backward_branch_count; + size_t quota = ira->new_irb.exec->backward_branch_quota; + + // If we're already over quota, we've already given an error message for this. + if (*bbc > quota) + return false; + + *bbc += 1; + if (*bbc > quota) { + ir_add_error(ira, source_instruction, buf_sprintf("evaluation exceeded %zu backwards branches", quota)); + return false; + } + return true; +} + static TypeTableEntry *ir_inline_bb(IrAnalyze *ira, IrInstruction *source_instruction, IrBasicBlock *old_bb) { if (old_bb->debug_id <= ira->old_irb.current_basic_block->debug_id) { - ira->new_irb.exec->backward_branch_count += 1; - if (ira->new_irb.exec->backward_branch_count > ira->new_irb.exec->backward_branch_quota) { - add_node_error(ira->codegen, source_instruction->source_node, - buf_sprintf("evaluation exceeded %zu backwards branches", ira->new_irb.exec->backward_branch_quota)); + if (!ir_emit_backward_branch(ira, source_instruction)) return ir_unreach_error(ira); - } } ir_start_bb(ira, old_bb, ira->old_irb.current_basic_block); @@ -3216,13 +3241,90 @@ static TypeTableEntry *ir_analyze_const_usize(IrAnalyze *ira, IrInstruction *ins static ConstExprValue *ir_resolve_const(IrAnalyze *ira, IrInstruction *value) { if (value->static_value.special != ConstValSpecialStatic) { - add_node_error(ira->codegen, value->source_node, - buf_sprintf("unable to evaluate constant expression")); + ir_add_error(ira, value, buf_sprintf("unable to evaluate constant expression")); return nullptr; } return &value->static_value; } +IrInstruction *ir_eval_const_value(CodeGen *codegen, Scope *scope, AstNode *node, + TypeTableEntry *expected_type, size_t *backward_branch_count, size_t backward_branch_quota) +{ + IrExecutable ir_executable = {0}; + ir_executable.is_inline = true; + ir_gen(codegen, node, scope, &ir_executable); + + if (ir_executable.invalid) + return codegen->invalid_instruction; + + if (codegen->verbose) { + fprintf(stderr, "\nSource: "); + ast_render(stderr, node, 4); + fprintf(stderr, "\n{ // (IR)\n"); + ir_print(stderr, &ir_executable, 4); + fprintf(stderr, "}\n"); + } + IrExecutable analyzed_executable = {0}; + analyzed_executable.is_inline = true; + analyzed_executable.backward_branch_count = backward_branch_count; + analyzed_executable.backward_branch_quota = backward_branch_quota; + TypeTableEntry *result_type = ir_analyze(codegen, &ir_executable, &analyzed_executable, expected_type, node); + if (result_type->id == TypeTableEntryIdInvalid) + return codegen->invalid_instruction; + + if (codegen->verbose) { + fprintf(stderr, "{ // (analyzed)\n"); + ir_print(stderr, &analyzed_executable, 4); + fprintf(stderr, "}\n"); + } + + IrInstruction *result = ir_exec_const_result(&analyzed_executable); + if (!result) { + add_node_error(codegen, node, buf_sprintf("unable to evaluate constant expression")); + return codegen->invalid_instruction; + } + + return result; +} + +static IrInstruction *ir_eval_fn(IrAnalyze *ira, IrInstruction *source_instruction, + FnTableEntry *fn_entry, IrInstruction **args) +{ + if (!fn_entry) { + ir_add_error(ira, source_instruction, + buf_sprintf("unable to evaluate constant expression")); + return ira->codegen->invalid_instruction; + } + + if (!ir_emit_backward_branch(ira, source_instruction)) + return ira->codegen->invalid_instruction; + + TypeTableEntry *fn_type = fn_entry->type_entry; + FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id; + + // Fork a scope of the function with known values for the parameters. + + Scope *exec_scope = &fn_entry->fndef_scope->base; + for (size_t i = 0; i < fn_type_id->param_count; i += 1) { + AstNode *param_decl_node = fn_entry->proto_node->data.fn_proto.params.at(i); + Buf *param_name = param_decl_node->data.param_decl.name; + IrInstruction *arg = args[i]; + ConstExprValue *arg_val = ir_resolve_const(ira, arg); + if (!arg_val) + return ira->codegen->invalid_instruction; + + VariableTableEntry *var = add_variable(ira->codegen, param_decl_node, exec_scope, param_name, + arg->type_entry, true, arg_val); + exec_scope = var->child_scope; + } + + // Analyze the fn body block like any other constant expression. + + AstNode *body_node = fn_entry->fn_def_node->data.fn_def.body; + return ir_eval_const_value(ira->codegen, exec_scope, body_node, fn_type_id->return_type, + ira->new_irb.exec->backward_branch_count, ira->new_irb.exec->backward_branch_quota); +} + static TypeTableEntry *ir_resolve_type_lval(IrAnalyze *ira, IrInstruction *type_value, LValPurpose lval) { if (lval != LValPurposeNone) zig_panic("TODO"); @@ -4238,7 +4340,8 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal return ira->codegen->builtin_types.entry_invalid; if (is_inline) { - IrInstruction *result = ir_eval_fn(ira, &call_instruction->base, call_param_count, casted_args); + assert(call_param_count == fn_type_id->param_count); + IrInstruction *result = ir_eval_fn(ira, &call_instruction->base, fn_entry, casted_args); if (result->type_entry->id == TypeTableEntryIdInvalid) return ira->codegen->builtin_types.entry_invalid; @@ -4252,9 +4355,9 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal fn_entry, fn_ref, call_param_count, casted_args); if (type_has_bits(return_type) && handle_is_ptr(return_type)) { - FnTableEntry *fn_entry = exec_fn_entry(ira->new_irb.exec); - assert(fn_entry); - fn_entry->alloca_list.append(new_call_instruction); + FnTableEntry *owner_fn = exec_fn_entry(ira->new_irb.exec); + assert(owner_fn); + owner_fn->alloca_list.append(new_call_instruction); } return ir_finish_anal(ira, return_type); @@ -4782,13 +4885,13 @@ static TypeTableEntry *ir_analyze_var_ptr(IrAnalyze *ira, IrInstruction *instruc ConstExprValue *mem_slot = nullptr; FnTableEntry *fn_entry = scope_fn_entry(var->parent_scope); - if (fn_entry) { + if (var->src_is_const && var->value) { + mem_slot = var->value; + assert(mem_slot->special != ConstValSpecialRuntime); + } else if (fn_entry) { // TODO once the analyze code is fully ported over to IR we won't need this SIZE_MAX thing. if (var->mem_slot_index != SIZE_MAX) mem_slot = &ira->exec_context.mem_slot_list[var->mem_slot_index]; - } else if (var->src_is_const) { - mem_slot = var->value; - assert(mem_slot->special != ConstValSpecialRuntime); } if (mem_slot && mem_slot->special != ConstValSpecialRuntime) { @@ -6481,24 +6584,6 @@ bool ir_has_side_effects(IrInstruction *instruction) { zig_unreachable(); } -IrInstruction *ir_exec_const_result(IrExecutable *exec) { - if (exec->basic_block_list.length != 1) - return nullptr; - - IrBasicBlock *bb = exec->basic_block_list.at(0); - if (bb->instruction_list.length != 1) - return nullptr; - - IrInstruction *only_inst = bb->instruction_list.at(0); - if (only_inst->id != IrInstructionIdReturn) - return nullptr; - - IrInstructionReturn *ret_inst = (IrInstructionReturn *)only_inst; - IrInstruction *value = ret_inst->value; - assert(value->static_value.special != ConstValSpecialRuntime); - return value; -} - // TODO port over all this commented out code into new IR way of doing things //static TypeTableEntry *analyze_min_max_value(CodeGen *g, ImportTableEntry *import, BlockContext *context, diff --git a/src/ir.hpp b/src/ir.hpp index 5f0e4b5ad9..4fcf3d34ab 100644 --- a/src/ir.hpp +++ b/src/ir.hpp @@ -13,11 +13,12 @@ IrInstruction *ir_gen(CodeGen *g, AstNode *node, Scope *scope, IrExecutable *ir_executable); IrInstruction *ir_gen_fn(CodeGen *g, FnTableEntry *fn_entry); +IrInstruction *ir_eval_const_value(CodeGen *codegen, Scope *scope, AstNode *node, + TypeTableEntry *expected_type, size_t *backward_branch_count, size_t backward_branch_quota); + TypeTableEntry *ir_analyze(CodeGen *g, IrExecutable *old_executable, IrExecutable *new_executable, TypeTableEntry *expected_type, AstNode *expected_type_source_node); -IrInstruction *ir_exec_const_result(IrExecutable *exec); - bool ir_has_side_effects(IrInstruction *instruction); ConstExprValue *const_ptr_pointee(ConstExprValue *const_val); diff --git a/test/self_hosted2.zig b/test/self_hosted2.zig index 980e14d65c..605720be1a 100644 --- a/test/self_hosted2.zig +++ b/test/self_hosted2.zig @@ -96,6 +96,22 @@ fn testStructStatic() { assert(result == 7); } +const should_be_11 = FooA.add(5, 6); +fn testStaticFnEval() { + assert(should_be_11 == 11); +} + +fn fib(x: i32) -> i32 { + if (x < 2) x else fib(x - 1) + fib(x - 2) +} + +const fib_7 = fib(7); + +fn testCompileTimeFib() { + assert(fib_7 == 13); +} + + fn assert(ok: bool) { if (!ok) @unreachable(); @@ -111,6 +127,8 @@ fn runAllTests() { testNamespaceFnCall(); gotoAndLabels(); testStructStatic(); + testStaticFnEval(); + testCompileTimeFib(); } export nakedcc fn _start() -> unreachable {