riscv: reorganize binOp and implement cmp_imm_gte MIR

This was an annoying one to do: there is no (to my knowledge) myriad sequence
that lets us do `gte` compares against an immediate without allocating a register.
RISC-V provides only a single compare instruction, `lt` (set-less-than), so every
other compare variant has to be built from more than one instruction, and in this
case I believe you need to allocate a register.
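
For reference, a minimal sketch of the identity the lowering leans on: with only
set-less-than available, `a >= b` is computed as `!(a < b)`, i.e. an `slt` followed
by an `xori`. The Zig function below is purely illustrative (the name `cmpGteViaSlt`
and the test are not part of the commit, and the exact operand order and immediate
adjustment used by the backend are in the emit code in the diff); the comments show
the corresponding RISC-V pair once the immediate has been materialized into a
temporary register.

```zig
const std = @import("std");

/// Illustrative only: `lhs >= imm` expressed with nothing but a
/// less-than compare, the same slt + xori pattern the MIR lowering uses:
///   slt  rd, lhs_reg, imm_reg   ; rd = (lhs < imm)
///   xori rd, rd, 1              ; rd = !(lhs < imm) = (lhs >= imm)
fn cmpGteViaSlt(lhs: i64, imm: i64) u1 {
    const lt: u1 = @intFromBool(lhs < imm); // slt
    return lt ^ 1; // xori rd, rd, 1
}

test "cmpGteViaSlt behaves like >=" {
    try std.testing.expect(cmpGteViaSlt(5, 3) == 1);
    try std.testing.expect(cmpGteViaSlt(3, 3) == 1);
    try std.testing.expect(cmpGteViaSlt(2, 3) == 0);
}
```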
David Rubin 2024-03-22 20:14:10 -07:00
parent 63bbf66553
commit 5e010b6dea
3 changed files with 210 additions and 162 deletions

@ -850,6 +850,122 @@ fn airSlice(self: *Self, inst: Air.Inst.Index) !void {
return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none });
}
fn airBinOp(self: *Self, inst: Air.Inst.Index, tag: Air.Inst.Tag) !void {
const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;
const lhs = try self.resolveInst(bin_op.lhs);
const rhs = try self.resolveInst(bin_op.rhs);
const lhs_ty = self.typeOf(bin_op.lhs);
const rhs_ty = self.typeOf(bin_op.rhs);
const result: MCValue = if (self.liveness.isUnused(inst)) .dead else try self.binOp(tag, inst, lhs, rhs, lhs_ty, rhs_ty);
return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none });
}
/// For all your binary operation needs, this function will generate
/// the corresponding Mir instruction(s). Returns the location of the
/// result.
///
/// If the binary operation itself happens to be an Air instruction,
/// pass the corresponding index in the inst parameter. That helps
/// this function do stuff like reusing operands.
///
/// This function does not do any lowering to Mir itself, but instead
/// looks at the lhs and rhs and determines which kind of lowering
/// would be most suitable and then delegates the lowering to other
/// functions.
///
/// `maybe_inst` **needs** to be a bin_op, make sure of that.
fn binOp(
self: *Self,
tag: Air.Inst.Tag,
maybe_inst: ?Air.Inst.Index,
lhs: MCValue,
rhs: MCValue,
lhs_ty: Type,
rhs_ty: Type,
) InnerError!MCValue {
const mod = self.bin_file.comp.module.?;
switch (tag) {
// Arithmetic operations on integers and floats
.add,
.sub,
.cmp_eq,
.cmp_neq,
.cmp_gt,
.cmp_gte,
.cmp_lt,
.cmp_lte,
=> {
switch (lhs_ty.zigTypeTag(mod)) {
.Float => return self.fail("TODO binary operations on floats", .{}),
.Vector => return self.fail("TODO binary operations on vectors", .{}),
.Int => {
assert(lhs_ty.eql(rhs_ty, mod));
const int_info = lhs_ty.intInfo(mod);
if (int_info.bits <= 64) {
if (rhs == .immediate) {
return self.binOpImm(tag, maybe_inst, lhs, rhs, lhs_ty, rhs_ty);
}
return self.binOpRegister(tag, maybe_inst, lhs, rhs, lhs_ty, rhs_ty);
} else {
return self.fail("TODO binary operations on int with bits > 64", .{});
}
},
else => unreachable,
}
},
.ptr_add,
.ptr_sub,
=> {
switch (lhs_ty.zigTypeTag(mod)) {
.Pointer => {
const ptr_ty = lhs_ty;
const elem_ty = switch (ptr_ty.ptrSize(mod)) {
.One => ptr_ty.childType(mod).childType(mod), // ptr to array, so get array element type
else => ptr_ty.childType(mod),
};
const elem_size = elem_ty.abiSize(mod);
if (elem_size == 1) {
const base_tag: Air.Inst.Tag = switch (tag) {
.ptr_add => .add,
.ptr_sub => .sub,
else => unreachable,
};
return try self.binOpRegister(base_tag, maybe_inst, lhs, rhs, lhs_ty, rhs_ty);
} else {
return self.fail("TODO ptr_add with elem_size > 1", .{});
}
},
else => unreachable,
}
},
// These instructions have asymmetric bit sizes on the RHS and LHS.
.shr,
.shl,
=> {
switch (lhs_ty.zigTypeTag(mod)) {
.Float => return self.fail("TODO binary operations on floats", .{}),
.Vector => return self.fail("TODO binary operations on vectors", .{}),
.Int => {
const int_info = lhs_ty.intInfo(mod);
if (int_info.bits <= 64) {
if (rhs == .immediate) {
return self.binOpImm(tag, maybe_inst, lhs, rhs, lhs_ty, rhs_ty);
}
return self.binOpRegister(tag, maybe_inst, lhs, rhs, lhs_ty, rhs_ty);
} else {
return self.fail("TODO binary operations on int with bits > 64", .{});
}
},
else => unreachable,
}
},
else => unreachable,
}
}
/// Don't call this function directly. Use binOp instead.
///
/// Calling this function signals an intention to generate a Mir
@ -963,7 +1079,6 @@ fn binOpImm(
lhs_ty: Type,
rhs_ty: Type,
) !MCValue {
_ = rhs_ty;
assert(rhs == .immediate);
const lhs_is_register = lhs == .register;
@ -1006,142 +1121,44 @@ fn binOpImm(
const mir_tag: Mir.Inst.Tag = switch (tag) {
.shl => .slli,
.shr => .srli,
.cmp_gte => .cmp_imm_gte,
else => return self.fail("TODO: binOpImm {s}", .{@tagName(tag)}),
};
_ = try self.addInst(.{
.tag = mir_tag,
.data = .{
.i_type = .{
.rd = dest_reg,
.rs1 = lhs_reg,
.imm12 = math.cast(i12, rhs.immediate) orelse {
return self.fail("TODO: binOpImm larger than i12 i_type payload", .{});
},
},
},
});
// generate the struct for overflow checks
return MCValue{ .register = dest_reg };
}
/// For all your binary operation needs, this function will generate
/// the corresponding Mir instruction(s). Returns the location of the
/// result.
///
/// If the binary operation itself happens to be an Air instruction,
/// pass the corresponding index in the inst parameter. That helps
/// this function do stuff like reusing operands.
///
/// This function does not do any lowering to Mir itself, but instead
/// looks at the lhs and rhs and determines which kind of lowering
/// would be most suitable and then delegates the lowering to other
/// functions.
///
/// `maybe_inst` **needs** to be a bin_op, make sure of that.
fn binOp(
self: *Self,
tag: Air.Inst.Tag,
maybe_inst: ?Air.Inst.Index,
lhs: MCValue,
rhs: MCValue,
lhs_ty: Type,
rhs_ty: Type,
) InnerError!MCValue {
const mod = self.bin_file.comp.module.?;
switch (tag) {
// Arithmetic operations on integers and floats
.add,
.sub,
.cmp_eq,
.cmp_neq,
.cmp_gt,
.cmp_gte,
.cmp_lt,
.cmp_lte,
// apply some special operations needed
switch (mir_tag) {
.slli,
.srli,
=> {
switch (lhs_ty.zigTypeTag(mod)) {
.Float => return self.fail("TODO binary operations on floats", .{}),
.Vector => return self.fail("TODO binary operations on vectors", .{}),
.Int => {
assert(lhs_ty.eql(rhs_ty, mod));
const int_info = lhs_ty.intInfo(mod);
if (int_info.bits <= 64) {
if (rhs == .immediate) {
return self.binOpImm(tag, maybe_inst, lhs, rhs, lhs_ty, rhs_ty);
}
return self.binOpRegister(tag, maybe_inst, lhs, rhs, lhs_ty, rhs_ty);
} else {
return self.fail("TODO binary operations on int with bits > 64", .{});
}
},
else => unreachable,
}
_ = try self.addInst(.{
.tag = mir_tag,
.data = .{ .i_type = .{
.rd = dest_reg,
.rs1 = lhs_reg,
.imm12 = math.cast(i12, rhs.immediate) orelse {
return self.fail("TODO: binOpImm larger than i12 i_type payload", .{});
},
} },
});
},
.ptr_add,
.ptr_sub,
=> {
switch (lhs_ty.zigTypeTag(mod)) {
.Pointer => {
const ptr_ty = lhs_ty;
const elem_ty = switch (ptr_ty.ptrSize(mod)) {
.One => ptr_ty.childType(mod).childType(mod), // ptr to array, so get array element type
else => ptr_ty.childType(mod),
};
const elem_size = elem_ty.abiSize(mod);
.cmp_imm_gte => {
const imm_reg = try self.copyToTmpRegister(rhs_ty, .{ .immediate = rhs.immediate - 1 });
if (elem_size == 1) {
const base_tag: Air.Inst.Tag = switch (tag) {
.ptr_add => .add,
.ptr_sub => .sub,
else => unreachable,
};
return try self.binOpRegister(base_tag, maybe_inst, lhs, rhs, lhs_ty, rhs_ty);
} else {
return self.fail("TODO ptr_add with elem_size > 1", .{});
}
},
else => unreachable,
}
},
// These instructions have asymmetric bit sizes.
.shr,
.shl,
=> {
switch (lhs_ty.zigTypeTag(mod)) {
.Float => return self.fail("TODO binary operations on floats", .{}),
.Vector => return self.fail("TODO binary operations on vectors", .{}),
.Int => {
const int_info = lhs_ty.intInfo(mod);
if (int_info.bits <= 64) {
if (rhs == .immediate) {
return self.binOpImm(tag, maybe_inst, lhs, rhs, lhs_ty, rhs_ty);
}
return self.binOpRegister(tag, maybe_inst, lhs, rhs, lhs_ty, rhs_ty);
} else {
return self.fail("TODO binary operations on int with bits > 64", .{});
}
},
else => unreachable,
}
_ = try self.addInst(.{
.tag = mir_tag,
.data = .{ .r_type = .{
.rd = dest_reg,
.rs1 = imm_reg,
.rs2 = lhs_reg,
} },
});
},
else => unreachable,
}
}
fn airBinOp(self: *Self, inst: Air.Inst.Index, tag: Air.Inst.Tag) !void {
const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;
const lhs = try self.resolveInst(bin_op.lhs);
const rhs = try self.resolveInst(bin_op.rhs);
const lhs_ty = self.typeOf(bin_op.lhs);
const rhs_ty = self.typeOf(bin_op.rhs);
// generate the struct for overflow checks
const result: MCValue = if (self.liveness.isUnused(inst)) .dead else try self.binOp(tag, inst, lhs, rhs, lhs_ty, rhs_ty);
return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none });
return MCValue{ .register = dest_reg };
}
fn airPtrArithmetic(self: *Self, inst: Air.Inst.Index, tag: Air.Inst.Tag) !void {
@ -2101,8 +2118,12 @@ fn airCondBr(self: *Self, inst: Air.Inst.Index) !void {
const else_body: []const Air.Inst.Index = @ptrCast(self.air.extra[extra.end + then_body.len ..][0..extra.data.else_body_len]);
const liveness_condbr = self.liveness.getCondBr(inst);
// A branch to the false section. Uses beq
const reloc = try self.condBr(cond_ty, cond);
const cond_reg = try self.register_manager.allocReg(inst, gp);
const cond_reg_lock = self.register_manager.lockRegAssumeUnused(cond_reg);
defer self.register_manager.unlockReg(cond_reg_lock);
// A branch to the false section. Uses bne
const reloc = try self.condBr(cond_ty, cond, cond_reg);
// If the condition dies here in this condbr instruction, process
// that death now instead of later as this has an effect on
@ -2233,19 +2254,14 @@ fn airCondBr(self: *Self, inst: Air.Inst.Index) !void {
}
}
fn condBr(self: *Self, cond_ty: Type, condition: MCValue) !Mir.Inst.Index {
_ = cond_ty;
const reg = switch (condition) {
.register => |r| r,
else => try self.copyToTmpRegister(Type.bool, condition),
};
fn condBr(self: *Self, cond_ty: Type, condition: MCValue, cond_reg: Register) !Mir.Inst.Index {
try self.genSetReg(cond_ty, cond_reg, condition);
return try self.addInst(.{
.tag = .bne,
.data = .{
.b_type = .{
.rs1 = reg,
.rs1 = cond_reg,
.rs2 = .zero,
.inst = undefined,
},
@ -2739,6 +2755,7 @@ fn genSetStack(self: *Self, ty: Type, stack_offset: u32, src_val: MCValue) Inner
} else return self.fail("TODO genSetStack for {s}", .{@tagName(self.bin_file.tag)});
};
// setup the src pointer
_ = try self.addInst(.{
.tag = .load_symbol,
.data = .{
@ -2789,7 +2806,7 @@ fn genInlineMemcpy(
// compare count to length
const compare_inst = try self.addInst(.{
.tag = .cmp_gt,
.tag = .cmp_eq,
.data = .{ .r_type = .{
.rd = tmp,
.rs1 = count,
@ -2861,9 +2878,12 @@ fn genSetReg(self: *Self, ty: Type, reg: Register, src_val: MCValue) InnerError!
} },
});
} else {
// TODO: use a more advanced myriad seq to do this without a reg.
// see: https://github.com/llvm/llvm-project/blob/081a66ffacfe85a37ff775addafcf3371e967328/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMatInt.cpp#L224
const temp = try self.register_manager.allocReg(null, gp);
const maybe_temp_lock = self.register_manager.lockReg(temp);
defer if (maybe_temp_lock) |temp_lock| self.register_manager.unlockReg(temp_lock);
const temp_lock = self.register_manager.lockRegAssumeUnused(temp);
defer self.register_manager.unlockReg(temp_lock);
const lo32: i32 = @truncate(x);
const carry: i32 = if (lo32 < 0) 1 else 0;

@ -59,6 +59,7 @@ pub fn emitMir(
.cmp_eq => try emit.mirRType(inst),
.cmp_gt => try emit.mirRType(inst),
.cmp_imm_gte => try emit.mirRType(inst),
.beq => try emit.mirBType(inst),
.bne => try emit.mirBType(inst),
@ -185,14 +186,27 @@ fn mirRType(emit: *Emit, inst: Mir.Inst.Index) !void {
switch (tag) {
.add => try emit.writeInstruction(Instruction.add(rd, rs1, rs2)),
.sub => try emit.writeInstruction(Instruction.sub(rd, rs1, rs2)),
.cmp_gt => try emit.writeInstruction(Instruction.slt(rd, rs1, rs2)),
.cmp_gt => {
// rs1 > rs2
try emit.writeInstruction(Instruction.slt(rd, rs1, rs2));
},
.cmp_eq => {
// rs1 == rs2
// if equal, write 0 to rd
try emit.writeInstruction(Instruction.xor(rd, rs1, rs2));
// if rd == 0, set rd to 1
try emit.writeInstruction(Instruction.sltiu(rd, rd, 1));
},
.sllw => try emit.writeInstruction(Instruction.sllw(rd, rs1, rs2)),
.srlw => try emit.writeInstruction(Instruction.srlw(rd, rs1, rs2)),
.@"or" => try emit.writeInstruction(Instruction.@"or"(rd, rs1, rs2)),
.cmp_imm_gte => {
// rd = rs1 >= imm12
// see the docstring for cmp_imm_gte for why we use r_type here
try emit.writeInstruction(Instruction.slt(rd, rs1, rs2));
try emit.writeInstruction(Instruction.xori(rd, rd, 1));
},
else => unreachable,
}
}
@ -220,30 +234,34 @@ fn mirIType(emit: *Emit, inst: Mir.Inst.Index) !void {
const tag = emit.mir.instructions.items(.tag)[inst];
const i_type = emit.mir.instructions.items(.data)[inst].i_type;
const rd = i_type.rd;
const rs1 = i_type.rs1;
const imm12 = i_type.imm12;
switch (tag) {
.addi => try emit.writeInstruction(Instruction.addi(i_type.rd, i_type.rs1, i_type.imm12)),
.jalr => try emit.writeInstruction(Instruction.jalr(i_type.rd, i_type.imm12, i_type.rs1)),
.addi => try emit.writeInstruction(Instruction.addi(rd, rs1, imm12)),
.jalr => try emit.writeInstruction(Instruction.jalr(rd, imm12, rs1)),
.ld => try emit.writeInstruction(Instruction.ld(i_type.rd, i_type.imm12, i_type.rs1)),
.lw => try emit.writeInstruction(Instruction.lw(i_type.rd, i_type.imm12, i_type.rs1)),
.lh => try emit.writeInstruction(Instruction.lh(i_type.rd, i_type.imm12, i_type.rs1)),
.lb => try emit.writeInstruction(Instruction.lb(i_type.rd, i_type.imm12, i_type.rs1)),
.ld => try emit.writeInstruction(Instruction.ld(rd, imm12, rs1)),
.lw => try emit.writeInstruction(Instruction.lw(rd, imm12, rs1)),
.lh => try emit.writeInstruction(Instruction.lh(rd, imm12, rs1)),
.lb => try emit.writeInstruction(Instruction.lb(rd, imm12, rs1)),
.sd => try emit.writeInstruction(Instruction.sd(i_type.rd, i_type.imm12, i_type.rs1)),
.sw => try emit.writeInstruction(Instruction.sw(i_type.rd, i_type.imm12, i_type.rs1)),
.sh => try emit.writeInstruction(Instruction.sh(i_type.rd, i_type.imm12, i_type.rs1)),
.sb => try emit.writeInstruction(Instruction.sb(i_type.rd, i_type.imm12, i_type.rs1)),
.sd => try emit.writeInstruction(Instruction.sd(rd, imm12, rs1)),
.sw => try emit.writeInstruction(Instruction.sw(rd, imm12, rs1)),
.sh => try emit.writeInstruction(Instruction.sh(rd, imm12, rs1)),
.sb => try emit.writeInstruction(Instruction.sb(rd, imm12, rs1)),
.ldr_ptr_stack => try emit.writeInstruction(Instruction.add(i_type.rd, i_type.rs1, .sp)),
.ldr_ptr_stack => try emit.writeInstruction(Instruction.add(rd, rs1, .sp)),
.abs => {
try emit.writeInstruction(Instruction.sraiw(i_type.rd, i_type.rs1, @intCast(i_type.imm12)));
try emit.writeInstruction(Instruction.xor(i_type.rs1, i_type.rs1, i_type.rd));
try emit.writeInstruction(Instruction.subw(i_type.rs1, i_type.rs1, i_type.rd));
try emit.writeInstruction(Instruction.sraiw(rd, rs1, @intCast(imm12)));
try emit.writeInstruction(Instruction.xor(rs1, rs1, rd));
try emit.writeInstruction(Instruction.subw(rs1, rs1, rd));
},
.srli => try emit.writeInstruction(Instruction.srli(i_type.rd, i_type.rs1, @intCast(i_type.imm12))),
.slli => try emit.writeInstruction(Instruction.slli(i_type.rd, i_type.rs1, @intCast(i_type.imm12))),
.srli => try emit.writeInstruction(Instruction.srli(rd, rs1, @intCast(imm12))),
.slli => try emit.writeInstruction(Instruction.slli(rd, rs1, @intCast(imm12))),
else => unreachable,
}
@ -471,12 +489,13 @@ fn instructionSize(emit: *Emit, inst: Mir.Inst.Index) usize {
.dbg_prologue_end,
=> 0,
.psuedo_epilogue => 12, // 3 * 4
.psuedo_prologue => 16, // 4 * 4
.psuedo_epilogue => 12,
.psuedo_prologue => 16,
.abs => 12, // 3 * 4
.abs => 12,
.cmp_eq => 8,
.cmp_imm_gte => 8,
else => 4,
};

@ -57,12 +57,21 @@ pub const Inst = struct {
/// Jumps. Uses `inst` payload.
j,
// TODO: Maybe create a special data for compares that includes the ops
/// Compare equal, uses r_type
// NOTE: Maybe create a special data for compares that includes the ops
/// Register `==`, uses r_type
cmp_eq,
/// Compare greater than, uses r_type
/// Register `>`, uses r_type
cmp_gt,
/// Immediate `>=`, uses r_type
///
/// Note: this uses r_type because RISC-V does not provide a good way
/// to do `>=` comparisons on immediates. Usually we would just subtract
/// 1 from the immediate and do a `>` comparison, however there is no `>`
/// register to immedate comparison in RISC-V. This leads us to need to
/// allocate a register for temporary use.
cmp_imm_gte,
/// Branch if equal. Uses b_type
beq,
/// Branch if not equal. Uses b_type