From e26ccd5166000f81a589c446d04102c21045bff6 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 16 Nov 2017 21:15:15 -0500 Subject: [PATCH] debug safety for unions --- src/all_types.hpp | 1 + src/codegen.cpp | 37 ++++++++++++++++++++++++++----------- test/cases/union.zig | 2 +- test/debug_safety.zig | 20 ++++++++++++++++++++ 4 files changed, 48 insertions(+), 12 deletions(-) diff --git a/src/all_types.hpp b/src/all_types.hpp index 86c9720f69..2b09131bef 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1317,6 +1317,7 @@ enum PanicMsgId { PanicMsgIdUnwrapMaybeFail, PanicMsgIdInvalidErrorCode, PanicMsgIdIncorrectAlignment, + PanicMsgIdBadUnionField, PanicMsgIdCount, }; diff --git a/src/codegen.cpp b/src/codegen.cpp index 3777c3a87a..eb56d26cae 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -810,6 +810,8 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) { return buf_create_from_str("invalid error code"); case PanicMsgIdIncorrectAlignment: return buf_create_from_str("incorrect alignment"); + case PanicMsgIdBadUnionField: + return buf_create_from_str("access of inactive union field"); } zig_unreachable(); } @@ -2415,6 +2417,23 @@ static LLVMValueRef ir_render_union_field_ptr(CodeGen *g, IrExecutable *executab return bitcasted_union_field_ptr; } + if (ir_want_debug_safety(g, &instruction->base)) { + LLVMValueRef tag_field_ptr = LLVMBuildStructGEP(g->builder, union_ptr, union_type->data.unionation.gen_tag_index, ""); + LLVMValueRef tag_value = gen_load_untyped(g, tag_field_ptr, 0, false, ""); + LLVMValueRef expected_tag_value = LLVMConstInt(union_type->data.unionation.tag_type->type_ref, + field->value, false); + + LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "UnionCheckOk"); + LLVMBasicBlockRef bad_block = LLVMAppendBasicBlock(g->cur_fn_val, "UnionCheckFail"); + LLVMValueRef ok_val = LLVMBuildICmp(g->builder, LLVMIntEQ, tag_value, expected_tag_value, ""); + LLVMBuildCondBr(g->builder, ok_val, ok_block, bad_block); + + LLVMPositionBuilderAtEnd(g->builder, bad_block); + gen_debug_safety_crash(g, PanicMsgIdBadUnionField); + + LLVMPositionBuilderAtEnd(g->builder, ok_block); + } + LLVMValueRef union_field_ptr = LLVMBuildStructGEP(g->builder, union_ptr, union_type->data.unionation.gen_union_index, ""); LLVMValueRef bitcasted_union_field_ptr = LLVMBuildBitCast(g->builder, union_field_ptr, field_type_ref, ""); return bitcasted_union_field_ptr; @@ -3977,21 +3996,17 @@ static LLVMValueRef gen_const_val(CodeGen *g, ConstExprValue *const_val) { LLVMValueRef union_value_ref; { - unsigned field_count; - LLVMValueRef fields[2]; - fields[0] = correctly_typed_value; if (pad_bytes == 0) { - field_count = 1; + union_value_ref = correctly_typed_value; } else { + LLVMValueRef fields[2]; fields[0] = correctly_typed_value; fields[1] = LLVMGetUndef(LLVMArrayType(LLVMInt8Type(), (unsigned)pad_bytes)); - field_count = 2; - } - - if (make_unnamed_struct || type_entry->data.unionation.gen_tag_index != SIZE_MAX) { - union_value_ref = LLVMConstStruct(fields, field_count, false); - } else { - union_value_ref = LLVMConstNamedStruct(union_type_ref, fields, field_count); + if (make_unnamed_struct || type_entry->data.unionation.gen_tag_index != SIZE_MAX) { + union_value_ref = LLVMConstStruct(fields, 2, false); + } else { + union_value_ref = LLVMConstNamedStruct(union_type_ref, fields, 2); + } } } diff --git a/test/cases/union.zig b/test/cases/union.zig index 377374c157..1abebb3b30 100644 --- a/test/cases/union.zig +++ b/test/cases/union.zig @@ -41,7 +41,7 @@ const Foo = union { test "basic unions" { var foo = Foo { .int = 1 }; assert(foo.int == 1); - foo.float = 12.34; + foo = Foo {.float = 12.34}; assert(foo.float == 12.34); } diff --git a/test/debug_safety.zig b/test/debug_safety.zig index 9e9ff98349..36f8d020c3 100644 --- a/test/debug_safety.zig +++ b/test/debug_safety.zig @@ -260,4 +260,24 @@ pub fn addCases(cases: &tests.CompareOutputContext) { \\ return int_slice[0]; \\} ); + + cases.addDebugSafety("bad union field access", + \\pub fn panic(message: []const u8) -> noreturn { + \\ @import("std").os.exit(126); + \\} + \\ + \\const Foo = union { + \\ float: f32, + \\ int: u32, + \\}; + \\ + \\pub fn main() -> %void { + \\ var f = Foo { .int = 42 }; + \\ bar(&f); + \\} + \\ + \\fn bar(f: &Foo) { + \\ f.float = 12.34; + \\} + ); }