From 419652ee8ff709deae988603ecca2c335599d969 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 24 Jan 2016 18:34:50 -0700 Subject: [PATCH] ability to return structs byvalue from functions closes #57 --- src/all_types.hpp | 7 ++- src/analyze.cpp | 122 ++++++++++++++++++++++++++++++--------------- src/codegen.cpp | 116 ++++++++++++++++++++++++++++++------------ test/run_tests.cpp | 21 ++++++++ 4 files changed, 193 insertions(+), 73 deletions(-) diff --git a/src/all_types.hpp b/src/all_types.hpp index 89d3fd8ecf..8ab24f33d2 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -214,6 +214,9 @@ struct AstNodeParamDecl { // populated by semantic analyzer VariableTableEntry *variable; + bool is_byval; + int src_index; + int gen_index; }; struct AstNodeBlock { @@ -794,7 +797,8 @@ struct TypeTableEntryEnum { }; struct TypeTableEntryFn { - TypeTableEntry *return_type; + TypeTableEntry *src_return_type; + TypeTableEntry *gen_return_type; TypeTableEntry **param_types; int src_param_count; LLVMTypeRef raw_type_ref; @@ -989,6 +993,7 @@ struct CodeGen { OutType out_type; FnTableEntry *cur_fn; + LLVMValueRef cur_ret_ptr; // TODO remove this in favor of get_resolved_expr(expr_node)->context BlockContext *cur_block_context; ZigList break_block_stack; diff --git a/src/analyze.cpp b/src/analyze.cpp index bc4f463dd8..28bce3d52b 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -424,6 +424,7 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t AstNodeFnProto *fn_proto = &node->data.fn_proto; TypeTableEntry *fn_type = new_type_table_entry(TypeTableEntryIdFn); + fn_table_entry->type_entry = fn_type; fn_type->data.fn.calling_convention = fn_table_entry->internal_linkage ? LLVMFastCallConv : LLVMCCallConv; for (int i = 0; i < fn_proto->directives->length; i += 1) { @@ -452,22 +453,18 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t } int src_param_count = node->data.fn_proto.params.length; - fn_type->size_in_bits = g->pointer_size_bytes * 8; fn_type->align_in_bits = g->pointer_size_bytes * 8; fn_type->data.fn.src_param_count = src_param_count; fn_type->data.fn.param_types = allocate(src_param_count); - fn_table_entry->type_entry = fn_type; - + // first, analyze the parameters and return type in order they appear in + // source code in order for error messages to be in the best order. buf_resize(&fn_type->name, 0); const char *export_str = fn_table_entry->internal_linkage ? "" : "export "; const char *inline_str = fn_table_entry->is_inline ? "inline " : ""; const char *naked_str = fn_type->data.fn.is_naked ? "naked " : ""; buf_appendf(&fn_type->name, "%s%s%sfn(", export_str, inline_str, naked_str); - int gen_param_count = 0; - LLVMTypeRef *gen_param_types = allocate(src_param_count); - LLVMZigDIType **param_di_types = allocate(1 + src_param_count); for (int i = 0; i < src_param_count; i += 1) { AstNode *child = node->data.fn_proto.params.at(i); assert(child->type == NodeTypeParamDecl); @@ -475,6 +472,54 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t child->data.param_decl.type); fn_type->data.fn.param_types[i] = type_entry; + const char *comma = (i == 0) ? "" : ", "; + buf_appendf(&fn_type->name, "%s%s", comma, buf_ptr(&type_entry->name)); + } + + TypeTableEntry *return_type = analyze_type_expr(g, import, import->block_context, + node->data.fn_proto.return_type); + fn_type->data.fn.src_return_type = return_type; + if (return_type->id == TypeTableEntryIdInvalid) { + fn_proto->skip = true; + } + fn_type->data.fn.is_var_args = fn_proto->is_var_args; + if (fn_proto->is_var_args) { + const char *comma = (src_param_count == 0) ? "" : ", "; + buf_appendf(&fn_type->name, "%s...", comma); + } + + buf_appendf(&fn_type->name, ")"); + if (return_type->id != TypeTableEntryIdVoid) { + buf_appendf(&fn_type->name, " %s", buf_ptr(&return_type->name)); + } + + + // next, loop over the parameters again and compute debug information + // and codegen information + bool first_arg_return = handle_is_ptr(return_type); + // +1 for maybe making the first argument the return value + LLVMTypeRef *gen_param_types = allocate(1 + src_param_count); + // +1 because 0 is the return type and +1 for maybe making first arg ret val + LLVMZigDIType **param_di_types = allocate(2 + src_param_count); + param_di_types[0] = return_type->di_type; + int gen_param_index = 0; + TypeTableEntry *gen_return_type; + if (first_arg_return) { + TypeTableEntry *gen_type = get_pointer_to_type(g, return_type, false); + gen_param_types[gen_param_index] = gen_type->type_ref; + gen_param_index += 1; + // after the gen_param_index += 1 because 0 is the return type + param_di_types[gen_param_index] = gen_type->di_type; + gen_return_type = g->builtin_types.entry_void; + } else { + gen_return_type = return_type; + } + fn_type->data.fn.gen_return_type = gen_return_type; + for (int i = 0; i < src_param_count; i += 1) { + AstNode *child = node->data.fn_proto.params.at(i); + assert(child->type == NodeTypeParamDecl); + TypeTableEntry *type_entry = fn_type->data.fn.param_types[i]; + if (type_entry->id == TypeTableEntryIdUnreachable) { add_node_error(g, child->data.param_decl.type, buf_sprintf("parameter of type 'unreachable' not allowed")); @@ -483,39 +528,29 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t fn_proto->skip = true; } + child->data.param_decl.src_index = i; + child->data.param_decl.gen_index = -1; + if (!fn_proto->skip && type_entry->size_in_bits > 0) { - const char *comma = (gen_param_count == 0) ? "" : ", "; - buf_appendf(&fn_type->name, "%s%s", comma, buf_ptr(&type_entry->name)); - TypeTableEntry *gen_type = handle_is_ptr(type_entry) ? - get_pointer_to_type(g, type_entry, true) : type_entry; - gen_param_types[gen_param_count] = gen_type->type_ref; + TypeTableEntry *gen_type; + if (handle_is_ptr(type_entry)) { + gen_type = get_pointer_to_type(g, type_entry, true); + child->data.param_decl.is_byval = true; + } else { + gen_type = type_entry; + } + gen_param_types[gen_param_index] = gen_type->type_ref; + child->data.param_decl.gen_index = gen_param_index; - gen_param_count += 1; + gen_param_index += 1; - // after the gen_param_count += 1 because 0 is the return type - param_di_types[gen_param_count] = gen_type->di_type; + // after the gen_param_index += 1 because 0 is the return type + param_di_types[gen_param_index] = gen_type->di_type; } } - fn_type->data.fn.gen_param_count = gen_param_count; - fn_type->data.fn.is_var_args = fn_proto->is_var_args; - if (fn_proto->is_var_args) { - const char *comma = (gen_param_count == 0) ? "" : ", "; - buf_appendf(&fn_type->name, "%s...", comma); - } - - TypeTableEntry *return_type = analyze_type_expr(g, import, import->block_context, - node->data.fn_proto.return_type); - fn_type->data.fn.return_type = return_type; - if (return_type->id == TypeTableEntryIdInvalid) { - fn_proto->skip = true; - } - - buf_appendf(&fn_type->name, ")"); - if (return_type->id != TypeTableEntryIdVoid) { - buf_appendf(&fn_type->name, " %s", buf_ptr(&return_type->name)); - } + fn_type->data.fn.gen_param_count = gen_param_index; if (fn_proto->skip) { return; @@ -526,12 +561,11 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t fn_type = table_entry->value; fn_table_entry->type_entry = fn_type; } else { - fn_type->data.fn.raw_type_ref = LLVMFunctionType(return_type->type_ref, gen_param_types, gen_param_count, - fn_type->data.fn.is_var_args); + fn_type->data.fn.raw_type_ref = LLVMFunctionType(gen_return_type->type_ref, + gen_param_types, gen_param_index, fn_type->data.fn.is_var_args); fn_type->type_ref = LLVMPointerType(fn_type->data.fn.raw_type_ref, 0); - param_di_types[0] = return_type->di_type; fn_type->di_type = LLVMZigCreateSubroutineType(g->dbuilder, import->di_file, - param_di_types, gen_param_count + 1, 0); + param_di_types, gen_param_index + 1, 0); import->fn_type_table.put(&fn_type->name, fn_type); } @@ -550,7 +584,7 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t LLVMSetLinkage(fn_table_entry->fn_value, fn_table_entry->internal_linkage ? LLVMInternalLinkage : LLVMExternalLinkage); - if (return_type->id == TypeTableEntryIdUnreachable) { + if (gen_return_type->id == TypeTableEntryIdUnreachable) { LLVMAddFunctionAttr(fn_table_entry->fn_value, LLVMNoReturnAttribute); } LLVMSetFunctionCallConv(fn_table_entry->fn_value, fn_type->data.fn.calling_convention); @@ -3320,7 +3354,13 @@ static TypeTableEntry *analyze_fn_call_raw(CodeGen *g, ImportTableEntry *import, analyze_expression(g, import, context, expected_param_type, child); } - return unwrapped_node_type(fn_proto->return_type); + TypeTableEntry *return_type = unwrapped_node_type(fn_proto->return_type); + + if (handle_is_ptr(return_type)) { + context->cast_alloca_list.append(node); + } + + return return_type; } static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context, @@ -3417,7 +3457,7 @@ static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import // function pointer if (invoke_type_entry->id == TypeTableEntryIdFn) { - return invoke_type_entry->data.fn.return_type; + return invoke_type_entry->data.fn.src_return_type; } else { add_node_error(g, fn_ref_expr, buf_sprintf("type '%s' not a function", buf_ptr(&invoke_type_entry->name))); @@ -3681,6 +3721,7 @@ static TypeTableEntry *analyze_return_expr(CodeGen *g, ImportTableEntry *import, } else { add_node_error(g, node->data.return_expr.expr, buf_sprintf("expected error type, got '%s'", buf_ptr(&resolved_type->name))); + return g->builtin_types.entry_invalid; } } case ReturnKindMaybe: @@ -4711,6 +4752,7 @@ bool handle_is_ptr(TypeTableEntry *type_entry) { return type_entry->id == TypeTableEntryIdStruct || (type_entry->id == TypeTableEntryIdEnum && type_entry->data.enumeration.gen_field_count != 0) || type_entry->id == TypeTableEntryIdMaybe || - type_entry->id == TypeTableEntryIdArray; + type_entry->id == TypeTableEntryIdArray || + (type_entry->id == TypeTableEntryIdErrorUnion && type_entry->data.error.child_type->size_in_bits > 0); } diff --git a/src/codegen.cpp b/src/codegen.cpp index 10395dd892..5ee40c144f 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -81,11 +81,6 @@ static TypeTableEntry *get_type_for_type_node(AstNode *node) { return const_val->data.x_type; } -static bool is_param_decl_type_void(CodeGen *g, AstNode *param_decl_node) { - assert(param_decl_node->type == NodeTypeParamDecl); - return get_type_for_type_node(param_decl_node->data.param_decl.type)->size_in_bits == 0; -} - static void add_debug_source_node(CodeGen *g, AstNode *node) { if (!g->cur_block_context) return; @@ -424,17 +419,22 @@ static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) { fn_type = get_expr_type(fn_ref_expr); } - int expected_param_count = fn_type->data.fn.src_param_count; + TypeTableEntry *src_return_type = fn_type->data.fn.src_return_type; + TypeTableEntry *gen_return_type = fn_type->data.fn.gen_return_type; + int fn_call_param_count = node->data.fn_call_expr.params.length; - int actual_param_count = fn_call_param_count + (struct_type ? 1 : 0); + bool first_arg_ret = handle_is_ptr(src_return_type); + int actual_param_count = fn_call_param_count + (struct_type ? 1 : 0) + (first_arg_ret ? 1 : 0); bool is_var_args = fn_type->data.fn.is_var_args; - assert((is_var_args && actual_param_count >= expected_param_count) || - actual_param_count == expected_param_count); // don't really include void values LLVMValueRef *gen_param_values = allocate(actual_param_count); int gen_param_index = 0; + if (first_arg_ret) { + gen_param_values[gen_param_index] = node->data.fn_call_expr.tmp_ptr; + gen_param_index += 1; + } if (struct_type) { gen_param_values[gen_param_index] = gen_expr(g, first_param_expr); gen_param_index += 1; @@ -454,8 +454,10 @@ static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) { LLVMValueRef result = LLVMZigBuildCall(g->builder, fn_val, gen_param_values, gen_param_index, fn_type->data.fn.calling_convention, ""); - if (fn_type->data.fn.return_type->id == TypeTableEntryIdUnreachable) { + if (gen_return_type->id == TypeTableEntryIdUnreachable) { return LLVMBuildUnreachable(g->builder); + } else if (first_arg_ret) { + return node->data.fn_call_expr.tmp_ptr; } else { return result; } @@ -1236,21 +1238,60 @@ static LLVMValueRef gen_bin_op_expr(CodeGen *g, AstNode *node) { zig_unreachable(); } +static LLVMValueRef gen_return(CodeGen *g, AstNode *source_node, LLVMValueRef value) { + TypeTableEntry *return_type = g->cur_fn->type_entry->data.fn.src_return_type; + if (handle_is_ptr(return_type)) { + assert(g->cur_ret_ptr); + gen_assign_raw(g, source_node, BinOpTypeAssign, g->cur_ret_ptr, value, return_type, return_type); + add_debug_source_node(g, source_node); + return LLVMBuildRetVoid(g->builder); + } else { + add_debug_source_node(g, source_node); + return LLVMBuildRet(g->builder, value); + } +} + static LLVMValueRef gen_return_expr(CodeGen *g, AstNode *node) { assert(node->type == NodeTypeReturnExpr); AstNode *param_node = node->data.return_expr.expr; assert(param_node); + LLVMValueRef value = gen_expr(g, param_node); + TypeTableEntry *value_type = get_expr_type(param_node); switch (node->data.return_expr.kind) { case ReturnKindUnconditional: - { - LLVMValueRef value = gen_expr(g, param_node); - - add_debug_source_node(g, node); - return LLVMBuildRet(g->builder, value); - } + return gen_return(g, node, value); case ReturnKindError: - zig_panic("TODO"); + { + assert(value_type->id == TypeTableEntryIdErrorUnion); + TypeTableEntry *child_type = value_type->data.error.child_type; + + LLVMBasicBlockRef return_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "ErrReturnYes"); + LLVMBasicBlockRef continue_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "ErrReturnNo"); + + LLVMPositionBuilderAtEnd(g->builder, return_block); + if (child_type->size_in_bits > 0) { + zig_panic("TODO write the error tag value to sret"); + add_debug_source_node(g, node); + LLVMBuildRetVoid(g->builder); + } else { + add_debug_source_node(g, node); + LLVMBuildRet(g->builder, value); + } + + LLVMPositionBuilderAtEnd(g->builder, continue_block); + if (child_type->size_in_bits > 0) { + add_debug_source_node(g, node); + LLVMValueRef val_ptr = LLVMBuildStructGEP(g->builder, value, 1, ""); + if (handle_is_ptr(child_type)) { + return val_ptr; + } else { + return LLVMBuildLoad(g->builder, val_ptr, ""); + } + } else { + return nullptr; + } + } case ReturnKindMaybe: zig_panic("TODO"); } @@ -1386,13 +1427,8 @@ static LLVMValueRef gen_block(CodeGen *g, AstNode *block_node, TypeTableEntry *i return_value = gen_expr(g, statement_node); } - if (implicit_return_type) { - add_debug_source_node(g, block_node); - if (implicit_return_type->id == TypeTableEntryIdVoid) { - LLVMBuildRetVoid(g->builder); - } else if (implicit_return_type->id != TypeTableEntryIdUnreachable) { - LLVMBuildRet(g->builder, return_value); - } + if (implicit_return_type && implicit_return_type->id != TypeTableEntryIdUnreachable) { + gen_return(g, block_node, return_value); } g->cur_block_context = old_block_context; @@ -2257,25 +2293,35 @@ static void do_code_gen(CodeGen *g) { assert(proto_node->type == NodeTypeFnProto); AstNodeFnProto *fn_proto = &proto_node->data.fn_proto; + if (handle_is_ptr(fn_table_entry->type_entry->data.fn.src_return_type)) { + LLVMValueRef first_arg = LLVMGetParam(fn_table_entry->fn_value, 0); + LLVMAddAttribute(first_arg, LLVMStructRetAttribute); + } + // set parameter attributes - int gen_param_index = 0; for (int param_decl_i = 0; param_decl_i < fn_proto->params.length; param_decl_i += 1) { AstNode *param_node = fn_proto->params.at(param_decl_i); assert(param_node->type == NodeTypeParamDecl); - if (is_param_decl_type_void(g, param_node)) + + int gen_index = param_node->data.param_decl.gen_index; + + if (gen_index < 0) { continue; + } + AstNode *type_node = param_node->data.param_decl.type; TypeTableEntry *param_type = fn_proto_type_from_type_node(g, type_node); - LLVMValueRef argument_val = LLVMGetParam(fn_table_entry->fn_value, gen_param_index); + LLVMValueRef argument_val = LLVMGetParam(fn_table_entry->fn_value, gen_index); bool param_is_noalias = param_node->data.param_decl.is_noalias; if (param_type->id == TypeTableEntryIdPointer && param_is_noalias) { LLVMAddAttribute(argument_val, LLVMNoAliasAttribute); - } else if (param_type->id == TypeTableEntryIdPointer && - param_type->data.pointer.is_const) - { + } + if (param_type->id == TypeTableEntryIdPointer && param_type->data.pointer.is_const) { LLVMAddAttribute(argument_val, LLVMReadOnlyAttribute); } - gen_param_index += 1; + if (param_node->data.param_decl.is_byval) { + LLVMAddAttribute(argument_val, LLVMByValAttribute); + } } } @@ -2287,6 +2333,11 @@ static void do_code_gen(CodeGen *g) { AstNode *fn_def_node = fn_table_entry->fn_def_node; LLVMValueRef fn = fn_table_entry->fn_value; g->cur_fn = fn_table_entry; + if (handle_is_ptr(fn_table_entry->type_entry->data.fn.src_return_type)) { + g->cur_ret_ptr = LLVMGetParam(fn, 0); + } else { + g->cur_ret_ptr = nullptr; + } AstNode *proto_node = fn_table_entry->proto_node; assert(proto_node->type == NodeTypeFnProto); @@ -2368,8 +2419,9 @@ static void do_code_gen(CodeGen *g) { AstNode *param_decl = fn_proto->params.at(param_i); assert(param_decl->type == NodeTypeParamDecl); - if (is_param_decl_type_void(g, param_decl)) + if (param_decl->data.param_decl.gen_index < 0) { continue; + } VariableTableEntry *variable = param_decl->data.param_decl.variable; diff --git a/test/run_tests.cpp b/test/run_tests.cpp index 5f7a9ec8e0..e7b3e57c66 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -1233,6 +1233,27 @@ pub fn main(args: [][]u8) %void => { print_str("OK\n"); return; } +} + )SOURCE", "OK\n"); + + add_simple_case("return struct byval from function", R"SOURCE( +import "std.zig"; +struct Foo { + x: i32, + y: i32, +} +fn make_foo() Foo => { + Foo { + .x = 1234, + .y = 5678, + } +} +pub fn main(args: [][]u8) %void => { + const foo = make_foo(); + if (foo.y != 5678) { + print_str("BAD\n"); + } + print_str("OK\n"); } )SOURCE", "OK\n"); }