From adfcc8851b6bb47b085cfe2526f0797b1f414996 Mon Sep 17 00:00:00 2001 From: John Schmidt Date: Thu, 17 Mar 2022 17:25:42 +0100 Subject: [PATCH 1/3] Implement `@byteSwap` for vectors Make the behavior tests for this a little more primitive to exercise as little extra functionality as possible. --- src/Sema.zig | 85 ++++++++++++++++++++++++++++-------- src/codegen/llvm.zig | 23 ++++++++-- test/behavior/byteswap.zig | 89 ++++++++++++++++++++++++++++---------- 3 files changed, 153 insertions(+), 44 deletions(-) diff --git a/src/Sema.zig b/src/Sema.zig index 7805d7f095..4e30953fd6 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -13491,30 +13491,77 @@ fn zirByteSwap(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg1 = inst_data.src_node }; const operand = sema.resolveInst(inst_data.operand); const operand_ty = sema.typeOf(operand); - // TODO implement support for vectors - if (operand_ty.zigTypeTag() != .Int) { - return sema.fail(block, ty_src, "expected integer type, found '{}'", .{ - operand_ty, - }); + + const scalar_ty = if (operand_ty.zigTypeTag() == .Vector) + operand_ty.elemType2() + else + operand_ty; + + switch (operand_ty.zigTypeTag()) { + .Int, .ComptimeInt => {}, + .Vector => { + switch (scalar_ty.zigTypeTag()) { + .Int, .ComptimeInt => {}, + else => return sema.fail(block, ty_src, "expected vector of integer type, found vector of '{}'", .{scalar_ty}), + } + }, + else => return sema.fail(block, ty_src, "expected integer type or vector of integer type, found '{}'", .{operand_ty}), } + const target = sema.mod.getTarget(); - const bits = operand_ty.intInfo(target).bits; - if (bits == 0) return Air.Inst.Ref.zero; - if (operand_ty.intInfo(target).bits % 8 != 0) { - return sema.fail(block, ty_src, "@byteSwap requires the number of bits to be evenly divisible by 8, but {} has {} bits", .{ - operand_ty, - operand_ty.intInfo(target).bits, - }); + const bits = scalar_ty.intInfo(target).bits; + if (bits % 8 != 0) { + return sema.fail( + block, + ty_src, + "@byteSwap requires the number of bits to be evenly divisible by 8, but {} has {} bits", + .{ scalar_ty, bits }, + ); } - const runtime_src = if (try sema.resolveMaybeUndefVal(block, operand_src, operand)) |val| { - if (val.isUndef()) return sema.addConstUndef(operand_ty); - const result_val = try val.byteSwap(operand_ty, target, sema.arena); - return sema.addConstant(operand_ty, result_val); - } else operand_src; + switch (operand_ty.zigTypeTag()) { + .Int, .ComptimeInt => { + if (bits == 0) return Air.Inst.Ref.zero; - try sema.requireRuntimeBlock(block, runtime_src); - return block.addTyOp(.byte_swap, operand_ty, operand); + const runtime_src = if (try sema.resolveMaybeUndefVal(block, operand_src, operand)) |val| { + if (val.isUndef()) return sema.addConstUndef(operand_ty); + const result_val = try val.byteSwap(operand_ty, target, sema.arena); + return sema.addConstant(operand_ty, result_val); + } else operand_src; + + try sema.requireRuntimeBlock(block, runtime_src); + return block.addTyOp(.byte_swap, operand_ty, operand); + }, + .Vector => { + if (bits == 0) { + return sema.addConstant( + operand_ty, + try Value.Tag.repeated.create(sema.arena, Value.zero), + ); + } + + const runtime_src = if (try sema.resolveMaybeUndefVal(block, operand_src, operand)) |val| { + if (val.isUndef()) + return sema.addConstUndef(operand_ty); + + const vec_len = operand_ty.vectorLen(); + var elem_buf: Value.ElemValueBuffer = undefined; + const elems = try sema.arena.alloc(Value, vec_len); + for (elems) |*elem, i| { + const elem_val = val.elemValueBuffer(i, &elem_buf); + elem.* = try elem_val.byteSwap(operand_ty, target, sema.arena); + } + return sema.addConstant( + operand_ty, + try Value.Tag.aggregate.create(sema.arena, elems), + ); + } else operand_src; + + try sema.requireRuntimeBlock(block, runtime_src); + return block.addTyOp(.byte_swap, operand_ty, operand); + }, + else => unreachable, + } } fn zirBitReverse(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index 481730452c..5a231ddc6e 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -6078,9 +6078,26 @@ pub const FuncGen = struct { if (bits % 16 == 8) { // If not an even byte-multiple, we need zero-extend + shift-left 1 byte // The truncated result at the end will be the correct bswap - operand_llvm_ty = self.context.intType(bits + 8); - const extended = self.builder.buildZExt(operand, operand_llvm_ty, ""); - operand = self.builder.buildShl(extended, operand_llvm_ty.constInt(8, .False), ""); + const scalar_llvm_ty = self.context.intType(bits + 8); + if (operand_ty.zigTypeTag() == .Vector) { + const vec_len = operand_ty.vectorLen(); + operand_llvm_ty = scalar_llvm_ty.vectorType(vec_len); + + const shifts = try self.gpa.alloc(*const llvm.Value, vec_len); + defer self.gpa.free(shifts); + + for (shifts) |*elem| { + elem.* = scalar_llvm_ty.constInt(8, .False); + } + const shift_vec = llvm.constVector(shifts.ptr, vec_len); + + const extended = self.builder.buildZExt(operand, operand_llvm_ty, ""); + operand = self.builder.buildShl(extended, shift_vec, ""); + } else { + const extended = self.builder.buildZExt(operand, scalar_llvm_ty, ""); + operand = self.builder.buildShl(extended, scalar_llvm_ty.constInt(8, .False), ""); + operand_llvm_ty = scalar_llvm_ty; + } bits = bits + 8; } diff --git a/test/behavior/byteswap.zig b/test/behavior/byteswap.zig index 15046cc0e1..6fd8867d42 100644 --- a/test/behavior/byteswap.zig +++ b/test/behavior/byteswap.zig @@ -52,32 +52,77 @@ test "@byteSwap integers" { try ByteSwapIntTest.run(); } -test "@byteSwap vectors" { - if (builtin.zig_backend == .stage2_llvm) return error.SkipZigTest; +fn vector8() !void { + var v = @Vector(2, u8){ 0x12, 0x13 }; + var result = @byteSwap(u8, v); + try expect(result[0] == 0x12); + try expect(result[1] == 0x13); +} + +test "@byteSwap vectors u8" { if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; - const ByteSwapVectorTest = struct { - fn run() !void { - try t(u8, 2, [_]u8{ 0x12, 0x13 }, [_]u8{ 0x12, 0x13 }); - try t(u16, 2, [_]u16{ 0x1234, 0x2345 }, [_]u16{ 0x3412, 0x4523 }); - try t(u24, 2, [_]u24{ 0x123456, 0x234567 }, [_]u24{ 0x563412, 0x674523 }); - } - - fn t( - comptime I: type, - comptime n: comptime_int, - input: std.meta.Vector(n, I), - expected_vector: std.meta.Vector(n, I), - ) !void { - const actual_output: [n]I = @byteSwap(I, input); - const expected_output: [n]I = expected_vector; - try std.testing.expectEqual(expected_output, actual_output); - } - }; - comptime try ByteSwapVectorTest.run(); - try ByteSwapVectorTest.run(); + comptime try vector8(); + try vector8(); +} + +fn vector16() !void { + var v = @Vector(2, u16){ 0x1234, 0x2345 }; + var result = @byteSwap(u16, v); + try expect(result[0] == 0x3412); + try expect(result[1] == 0x4523); +} + +test "@byteSwap vectors u16" { + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; + + comptime try vector16(); + try vector16(); +} + +fn vector24() !void { + var v = @Vector(2, u24){ 0x123456, 0x234567 }; + var result = @byteSwap(u24, v); + try expect(result[0] == 0x563412); + try expect(result[1] == 0x674523); +} + +test "@byteSwap vectors u24" { + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; + + comptime try vector24(); + try vector24(); +} + +fn vector0() !void { + var v = @Vector(2, u0){ 0, 0 }; + var result = @byteSwap(u0, v); + try expect(result[0] == 0); + try expect(result[1] == 0); +} + +test "@byteSwap vectors u0" { + // TODO: vector initialization for @Vector(x, u0) currently fails. + if (builtin.zig_backend == .stage2_llvm) return error.SkipZigTest; + + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; + + comptime try vector0(); + try vector0(); } From d7d2ccb7af7349617fa76eef3bdead7c259654a5 Mon Sep 17 00:00:00 2001 From: John Schmidt Date: Thu, 17 Mar 2022 22:29:39 +0100 Subject: [PATCH 2/3] Avoid index out of bounds for one-valued types in zirValidateArrayInit Previously, the code assumed that `ptr_elem_ptr` was always followed by a `store`, but this is not true for types with one value (such as `u0`). --- src/Sema.zig | 14 ++++++-------- test/behavior/byteswap.zig | 3 --- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/Sema.zig b/src/Sema.zig index 4e30953fd6..7c7eef4818 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -3205,14 +3205,11 @@ fn zirValidateArrayInit( // instruction after it within the same block. // Possible performance enhancement: save the `block_index` between iterations // of the for loop. - const next_air_inst = inst: { - var block_index = block.instructions.items.len - 1; - while (block.instructions.items[block_index] != elem_ptr_air_inst) { - block_index -= 1; - } - first_block_index = @minimum(first_block_index, block_index); - break :inst block.instructions.items[block_index + 1]; - }; + var block_index = block.instructions.items.len - 1; + while (block.instructions.items[block_index] != elem_ptr_air_inst) { + block_index -= 1; + } + first_block_index = @minimum(first_block_index, block_index); // Array has one possible value, so value is always comptime-known if (opt_opv) |opv| { @@ -3222,6 +3219,7 @@ fn zirValidateArrayInit( // If the next instructon is a store with a comptime operand, this element // is comptime. + const next_air_inst = block.instructions.items[block_index + 1]; switch (air_tags[next_air_inst]) { .store => { const bin_op = air_datas[next_air_inst].bin_op; diff --git a/test/behavior/byteswap.zig b/test/behavior/byteswap.zig index 6fd8867d42..6e317b7fc2 100644 --- a/test/behavior/byteswap.zig +++ b/test/behavior/byteswap.zig @@ -114,9 +114,6 @@ fn vector0() !void { } test "@byteSwap vectors u0" { - // TODO: vector initialization for @Vector(x, u0) currently fails. - if (builtin.zig_backend == .stage2_llvm) return error.SkipZigTest; - if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; From 4fa506063373b38bed842af53e6405dec7107534 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 17 Mar 2022 18:04:06 -0700 Subject: [PATCH 3/3] Sema: take advantage of checkIntOrVectorAllowComptime --- src/Sema.zig | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/src/Sema.zig b/src/Sema.zig index 7c7eef4818..8fda67b652 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -13489,23 +13489,7 @@ fn zirByteSwap(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg1 = inst_data.src_node }; const operand = sema.resolveInst(inst_data.operand); const operand_ty = sema.typeOf(operand); - - const scalar_ty = if (operand_ty.zigTypeTag() == .Vector) - operand_ty.elemType2() - else - operand_ty; - - switch (operand_ty.zigTypeTag()) { - .Int, .ComptimeInt => {}, - .Vector => { - switch (scalar_ty.zigTypeTag()) { - .Int, .ComptimeInt => {}, - else => return sema.fail(block, ty_src, "expected vector of integer type, found vector of '{}'", .{scalar_ty}), - } - }, - else => return sema.fail(block, ty_src, "expected integer type or vector of integer type, found '{}'", .{operand_ty}), - } - + const scalar_ty = try sema.checkIntOrVectorAllowComptime(block, operand, operand_src); const target = sema.mod.getTarget(); const bits = scalar_ty.intInfo(target).bits; if (bits % 8 != 0) {