diff --git a/doc/langref.md b/doc/langref.md
index 3930b3f9b2..e1948d20ad 100644
--- a/doc/langref.md
+++ b/doc/langref.md
@@ -160,7 +160,7 @@ SliceExpression : token(LBracket) Expression token(Ellipsis) option(Expression) 
 
 PrefixOp : token(Not) | token(Dash) | token(Tilde) | token(Star) | (token(Ampersand) option(token(Const)))
 
-PrimaryExpression : token(Number) | token(String) | token(CharLiteral) | KeywordLiteral | GroupedExpression | Goto | token(Break) | token(Continue) | BlockExpression | token(Symbol) | StructValueExpression | CompilerFnType
+PrimaryExpression : token(Number) | token(String) | token(CharLiteral) | KeywordLiteral | GroupedExpression | Goto | token(Break) | token(Continue) | BlockExpression | token(Symbol) | StructValueExpression | CompilerFnType | (token(AtSign) token(Symbol) FnCallExpression)
 
 StructValueExpression : token(Type) token(LBrace) list(StructValueExpressionField, token(Comma)) token(RBrace)
 
diff --git a/src/analyze.cpp b/src/analyze.cpp
index 18afb9eb05..3f4462be71 100644
--- a/src/analyze.cpp
+++ b/src/analyze.cpp
@@ -2064,6 +2064,50 @@ static TypeTableEntry *analyze_compiler_fn_type(CodeGen *g, ImportTableEntry *im
     }
 }
 
+// Analyzes a builtin call of the form `@name(args...)`.
+// Looks `name` up in g->builtin_fn_table, records the matching entry on the
+// node's codegen data for gen_builtin_fn_call_expr, reports an error on an
+// unknown name or an argument-count mismatch, and analyzes every argument.
+// Returns the builtin's return type, or the invalid type for an unknown name.
+static TypeTableEntry *analyze_builtin_fn_call_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
+        TypeTableEntry *expected_type, AstNode *node)
+{
+    AstNode *fn_ref_expr = node->data.fn_call_expr.fn_ref_expr;
+    Buf *name = &fn_ref_expr->data.symbol;
+
+    auto entry = g->builtin_fn_table.maybe_get(name);
+
+    if (entry) {
+        BuiltinFnEntry *builtin_fn = entry->value;
+        int actual_param_count = node->data.fn_call_expr.params.length;
+
+        assert(node->codegen_node);
+        node->codegen_node->data.fn_call_node.builtin_fn = builtin_fn;
+
+        if (builtin_fn->param_count != actual_param_count) {
+            add_node_error(g, node,
+                    buf_sprintf("expected %d arguments, got %d",
+                        builtin_fn->param_count, actual_param_count));
+        }
+
+        for (int i = 0; i < actual_param_count; i += 1) {
+            AstNode *child = node->data.fn_call_expr.params.at(i);
+            // Surplus arguments were already reported above and have no declared
+            // type, so never index param_types past param_count.
+            // NOTE(review): assumes analyze_expression accepts a null expected type.
+            TypeTableEntry *expected_param_type = (i < builtin_fn->param_count) ?
+                builtin_fn->param_types[i] : nullptr;
+            analyze_expression(g, import, context, expected_param_type, child);
+        }
+
+        return builtin_fn->return_type;
+    } else {
+        add_node_error(g, node,
+                buf_sprintf("invalid builtin function: '%s'", buf_ptr(name)));
+        return g->builtin_types.entry_invalid;
+    }
+}
+
 static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
         TypeTableEntry *expected_type, AstNode *node)
 {
@@ -2091,6 +2135,9 @@ static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import
             return g->builtin_types.entry_invalid;
         }
     } else if (fn_ref_expr->type == NodeTypeSymbol) {
+        if (node->data.fn_call_expr.is_builtin) {
+            return analyze_builtin_fn_call_expr(g, import, context, expected_type, node);
+        }
         name = &fn_ref_expr->data.symbol;
     } else {
         add_node_error(g, node,
@@ -2126,12 +2173,12 @@ static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import
     if (fn_proto->is_var_args) {
         if (actual_param_count < expected_param_count) {
             add_node_error(g, node,
-                    buf_sprintf("wrong number of arguments. Expected at least %d, got %d.",
+                    buf_sprintf("expected at least %d arguments, got %d",
                         expected_param_count, actual_param_count));
         }
     } else if (expected_param_count != actual_param_count) {
         add_node_error(g, node,
-                buf_sprintf("wrong number of arguments. Expected %d, got %d.",
+                buf_sprintf("expected %d arguments, got %d",
                     expected_param_count, actual_param_count));
     }
 
diff --git a/src/analyze.hpp b/src/analyze.hpp
index fd2d11cb71..8785a4c616 100644
--- a/src/analyze.hpp
+++ b/src/analyze.hpp
@@ -148,6 +148,22 @@ struct FnTableEntry {
     HashMap label_table;
 };
 
+// Identifies how codegen lowers a builtin function call.
+enum BuiltinFnId {
+    BuiltinFnIdInvalid,
+    BuiltinFnIdArithmeticWithOverflow,
+};
+
+// One entry of CodeGen::builtin_fn_table, keyed by the builtin's name.
+struct BuiltinFnEntry {
+    BuiltinFnId id;
+    Buf name;
+    int param_count;
+    TypeTableEntry *return_type;
+    TypeTableEntry **param_types;
+    LLVMValueRef fn_val;
+};
+
 struct CodeGen {
     LLVMModuleRef module;
     ZigList errors;
@@ -161,6 +177,7 @@ struct CodeGen {
     HashMap str_table;
     HashMap link_table;
     HashMap import_table;
+    HashMap builtin_fn_table;
 
     struct {
         TypeTableEntry *entry_bool;
@@ -342,6 +359,11 @@ struct WhileNode {
     bool contains_break;
 };
 
+// Analysis results attached to a NodeTypeFnCallExpr node.
+struct FnCallNode {
+    BuiltinFnEntry *builtin_fn;
+};
+
 struct CodeGenNode {
     union {
         TypeNode type_node; // for NodeTypeType
@@ -363,17 +385,11 @@ struct CodeGenNode {
         ParamDeclNode param_decl_node; // for NodeTypeParamDecl
         ImportNode import_node; // for NodeTypeUse
         WhileNode while_node; // for NodeTypeWhileExpr
+        FnCallNode fn_call_node; // for NodeTypeFnCallExpr
     } data;
     ExprNode expr_node; // for all the expression nodes
 };
 
-static inline Buf *hack_get_fn_call_name(CodeGen *g, AstNode *node) {
-    // Assume that the expression evaluates to a simple name and return the buf
-    // TODO after type checking works we should be able to remove this hack
-    assert(node->type == NodeTypeSymbol);
-    return &node->data.symbol;
-}
-
 void semantic_analyze(CodeGen *g);
 void add_node_error(CodeGen *g, AstNode *node, Buf *msg);
 void alloc_codegen_node(AstNode *node);
diff --git a/src/codegen.cpp b/src/codegen.cpp
index 5ea7ff401f..b23757bee3 100644
--- a/src/codegen.cpp
+++ b/src/codegen.cpp
@@ -22,6 +22,7 @@ CodeGen *codegen_create(Buf *root_source_dir) {
     g->str_table.init(32);
     g->link_table.init(32);
     g->import_table.init(32);
+    g->builtin_fn_table.init(32);
     g->build_type = CodeGenBuildTypeDebug;
     g->root_source_dir = root_source_dir;
 
@@ -139,6 +140,45 @@ static TypeTableEntry *get_expr_type(AstNode *node) {
     return cast_type ? cast_type : node->codegen_node->expr_node.type_entry;
 }
 
+// Emits LLVM IR for a builtin call that analysis has already resolved.
+// For the arithmetic-with-overflow builtins: calls the LLVM intrinsic on the
+// first two arguments, stores the arithmetic result through the third
+// (pointer) argument, and returns the overflow bit.
+static LLVMValueRef gen_builtin_fn_call_expr(CodeGen *g, AstNode *node) {
+    assert(node->type == NodeTypeFnCallExpr);
+    AstNode *fn_ref_expr = node->data.fn_call_expr.fn_ref_expr;
+    assert(fn_ref_expr->type == NodeTypeSymbol);
+    BuiltinFnEntry *builtin_fn = node->codegen_node->data.fn_call_node.builtin_fn;
+
+    switch (builtin_fn->id) {
+        case BuiltinFnIdInvalid:
+            zig_unreachable();
+        case BuiltinFnIdArithmeticWithOverflow:
+            {
+                int fn_call_param_count = node->data.fn_call_expr.params.length;
+                assert(fn_call_param_count == 3);
+
+                LLVMValueRef op1 = gen_expr(g, node->data.fn_call_expr.params.at(0));
+                LLVMValueRef op2 = gen_expr(g, node->data.fn_call_expr.params.at(1));
+                LLVMValueRef ptr_result = gen_expr(g, node->data.fn_call_expr.params.at(2));
+
+                LLVMValueRef params[] = {
+                    op1,
+                    op2,
+                };
+
+                add_debug_source_node(g, node);
+                LLVMValueRef result_struct = LLVMBuildCall(g->builder, builtin_fn->fn_val, params, 2, "");
+                LLVMValueRef result = LLVMBuildExtractValue(g->builder, result_struct, 0, "");
+                LLVMValueRef overflow_bit = LLVMBuildExtractValue(g->builder, result_struct, 1, "");
+                LLVMBuildStore(g->builder, result, ptr_result);
+
+                return overflow_bit;
+            }
+    }
+    zig_unreachable();
+}
+
 static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) {
     assert(node->type == NodeTypeFnCallExpr);
 
@@ -159,7 +199,15 @@ static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) {
             zig_unreachable();
         }
     } else if (fn_ref_expr->type == NodeTypeSymbol) {
-        Buf *name = hack_get_fn_call_name(g, fn_ref_expr);
+        if (node->data.fn_call_expr.is_builtin) {
+            return gen_builtin_fn_call_expr(g, node);
+        }
+
+        // Assume that the expression evaluates to a simple name.
+        // TODO after we support function pointers we can make this generic
+        assert(fn_ref_expr->type == NodeTypeSymbol);
+        Buf *name = &fn_ref_expr->data.symbol;
+
         struct_type = nullptr;
         first_param_expr = nullptr;
         fn_table_entry = g->cur_fn->import_entry->fn_table.get(name);
@@ -2167,6 +2215,69 @@ static void define_builtin_types(CodeGen *g) {
     }
 }
 
+// Registers the add/sub/mul "with overflow" builtins for one integer type.
+// Each builtin is named `<op>_with_overflow_<type name>`, takes
+// (type, type, pointer to type), returns bool, and is backed by the matching
+// llvm.{s,u}<op>.with.overflow.iN intrinsic declared in g->module.
+static void define_builtin_fns_int(CodeGen *g, TypeTableEntry *type_entry) {
+    assert(type_entry->id == TypeTableEntryIdInt);
+    struct OverflowFn {
+        const char *bare_name;
+        const char *signed_name;
+        const char *unsigned_name;
+    };
+    OverflowFn overflow_fns[] = {
+        {"add", "sadd", "uadd"},
+        {"sub", "ssub", "usub"},
+        {"mul", "smul", "umul"},
+    };
+    for (size_t i = 0; i < sizeof(overflow_fns)/sizeof(overflow_fns[0]); i += 1) {
+        OverflowFn *overflow_fn = &overflow_fns[i];
+        BuiltinFnEntry *builtin_fn = allocate<BuiltinFnEntry>(1);
+        buf_resize(&builtin_fn->name, 0);
+        buf_appendf(&builtin_fn->name, "%s_with_overflow_%s", overflow_fn->bare_name, buf_ptr(&type_entry->name));
+        builtin_fn->id = BuiltinFnIdArithmeticWithOverflow;
+        builtin_fn->return_type = g->builtin_types.entry_bool;
+        builtin_fn->param_count = 3;
+        builtin_fn->param_types = allocate<TypeTableEntry *>(builtin_fn->param_count);
+        builtin_fn->param_types[0] = type_entry;
+        builtin_fn->param_types[1] = type_entry;
+        builtin_fn->param_types[2] = get_pointer_to_type(g, type_entry, false, false);
+
+
+        const char *signed_str = type_entry->data.integral.is_signed ?
+            overflow_fn->signed_name : overflow_fn->unsigned_name;
+        Buf *llvm_name = buf_sprintf("llvm.%s.with.overflow.i%" PRIu64, signed_str, type_entry->size_in_bits);
+
+        LLVMTypeRef return_elem_types[] = {
+            type_entry->type_ref,
+            LLVMInt1Type(),
+        };
+        LLVMTypeRef param_types[] = {
+            type_entry->type_ref,
+            type_entry->type_ref,
+        };
+        LLVMTypeRef return_struct_type = LLVMStructType(return_elem_types, 2, false);
+        LLVMTypeRef fn_type = LLVMFunctionType(return_struct_type, param_types, 2, false);
+        builtin_fn->fn_val = LLVMAddFunction(g->module, buf_ptr(llvm_name), fn_type);
+        assert(LLVMGetIntrinsicID(builtin_fn->fn_val));
+
+        g->builtin_fn_table.put(&builtin_fn->name, builtin_fn);
+    }
+}
+
+// Registers the overflow builtins for u8, u16, u32, u64, i8, i16, i32 and i64.
+static void define_builtin_fns(CodeGen *g) {
+    define_builtin_fns_int(g, g->builtin_types.entry_u8);
+    define_builtin_fns_int(g, g->builtin_types.entry_u16);
+    define_builtin_fns_int(g, g->builtin_types.entry_u32);
+    define_builtin_fns_int(g, g->builtin_types.entry_u64);
+    define_builtin_fns_int(g, g->builtin_types.entry_i8);
+    define_builtin_fns_int(g, g->builtin_types.entry_i16);
+    define_builtin_fns_int(g, g->builtin_types.entry_i32);
+    define_builtin_fns_int(g, g->builtin_types.entry_i64);
+}
+
 
 
 static void init(CodeGen *g, Buf *source_path) {
@@ -2228,9 +2339,10 @@ static void init(CodeGen *g, Buf *source_path) {
             "", 0, !g->strip_debug_symbols);
 
     // This is for debug stuff that doesn't have a real file.
-    g->dummy_di_file = nullptr; //LLVMZigCreateFile(g->dbuilder, "", "");
+    g->dummy_di_file = nullptr;
 
     define_builtin_types(g);
+    define_builtin_fns(g);
 
 }
 
diff --git a/src/parser.cpp b/src/parser.cpp
index 0a1fb654b3..828a2f6406 100644
--- a/src/parser.cpp
+++ b/src/parser.cpp
@@ -1313,7 +1313,7 @@ static AstNode *ast_parse_struct_val_expr(ParseContext *pc, int *token_index) {
 }
 
 /*
-PrimaryExpression : token(Number) | token(String) | token(CharLiteral) | KeywordLiteral | GroupedExpression | Goto | token(Break) | token(Continue) | BlockExpression | token(Symbol) | StructValueExpression | CompilerFnType
+PrimaryExpression : token(Number) | token(String) | token(CharLiteral) | KeywordLiteral | GroupedExpression | Goto | token(Break) | token(Continue) | BlockExpression | token(Symbol) | StructValueExpression | CompilerFnType | (token(AtSign) token(Symbol) FnCallExpression)
 KeywordLiteral : token(Unreachable) | token(Void) | token(True) | token(False) | token(Null)
 */
 static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool mandatory) {
@@ -1356,6 +1356,19 @@ static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool 
         AstNode *node = ast_create_node(pc, NodeTypeNullLiteral, token);
         *token_index += 1;
         return node;
+    } else if (token->id == TokenIdAtSign) {
+        // Builtin call: '@' Symbol '(' params ')', marked with is_builtin.
+        *token_index += 1;
+        Token *name_tok = ast_eat_token(pc, token_index, TokenIdSymbol);
+        AstNode *name_node = ast_create_node(pc, NodeTypeSymbol, name_tok);
+        ast_buf_from_token(pc, name_tok, &name_node->data.symbol);
+
+        AstNode *node = ast_create_node(pc, NodeTypeFnCallExpr, token);
+        node->data.fn_call_expr.fn_ref_expr = name_node;
+        ast_eat_token(pc, token_index, TokenIdLParen);
+        ast_parse_fn_call_param_list(pc, token_index, &node->data.fn_call_expr.params);
+        node->data.fn_call_expr.is_builtin = true;
+        return node;
     } else if (token->id == TokenIdSymbol) {
         Token *next_token = &pc->tokens->at(*token_index + 1);
 
diff --git a/src/parser.hpp b/src/parser.hpp
index 182ac13fd0..4f7bed03b5 100644
--- a/src/parser.hpp
+++ b/src/parser.hpp
@@ -176,6 +176,7 @@ struct AstNodeBinOpExpr {
 struct AstNodeFnCallExpr {
     AstNode *fn_ref_expr;
     ZigList params;
+    bool is_builtin;
 };
 
 struct AstNodeArrayAccessExpr {
diff --git a/src/tokenizer.cpp b/src/tokenizer.cpp
index 94f10c6e03..6abaa8448a 100644
--- a/src/tokenizer.cpp
+++ b/src/tokenizer.cpp
@@ -376,6 +376,10 @@ void tokenize(Buf *buf, Tokenization *out) {
                         begin_token(&t, TokenIdTilde);
                         end_token(&t);
                         break;
+                    case '@':
+                        begin_token(&t, TokenIdAtSign);
+                        end_token(&t);
+                        break;
                     case '-':
                         begin_token(&t, TokenIdDash);
                         t.state = TokenizeStateSawDash;
@@ -1074,6 +1078,7 @@ static const char * token_name(Token *token) {
         case TokenIdMaybe: return "Maybe";
         case TokenIdDoubleQuestion: return "DoubleQuestion";
         case TokenIdMaybeAssign: return "MaybeAssign";
+        case TokenIdAtSign: return "AtSign";
     }
     return "(invalid token)";
 }
diff --git a/src/tokenizer.hpp b/src/tokenizer.hpp
index f9224de98c..dc1da63a7b 100644
--- a/src/tokenizer.hpp
+++ b/src/tokenizer.hpp
@@ -89,6 +89,7 @@ enum TokenId {
     TokenIdMaybe,
     TokenIdDoubleQuestion,
     TokenIdMaybeAssign,
+    TokenIdAtSign,
 };
 
 struct Token {
diff --git a/std/std.zig b/std/std.zig
index 82471748b1..2ccacc3166 100644
--- a/std/std.zig
+++ b/std/std.zig
@@ -63,10 +63,6 @@ pub fn parse_u64(buf: []u8, radix: u8, result: &u64) -> bool {
             return true;
         }
 
-        x *= radix;
-        x += digit;
-
-        /* TODO intrinsics mul and add with overflow
         // x *= radix
         if (@mul_with_overflow_u64(x, radix, &x)) {
             return true;
@@ -76,7 +72,6 @@ pub fn parse_u64(buf: []u8, radix: u8, result: &u64) -> bool {
         if (@add_with_overflow_u64(x, digit, &x)) {
             return true;
         }
-        */
 
         i += 1;
     }
diff --git a/test/run_tests.cpp b/test/run_tests.cpp
index cc93e188e3..b72e23d23b 100644
--- a/test/run_tests.cpp
+++ b/test/run_tests.cpp
@@ -953,6 +953,24 @@ fn f(c: u8) -> u8 {
     } else {
         2
     }
+}
+    )SOURCE", "OK\n");
+
+    add_simple_case("overflow intrinsics", R"SOURCE(
+use "std.zig";
+pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 {
+    var result: u8;
+    if (!@add_with_overflow_u8(250, 100, &result)) {
+        print_str("BAD\n");
+    }
+    if (@add_with_overflow_u8(100, 150, &result)) {
+        print_str("BAD\n");
+    }
+    if (result != 250) {
+        print_str("BAD\n");
+    }
+    print_str("OK\n");
+    return 0;
 }
     )SOURCE", "OK\n");
 }
@@ -995,7 +1013,7 @@ fn a() {
     b(1);
 }
 fn b(a: i32, b: i32, c: i32) { }
-    )SOURCE", 1, ".tmp_source.zig:3:6: error: wrong number of arguments. Expected 3, got 1.");
+    )SOURCE", 1, ".tmp_source.zig:3:6: error: expected 3 arguments, got 1");
 
     add_compile_fail_case("invalid type", R"SOURCE(
 fn a() -> bogus {}