support checked arithmetic operations via intrinsics

closes #32
This commit is contained in:
Andrew Kelley 2016-01-08 23:41:40 -07:00
parent 14b9cbd43c
commit b7dd88ad68
10 changed files with 205 additions and 19 deletions

View File

@ -160,7 +160,7 @@ SliceExpression : token(LBracket) Expression token(Ellipsis) option(Expression)
PrefixOp : token(Not) | token(Dash) | token(Tilde) | token(Star) | (token(Ampersand) option(token(Const)))
PrimaryExpression : token(Number) | token(String) | token(CharLiteral) | KeywordLiteral | GroupedExpression | Goto | token(Break) | token(Continue) | BlockExpression | token(Symbol) | StructValueExpression | CompilerFnType
PrimaryExpression : token(Number) | token(String) | token(CharLiteral) | KeywordLiteral | GroupedExpression | Goto | token(Break) | token(Continue) | BlockExpression | token(Symbol) | StructValueExpression | CompilerFnType | (token(AtSign) token(Symbol) FnCallExpression)
StructValueExpression : token(Type) token(LBrace) list(StructValueExpressionField, token(Comma)) token(RBrace)

View File

@ -2064,6 +2064,41 @@ static TypeTableEntry *analyze_compiler_fn_type(CodeGen *g, ImportTableEntry *im
}
}
// Type-checks a call to a compiler builtin, written in source as `@name(args...)`.
//
// The callee symbol is looked up in g->builtin_fn_table. On a hit the resolved
// entry is stored on the call's codegen node (gen_builtin_fn_call_expr reads it
// back from there), the argument count is checked, and each argument is analyzed
// against the builtin's declared parameter type.
//
// Returns the builtin's return type, or the invalid type entry (after reporting
// an error) when no builtin with that name exists. `expected_type` is not
// consulted.
static TypeTableEntry *analyze_builtin_fn_call_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
        TypeTableEntry *expected_type, AstNode *node)
{
    AstNode *fn_ref_expr = node->data.fn_call_expr.fn_ref_expr;
    Buf *name = &fn_ref_expr->data.symbol;

    auto entry = g->builtin_fn_table.maybe_get(name);
    if (!entry) {
        add_node_error(g, node,
                buf_sprintf("invalid builtin function: '%s'", buf_ptr(name)));
        return g->builtin_types.entry_invalid;
    }

    BuiltinFnEntry *builtin_fn = entry->value;
    int actual_param_count = node->data.fn_call_expr.params.length;

    assert(node->codegen_node);
    node->codegen_node->data.fn_call_node.builtin_fn = builtin_fn;

    if (builtin_fn->param_count != actual_param_count) {
        add_node_error(g, node,
                buf_sprintf("expected %d arguments, got %d",
                    builtin_fn->param_count, actual_param_count));
    }

    // param_types holds exactly param_count entries, so only arguments that
    // line up with a declared parameter can be analyzed against it. Surplus
    // arguments have already been reported by the count check above; indexing
    // param_types for them would read past the end of the array.
    int checked_param_count = actual_param_count;
    if (checked_param_count > builtin_fn->param_count) {
        checked_param_count = builtin_fn->param_count;
    }
    for (int i = 0; i < checked_param_count; i += 1) {
        AstNode *child = node->data.fn_call_expr.params.at(i);
        TypeTableEntry *expected_param_type = builtin_fn->param_types[i];
        analyze_expression(g, import, context, expected_param_type, child);
    }

    return builtin_fn->return_type;
}
static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
TypeTableEntry *expected_type, AstNode *node)
{
@ -2091,6 +2126,9 @@ static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import
return g->builtin_types.entry_invalid;
}
} else if (fn_ref_expr->type == NodeTypeSymbol) {
if (node->data.fn_call_expr.is_builtin) {
return analyze_builtin_fn_call_expr(g, import, context, expected_type, node);
}
name = &fn_ref_expr->data.symbol;
} else {
add_node_error(g, node,
@ -2126,12 +2164,12 @@ static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import
if (fn_proto->is_var_args) {
if (actual_param_count < expected_param_count) {
add_node_error(g, node,
buf_sprintf("wrong number of arguments. Expected at least %d, got %d.",
buf_sprintf("expected at least %d arguments, got %d",
expected_param_count, actual_param_count));
}
} else if (expected_param_count != actual_param_count) {
add_node_error(g, node,
buf_sprintf("wrong number of arguments. Expected %d, got %d.",
buf_sprintf("expected %d arguments, got %d",
expected_param_count, actual_param_count));
}

View File

@ -148,6 +148,20 @@ struct FnTableEntry {
HashMap<Buf *, LabelTableEntry *, buf_hash, buf_eql_buf> label_table;
};
// Identifies which kind of compiler builtin a BuiltinFnEntry describes.
enum BuiltinFnId {
// Not a real builtin; code generation treats this value as unreachable.
BuiltinFnIdInvalid,
// The add/sub/mul "with overflow" builtins registered per integer type.
BuiltinFnIdArithmeticWithOverflow,
};
// Describes one compiler builtin function that can be called with `@name(...)`.
struct BuiltinFnEntry {
// Which kind of builtin this is; code generation switches on it.
BuiltinFnId id;
// Name used as the key in CodeGen::builtin_fn_table.
Buf name;
// Number of arguments a call must supply.
int param_count;
// Type produced by a call to this builtin.
TypeTableEntry *return_type;
// Array of param_count expected argument types, indexed by argument position.
TypeTableEntry **param_types;
// LLVM function value that code generation calls for this builtin.
LLVMValueRef fn_val;
};
struct CodeGen {
LLVMModuleRef module;
ZigList<ErrorMsg*> errors;
@ -161,6 +175,7 @@ struct CodeGen {
HashMap<Buf *, LLVMValueRef, buf_hash, buf_eql_buf> str_table;
HashMap<Buf *, bool, buf_hash, buf_eql_buf> link_table;
HashMap<Buf *, ImportTableEntry *, buf_hash, buf_eql_buf> import_table;
HashMap<Buf *, BuiltinFnEntry *, buf_hash, buf_eql_buf> builtin_fn_table;
struct {
TypeTableEntry *entry_bool;
@ -342,6 +357,10 @@ struct WhileNode {
bool contains_break;
};
// Per-node analysis data attached to function call expressions.
struct FnCallNode {
// Builtin resolved for a `@name(...)` call; set during semantic analysis and
// read back during code generation.
BuiltinFnEntry *builtin_fn;
};
struct CodeGenNode {
union {
TypeNode type_node; // for NodeTypeType
@ -363,17 +382,11 @@ struct CodeGenNode {
ParamDeclNode param_decl_node; // for NodeTypeParamDecl
ImportNode import_node; // for NodeTypeUse
WhileNode while_node; // for NodeTypeWhileExpr
FnCallNode fn_call_node; // for NodeTypeFnCallExpr
} data;
ExprNode expr_node; // for all the expression nodes
};
// Returns the symbol name of a function-reference expression.
// Only plain symbol nodes are supported; any other node type trips the assert.
// `g` is not used.
static inline Buf *hack_get_fn_call_name(CodeGen *g, AstNode *node) {
// Assume that the expression evaluates to a simple name and return the buf
// TODO after type checking works we should be able to remove this hack
assert(node->type == NodeTypeSymbol);
return &node->data.symbol;
}
void semantic_analyze(CodeGen *g);
void add_node_error(CodeGen *g, AstNode *node, Buf *msg);
void alloc_codegen_node(AstNode *node);

View File

@ -22,6 +22,7 @@ CodeGen *codegen_create(Buf *root_source_dir) {
g->str_table.init(32);
g->link_table.init(32);
g->import_table.init(32);
g->builtin_fn_table.init(32);
g->build_type = CodeGenBuildTypeDebug;
g->root_source_dir = root_source_dir;
@ -139,6 +140,41 @@ static TypeTableEntry *get_expr_type(AstNode *node) {
return cast_type ? cast_type : node->codegen_node->expr_node.type_entry;
}
// Emits the IR for an arithmetic-with-overflow builtin call.
//
// The three call arguments are generated in source order: the two operands and
// then the destination pointer. The builtin's LLVM function is called with the
// operands; field 0 of the returned struct is stored through the destination
// pointer and field 1 (the overflow bit) is returned as the call's value.
static LLVMValueRef gen_arithmetic_with_overflow_call(CodeGen *g, AstNode *node,
        BuiltinFnEntry *builtin_fn)
{
    int fn_call_param_count = node->data.fn_call_expr.params.length;
    assert(fn_call_param_count == 3);

    LLVMValueRef lhs = gen_expr(g, node->data.fn_call_expr.params.at(0));
    LLVMValueRef rhs = gen_expr(g, node->data.fn_call_expr.params.at(1));
    LLVMValueRef dest_ptr = gen_expr(g, node->data.fn_call_expr.params.at(2));

    LLVMValueRef call_args[] = { lhs, rhs };

    add_debug_source_node(g, node);
    LLVMValueRef result_pair = LLVMBuildCall(g->builder, builtin_fn->fn_val, call_args, 2, "");
    LLVMValueRef value = LLVMBuildExtractValue(g->builder, result_pair, 0, "");
    LLVMValueRef overflowed = LLVMBuildExtractValue(g->builder, result_pair, 1, "");
    LLVMBuildStore(g->builder, value, dest_ptr);

    return overflowed;
}

// Generates code for a `@name(...)` builtin call.
//
// The builtin entry is the one semantic analysis stored on the node's codegen
// data; this function dispatches on its id.
static LLVMValueRef gen_builtin_fn_call_expr(CodeGen *g, AstNode *node) {
    assert(node->type == NodeTypeFnCallExpr);
    AstNode *fn_ref_expr = node->data.fn_call_expr.fn_ref_expr;
    assert(fn_ref_expr->type == NodeTypeSymbol);

    BuiltinFnEntry *builtin_fn = node->codegen_node->data.fn_call_node.builtin_fn;

    switch (builtin_fn->id) {
        case BuiltinFnIdInvalid:
            zig_unreachable();
        case BuiltinFnIdArithmeticWithOverflow:
            return gen_arithmetic_with_overflow_call(g, node, builtin_fn);
    }
    zig_unreachable();
}
static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeFnCallExpr);
@ -159,7 +195,15 @@ static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) {
zig_unreachable();
}
} else if (fn_ref_expr->type == NodeTypeSymbol) {
Buf *name = hack_get_fn_call_name(g, fn_ref_expr);
if (node->data.fn_call_expr.is_builtin) {
return gen_builtin_fn_call_expr(g, node);
}
// Assume that the expression evaluates to a simple name and return the buf
// TODO after we support function pointers we can make this generic
assert(fn_ref_expr->type == NodeTypeSymbol);
Buf *name = &fn_ref_expr->data.symbol;
struct_type = nullptr;
first_param_expr = nullptr;
fn_table_entry = g->cur_fn->import_entry->fn_table.get(name);
@ -2167,6 +2211,64 @@ static void define_builtin_types(CodeGen *g) {
}
}
// Registers the add/sub/mul "with overflow" builtins for one integer type.
//
// For each operation this allocates a BuiltinFnEntry named
// "<op>_with_overflow_<type name>" taking (T, T, pointer-to-T) and returning
// bool, declares the corresponding "llvm.<s|u><op>.with.overflow.i<bits>"
// function in g->module, and inserts the entry into g->builtin_fn_table.
static void define_builtin_fns_int(CodeGen *g, TypeTableEntry *type_entry) {
    assert(type_entry->id == TypeTableEntryIdInt);

    struct OverflowFn {
        const char *bare_name;
        const char *signed_name;
        const char *unsigned_name;
    };

    const OverflowFn overflow_fns[] = {
        {"add", "sadd", "uadd"},
        {"sub", "ssub", "usub"},
        {"mul", "smul", "umul"},
    };

    for (const OverflowFn &op : overflow_fns) {
        BuiltinFnEntry *builtin_fn = allocate<BuiltinFnEntry>(1);

        // Language-level name, e.g. "<op>_with_overflow_<type name>".
        buf_resize(&builtin_fn->name, 0);
        buf_appendf(&builtin_fn->name, "%s_with_overflow_%s", op.bare_name, buf_ptr(&type_entry->name));

        // Language-level signature: (T, T, pointer-to-T) -> bool.
        builtin_fn->id = BuiltinFnIdArithmeticWithOverflow;
        builtin_fn->return_type = g->builtin_types.entry_bool;
        builtin_fn->param_count = 3;
        builtin_fn->param_types = allocate<TypeTableEntry *>(builtin_fn->param_count);
        builtin_fn->param_types[0] = type_entry;
        builtin_fn->param_types[1] = type_entry;
        builtin_fn->param_types[2] = get_pointer_to_type(g, type_entry, false, false);

        // The intrinsic name depends on the signedness and bit width of T.
        const char *llvm_op;
        if (type_entry->data.integral.is_signed) {
            llvm_op = op.signed_name;
        } else {
            llvm_op = op.unsigned_name;
        }
        Buf *intrinsic_name = buf_sprintf("llvm.%s.with.overflow.i%" PRIu64, llvm_op, type_entry->size_in_bits);

        // LLVM-level signature: (T, T) -> { T, i1 }.
        LLVMTypeRef result_fields[] = {
            type_entry->type_ref,
            LLVMInt1Type(),
        };
        LLVMTypeRef operand_types[] = {
            type_entry->type_ref,
            type_entry->type_ref,
        };
        LLVMTypeRef result_type = LLVMStructType(result_fields, 2, false);
        LLVMTypeRef intrinsic_type = LLVMFunctionType(result_type, operand_types, 2, false);

        builtin_fn->fn_val = LLVMAddFunction(g->module, buf_ptr(intrinsic_name), intrinsic_type);
        assert(LLVMGetIntrinsicID(builtin_fn->fn_val));

        g->builtin_fn_table.put(&builtin_fn->name, builtin_fn);
    }
}
// Registers the integer builtins for every fixed-width integer type, unsigned
// widths first and then signed, each in order 8, 16, 32, 64.
static void define_builtin_fns(CodeGen *g) {
    TypeTableEntry *int_types[] = {
        g->builtin_types.entry_u8,
        g->builtin_types.entry_u16,
        g->builtin_types.entry_u32,
        g->builtin_types.entry_u64,
        g->builtin_types.entry_i8,
        g->builtin_types.entry_i16,
        g->builtin_types.entry_i32,
        g->builtin_types.entry_i64,
    };

    for (TypeTableEntry *int_type : int_types) {
        define_builtin_fns_int(g, int_type);
    }
}
static void init(CodeGen *g, Buf *source_path) {
@ -2228,9 +2330,10 @@ static void init(CodeGen *g, Buf *source_path) {
"", 0, !g->strip_debug_symbols);
// This is for debug stuff that doesn't have a real file.
g->dummy_di_file = nullptr; //LLVMZigCreateFile(g->dbuilder, "", "");
g->dummy_di_file = nullptr;
define_builtin_types(g);
define_builtin_fns(g);
}

View File

@ -1313,7 +1313,7 @@ static AstNode *ast_parse_struct_val_expr(ParseContext *pc, int *token_index) {
}
/*
PrimaryExpression : token(Number) | token(String) | token(CharLiteral) | KeywordLiteral | GroupedExpression | Goto | token(Break) | token(Continue) | BlockExpression | token(Symbol) | StructValueExpression | CompilerFnType
PrimaryExpression : token(Number) | token(String) | token(CharLiteral) | KeywordLiteral | GroupedExpression | Goto | token(Break) | token(Continue) | BlockExpression | token(Symbol) | StructValueExpression | CompilerFnType | (token(AtSign) token(Symbol) FnCallExpression)
KeywordLiteral : token(Unreachable) | token(Void) | token(True) | token(False) | token(Null)
*/
static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool mandatory) {
@ -1356,6 +1356,18 @@ static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool
AstNode *node = ast_create_node(pc, NodeTypeNullLiteral, token);
*token_index += 1;
return node;
} else if (token->id == TokenIdAtSign) {
*token_index += 1;
Token *name_tok = ast_eat_token(pc, token_index, TokenIdSymbol);
AstNode *name_node = ast_create_node(pc, NodeTypeSymbol, name_tok);
ast_buf_from_token(pc, name_tok, &name_node->data.symbol);
AstNode *node = ast_create_node(pc, NodeTypeFnCallExpr, token);
node->data.fn_call_expr.fn_ref_expr = name_node;
ast_eat_token(pc, token_index, TokenIdLParen);
ast_parse_fn_call_param_list(pc, token_index, &node->data.fn_call_expr.params);
node->data.fn_call_expr.is_builtin = true;
return node;
} else if (token->id == TokenIdSymbol) {
Token *next_token = &pc->tokens->at(*token_index + 1);

View File

@ -176,6 +176,7 @@ struct AstNodeBinOpExpr {
struct AstNodeFnCallExpr {
AstNode *fn_ref_expr;
ZigList<AstNode *> params;
bool is_builtin;
};
struct AstNodeArrayAccessExpr {

View File

@ -376,6 +376,10 @@ void tokenize(Buf *buf, Tokenization *out) {
begin_token(&t, TokenIdTilde);
end_token(&t);
break;
case '@':
begin_token(&t, TokenIdAtSign);
end_token(&t);
break;
case '-':
begin_token(&t, TokenIdDash);
t.state = TokenizeStateSawDash;
@ -1074,6 +1078,7 @@ static const char * token_name(Token *token) {
case TokenIdMaybe: return "Maybe";
case TokenIdDoubleQuestion: return "DoubleQuestion";
case TokenIdMaybeAssign: return "MaybeAssign";
case TokenIdAtSign: return "AtSign";
}
return "(invalid token)";
}

View File

@ -89,6 +89,7 @@ enum TokenId {
TokenIdMaybe,
TokenIdDoubleQuestion,
TokenIdMaybeAssign,
TokenIdAtSign,
};
struct Token {

View File

@ -63,10 +63,6 @@ pub fn parse_u64(buf: []u8, radix: u8, result: &u64) -> bool {
return true;
}
x *= radix;
x += digit;
/* TODO intrinsics mul and add with overflow
// x *= radix
if (@mul_with_overflow_u64(x, radix, &x)) {
return true;
@ -76,7 +72,6 @@ pub fn parse_u64(buf: []u8, radix: u8, result: &u64) -> bool {
if (@add_with_overflow_u64(x, digit, &x)) {
return true;
}
*/
i += 1;
}

View File

@ -953,6 +953,24 @@ fn f(c: u8) -> u8 {
} else {
2
}
}
)SOURCE", "OK\n");
add_simple_case("overflow intrinsics", R"SOURCE(
use "std.zig";
pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 {
var result: u8;
if (!@add_with_overflow_u8(250, 100, &result)) {
print_str("BAD\n");
}
if (@add_with_overflow_u8(100, 150, &result)) {
print_str("BAD\n");
}
if (result != 250) {
print_str("BAD\n");
}
print_str("OK\n");
return 0;
}
)SOURCE", "OK\n");
}
@ -995,7 +1013,7 @@ fn a() {
b(1);
}
fn b(a: i32, b: i32, c: i32) { }
)SOURCE", 1, ".tmp_source.zig:3:6: error: wrong number of arguments. Expected 3, got 1.");
)SOURCE", 1, ".tmp_source.zig:3:6: error: expected 3 arguments, got 1");
add_compile_fail_case("invalid type", R"SOURCE(
fn a() -> bogus {}