stage2: add support for optionals in the LLVM backend

We can now codegen optionals! This includes the following instructions:
- is_null
- is_null_ptr
- is_non_null
- is_non_null_ptr
- optional_payload
- optional_payload_ptr
- br_void

Also includes a test for optionals.
This commit is contained in:
Timon Kruiper 2021-01-09 16:22:43 +01:00 committed by Andrew Kelley
parent 3ad9cb8b47
commit d4ec0279d3
4 changed files with 176 additions and 15 deletions

View File

@ -453,13 +453,23 @@ pub fn expr(mod: *Module, scope: *Scope, rl: ResultLoc, node: ast.Node.Index) In
return rvalue(mod, scope, rl, result);
},
.unwrap_optional => {
const operand = try expr(mod, scope, rl, node_datas[node].lhs);
const op: zir.Inst.Tag = switch (rl) {
.ref => .optional_payload_safe_ptr,
else => .optional_payload_safe,
};
const src = token_starts[main_tokens[node]];
return addZIRUnOp(mod, scope, src, op, operand);
switch (rl) {
.ref => return addZIRUnOp(
mod,
scope,
src,
.optional_payload_safe_ptr,
try expr(mod, scope, .ref, node_datas[node].lhs),
),
else => return rvalue(mod, scope, rl, try addZIRUnOp(
mod,
scope,
src,
.optional_payload_safe,
try expr(mod, scope, .none, node_datas[node].lhs),
)),
}
},
.block_two, .block_two_semicolon => {
const statements = [2]ast.Node.Index{ node_datas[node].lhs, node_datas[node].rhs };
@ -1701,7 +1711,12 @@ fn orelseCatchExpr(
// This could be a pointer or value depending on the `rl` parameter.
block_scope.break_count += 1;
const operand = try expr(mod, &block_scope.base, block_scope.break_result_loc, lhs);
const operand = try expr(
mod,
&block_scope.base,
if (block_scope.break_result_loc == .ref) .ref else .none,
lhs,
);
const cond = try addZIRUnOp(mod, &block_scope.base, src, cond_op, operand);
const condbr = try addZIRInstSpecial(mod, &block_scope.base, src, zir.Inst.CondBr, .{

View File

@ -397,6 +397,7 @@ pub const LLVMIRModule = struct {
.block => try self.genBlock(inst.castTag(.block).?),
.br => try self.genBr(inst.castTag(.br).?),
.breakpoint => try self.genBreakpoint(inst.castTag(.breakpoint).?),
.br_void => try self.genBrVoid(inst.castTag(.br_void).?),
.call => try self.genCall(inst.castTag(.call).?),
.cmp_eq => try self.genCmp(inst.castTag(.cmp_eq).?, .eq),
.cmp_gt => try self.genCmp(inst.castTag(.cmp_gt).?, .gt),
@ -406,6 +407,10 @@ pub const LLVMIRModule = struct {
.cmp_neq => try self.genCmp(inst.castTag(.cmp_neq).?, .neq),
.condbr => try self.genCondBr(inst.castTag(.condbr).?),
.intcast => try self.genIntCast(inst.castTag(.intcast).?),
.is_non_null => try self.genIsNonNull(inst.castTag(.is_non_null).?, false),
.is_non_null_ptr => try self.genIsNonNull(inst.castTag(.is_non_null_ptr).?, true),
.is_null => try self.genIsNull(inst.castTag(.is_null).?, false),
.is_null_ptr => try self.genIsNull(inst.castTag(.is_null_ptr).?, true),
.load => try self.genLoad(inst.castTag(.load).?),
.loop => try self.genLoop(inst.castTag(.loop).?),
.not => try self.genNot(inst.castTag(.not).?),
@ -414,6 +419,8 @@ pub const LLVMIRModule = struct {
.store => try self.genStore(inst.castTag(.store).?),
.sub => try self.genSub(inst.castTag(.sub).?),
.unreach => self.genUnreach(inst.castTag(.unreach).?),
.optional_payload => try self.genOptionalPayload(inst.castTag(.optional_payload).?, false),
.optional_payload_ptr => try self.genOptionalPayload(inst.castTag(.optional_payload_ptr).?, true),
.dbg_stmt => blk: {
// TODO: implement debug info
break :blk null;
@ -534,21 +541,29 @@ pub const LLVMIRModule = struct {
}
fn genBr(self: *LLVMIRModule, inst: *Inst.Br) !?*const llvm.Value {
// Get the block that we want to break to.
var block = self.blocks.get(inst.block).?;
_ = self.builder.buildBr(block.parent_bb);
// If the break doesn't break a value, then we don't have to add
// the values to the lists.
if (!inst.operand.ty.hasCodeGenBits()) return null;
if (!inst.operand.ty.hasCodeGenBits()) {
// TODO: in astgen these instructions should turn into `br_void` instructions.
_ = self.builder.buildBr(block.parent_bb);
} else {
const val = try self.resolveInst(inst.operand);
// For the phi node, we need the basic blocks and the values of the
// break instructions.
try block.break_bbs.append(self.gpa, self.builder.getInsertBlock());
// For the phi node, we need the basic blocks and the values of the
// break instructions.
try block.break_bbs.append(self.gpa, self.builder.getInsertBlock());
try block.break_vals.append(self.gpa, val);
const val = try self.resolveInst(inst.operand);
try block.break_vals.append(self.gpa, val);
_ = self.builder.buildBr(block.parent_bb);
}
return null;
}
fn genBrVoid(self: *LLVMIRModule, inst: *Inst.BrVoid) !?*const llvm.Value {
var block = self.blocks.get(inst.block).?;
_ = self.builder.buildBr(block.parent_bb);
return null;
}
@ -591,6 +606,44 @@ pub const LLVMIRModule = struct {
return null;
}
fn genIsNonNull(self: *LLVMIRModule, inst: *Inst.UnOp, operand_is_ptr: bool) !?*const llvm.Value {
const operand = try self.resolveInst(inst.operand);
if (operand_is_ptr) {
const index_type = self.context.intType(32);
var indices: [2]*const llvm.Value = .{
index_type.constNull(),
index_type.constInt(1, false),
};
return self.builder.buildLoad(self.builder.buildInBoundsGEP(operand, &indices, 2, ""), "");
} else {
return self.builder.buildExtractValue(operand, 1, "");
}
}
fn genIsNull(self: *LLVMIRModule, inst: *Inst.UnOp, operand_is_ptr: bool) !?*const llvm.Value {
return self.builder.buildNot((try self.genIsNonNull(inst, operand_is_ptr)).?, "");
}
fn genOptionalPayload(self: *LLVMIRModule, inst: *Inst.UnOp, operand_is_ptr: bool) !?*const llvm.Value {
const operand = try self.resolveInst(inst.operand);
if (operand_is_ptr) {
const index_type = self.context.intType(32);
var indices: [2]*const llvm.Value = .{
index_type.constNull(),
index_type.constNull(),
};
return self.builder.buildInBoundsGEP(operand, &indices, 2, "");
} else {
return self.builder.buildExtractValue(operand, 0, "");
}
}
fn genAdd(self: *LLVMIRModule, inst: *Inst.BinOp) !?*const llvm.Value {
const lhs = try self.resolveInst(inst.lhs);
const rhs = try self.resolveInst(inst.rhs);
@ -751,6 +804,13 @@ pub const LLVMIRModule = struct {
// TODO: consider using buildInBoundsGEP2 for opaque pointers
return self.builder.buildInBoundsGEP(val, &indices, 2, "");
},
.ref_val => {
const elem_value = tv.val.castTag(.ref_val).?.data;
const elem_type = tv.ty.castPointer().?.data;
const alloca = self.buildAlloca(try self.getLLVMType(elem_type, src));
_ = self.builder.buildStore(try self.genTypedValue(src, .{ .ty = elem_type, .val = elem_value }), alloca);
return alloca;
},
else => return self.fail(src, "TODO implement const of pointer type '{}'", .{tv.ty}),
},
.Array => {
@ -765,6 +825,29 @@ pub const LLVMIRModule = struct {
return self.fail(src, "TODO handle more array values", .{});
}
},
.Optional => {
if (!tv.ty.isPtrLikeOptional()) {
var buf: Type.Payload.ElemType = undefined;
const child_type = tv.ty.optionalChild(&buf);
const llvm_child_type = try self.getLLVMType(child_type, src);
if (tv.val.tag() == .null_value) {
var optional_values: [2]*const llvm.Value = .{
llvm_child_type.constNull(),
self.context.intType(1).constNull(),
};
return self.context.constStruct(&optional_values, 2, false);
} else {
var optional_values: [2]*const llvm.Value = .{
try self.genTypedValue(src, .{ .ty = child_type, .val = tv.val }),
self.context.intType(1).constAllOnes(),
};
return self.context.constStruct(&optional_values, 2, false);
}
} else {
return self.fail(src, "TODO implement const of optional pointer", .{});
}
},
else => return self.fail(src, "TODO implement const of type '{}'", .{tv.ty}),
}
}
@ -790,6 +873,20 @@ pub const LLVMIRModule = struct {
const elem_type = try self.getLLVMType(t.elemType(), src);
return elem_type.arrayType(@intCast(c_uint, t.abiSize(self.module.getTarget())));
},
.Optional => {
if (!t.isPtrLikeOptional()) {
var buf: Type.Payload.ElemType = undefined;
const child_type = t.optionalChild(&buf);
var optional_types: [2]*const llvm.Type = .{
try self.getLLVMType(child_type, src),
self.context.intType(1),
};
return self.context.structType(&optional_types, 2, false);
} else {
return self.fail(src, "TODO implement optional pointers as actual pointers", .{});
}
},
else => return self.fail(src, "TODO implement getLLVMType for type '{}'", .{t}),
}
}

View File

@ -21,9 +21,15 @@ pub const Context = opaque {
pub const voidType = LLVMVoidTypeInContext;
extern fn LLVMVoidTypeInContext(C: *const Context) *const Type;
pub const structType = LLVMStructTypeInContext;
extern fn LLVMStructTypeInContext(C: *const Context, ElementTypes: [*]*const Type, ElementCount: c_uint, Packed: LLVMBool) *const Type;
pub const constString = LLVMConstStringInContext;
extern fn LLVMConstStringInContext(C: *const Context, Str: [*]const u8, Length: c_uint, DontNullTerminate: LLVMBool) *const Value;
pub const constStruct = LLVMConstStructInContext;
extern fn LLVMConstStructInContext(C: *const Context, ConstantVals: [*]*const Value, Count: c_uint, Packed: LLVMBool) *const Value;
pub const createBasicBlock = LLVMCreateBasicBlockInContext;
extern fn LLVMCreateBasicBlockInContext(C: *const Context, Name: [*:0]const u8) *const BasicBlock;
@ -204,6 +210,9 @@ pub const Builder = opaque {
pub const buildPhi = LLVMBuildPhi;
extern fn LLVMBuildPhi(*const Builder, Ty: *const Type, Name: [*:0]const u8) *const Value;
pub const buildExtractValue = LLVMBuildExtractValue;
extern fn LLVMBuildExtractValue(*const Builder, AggVal: *const Value, Index: c_uint, Name: [*:0]const u8) *const Value;
};
pub const IntPredicate = extern enum {

View File

@ -132,4 +132,44 @@ pub fn addCases(ctx: *TestContext) !void {
\\}
, "");
}
{
var case = ctx.exeUsingLlvmBackend("optionals", linux_x64);
case.addCompareOutput(
\\fn assert(ok: bool) void {
\\ if (!ok) unreachable;
\\}
\\
\\export fn main() c_int {
\\ var opt_val: ?i32 = 10;
\\ var null_val: ?i32 = null;
\\
\\ var val1: i32 = opt_val.?;
\\ const val1_1: i32 = opt_val.?;
\\ var ptr_val1 = &(opt_val.?);
\\ const ptr_val1_1 = &(opt_val.?);
\\
\\ var val2: i32 = null_val orelse 20;
\\ const val2_2: i32 = null_val orelse 20;
\\
\\ var value: i32 = 20;
\\ var ptr_val2 = &(null_val orelse value);
\\
\\ const val3 = opt_val orelse 30;
\\
\\ assert(val1 == 10);
\\ assert(val1_1 == 10);
\\ assert(ptr_val1.* == 10);
\\ assert(ptr_val1_1.* == 10);
\\
\\ assert(val2 == 20);
\\ assert(val2_2 == 20);
\\ assert(ptr_val2.* == 20);
\\
\\ assert(val3 == 10);
\\ return 0;
\\}
, "");
}
}