stage2: add inline memset for x86_64 backend

* introduce a new Mir tag `mov_mem_index_imm` which selects an instruction
  of the form `OP ptr [reg + rax*1 + imm32], imm32`, where the encoded
  flags select the appropriate ptr width for the memory store operation
  (note that the scale is fixed and set to 1)
This commit is contained in:
Jakub Konka 2022-01-18 11:58:22 +01:00
parent 4938fb8f5c
commit aaa641feba
4 changed files with 181 additions and 9 deletions

View File

@ -3148,7 +3148,7 @@ fn genSetStack(self: *Self, ty: Type, stack_offset: u32, mcv: MCValue) InnerErro
2 => return self.genSetStack(ty, stack_offset, .{ .immediate = 0xaaaa }),
4 => return self.genSetStack(ty, stack_offset, .{ .immediate = 0xaaaaaaaa }),
8 => return self.genSetStack(ty, stack_offset, .{ .immediate = 0xaaaaaaaaaaaaaaaa }),
else => return self.fail("TODO implement memset", .{}),
else => return self.genInlineMemset(ty, stack_offset, .{ .immediate = 0xaa }),
}
},
.compare_flags_unsigned => |op| {
@ -3398,6 +3398,97 @@ fn genInlineMemcpy(
try self.performReloc(loop_reloc);
}
/// Emits an inline loop that fills the stack slot holding a value of type `ty`
/// at `stack_offset` with `value`, one store per iteration.
///
/// `rax` is claimed through the register manager and serves both as the
/// remaining-iterations counter and as the index into the slot. Only `.immediate`
/// values that fit in 32 bits and slots with `stack_offset + abi_size <= 128` are
/// handled; every other case is reported through `self.fail` as a TODO.
/// A type whose ABI size is zero emits nothing.
fn genInlineMemset(self: *Self, ty: Type, stack_offset: u32, value: MCValue) InnerError!void {
    const abi_size = ty.abiSize(self.target.*);
    // Nothing to fill for a zero-sized type. Return before `abi_size - 1` is
    // computed below, which would underflow, and before `rax` is claimed.
    if (abi_size == 0) return;

    try self.register_manager.getReg(.rax, null);

    const adj_off = stack_offset + abi_size;
    if (adj_off > 128) {
        return self.fail("TODO inline memset with large stack offset", .{});
    }
    // Displacement of the lowest byte of the slot relative to `rbp`.
    const negative_offset = @bitCast(u32, -@intCast(i32, adj_off));

    // We are actually counting `abi_size` bytes; however, we reuse the index register
    // as both the counter and offset scaler, hence we need to subtract one from `abi_size`
    // and count until -1.
    if (abi_size > math.maxInt(i32)) {
        // With the `adj_off > 128` limit above this branch cannot currently be taken;
        // it only becomes relevant once that limit is lifted.
        // movabs rax, abi_size - 1
        const payload = try self.addExtra(Mir.Imm64.encode(abi_size - 1));
        _ = try self.addInst(.{
            .tag = .movabs,
            .ops = (Mir.Ops{
                .reg1 = .rax,
            }).encode(),
            .data = .{ .payload = payload },
        });
    } else {
        // mov rax, abi_size - 1
        _ = try self.addInst(.{
            .tag = .mov,
            .ops = (Mir.Ops{
                .reg1 = .rax,
            }).encode(),
            .data = .{ .imm = @truncate(u32, abi_size - 1) },
        });
    }

    // loop:
    // cmp rax, -1
    const loop_start = try self.addInst(.{
        .tag = .cmp,
        .ops = (Mir.Ops{
            .reg1 = .rax,
        }).encode(),
        .data = .{ .imm = @bitCast(u32, @as(i32, -1)) },
    });

    // je end
    // The jump target is patched by `performReloc` once the end of the loop is known.
    const loop_reloc = try self.addInst(.{
        .tag = .cond_jmp_eq_ne,
        .ops = (Mir.Ops{ .flags = 0b01 }).encode(),
        .data = .{ .inst = undefined },
    });

    switch (value) {
        .immediate => |x| {
            if (x > math.maxInt(i32)) {
                return self.fail("TODO inline memset for value immediate larger than 32bits", .{});
            }
            // mov byte ptr [rbp + rax - (stack_offset + abi_size)], imm
            // NOTE(review): the ops flags are left at their default here; this assumes
            // the default is 0b00, i.e. the byte ptr form - confirm against `Mir.Ops`.
            const payload = try self.addExtra(Mir.ImmPair{
                .dest_off = negative_offset,
                .operand = @truncate(u32, x),
            });
            _ = try self.addInst(.{
                .tag = .mov_mem_index_imm,
                .ops = (Mir.Ops{
                    .reg1 = .rbp,
                }).encode(),
                .data = .{ .payload = payload },
            });
        },
        else => return self.fail("TODO inline memset for value of type {}", .{value}),
    }

    // sub rax, 1
    _ = try self.addInst(.{
        .tag = .sub,
        .ops = (Mir.Ops{
            .reg1 = .rax,
        }).encode(),
        .data = .{ .imm = 1 },
    });

    // jmp loop
    _ = try self.addInst(.{
        .tag = .jmp,
        .ops = (Mir.Ops{ .flags = 0b00 }).encode(),
        .data = .{ .inst = loop_start },
    });

    // end:
    try self.performReloc(loop_reloc);
}
fn genSetReg(self: *Self, ty: Type, reg: Register, mcv: MCValue) InnerError!void {
switch (mcv) {
.dead => unreachable,
@ -3639,7 +3730,7 @@ fn airArrayToSlice(self: *Self, inst: Air.Inst.Index) !void {
const stack_offset = try self.allocMem(inst, 16, 16);
const array_ty = ptr_ty.childType();
const array_len = array_ty.arrayLenIncludingSentinel();
try self.genSetStack(Type.initTag(.usize), stack_offset + 8, ptr);
try self.genSetStack(ptr_ty, stack_offset + 8, ptr);
try self.genSetStack(Type.initTag(.u64), stack_offset + 16, .{ .immediate = array_len });
break :blk .{ .stack_offset = stack_offset };
};

View File

@ -115,6 +115,16 @@ pub fn lowerMir(emit: *Emit) InnerError!void {
.cmp_scale_imm => try emit.mirArithScaleImm(.cmp, inst),
.mov_scale_imm => try emit.mirArithScaleImm(.mov, inst),
.adc_mem_index_imm => try emit.mirArithMemIndexImm(.adc, inst),
.add_mem_index_imm => try emit.mirArithMemIndexImm(.add, inst),
.sub_mem_index_imm => try emit.mirArithMemIndexImm(.sub, inst),
.xor_mem_index_imm => try emit.mirArithMemIndexImm(.xor, inst),
.and_mem_index_imm => try emit.mirArithMemIndexImm(.@"and", inst),
.or_mem_index_imm => try emit.mirArithMemIndexImm(.@"or", inst),
.sbb_mem_index_imm => try emit.mirArithMemIndexImm(.sbb, inst),
.cmp_mem_index_imm => try emit.mirArithMemIndexImm(.cmp, inst),
.mov_mem_index_imm => try emit.mirArithMemIndexImm(.mov, inst),
.movabs => try emit.mirMovabs(inst),
.lea => try emit.mirLea(inst),
@ -549,6 +559,29 @@ fn mirArithScaleImm(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void
}), imm_pair.operand, emit.code) catch |err| emit.failWithLoweringError(err);
}
/// Lowers an `*_mem_index_imm` MIR instruction to an encoding of the form
/// `OP ptr [reg1 + rax*1 + imm32], imm32`.
///
/// The ops flags select the pointer width of the memory operand (0b00 byte,
/// 0b01 word, 0b10 dword, 0b11 qword). The index register is always `rax` and the
/// scale field is always 0 (the `rax*1` in the form above). The displacement and
/// the immediate operand are read from the `ImmPair` that the instruction's
/// payload points at. `reg2` is asserted to be `.none`.
fn mirArithMemIndexImm(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
    const ops = Mir.Ops.decode(emit.mir.instructions.items(.ops)[inst]);
    assert(ops.reg2 == .none);

    const extra_index = emit.mir.instructions.items(.data)[inst].payload;
    const pair = emit.mir.extraData(Mir.ImmPair, extra_index).data;

    const width: Memory.PtrSize = switch (ops.flags) {
        0b11 => .qword_ptr,
        0b10 => .dword_ptr,
        0b01 => .word_ptr,
        0b00 => .byte_ptr,
    };

    // OP ptr [reg1 + rax*1 + imm32], imm32
    const dest = RegisterOrMemory.mem(width, .{
        .disp = pair.dest_off,
        .base = ops.reg1,
        .scale_index = ScaleIndex{ .scale = 0, .index = .rax },
    });
    return lowerToMiEnc(tag, dest, pair.operand, emit.code) catch |err|
        emit.failWithLoweringError(err);
}
fn mirMovabs(emit: *Emit, inst: Mir.Inst.Index) InnerError!void {
const tag = emit.mir.instructions.items(.tag)[inst];
assert(tag == .movabs);

View File

@ -79,6 +79,13 @@ pub const Inst = struct {
/// * Data field `payload` points at `ImmPair`.
adc_scale_imm,
/// ops flags: form:
/// 0b00 byte ptr [reg1 + rax + imm32], imm8
/// 0b01 word ptr [reg1 + rax + imm32], imm16
/// 0b10 dword ptr [reg1 + rax + imm32], imm32
/// 0b11 qword ptr [reg1 + rax + imm32], imm32 (sign-extended to imm64)
adc_mem_index_imm,
// The following instructions all have the same encoding as `adc`.
add,
@ -86,81 +93,97 @@ pub const Inst = struct {
add_scale_src,
add_scale_dst,
add_scale_imm,
add_mem_index_imm,
sub,
sub_mem_imm,
sub_scale_src,
sub_scale_dst,
sub_scale_imm,
sub_mem_index_imm,
xor,
xor_mem_imm,
xor_scale_src,
xor_scale_dst,
xor_scale_imm,
xor_mem_index_imm,
@"and",
and_mem_imm,
and_scale_src,
and_scale_dst,
and_scale_imm,
and_mem_index_imm,
@"or",
or_mem_imm,
or_scale_src,
or_scale_dst,
or_scale_imm,
or_mem_index_imm,
rol,
rol_mem_imm,
rol_scale_src,
rol_scale_dst,
rol_scale_imm,
rol_mem_index_imm,
ror,
ror_mem_imm,
ror_scale_src,
ror_scale_dst,
ror_scale_imm,
ror_mem_index_imm,
rcl,
rcl_mem_imm,
rcl_scale_src,
rcl_scale_dst,
rcl_scale_imm,
rcl_mem_index_imm,
rcr,
rcr_mem_imm,
rcr_scale_src,
rcr_scale_dst,
rcr_scale_imm,
rcr_mem_index_imm,
shl,
shl_mem_imm,
shl_scale_src,
shl_scale_dst,
shl_scale_imm,
shl_mem_index_imm,
sal,
sal_mem_imm,
sal_scale_src,
sal_scale_dst,
sal_scale_imm,
sal_mem_index_imm,
shr,
shr_mem_imm,
shr_scale_src,
shr_scale_dst,
shr_scale_imm,
shr_mem_index_imm,
sar,
sar_mem_imm,
sar_scale_src,
sar_scale_dst,
sar_scale_imm,
sar_mem_index_imm,
sbb,
sbb_mem_imm,
sbb_scale_src,
sbb_scale_dst,
sbb_scale_imm,
sbb_mem_index_imm,
cmp,
cmp_mem_imm,
cmp_scale_src,
cmp_scale_dst,
cmp_scale_imm,
cmp_mem_index_imm,
mov,
mov_mem_imm,
mov_scale_src,
mov_scale_dst,
mov_scale_imm,
mov_mem_index_imm,
/// ops flags: form:
/// 0b00 reg1, [reg2 + imm32]

View File

@ -64,6 +64,7 @@ pub fn printMir(print: *const Print, w: anytype, mir_to_air_map: std.AutoHashMap
.@"or" => try print.mirArith(.@"or", inst, w),
.sbb => try print.mirArith(.sbb, inst, w),
.cmp => try print.mirArith(.cmp, inst, w),
.mov => try print.mirArith(.mov, inst, w),
.adc_mem_imm => try print.mirArithMemImm(.adc, inst, w),
.add_mem_imm => try print.mirArithMemImm(.add, inst, w),
@ -73,6 +74,7 @@ pub fn printMir(print: *const Print, w: anytype, mir_to_air_map: std.AutoHashMap
.or_mem_imm => try print.mirArithMemImm(.@"or", inst, w),
.sbb_mem_imm => try print.mirArithMemImm(.sbb, inst, w),
.cmp_mem_imm => try print.mirArithMemImm(.cmp, inst, w),
.mov_mem_imm => try print.mirArithMemImm(.mov, inst, w),
.adc_scale_src => try print.mirArithScaleSrc(.adc, inst, w),
.add_scale_src => try print.mirArithScaleSrc(.add, inst, w),
@ -82,6 +84,7 @@ pub fn printMir(print: *const Print, w: anytype, mir_to_air_map: std.AutoHashMap
.or_scale_src => try print.mirArithScaleSrc(.@"or", inst, w),
.sbb_scale_src => try print.mirArithScaleSrc(.sbb, inst, w),
.cmp_scale_src => try print.mirArithScaleSrc(.cmp, inst, w),
.mov_scale_src => try print.mirArithScaleSrc(.mov, inst, w),
.adc_scale_dst => try print.mirArithScaleDst(.adc, inst, w),
.add_scale_dst => try print.mirArithScaleDst(.add, inst, w),
@ -91,6 +94,7 @@ pub fn printMir(print: *const Print, w: anytype, mir_to_air_map: std.AutoHashMap
.or_scale_dst => try print.mirArithScaleDst(.@"or", inst, w),
.sbb_scale_dst => try print.mirArithScaleDst(.sbb, inst, w),
.cmp_scale_dst => try print.mirArithScaleDst(.cmp, inst, w),
.mov_scale_dst => try print.mirArithScaleDst(.mov, inst, w),
.adc_scale_imm => try print.mirArithScaleImm(.adc, inst, w),
.add_scale_imm => try print.mirArithScaleImm(.add, inst, w),
@ -100,11 +104,18 @@ pub fn printMir(print: *const Print, w: anytype, mir_to_air_map: std.AutoHashMap
.or_scale_imm => try print.mirArithScaleImm(.@"or", inst, w),
.sbb_scale_imm => try print.mirArithScaleImm(.sbb, inst, w),
.cmp_scale_imm => try print.mirArithScaleImm(.cmp, inst, w),
.mov => try print.mirArith(.mov, inst, w),
.mov_scale_src => try print.mirArithScaleSrc(.mov, inst, w),
.mov_scale_dst => try print.mirArithScaleDst(.mov, inst, w),
.mov_scale_imm => try print.mirArithScaleImm(.mov, inst, w),
.adc_mem_index_imm => try print.mirArithMemIndexImm(.adc, inst, w),
.add_mem_index_imm => try print.mirArithMemIndexImm(.add, inst, w),
.sub_mem_index_imm => try print.mirArithMemIndexImm(.sub, inst, w),
.xor_mem_index_imm => try print.mirArithMemIndexImm(.xor, inst, w),
.and_mem_index_imm => try print.mirArithMemIndexImm(.@"and", inst, w),
.or_mem_index_imm => try print.mirArithMemIndexImm(.@"or", inst, w),
.sbb_mem_index_imm => try print.mirArithMemIndexImm(.sbb, inst, w),
.cmp_mem_index_imm => try print.mirArithMemIndexImm(.cmp, inst, w),
.mov_mem_index_imm => try print.mirArithMemIndexImm(.mov, inst, w),
.movabs => try print.mirMovabs(inst, w),
.lea => try print.mirLea(inst, w),
@ -316,11 +327,11 @@ fn mirArithScaleDst(print: *const Print, tag: Mir.Inst.Tag, inst: Mir.Inst.Index
if (ops.reg2 == .none) {
// OP [reg1 + scale*rax + 0], imm32
try w.print("{s} [{s} + {d}*rcx + 0], {d}\n", .{ @tagName(tag), @tagName(ops.reg1), scale, imm });
try w.print("{s} [{s} + {d}*rax + 0], {d}\n", .{ @tagName(tag), @tagName(ops.reg1), scale, imm });
}
// OP [reg1 + scale*rax + imm32], reg2
try w.print("{s} [{s} + {d}*rcx + {d}], {s}\n", .{ @tagName(tag), @tagName(ops.reg1), scale, imm, @tagName(ops.reg2) });
try w.print("{s} [{s} + {d}*rax + {d}], {s}\n", .{ @tagName(tag), @tagName(ops.reg1), scale, imm, @tagName(ops.reg2) });
}
fn mirArithScaleImm(print: *const Print, tag: Mir.Inst.Tag, inst: Mir.Inst.Index, w: anytype) !void {
@ -328,7 +339,21 @@ fn mirArithScaleImm(print: *const Print, tag: Mir.Inst.Tag, inst: Mir.Inst.Index
const scale = ops.flags;
const payload = print.mir.instructions.items(.data)[inst].payload;
const imm_pair = print.mir.extraData(Mir.ImmPair, payload).data;
try w.print("{s} [{s} + {d}*rcx + {d}], {d}\n", .{ @tagName(tag), @tagName(ops.reg1), scale, imm_pair.dest_off, imm_pair.operand });
try w.print("{s} [{s} + {d}*rax + {d}], {d}\n", .{ @tagName(tag), @tagName(ops.reg1), scale, imm_pair.dest_off, imm_pair.operand });
}
/// Prints an `*_mem_index_imm` MIR instruction to `w` as
/// `OP <width> ptr [reg1 + 1*rax + dest_off], operand`, terminated by a newline.
///
/// `<width>` is chosen from the ops flags (0b00 byte, 0b01 word, 0b10 dword,
/// 0b11 qword); `dest_off` and `operand` come from the `ImmPair` that the
/// instruction's payload points at.
fn mirArithMemIndexImm(print: *const Print, tag: Mir.Inst.Tag, inst: Mir.Inst.Index, w: anytype) !void {
    const ops = Mir.Ops.decode(print.mir.instructions.items(.ops)[inst]);
    const extra_index = print.mir.instructions.items(.data)[inst].payload;
    const pair = print.mir.extraData(Mir.ImmPair, extra_index).data;

    // Each prefix carries its own trailing space so the bracket follows directly.
    const width_prefix: []const u8 = switch (ops.flags) {
        0b00 => "byte ptr ",
        0b01 => "word ptr ",
        0b10 => "dword ptr ",
        0b11 => "qword ptr ",
    };

    try w.print("{s} {s}[{s} + 1*rax + {d}], {d}\n", .{
        @tagName(tag),
        width_prefix,
        @tagName(ops.reg1),
        pair.dest_off,
        pair.operand,
    });
}
fn mirMovabs(print: *const Print, inst: Mir.Inst.Index, w: anytype) !void {