From 8d0ac6dc4d32daea3561e7de8eeee9ce34d2c5cb Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 17 Mar 2020 17:33:44 -0400 Subject: [PATCH] `@ptrCast` supports casting a slice to pointer --- lib/std/mem.zig | 46 +++++++++++++++++++++++++++++++--------------- src/analyze.cpp | 20 +++++++++++++++++--- src/analyze.hpp | 2 +- src/ir.cpp | 44 +++++++++++++++++++++++++++++++++++++------- 4 files changed, 86 insertions(+), 26 deletions(-) diff --git a/lib/std/mem.zig b/lib/std/mem.zig index 2ffac680c7..eb2539720a 100644 --- a/lib/std/mem.zig +++ b/lib/std/mem.zig @@ -1750,34 +1750,50 @@ fn BytesAsSliceReturnType(comptime T: type, comptime bytesType: type) type { } pub fn bytesAsSlice(comptime T: type, bytes: var) BytesAsSliceReturnType(T, @TypeOf(bytes)) { - const bytesSlice = if (comptime trait.isPtrTo(.Array)(@TypeOf(bytes))) bytes[0..] else bytes; - // let's not give an undefined pointer to @ptrCast // it may be equal to zero and fail a null check - if (bytesSlice.len == 0) { + if (bytes.len == 0) { return &[0]T{}; } - const bytesType = @TypeOf(bytesSlice); - const alignment = comptime meta.alignment(bytesType); + const Bytes = @TypeOf(bytes); + const alignment = comptime meta.alignment(Bytes); - const castTarget = if (comptime trait.isConstPtr(bytesType)) [*]align(alignment) const T else [*]align(alignment) T; + const cast_target = if (comptime trait.isConstPtr(Bytes)) [*]align(alignment) const T else [*]align(alignment) T; - return @ptrCast(castTarget, bytesSlice.ptr)[0..@divExact(bytes.len, @sizeOf(T))]; + return @ptrCast(cast_target, bytes)[0..@divExact(bytes.len, @sizeOf(T))]; } test "bytesAsSlice" { - const bytes = [_]u8{ 0xDE, 0xAD, 0xBE, 0xEF }; - const slice = bytesAsSlice(u16, bytes[0..]); - testing.expect(slice.len == 2); - testing.expect(bigToNative(u16, slice[0]) == 0xDEAD); - testing.expect(bigToNative(u16, slice[1]) == 0xBEEF); + { + const bytes = [_]u8{ 0xDE, 0xAD, 0xBE, 0xEF }; + const slice = bytesAsSlice(u16, bytes[0..]); + testing.expect(slice.len == 2); + testing.expect(bigToNative(u16, slice[0]) == 0xDEAD); + testing.expect(bigToNative(u16, slice[1]) == 0xBEEF); + } + { + const bytes = [_]u8{ 0xDE, 0xAD, 0xBE, 0xEF }; + var runtime_zero: usize = 0; + const slice = bytesAsSlice(u16, bytes[runtime_zero..]); + testing.expect(slice.len == 2); + testing.expect(bigToNative(u16, slice[0]) == 0xDEAD); + testing.expect(bigToNative(u16, slice[1]) == 0xBEEF); + } } test "bytesAsSlice keeps pointer alignment" { - var bytes = [_]u8{ 0x01, 0x02, 0x03, 0x04 }; - const numbers = bytesAsSlice(u32, bytes[0..]); - comptime testing.expect(@TypeOf(numbers) == []align(@alignOf(@TypeOf(bytes))) u32); + { + var bytes = [_]u8{ 0x01, 0x02, 0x03, 0x04 }; + const numbers = bytesAsSlice(u32, bytes[0..]); + comptime testing.expect(@TypeOf(numbers) == []align(@alignOf(@TypeOf(bytes))) u32); + } + { + var bytes = [_]u8{ 0x01, 0x02, 0x03, 0x04 }; + var runtime_zero: usize = 0; + const numbers = bytesAsSlice(u32, bytes[runtime_zero..]); + comptime testing.expect(@TypeOf(numbers) == []align(@alignOf(@TypeOf(bytes))) u32); + } } test "bytesAsSlice on a packed struct" { diff --git a/src/analyze.cpp b/src/analyze.cpp index afffe47c82..1f511fa834 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -4486,7 +4486,14 @@ static uint32_t get_async_frame_align_bytes(CodeGen *g) { } uint32_t get_ptr_align(CodeGen *g, ZigType *type) { - ZigType *ptr_type = get_src_ptr_type(type); + ZigType *ptr_type; + if (type->id == ZigTypeIdStruct) { + assert(type->data.structure.special == StructSpecialSlice); + TypeStructField *ptr_field = type->data.structure.fields[slice_ptr_index]; + ptr_type = resolve_struct_field_type(g, ptr_field); + } else { + ptr_type = get_src_ptr_type(type); + } if (ptr_type->id == ZigTypeIdPointer) { return (ptr_type->data.pointer.explicit_alignment == 0) ? get_abi_alignment(g, ptr_type->data.pointer.child_type) : ptr_type->data.pointer.explicit_alignment; @@ -4503,8 +4510,15 @@ uint32_t get_ptr_align(CodeGen *g, ZigType *type) { } } -bool get_ptr_const(ZigType *type) { - ZigType *ptr_type = get_src_ptr_type(type); +bool get_ptr_const(CodeGen *g, ZigType *type) { + ZigType *ptr_type; + if (type->id == ZigTypeIdStruct) { + assert(type->data.structure.special == StructSpecialSlice); + TypeStructField *ptr_field = type->data.structure.fields[slice_ptr_index]; + ptr_type = resolve_struct_field_type(g, ptr_field); + } else { + ptr_type = get_src_ptr_type(type); + } if (ptr_type->id == ZigTypeIdPointer) { return ptr_type->data.pointer.is_const; } else if (ptr_type->id == ZigTypeIdFn) { diff --git a/src/analyze.hpp b/src/analyze.hpp index c7cbab4b8c..b6404b1882 100644 --- a/src/analyze.hpp +++ b/src/analyze.hpp @@ -76,7 +76,7 @@ void resolve_top_level_decl(CodeGen *g, Tld *tld, AstNode *source_node, bool all ZigType *get_src_ptr_type(ZigType *type); uint32_t get_ptr_align(CodeGen *g, ZigType *type); -bool get_ptr_const(ZigType *type); +bool get_ptr_const(CodeGen *g, ZigType *type); ZigType *validate_var_type(CodeGen *g, AstNode *source_node, ZigType *type_entry); ZigType *container_ref_type(ZigType *type_entry); bool type_is_complete(ZigType *type_entry); diff --git a/src/ir.cpp b/src/ir.cpp index 5558560606..15edfbfc7b 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -25479,11 +25479,22 @@ static IrInstGen *ir_analyze_instruction_err_set_cast(IrAnalyze *ira, IrInstSrcE static Error resolve_ptr_align(IrAnalyze *ira, ZigType *ty, uint32_t *result_align) { Error err; - ZigType *ptr_type = get_src_ptr_type(ty); + ZigType *ptr_type; + if (is_slice(ty)) { + TypeStructField *ptr_field = ty->data.structure.fields[slice_ptr_index]; + ptr_type = resolve_struct_field_type(ira->codegen, ptr_field); + } else { + ptr_type = get_src_ptr_type(ty); + } assert(ptr_type != nullptr); if (ptr_type->id == ZigTypeIdPointer) { if ((err = type_resolve(ira->codegen, ptr_type->data.pointer.child_type, ResolveStatusAlignmentKnown))) return err; + } else if (is_slice(ptr_type)) { + TypeStructField *ptr_field = ptr_type->data.structure.fields[slice_ptr_index]; + ZigType *slice_ptr_type = resolve_struct_field_type(ira->codegen, ptr_field); + if ((err = type_resolve(ira->codegen, slice_ptr_type->data.pointer.child_type, ResolveStatusAlignmentKnown))) + return err; } *result_align = get_ptr_align(ira->codegen, ty); @@ -27615,10 +27626,18 @@ static IrInstGen *ir_analyze_ptr_cast(IrAnalyze *ira, IrInst* source_instr, IrIn // We have a check for zero bits later so we use get_src_ptr_type to // validate src_type and dest_type. - ZigType *src_ptr_type = get_src_ptr_type(src_type); - if (src_ptr_type == nullptr) { - ir_add_error(ira, ptr_src, buf_sprintf("expected pointer, found '%s'", buf_ptr(&src_type->name))); - return ira->codegen->invalid_inst_gen; + ZigType *if_slice_ptr_type; + if (is_slice(src_type)) { + TypeStructField *ptr_field = src_type->data.structure.fields[slice_ptr_index]; + if_slice_ptr_type = resolve_struct_field_type(ira->codegen, ptr_field); + } else { + if_slice_ptr_type = src_type; + + ZigType *src_ptr_type = get_src_ptr_type(src_type); + if (src_ptr_type == nullptr) { + ir_add_error(ira, ptr_src, buf_sprintf("expected pointer, found '%s'", buf_ptr(&src_type->name))); + return ira->codegen->invalid_inst_gen; + } } ZigType *dest_ptr_type = get_src_ptr_type(dest_type); @@ -27628,7 +27647,7 @@ static IrInstGen *ir_analyze_ptr_cast(IrAnalyze *ira, IrInst* source_instr, IrIn return ira->codegen->invalid_inst_gen; } - if (get_ptr_const(src_type) && !get_ptr_const(dest_type)) { + if (get_ptr_const(ira->codegen, src_type) && !get_ptr_const(ira->codegen, dest_type)) { ir_add_error(ira, source_instr, buf_sprintf("cast discards const qualifier")); return ira->codegen->invalid_inst_gen; } @@ -27646,7 +27665,10 @@ static IrInstGen *ir_analyze_ptr_cast(IrAnalyze *ira, IrInst* source_instr, IrIn if ((err = type_resolve(ira->codegen, src_type, ResolveStatusZeroBitsKnown))) return ira->codegen->invalid_inst_gen; - if (type_has_bits(ira->codegen, dest_type) && !type_has_bits(ira->codegen, src_type) && safety_check_on) { + if (safety_check_on && + type_has_bits(ira->codegen, dest_type) && + !type_has_bits(ira->codegen, if_slice_ptr_type)) + { ErrorMsg *msg = ir_add_error(ira, source_instr, buf_sprintf("'%s' and '%s' do not have the same in-memory representation", buf_ptr(&src_type->name), buf_ptr(&dest_type->name))); @@ -27657,6 +27679,14 @@ static IrInstGen *ir_analyze_ptr_cast(IrAnalyze *ira, IrInst* source_instr, IrIn return ira->codegen->invalid_inst_gen; } + // For slices, follow the `ptr` field. + if (is_slice(src_type)) { + TypeStructField *ptr_field = src_type->data.structure.fields[slice_ptr_index]; + IrInstGen *ptr_ref = ir_get_ref(ira, source_instr, ptr, true, false); + IrInstGen *ptr_ptr = ir_analyze_struct_field_ptr(ira, source_instr, ptr_field, ptr_ref, src_type, false); + ptr = ir_get_deref(ira, source_instr, ptr_ptr, nullptr); + } + if (instr_is_comptime(ptr)) { bool dest_allows_addr_zero = ptr_allows_addr_zero(dest_type); UndefAllowed is_undef_allowed = dest_allows_addr_zero ? UndefOk : UndefBad;