diff --git a/doc/langref.md b/doc/langref.md index 287db13056..4ae1a2c0c8 100644 --- a/doc/langref.md +++ b/doc/langref.md @@ -141,7 +141,7 @@ StructLiteralField = "." "Symbol" "=" Expression PrefixOp = "!" | "-" | "~" | "*" | ("&" option("const")) | "?" | "%" | "%%" -PrimaryExpression = "Number" | "String" | "CharLiteral" | KeywordLiteral | GroupedExpression | GotoExpression | BlockExpression | "Symbol" | ("@" "Symbol" FnCallExpression) | ArrayType | FnProto | AsmExpression | ("error" "." "Symbol") +PrimaryExpression = "Number" | "String" | "CharLiteral" | KeywordLiteral | GroupedExpression | GotoExpression | BlockExpression | "Symbol" | ("@" "Symbol" FnCallExpression) | ArrayType | (option("extern") FnProto) | AsmExpression | ("error" "." "Symbol") ArrayType = "[" option(Expression) "]" option("const") PrefixOpExpression diff --git a/src/all_types.hpp b/src/all_types.hpp index f592c4c886..6983a63887 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -189,6 +189,7 @@ struct AstNodeFnProto { FnTableEntry *fn_table_entry; bool skip; TopLevelDecl top_level_decl; + Expr resolved_expr; }; struct AstNodeFnDef { @@ -828,6 +829,7 @@ struct TypeTableEntryFn { bool is_var_args; int gen_param_count; LLVMCallConv calling_convention; + bool is_extern; bool is_naked; }; diff --git a/src/analyze.cpp b/src/analyze.cpp index c3d66a26f8..38b2a8faca 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -451,44 +451,20 @@ static TypeTableEntry *analyze_type_expr(CodeGen *g, ImportTableEntry *import, B return resolve_type(g, *node_ptr); } -static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_table_entry, - ImportTableEntry *import) +static TypeTableEntry *analyze_fn_proto_type(CodeGen *g, ImportTableEntry *import, BlockContext *context, + TypeTableEntry *expected_type, AstNode *node, bool is_naked) { assert(node->type == NodeTypeFnProto); AstNodeFnProto *fn_proto = &node->data.fn_proto; if (fn_proto->skip) { - return; + return g->builtin_types.entry_invalid; } TypeTableEntry *fn_type = new_type_table_entry(TypeTableEntryIdFn); - fn_table_entry->type_entry = fn_type; - fn_type->data.fn.calling_convention = fn_table_entry->internal_linkage ? LLVMFastCallConv : LLVMCCallConv; - - for (int i = 0; i < fn_proto->directives->length; i += 1) { - AstNode *directive_node = fn_proto->directives->at(i); - Buf *name = &directive_node->data.directive.name; - - if (buf_eql_str(name, "attribute")) { - Buf *attr_name = &directive_node->data.directive.param; - if (fn_table_entry->fn_def_node) { - if (buf_eql_str(attr_name, "naked")) { - fn_type->data.fn.is_naked = true; - } else if (buf_eql_str(attr_name, "inline")) { - fn_table_entry->is_inline = true; - } else { - add_node_error(g, directive_node, - buf_sprintf("invalid function attribute: '%s'", buf_ptr(name))); - } - } else { - add_node_error(g, directive_node, - buf_sprintf("invalid function attribute: '%s'", buf_ptr(name))); - } - } else { - add_node_error(g, directive_node, - buf_sprintf("invalid directive: '%s'", buf_ptr(name))); - } - } + fn_type->data.fn.is_extern = fn_proto->is_extern || (fn_proto->visib_mod == VisibModExport); + fn_type->data.fn.is_naked = is_naked; + fn_type->data.fn.calling_convention = fn_proto->is_extern ? LLVMCCallConv : LLVMFastCallConv; int src_param_count = node->data.fn_proto.params.length; fn_type->size_in_bits = g->pointer_size_bytes * 8; @@ -499,10 +475,9 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t // first, analyze the parameters and return type in order they appear in // source code in order for error messages to be in the best order. buf_resize(&fn_type->name, 0); - const char *export_str = fn_table_entry->internal_linkage ? "" : "export "; - const char *inline_str = fn_table_entry->is_inline ? "inline " : ""; + const char *extern_str = fn_type->data.fn.is_extern ? "extern " : ""; const char *naked_str = fn_type->data.fn.is_naked ? "naked " : ""; - buf_appendf(&fn_type->name, "%s%s%sfn(", export_str, inline_str, naked_str); + buf_appendf(&fn_type->name, "%s%sfn(", extern_str, naked_str); for (int i = 0; i < src_param_count; i += 1) { AstNode *child = node->data.fn_proto.params.at(i); assert(child->type == NodeTypeParamDecl); @@ -525,10 +500,9 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t const char *comma = (src_param_count == 0) ? "" : ", "; buf_appendf(&fn_type->name, "%s...", comma); } - buf_appendf(&fn_type->name, ")"); if (return_type->id != TypeTableEntryIdVoid) { - buf_appendf(&fn_type->name, " %s", buf_ptr(&return_type->name)); + buf_appendf(&fn_type->name, " -> %s", buf_ptr(&return_type->name)); } @@ -593,13 +567,12 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t fn_type->data.fn.gen_param_count = gen_param_index; if (fn_proto->skip) { - return; + return g->builtin_types.entry_invalid; } auto table_entry = import->fn_type_table.maybe_get(&fn_type->name); if (table_entry) { - fn_type = table_entry->value; - fn_table_entry->type_entry = fn_type; + return table_entry->value; } else { fn_type->data.fn.raw_type_ref = LLVMFunctionType(gen_return_type->type_ref, gen_param_types, gen_param_index, fn_type->data.fn.is_var_args); @@ -608,8 +581,56 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t param_di_types, gen_param_index + 1, 0); import->fn_type_table.put(&fn_type->name, fn_type); + + return fn_type; + } +} + + +static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_table_entry, + ImportTableEntry *import) +{ + assert(node->type == NodeTypeFnProto); + AstNodeFnProto *fn_proto = &node->data.fn_proto; + + if (fn_proto->skip) { + return; } + bool is_naked = false; + for (int i = 0; i < fn_proto->directives->length; i += 1) { + AstNode *directive_node = fn_proto->directives->at(i); + Buf *name = &directive_node->data.directive.name; + + if (buf_eql_str(name, "attribute")) { + Buf *attr_name = &directive_node->data.directive.param; + if (fn_table_entry->fn_def_node) { + if (buf_eql_str(attr_name, "naked")) { + is_naked = true; + } else if (buf_eql_str(attr_name, "inline")) { + fn_table_entry->is_inline = true; + } else { + add_node_error(g, directive_node, + buf_sprintf("invalid function attribute: '%s'", buf_ptr(name))); + } + } else { + add_node_error(g, directive_node, + buf_sprintf("invalid function attribute: '%s'", buf_ptr(name))); + } + } else { + add_node_error(g, directive_node, + buf_sprintf("invalid directive: '%s'", buf_ptr(name))); + } + } + + TypeTableEntry *fn_type = analyze_fn_proto_type(g, import, import->block_context, nullptr, node, is_naked); + + if (fn_type->id == TypeTableEntryIdInvalid) { + fn_proto->skip = true; + return; + } + + fn_table_entry->type_entry = fn_type; fn_table_entry->fn_value = LLVMAddFunction(g->module, buf_ptr(&fn_table_entry->symbol_name), fn_type->data.fn.raw_type_ref); @@ -624,7 +645,7 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t LLVMSetLinkage(fn_table_entry->fn_value, fn_table_entry->internal_linkage ? LLVMInternalLinkage : LLVMExternalLinkage); - if (return_type->id == TypeTableEntryIdUnreachable) { + if (fn_type->data.fn.src_return_type->id == TypeTableEntryIdUnreachable) { LLVMAddFunctionAttr(fn_table_entry->fn_value, LLVMNoReturnAttribute); } LLVMSetFunctionCallConv(fn_table_entry->fn_value, fn_type->data.fn.calling_convention); @@ -1353,7 +1374,29 @@ static bool types_match_const_cast_only(TypeTableEntry *expected_type, TypeTable if (expected_type->id == TypeTableEntryIdFn && actual_type->id == TypeTableEntryIdFn) { - zig_panic("TODO types_match_const_cast_only for fns"); + if (expected_type->data.fn.is_extern != actual_type->data.fn.is_extern) { + return false; + } + if (expected_type->data.fn.is_naked != actual_type->data.fn.is_naked) { + return false; + } + if (!types_match_const_cast_only(expected_type->data.fn.src_return_type, + actual_type->data.fn.src_return_type)) + { + return false; + } + if (expected_type->data.fn.src_param_count != actual_type->data.fn.src_param_count) { + return false; + } + for (int i = 0; i < expected_type->data.fn.src_param_count; i += 1) { + // note it's reversed for parameters + if (types_match_const_cast_only(actual_type->data.fn.param_types[i], + expected_type->data.fn.param_types[i])) + { + return false; + } + } + return true; } @@ -2902,6 +2945,18 @@ static TypeTableEntry *analyze_array_type(CodeGen *g, ImportTableEntry *import, } } +static TypeTableEntry *analyze_fn_proto_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context, + TypeTableEntry *expected_type, AstNode *node) +{ + TypeTableEntry *type_entry = analyze_fn_proto_type(g, import, context, expected_type, node, false); + + if (type_entry->id == TypeTableEntryIdInvalid) { + return type_entry; + } + + return resolve_expr_const_val_as_type(g, node, type_entry); +} + static TypeTableEntry *analyze_while_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context, TypeTableEntry *expected_type, AstNode *node) { @@ -4240,6 +4295,9 @@ static TypeTableEntry *analyze_expression(CodeGen *g, ImportTableEntry *import, case NodeTypeArrayType: return_type = analyze_array_type(g, import, context, expected_type, node); break; + case NodeTypeFnProto: + return_type = analyze_fn_proto_expr(g, import, context, expected_type, node); + break; case NodeTypeErrorType: return_type = resolve_expr_const_val_as_type(g, node, g->builtin_types.entry_pure_error); break; @@ -4250,7 +4308,6 @@ static TypeTableEntry *analyze_expression(CodeGen *g, ImportTableEntry *import, case NodeTypeSwitchRange: case NodeTypeDirective: case NodeTypeFnDecl: - case NodeTypeFnProto: case NodeTypeParamDecl: case NodeTypeRoot: case NodeTypeRootExportDecl: @@ -4555,13 +4612,23 @@ static void collect_expr_decl_deps(CodeGen *g, ImportTableEntry *import, AstNode collect_expr_decl_deps(g, import, node->data.switch_range.start, decl_node); collect_expr_decl_deps(g, import, node->data.switch_range.end, decl_node); break; - case NodeTypeVariableDeclaration: case NodeTypeFnProto: + // remember that fn proto node is used for function definitions as well + // as types + for (int i = 0; i < node->data.fn_proto.params.length; i += 1) { + AstNode *param = node->data.fn_proto.params.at(i); + collect_expr_decl_deps(g, import, param, decl_node); + } + collect_expr_decl_deps(g, import, node->data.fn_proto.return_type, decl_node); + break; + case NodeTypeParamDecl: + collect_expr_decl_deps(g, import, node->data.param_decl.type, decl_node); + break; + case NodeTypeVariableDeclaration: case NodeTypeRootExportDecl: case NodeTypeFnDef: case NodeTypeRoot: case NodeTypeFnDecl: - case NodeTypeParamDecl: case NodeTypeDirective: case NodeTypeImport: case NodeTypeCImport: @@ -4705,12 +4772,8 @@ static void detect_top_level_decl_deps(CodeGen *g, ImportTableEntry *import, Ast // determine which other top level declarations this function prototype depends on. TopLevelDecl *decl_node = &node->data.fn_proto.top_level_decl; decl_node->deps.init(1); - for (int i = 0; i < node->data.fn_proto.params.length; i += 1) { - AstNode *param_node = node->data.fn_proto.params.at(i); - assert(param_node->type == NodeTypeParamDecl); - collect_expr_decl_deps(g, import, param_node->data.param_decl.type, decl_node); - } - collect_expr_decl_deps(g, import, node->data.fn_proto.return_type, decl_node); + + collect_expr_decl_deps(g, import, node, decl_node); decl_node->name = name; decl_node->import = import; @@ -4999,11 +5062,12 @@ Expr *get_resolved_expr(AstNode *node) { return &node->data.error_type.resolved_expr; case NodeTypeSwitchExpr: return &node->data.switch_expr.resolved_expr; + case NodeTypeFnProto: + return &node->data.fn_proto.resolved_expr; case NodeTypeSwitchProng: case NodeTypeSwitchRange: case NodeTypeRoot: case NodeTypeRootExportDecl: - case NodeTypeFnProto: case NodeTypeFnDef: case NodeTypeFnDecl: case NodeTypeParamDecl: diff --git a/src/parser.cpp b/src/parser.cpp index e717a18941..b9755f3123 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -503,6 +503,8 @@ static AstNode *ast_parse_if_expr(ParseContext *pc, int *token_index, bool manda static AstNode *ast_parse_block_expr(ParseContext *pc, int *token_index, bool mandatory); static AstNode *ast_parse_unwrap_expr(ParseContext *pc, int *token_index, bool mandatory); static AstNode *ast_parse_prefix_op_expr(ParseContext *pc, int *token_index, bool mandatory); +static AstNode *ast_parse_fn_proto(ParseContext *pc, int *token_index, bool mandatory, + ZigList *directives, VisibMod visib_mod); static void ast_expect_token(ParseContext *pc, Token *token, TokenId token_id) { if (token->id == token_id) { @@ -671,7 +673,7 @@ static AstNode *ast_parse_grouped_expr(ParseContext *pc, int *token_index, bool Token *l_paren = &pc->tokens->at(*token_index); if (l_paren->id != TokenIdLParen) { if (mandatory) { - ast_invalid_token_error(pc, l_paren); + ast_expect_token(pc, l_paren, TokenIdLParen); } else { return nullptr; } @@ -695,7 +697,7 @@ static AstNode *ast_parse_array_type_expr(ParseContext *pc, int *token_index, bo Token *l_bracket = &pc->tokens->at(*token_index); if (l_bracket->id != TokenIdLBracket) { if (mandatory) { - ast_invalid_token_error(pc, l_bracket); + ast_expect_token(pc, l_bracket, TokenIdLBracket); } else { return nullptr; } @@ -865,7 +867,7 @@ static AstNode *ast_parse_asm_expr(ParseContext *pc, int *token_index, bool mand if (asm_token->id != TokenIdKeywordAsm) { if (mandatory) { - ast_invalid_token_error(pc, asm_token); + ast_expect_token(pc, asm_token, TokenIdKeywordAsm); } else { return nullptr; } @@ -905,7 +907,7 @@ static AstNode *ast_parse_asm_expr(ParseContext *pc, int *token_index, bool mand } /* -PrimaryExpression : "Number" | "String" | "CharLiteral" | KeywordLiteral | GroupedExpression | GotoExpression | BlockExpression | "Symbol" | ("@" "Symbol" FnCallExpression) | ArrayType | AsmExpression | ("error" "." "Symbol") +PrimaryExpression = "Number" | "String" | "CharLiteral" | KeywordLiteral | GroupedExpression | GotoExpression | BlockExpression | "Symbol" | ("@" "Symbol" FnCallExpression) | ArrayType | FnProto | AsmExpression | ("error" "." "Symbol") KeywordLiteral : "true" | "false" | "null" | "break" | "continue" | "undefined" | "error" */ static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool mandatory) { @@ -956,6 +958,11 @@ static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool AstNode *node = ast_create_node(pc, NodeTypeErrorType, token); *token_index += 1; return node; + } else if (token->id == TokenIdKeywordExtern) { + *token_index += 1; + AstNode *node = ast_parse_fn_proto(pc, token_index, true, nullptr, VisibModPrivate); + node->data.fn_proto.is_extern = true; + return node; } else if (token->id == TokenIdAtSign) { *token_index += 1; Token *name_tok = ast_eat_token(pc, token_index, TokenIdSymbol); @@ -1002,6 +1009,11 @@ static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool return array_type_node; } + AstNode *fn_proto_node = ast_parse_fn_proto(pc, token_index, false, nullptr, VisibModPrivate); + if (fn_proto_node) { + return fn_proto_node; + } + AstNode *asm_expr = ast_parse_asm_expr(pc, token_index, false); if (asm_expr) { return asm_expr; @@ -1055,7 +1067,7 @@ static AstNode *ast_parse_curly_suffix_expr(ParseContext *pc, int *token_index, token = &pc->tokens->at(*token_index); continue; } else if (comma_tok->id != TokenIdRBrace) { - ast_invalid_token_error(pc, comma_tok); + ast_expect_token(pc, comma_tok, TokenIdRBrace); } else { *token_index += 1; break; @@ -1084,7 +1096,7 @@ static AstNode *ast_parse_curly_suffix_expr(ParseContext *pc, int *token_index, token = &pc->tokens->at(*token_index); continue; } else if (comma_tok->id != TokenIdRBrace) { - ast_invalid_token_error(pc, comma_tok); + ast_expect_token(pc, comma_tok, TokenIdRBrace); } else { *token_index += 1; break; @@ -1555,7 +1567,7 @@ static AstNode *ast_parse_else(ParseContext *pc, int *token_index, bool mandator if (else_token->id != TokenIdKeywordElse) { if (mandatory) { - ast_invalid_token_error(pc, else_token); + ast_expect_token(pc, else_token, TokenIdKeywordElse); } else { return nullptr; } @@ -1574,7 +1586,7 @@ static AstNode *ast_parse_if_expr(ParseContext *pc, int *token_index, bool manda Token *if_tok = &pc->tokens->at(*token_index); if (if_tok->id != TokenIdKeywordIf) { if (mandatory) { - ast_invalid_token_error(pc, if_tok); + ast_expect_token(pc, if_tok, TokenIdKeywordIf); } else { return nullptr; } @@ -1637,7 +1649,8 @@ static AstNode *ast_parse_return_expr(ParseContext *pc, int *token_index, bool m kind = ReturnKindError; *token_index += 2; } else if (mandatory) { - ast_invalid_token_error(pc, token); + ast_expect_token(pc, next_token, TokenIdKeywordReturn); + zig_unreachable(); } else { return nullptr; } @@ -1647,7 +1660,8 @@ static AstNode *ast_parse_return_expr(ParseContext *pc, int *token_index, bool m kind = ReturnKindMaybe; *token_index += 2; } else if (mandatory) { - ast_invalid_token_error(pc, token); + ast_expect_token(pc, next_token, TokenIdKeywordReturn); + zig_unreachable(); } else { return nullptr; } @@ -1655,7 +1669,8 @@ static AstNode *ast_parse_return_expr(ParseContext *pc, int *token_index, bool m kind = ReturnKindUnconditional; *token_index += 1; } else if (mandatory) { - ast_invalid_token_error(pc, token); + ast_expect_token(pc, token, TokenIdKeywordReturn); + zig_unreachable(); } else { return nullptr; } @@ -1756,7 +1771,7 @@ static AstNode *ast_parse_while_expr(ParseContext *pc, int *token_index, bool ma if (token->id != TokenIdKeywordWhile) { if (mandatory) { - ast_invalid_token_error(pc, token); + ast_expect_token(pc, token, TokenIdKeywordWhile); } else { return nullptr; } @@ -1791,7 +1806,7 @@ static AstNode *ast_parse_for_expr(ParseContext *pc, int *token_index, bool mand if (token->id != TokenIdKeywordFor) { if (mandatory) { - ast_invalid_token_error(pc, token); + ast_expect_token(pc, token, TokenIdKeywordFor); } else { return nullptr; } @@ -1829,7 +1844,7 @@ static AstNode *ast_parse_switch_expr(ParseContext *pc, int *token_index, bool m if (token->id != TokenIdKeywordSwitch) { if (mandatory) { - ast_invalid_token_error(pc, token); + ast_expect_token(pc, token, TokenIdKeywordSwitch); } else { return nullptr; } @@ -2082,7 +2097,7 @@ static AstNode *ast_parse_label(ParseContext *pc, int *token_index, bool mandato Token *symbol_token = &pc->tokens->at(*token_index); if (symbol_token->id != TokenIdSymbol) { if (mandatory) { - ast_invalid_token_error(pc, symbol_token); + ast_expect_token(pc, symbol_token, TokenIdSymbol); } else { return nullptr; } @@ -2091,7 +2106,7 @@ static AstNode *ast_parse_label(ParseContext *pc, int *token_index, bool mandato Token *colon_token = &pc->tokens->at(*token_index + 1); if (colon_token->id != TokenIdColon) { if (mandatory) { - ast_invalid_token_error(pc, colon_token); + ast_expect_token(pc, colon_token, TokenIdColon); } else { return nullptr; } @@ -2122,7 +2137,7 @@ static AstNode *ast_parse_block(ParseContext *pc, int *token_index, bool mandato if (last_token->id != TokenIdLBrace) { if (mandatory) { - ast_invalid_token_error(pc, last_token); + ast_expect_token(pc, last_token, TokenIdLBrace); } else { return nullptr; } @@ -2245,7 +2260,7 @@ static AstNode *ast_parse_extern_decl(ParseContext *pc, int *token_index, bool m Token *extern_kw = &pc->tokens->at(*token_index); if (extern_kw->id != TokenIdKeywordExtern) { if (mandatory) { - ast_invalid_token_error(pc, extern_kw); + ast_expect_token(pc, extern_kw, TokenIdKeywordExtern); } else { return nullptr; } @@ -2591,7 +2606,9 @@ void normalize_parent_ptrs(AstNode *node) { break; case NodeTypeFnProto: set_field(&node->data.fn_proto.return_type); - set_list_fields(node->data.fn_proto.directives); + if (node->data.fn_proto.directives) { + set_list_fields(node->data.fn_proto.directives); + } set_list_fields(&node->data.fn_proto.params); break; case NodeTypeFnDef: diff --git a/test/run_tests.cpp b/test/run_tests.cpp index 940c92769c..6fa6e2eed9 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -1866,6 +1866,23 @@ fn f(i32) {} )SOURCE", 2, ".tmp_source.zig:2:1: error: missing function name", ".tmp_source.zig:3:6: error: missing parameter name"); + + add_compile_fail_case("wrong function type", R"SOURCE( +const fns = []fn(){ a, b, c }; +fn a() -> i32 {0} +fn b() -> i32 {1} +fn c() -> i32 {2} + )SOURCE", 3, + ".tmp_source.zig:2:21: error: expected type 'fn()', got 'fn() -> i32'", + ".tmp_source.zig:2:24: error: expected type 'fn()', got 'fn() -> i32'", + ".tmp_source.zig:2:27: error: expected type 'fn()', got 'fn() -> i32'"); + + add_compile_fail_case("extern function pointer mismatch", R"SOURCE( +const fns = [](fn(i32)->i32){ a, b, c }; +pub fn a(x: i32) -> i32 {x + 0} +pub fn b(x: i32) -> i32 {x + 1} +export fn c(x: i32) -> i32 {x + 2} + )SOURCE", 1, ".tmp_source.zig:2:37: error: expected type 'fn(i32) -> i32', got 'extern fn(i32) -> i32'"); } //////////////////////////////////////////////////////////////////////////////