From 799fedf612aa8742c446b015c12d21707a1dbec0 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sat, 7 Aug 2021 20:34:28 -0700 Subject: [PATCH] stage2: pass some error union tests * Value: rename `error_union` to `eu_payload` and clarify the intended usage in the doc comments. The way error unions is represented with Value is fixed to not have ambiguous values. * Fix codegen for error union constants in all the backends. * Implement the AIR instructions having to do with error unions in the LLVM backend. --- src/Sema.zig | 16 ++++---- src/codegen.zig | 2 +- src/codegen/c.zig | 25 +++++------- src/codegen/llvm.zig | 78 ++++++++++++++++++++++++------------- src/codegen/wasm.zig | 18 ++++----- src/value.zig | 43 ++++++++++++++------ test/behavior.zig | 3 +- test/behavior/if.zig | 42 -------------------- test/behavior/if_stage1.zig | 45 +++++++++++++++++++++ 9 files changed, 155 insertions(+), 117 deletions(-) create mode 100644 test/behavior/if_stage1.zig diff --git a/src/Sema.zig b/src/Sema.zig index a783a48c64..96a09553f5 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -3468,7 +3468,7 @@ fn zirErrUnionPayload( if (val.getError()) |name| { return sema.mod.fail(&block.base, src, "caught unexpected error '{s}'", .{name}); } - const data = val.castTag(.error_union).?.data; + const data = val.castTag(.eu_payload).?.data; const result_ty = operand_ty.errorUnionPayload(); return sema.addConstant(result_ty, data); } @@ -3539,8 +3539,7 @@ fn zirErrUnionCode(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) Compi if (try sema.resolveDefinedValue(block, src, operand)) |val| { assert(val.getError() != null); - const data = val.castTag(.error_union).?.data; - return sema.addConstant(result_ty, data); + return sema.addConstant(result_ty, val); } try sema.requireRuntimeBlock(block, src); @@ -3566,8 +3565,7 @@ fn zirErrUnionCodePtr(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) Co if (try sema.resolveDefinedValue(block, src, operand)) |pointer_val| { if (try pointer_val.pointerDeref(sema.arena)) |val| { assert(val.getError() != null); - const data = val.castTag(.error_union).?.data; - return sema.addConstant(result_ty, data); + return sema.addConstant(result_ty, val); } } @@ -8900,7 +8898,9 @@ fn wrapErrorUnion( if (try sema.resolveMaybeUndefVal(block, inst_src, inst)) |val| { if (inst_ty.zigTypeTag() != .ErrorSet) { _ = try sema.coerce(block, dest_payload_ty, inst, inst_src); - } else switch (dest_err_set_ty.tag()) { + return sema.addConstant(dest_type, try Value.Tag.eu_payload.create(sema.arena, val)); + } + switch (dest_err_set_ty.tag()) { .anyerror => {}, .error_set_single => { const expected_name = val.castTag(.@"error").?.data.name; @@ -8946,9 +8946,7 @@ fn wrapErrorUnion( }, else => unreachable, } - - // Create a SubValue for the error_union payload. - return sema.addConstant(dest_type, try Value.Tag.error_union.create(sema.arena, val)); + return sema.addConstant(dest_type, val); } try sema.requireRuntimeBlock(block, inst_src); diff --git a/src/codegen.zig b/src/codegen.zig index f5cdc518f6..38cc27d5bc 100644 --- a/src/codegen.zig +++ b/src/codegen.zig @@ -4815,7 +4815,7 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type { .ErrorUnion => { const error_type = typed_value.ty.errorUnionSet(); const payload_type = typed_value.ty.errorUnionPayload(); - const sub_val = typed_value.val.castTag(.error_union).?.data; + const sub_val = typed_value.val.castTag(.eu_payload).?.data; if (!payload_type.hasCodeGenBits()) { // We use the error type directly as the type. diff --git a/src/codegen/c.zig b/src/codegen/c.zig index 65ad4bac8e..a67e2438c2 100644 --- a/src/codegen/c.zig +++ b/src/codegen/c.zig @@ -350,32 +350,25 @@ pub const DeclGen = struct { .ErrorUnion => { const error_type = t.errorUnionSet(); const payload_type = t.errorUnionPayload(); - const sub_val = val.castTag(.error_union).?.data; if (!payload_type.hasCodeGenBits()) { // We use the error type directly as the type. - return dg.renderValue(writer, error_type, sub_val); + const err_val = if (val.errorUnionIsPayload()) Value.initTag(.zero) else val; + return dg.renderValue(writer, error_type, err_val); } try writer.writeByte('('); try dg.renderType(writer, t); try writer.writeAll("){"); - if (val.getError()) |_| { - try writer.writeAll(" .error = "); - try dg.renderValue( - writer, - error_type, - sub_val, - ); - try writer.writeAll(" }"); - } else { + if (val.castTag(.eu_payload)) |pl| { + const payload_val = pl.data; try writer.writeAll(" .payload = "); - try dg.renderValue( - writer, - payload_type, - sub_val, - ); + try dg.renderValue(writer, payload_type, payload_val); try writer.writeAll(", .error = 0 }"); + } else { + try writer.writeAll(" .error = "); + try dg.renderValue(writer, error_type, val); + try writer.writeAll(" }"); } }, .Enum => { diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index eccd5fa04f..7cfbc8da5e 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -593,7 +593,7 @@ pub const DeclGen = struct { try self.llvmType(ptr_type), try self.llvmType(Type.initTag(.usize)), }; - return self.context.structType(&fields, 2, .False); + return self.context.structType(&fields, fields.len, .False); } else { const elem_type = try self.llvmType(t.elemType()); return elem_type.pointerType(0); @@ -621,10 +621,14 @@ pub const DeclGen = struct { .ErrorUnion => { const error_type = t.errorUnionSet(); const payload_type = t.errorUnionPayload(); + const llvm_error_type = try self.llvmType(error_type); if (!payload_type.hasCodeGenBits()) { - return self.llvmType(error_type); + return llvm_error_type; } - return self.todo("implement llvmType for error unions", .{}); + const llvm_payload_type = try self.llvmType(payload_type); + + const fields: [2]*const llvm.Type = .{ llvm_error_type, llvm_payload_type }; + return self.context.structType(&fields, fields.len, .False); }, .ErrorSet => { return self.context.intType(16); @@ -846,14 +850,25 @@ pub const DeclGen = struct { .ErrorUnion => { const error_type = tv.ty.errorUnionSet(); const payload_type = tv.ty.errorUnionPayload(); - const sub_val = tv.val.castTag(.error_union).?.data; + const is_pl = tv.val.errorUnionIsPayload(); if (!payload_type.hasCodeGenBits()) { // We use the error type directly as the type. - return self.genTypedValue(.{ .ty = error_type, .val = sub_val }); + const err_val = if (!is_pl) tv.val else Value.initTag(.zero); + return self.genTypedValue(.{ .ty = error_type, .val = err_val }); } - return self.todo("implement error union const of type '{}'", .{tv.ty}); + const fields: [2]*const llvm.Value = .{ + try self.genTypedValue(.{ + .ty = error_type, + .val = if (is_pl) Value.initTag(.zero) else tv.val, + }), + try self.genTypedValue(.{ + .ty = payload_type, + .val = if (tv.val.castTag(.eu_payload)) |pl| pl.data else Value.initTag(.undef), + }), + }; + return self.context.constStruct(&fields, fields.len, .False); }, .Struct => { const fields_len = tv.ty.structFieldCount(); @@ -984,10 +999,10 @@ pub const FuncGen = struct { .is_non_null_ptr => try self.airIsNonNull(inst, true), .is_null => try self.airIsNull(inst, false), .is_null_ptr => try self.airIsNull(inst, true), - .is_non_err => try self.airIsErr(inst, true, false), - .is_non_err_ptr => try self.airIsErr(inst, true, true), - .is_err => try self.airIsErr(inst, false, false), - .is_err_ptr => try self.airIsErr(inst, false, true), + .is_non_err => try self.airIsErr(inst, .EQ, false), + .is_non_err_ptr => try self.airIsErr(inst, .EQ, true), + .is_err => try self.airIsErr(inst, .NE, false), + .is_err_ptr => try self.airIsErr(inst, .NE, true), .alloc => try self.airAlloc(inst), .arg => try self.airArg(inst), @@ -1098,7 +1113,7 @@ pub const FuncGen = struct { const inst_ty = self.air.typeOfIndex(inst); switch (self.air.typeOf(bin_op.lhs).zigTypeTag()) { - .Int, .Bool, .Pointer => { + .Int, .Bool, .Pointer, .ErrorSet => { const is_signed = inst_ty.isSignedInt(); const operation = switch (op) { .eq => .EQ, @@ -1256,12 +1271,7 @@ pub const FuncGen = struct { const rhs = try self.resolveInst(bin_op.rhs); const base_ptr = ptr: { - const index_type = self.context.intType(32); - const indices: [2]*const llvm.Value = .{ - index_type.constNull(), - index_type.constInt(0, .False), - }; - const ptr_field_ptr = self.builder.buildInBoundsGEP(lhs, &indices, 2, ""); + const ptr_field_ptr = self.builder.buildStructGEP(lhs, 0, ""); break :ptr self.builder.buildLoad(ptr_field_ptr, ""); }; @@ -1472,7 +1482,7 @@ pub const FuncGen = struct { index_type.constInt(1, .False), }; - return self.builder.buildLoad(self.builder.buildInBoundsGEP(operand, &indices, 2, ""), ""); + return self.builder.buildLoad(self.builder.buildInBoundsGEP(operand, &indices, indices.len, ""), ""); } else { return self.builder.buildExtractValue(operand, 1, ""); } @@ -1488,7 +1498,7 @@ pub const FuncGen = struct { fn airIsErr( self: *FuncGen, inst: Air.Inst.Index, - invert_logic: bool, + op: llvm.IntPredicate, operand_is_ptr: bool, ) !?*const llvm.Value { if (self.liveness.isUnused(inst)) @@ -1498,16 +1508,22 @@ pub const FuncGen = struct { const operand = try self.resolveInst(un_op); const err_union_ty = self.air.typeOf(un_op); const payload_ty = err_union_ty.errorUnionPayload(); + const err_set_ty = try self.dg.llvmType(Type.initTag(.anyerror)); + const zero = err_set_ty.constNull(); if (!payload_ty.hasCodeGenBits()) { const loaded = if (operand_is_ptr) self.builder.buildLoad(operand, "") else operand; - const op: llvm.IntPredicate = if (invert_logic) .EQ else .NE; - const err_set_ty = try self.dg.llvmType(Type.initTag(.anyerror)); - const zero = err_set_ty.constNull(); return self.builder.buildICmp(op, loaded, zero, ""); } - return self.todo("implement 'airIsErr' for error unions with nonzero payload", .{}); + if (operand_is_ptr) { + const err_field_ptr = self.builder.buildStructGEP(operand, 0, ""); + const loaded = self.builder.buildLoad(err_field_ptr, ""); + return self.builder.buildICmp(op, loaded, zero, ""); + } + + const loaded = self.builder.buildExtractValue(operand, 0, ""); + return self.builder.buildICmp(op, loaded, zero, ""); } fn airOptionalPayload( @@ -1552,9 +1568,11 @@ pub const FuncGen = struct { return null; } - _ = operand; - _ = operand_is_ptr; - return self.todo("implement llvm codegen for 'airErrUnionPayload' for type {}", .{self.air.typeOf(ty_op.operand)}); + if (operand_is_ptr) { + return self.builder.buildStructGEP(operand, 1, ""); + } + + return self.builder.buildExtractValue(operand, 1, ""); } fn airErrUnionErr( @@ -1574,7 +1592,13 @@ pub const FuncGen = struct { if (!operand_is_ptr) return operand; return self.builder.buildLoad(operand, ""); } - return self.todo("implement llvm codegen for 'airErrUnionErr'", .{}); + + if (operand_is_ptr) { + const err_field_ptr = self.builder.buildStructGEP(operand, 0, ""); + return self.builder.buildLoad(err_field_ptr, ""); + } + + return self.builder.buildExtractValue(operand, 0, ""); } fn airWrapOptional(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value { diff --git a/src/codegen/wasm.zig b/src/codegen/wasm.zig index e21645d1ee..4814ba0b55 100644 --- a/src/codegen/wasm.zig +++ b/src/codegen/wasm.zig @@ -1167,12 +1167,18 @@ pub const Context = struct { try leb.writeULEB128(writer, error_index); }, .ErrorUnion => { - const data = val.castTag(.error_union).?.data; const error_type = ty.errorUnionSet(); const payload_type = ty.errorUnionPayload(); - if (val.getError()) |_| { + if (val.castTag(.eu_payload)) |pl| { + const payload_val = pl.data; + // no error, so write a '0' const + try writer.writeByte(wasm.opcode(.i32_const)); + try leb.writeULEB128(writer, @as(u32, 0)); + // after the error code, we emit the payload + try self.emitConstant(payload_val, payload_type); + } else { // write the error val - try self.emitConstant(data, error_type); + try self.emitConstant(val, error_type); // no payload, so write a '0' const const opcode: wasm.Opcode = buildOpcode(.{ @@ -1181,12 +1187,6 @@ pub const Context = struct { }); try writer.writeByte(wasm.opcode(opcode)); try leb.writeULEB128(writer, @as(u32, 0)); - } else { - // no error, so write a '0' const - try writer.writeByte(wasm.opcode(.i32_const)); - try leb.writeULEB128(writer, @as(u32, 0)); - // after the error code, we emit the payload - try self.emitConstant(data, payload_type); } }, .Optional => { diff --git a/src/value.zig b/src/value.zig index bf80c9d831..562d7171e8 100644 --- a/src/value.zig +++ b/src/value.zig @@ -129,7 +129,13 @@ pub const Value = extern union { /// A specific enum tag, indicated by the field index (declaration order). enum_field_index, @"error", - error_union, + /// When the type is error union: + /// * If the tag is `.@"error"`, the error union is an error. + /// * If the tag is `.eu_payload`, the error union is a payload. + /// * A nested error such as `((anyerror!T1)!T2)` in which the the outer error union + /// is non-error, but the inner error union is an error, is represented as + /// a tag of `.eu_payload`, with a sub-tag of `.@"error"`. + eu_payload, /// A pointer to the payload of an error union, based on a pointer to an error union. eu_payload_ptr, /// An instance of a struct. @@ -228,7 +234,7 @@ pub const Value = extern union { => Payload.Decl, .repeated, - .error_union, + .eu_payload, .eu_payload_ptr, => Payload.SubValue, @@ -450,7 +456,7 @@ pub const Value = extern union { return Value{ .ptr_otherwise = &new_payload.base }; }, .bytes => return self.copyPayloadShallow(allocator, Payload.Bytes), - .repeated, .error_union, .eu_payload_ptr => { + .repeated, .eu_payload, .eu_payload_ptr => { const payload = self.cast(Payload.SubValue).?; const new_payload = try allocator.create(Payload.SubValue); new_payload.* = .{ @@ -642,7 +648,10 @@ pub const Value = extern union { .float_128 => return out_stream.print("{}", .{val.castTag(.float_128).?.data}), .@"error" => return out_stream.print("error.{s}", .{val.castTag(.@"error").?.data.name}), // TODO to print this it should be error{ Set, Items }!T(val), but we need the type for that - .error_union => return out_stream.print("error_union_val({})", .{val.castTag(.error_union).?.data}), + .eu_payload => { + try out_stream.writeAll("(eu_payload) "); + val = val.castTag(.eu_payload).?.data; + }, .inferred_alloc => return out_stream.writeAll("(inferred allocation value)"), .inferred_alloc_comptime => return out_stream.writeAll("(inferred comptime allocation value)"), .eu_payload_ptr => { @@ -1241,7 +1250,7 @@ pub const Value = extern union { .eu_payload_ptr => blk: { const err_union_ptr = self.castTag(.eu_payload_ptr).?.data; const err_union_val = (try err_union_ptr.pointerDeref(allocator)) orelse return null; - break :blk err_union_val.castTag(.error_union).?.data; + break :blk err_union_val.castTag(.eu_payload).?.data; }, .zero, @@ -1351,16 +1360,16 @@ pub const Value = extern union { } /// Valid for all types. Asserts the value is not undefined and not unreachable. + /// Prefer `errorUnionIsPayload` to find out whether something is an error or not + /// because it works without having to figure out the string. pub fn getError(self: Value) ?[]const u8 { return switch (self.tag()) { - .error_union => { - const data = self.castTag(.error_union).?.data; - return if (data.tag() == .@"error") - data.castTag(.@"error").?.data.name - else - null; - }, .@"error" => self.castTag(.@"error").?.data.name, + .int_u64 => @panic("TODO"), + .int_i64 => @panic("TODO"), + .int_big_positive => @panic("TODO"), + .int_big_negative => @panic("TODO"), + .one => @panic("TODO"), .undef => unreachable, .unreachable_value => unreachable, .inferred_alloc => unreachable, @@ -1369,6 +1378,16 @@ pub const Value = extern union { else => null, }; } + + /// Assumes the type is an error union. Returns true if and only if the value is + /// the error union payload, not an error. + pub fn errorUnionIsPayload(val: Value) bool { + return switch (val.tag()) { + .eu_payload => true, + else => false, + }; + } + /// Valid for all types. Asserts the value is not undefined. pub fn isFloat(self: Value) bool { return switch (self.tag()) { diff --git a/test/behavior.zig b/test/behavior.zig index 936268af9c..a800b38458 100644 --- a/test/behavior.zig +++ b/test/behavior.zig @@ -7,6 +7,7 @@ test { _ = @import("behavior/generics.zig"); _ = @import("behavior/eval.zig"); _ = @import("behavior/pointers.zig"); + _ = @import("behavior/if.zig"); if (!builtin.zig_is_stage2) { // Tests that only pass for stage1. @@ -100,7 +101,7 @@ test { _ = @import("behavior/generics_stage1.zig"); _ = @import("behavior/hasdecl.zig"); _ = @import("behavior/hasfield.zig"); - _ = @import("behavior/if.zig"); + _ = @import("behavior/if_stage1.zig"); _ = @import("behavior/import.zig"); _ = @import("behavior/incomplete_struct_param_tld.zig"); _ = @import("behavior/inttoptr.zig"); diff --git a/test/behavior/if.zig b/test/behavior/if.zig index e8c84f4570..191d4817df 100644 --- a/test/behavior/if.zig +++ b/test/behavior/if.zig @@ -65,45 +65,3 @@ test "labeled break inside comptime if inside runtime if" { } try expect(answer == 42); } - -test "const result loc, runtime if cond, else unreachable" { - const Num = enum { - One, - Two, - }; - - var t = true; - const x = if (t) Num.Two else unreachable; - try expect(x == .Two); -} - -test "if prongs cast to expected type instead of peer type resolution" { - const S = struct { - fn doTheTest(f: bool) !void { - var x: i32 = 0; - x = if (f) 1 else 2; - try expect(x == 2); - - var b = true; - const y: i32 = if (b) 1 else 2; - try expect(y == 1); - } - }; - try S.doTheTest(false); - comptime try S.doTheTest(false); -} - -test "while copies its payload" { - const S = struct { - fn doTheTest() !void { - var tmp: ?i32 = 10; - if (tmp) |value| { - // Modify the original variable - tmp = null; - try expectEqual(@as(i32, 10), value); - } else unreachable; - } - }; - try S.doTheTest(); - comptime try S.doTheTest(); -} diff --git a/test/behavior/if_stage1.zig b/test/behavior/if_stage1.zig new file mode 100644 index 0000000000..36500fbaee --- /dev/null +++ b/test/behavior/if_stage1.zig @@ -0,0 +1,45 @@ +const std = @import("std"); +const expect = std.testing.expect; +const expectEqual = std.testing.expectEqual; + +test "const result loc, runtime if cond, else unreachable" { + const Num = enum { + One, + Two, + }; + + var t = true; + const x = if (t) Num.Two else unreachable; + try expect(x == .Two); +} + +test "if prongs cast to expected type instead of peer type resolution" { + const S = struct { + fn doTheTest(f: bool) !void { + var x: i32 = 0; + x = if (f) 1 else 2; + try expect(x == 2); + + var b = true; + const y: i32 = if (b) 1 else 2; + try expect(y == 1); + } + }; + try S.doTheTest(false); + comptime try S.doTheTest(false); +} + +test "while copies its payload" { + const S = struct { + fn doTheTest() !void { + var tmp: ?i32 = 10; + if (tmp) |value| { + // Modify the original variable + tmp = null; + try expectEqual(@as(i32, 10), value); + } else unreachable; + } + }; + try S.doTheTest(); + comptime try S.doTheTest(); +}