From 4a40282391f0b92a83a6a8c269c27a32be92884a Mon Sep 17 00:00:00 2001 From: Vexu Date: Wed, 12 Aug 2020 22:30:14 +0300 Subject: [PATCH] stage2: implement unwrap optional --- src-self-hosted/Module.zig | 26 +++++++++++++++++++-- src-self-hosted/astgen.zig | 12 ++++++++++ src-self-hosted/codegen.zig | 10 ++++++++ src-self-hosted/ir.zig | 22 ++++++++++++++++++ src-self-hosted/zir.zig | 35 ++++++++++++++++++++++++++++ src-self-hosted/zir_sema.zig | 42 +++++++++++++++++++++++++++++++++- test/stage2/compare_output.zig | 5 ++++ 7 files changed, 149 insertions(+), 3 deletions(-) diff --git a/src-self-hosted/Module.zig b/src-self-hosted/Module.zig index 110fe05b8e..bbbcde9a71 100644 --- a/src-self-hosted/Module.zig +++ b/src-self-hosted/Module.zig @@ -2016,6 +2016,28 @@ pub fn addCall( return &inst.base; } +pub fn addUnwrapOptional( + self: *Module, + block: *Scope.Block, + src: usize, + ty: Type, + operand: *Inst, + safety_check: bool, +) !*Inst { + const inst = try block.arena.create(Inst.UnwrapOptional); + inst.* = .{ + .base = .{ + .tag = .unwrap_optional, + .ty = ty, + .src = src, + }, + .operand = operand, + .safety_check = safety_check, + }; + try block.instructions.append(self.gpa, &inst.base); + return &inst.base; +} + pub fn constInst(self: *Module, scope: *Scope, src: usize, typed_value: TypedValue) !*Inst { const const_inst = try scope.arena().create(Inst.Constant); const_inst.* = .{ @@ -2488,9 +2510,9 @@ pub fn coerce(self: *Module, scope: *Scope, dest_type: Type, inst: *Inst) !*Inst if (child_type.eql(inst.ty)) { return self.constInst(scope, inst.src, .{ .ty = dest_type, .val = val }); } - return self.fail(scope, inst.src, "TODO optional wrap {} to {}", .{ val, inst.ty }); + return self.fail(scope, inst.src, "TODO optional wrap {} to {}", .{ val, dest_type }); } else if (child_type.eql(inst.ty)) { - return self.fail(scope, inst.src, "TODO optional wrap {}", .{inst.ty}); + return self.fail(scope, inst.src, "TODO optional wrap {}", .{dest_type}); } } diff --git a/src-self-hosted/astgen.zig b/src-self-hosted/astgen.zig index 57f399d696..ed3f8ab1b6 100644 --- a/src-self-hosted/astgen.zig +++ b/src-self-hosted/astgen.zig @@ -106,6 +106,7 @@ pub fn expr(mod: *Module, scope: *Scope, rl: ResultLoc, node: *ast.Node) InnerEr .BoolLiteral => return rlWrap(mod, scope, rl, try boolLiteral(mod, scope, node.castTag(.BoolLiteral).?)), .NullLiteral => return rlWrap(mod, scope, rl, try nullLiteral(mod, scope, node.castTag(.NullLiteral).?)), .OptionalType => return rlWrap(mod, scope, rl, try optionalType(mod, scope, node.castTag(.OptionalType).?)), + .UnwrapOptional => return unwrapOptional(mod, scope, rl, node.castTag(.UnwrapOptional).?), else => return mod.failNode(scope, node, "TODO implement astgen.Expr for {}", .{@tagName(node.tag)}), } } @@ -305,6 +306,17 @@ fn optionalType(mod: *Module, scope: *Scope, node: *ast.Node.SimplePrefixOp) Inn return addZIRUnOp(mod, scope, src, .optional_type, operand); } +fn unwrapOptional(mod: *Module, scope: *Scope, rl: ResultLoc, node: *ast.Node.SimpleSuffixOp) InnerError!*zir.Inst { + const tree = scope.tree(); + const src = tree.token_locs[node.rtoken].start; + + const operand = try expr(mod, scope, .lvalue, node.lhs); + const unwrapped_ptr = try addZIRInst(mod, scope, src, zir.Inst.UnwrapOptional, .{ .operand = operand }, .{}); + if (rl == .lvalue) return unwrapped_ptr; + + return rlWrap(mod, scope, rl, try addZIRUnOp(mod, scope, src, .deref, unwrapped_ptr)); +} + /// Identifier token -> String (allocated in scope.arena()) pub fn identifierTokenString(mod: *Module, scope: *Scope, token: ast.TokenIndex) InnerError![]const u8 { const tree = scope.tree(); diff --git a/src-self-hosted/codegen.zig b/src-self-hosted/codegen.zig index 6e8ab34478..887126ba2b 100644 --- a/src-self-hosted/codegen.zig +++ b/src-self-hosted/codegen.zig @@ -668,6 +668,7 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type { .store => return self.genStore(inst.castTag(.store).?), .sub => return self.genSub(inst.castTag(.sub).?), .unreach => return MCValue{ .unreach = {} }, + .unwrap_optional => return self.genUnwrapOptional(inst.castTag(.unwrap_optional).?), } } @@ -817,6 +818,15 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type { } } + fn genUnwrapOptional(self: *Self, inst: *ir.Inst.UnwrapOptional) !MCValue { + // No side effects, so if it's unreferenced, do nothing. + if (inst.base.isUnused()) + return MCValue.dead; + switch (arch) { + else => return self.fail(inst.base.src, "TODO implement unwrap optional for {}", .{self.target.cpu.arch}), + } + } + fn genLoad(self: *Self, inst: *ir.Inst.UnOp) !MCValue { const elem_ty = inst.base.ty; if (!elem_ty.hasCodeGenBits()) diff --git a/src-self-hosted/ir.zig b/src-self-hosted/ir.zig index f4262592de..93952a4214 100644 --- a/src-self-hosted/ir.zig +++ b/src-self-hosted/ir.zig @@ -82,6 +82,7 @@ pub const Inst = struct { not, floatcast, intcast, + unwrap_optional, pub fn Type(tag: Tag) type { return switch (tag) { @@ -124,6 +125,7 @@ pub const Inst = struct { .condbr => CondBr, .constant => Constant, .loop => Loop, + .unwrap_optional => UnwrapOptional, }; } @@ -420,6 +422,26 @@ pub const Inst = struct { } }; + pub const UnwrapOptional = struct { + pub const base_tag = Tag.unwrap_optional; + base: Inst, + + operand: *Inst, + safety_check: bool, + + pub fn operandCount(self: *const UnwrapOptional) usize { + return 1; + } + pub fn getOperand(self: *const UnwrapOptional, index: usize) ?*Inst { + var i = index; + + if (i < 1) + return self.operand; + i -= 1; + + return null; + } + }; }; pub const Body = struct { diff --git a/src-self-hosted/zir.zig b/src-self-hosted/zir.zig index c98fcf4a74..84e9731126 100644 --- a/src-self-hosted/zir.zig +++ b/src-self-hosted/zir.zig @@ -214,6 +214,8 @@ pub const Inst = struct { xor, /// Create an optional type '?T' optional_type, + /// Unwraps an optional value 'lhs.?' + unwrap_optional, pub fn Type(tag: Tag) type { return switch (tag) { @@ -301,6 +303,7 @@ pub const Inst = struct { .fntype => FnType, .elemptr => ElemPtr, .condbr => CondBr, + .unwrap_optional => UnwrapOptional, }; } @@ -376,6 +379,7 @@ pub const Inst = struct { .typeof, .xor, .optional_type, + .unwrap_optional, => false, .@"break", @@ -816,6 +820,18 @@ pub const Inst = struct { }, kw_args: struct {}, }; + + pub const UnwrapOptional = struct { + pub const base_tag = Tag.unwrap_optional; + base: Inst, + + positionals: struct { + operand: *Inst, + }, + kw_args: struct { + safety_check: bool = true, + }, + }; }; pub const ErrorMsg = struct { @@ -2141,6 +2157,25 @@ const EmitZIR = struct { }; break :blk &new_inst.base; }, + + .unwrap_optional => blk: { + const old_inst = inst.castTag(.unwrap_optional).?; + + const new_inst = try self.arena.allocator.create(Inst.UnwrapOptional); + new_inst.* = .{ + .base = .{ + .src = inst.src, + .tag = Inst.UnwrapOptional.base_tag, + }, + .positionals = .{ + .operand = try self.resolveInst(new_body, old_inst.operand), + }, + .kw_args = .{ + .safety_check = old_inst.safety_check, + }, + }; + break :blk &new_inst.base; + }, }; try instructions.append(new_inst); try inst_table.put(inst, new_inst); diff --git a/src-self-hosted/zir_sema.zig b/src-self-hosted/zir_sema.zig index e241caefc9..39fbf9221a 100644 --- a/src-self-hosted/zir_sema.zig +++ b/src-self-hosted/zir_sema.zig @@ -107,6 +107,7 @@ pub fn analyzeInst(mod: *Module, scope: *Scope, old_inst: *zir.Inst) InnerError! .boolnot => return analyzeInstBoolNot(mod, scope, old_inst.castTag(.boolnot).?), .typeof => return analyzeInstTypeOf(mod, scope, old_inst.castTag(.typeof).?), .optional_type => return analyzeInstOptionalType(mod, scope, old_inst.castTag(.optional_type).?), + .unwrap_optional => return analyzeInstUnwrapOptional(mod, scope, old_inst.castTag(.unwrap_optional).?), } } @@ -306,8 +307,19 @@ fn analyzeInstRetPtr(mod: *Module, scope: *Scope, inst: *zir.Inst.NoOp) InnerErr fn analyzeInstRef(mod: *Module, scope: *Scope, inst: *zir.Inst.UnOp) InnerError!*Inst { const operand = try resolveInst(mod, scope, inst.positionals.operand); - const b = try mod.requireRuntimeBlock(scope, inst.base.src); const ptr_type = try mod.singleConstPtrType(scope, inst.base.src, operand.ty); + + if (operand.value()) |val| { + const ref_payload = try scope.arena().create(Value.Payload.RefVal); + ref_payload.* = .{ .val = val }; + + return mod.constInst(scope, inst.base.src, .{ + .ty = ptr_type, + .val = Value.initPayload(&ref_payload.base), + }); + } + + const b = try mod.requireRuntimeBlock(scope, inst.base.src); return mod.addUnOp(b, inst.base.src, ptr_type, .ref, operand); } @@ -649,6 +661,34 @@ fn analyzeInstOptionalType(mod: *Module, scope: *Scope, optional: *zir.Inst.UnOp })); } +fn analyzeInstUnwrapOptional(mod: *Module, scope: *Scope, unwrap: *zir.Inst.UnwrapOptional) InnerError!*Inst { + const operand = try resolveInst(mod, scope, unwrap.positionals.operand); + assert(operand.ty.zigTypeTag() == .Pointer); + + if (operand.ty.elemType().zigTypeTag() != .Optional) { + return mod.fail(scope, unwrap.base.src, "expected optional type, found {}", .{operand.ty.elemType()}); + } + + const child_type = operand.ty.elemType().elemType(); + const child_pointer = if (operand.ty.isConstPtr()) + try mod.singleConstPtrType(scope, unwrap.base.src, child_type) + else + try mod.singleMutPtrType(scope, unwrap.base.src, child_type); + + if (operand.value()) |val| { + if (val.tag() == .null_value) { + return mod.fail(scope, unwrap.base.src, "unable to unwrap null", .{}); + } + return mod.constInst(scope, unwrap.base.src, .{ + .ty = child_pointer, + .val = val, + }); + } + + const b = try mod.requireRuntimeBlock(scope, unwrap.base.src); + return mod.addUnwrapOptional(b, unwrap.base.src, child_pointer, operand, unwrap.kw_args.safety_check); +} + fn analyzeInstFnType(mod: *Module, scope: *Scope, fntype: *zir.Inst.FnType) InnerError!*Inst { const return_type = try resolveType(mod, scope, fntype.positionals.return_type); diff --git a/test/stage2/compare_output.zig b/test/stage2/compare_output.zig index bb3e542f13..477b5d92b1 100644 --- a/test/stage2/compare_output.zig +++ b/test/stage2/compare_output.zig @@ -31,6 +31,11 @@ pub fn addCases(ctx: *TestContext) !void { \\export fn _start() noreturn { \\ print(); \\ + \\ const a: u32 = 2; + \\ const b: ?u32 = a; + \\ const c = b.?; + \\ if (c != 2) unreachable; + \\ \\ exit(); \\} \\