stage2: add lowering of M encoding

Examples include jmp / call near with memory or register operand
like `jmp [rax]`, or even RIP-relative `call [rip + 0x10]`.
This commit is contained in:
Jakub Konka 2021-12-23 01:22:07 +01:00
parent 1167e248ef
commit 9078cb0197

View File

@ -304,25 +304,13 @@ fn mirJmpCall(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
});
return;
}
const modrm_ext: u3 = switch (tag) {
.jmp_near => 0x4,
.call_near => 0x2,
else => unreachable,
};
if (ops.reg1 == .none) {
// JMP/CALL [imm]
const imm = emit.mir.instructions.items(.data)[inst].imm;
const encoder = try Encoder.init(emit.code, 7);
encoder.opcode_1byte(0xff);
encoder.modRm_SIBDisp0(modrm_ext);
encoder.sib_disp32();
encoder.imm32(imm);
return;
return lowerToMEnc(tag, RegisterOrMemory.mem(null, imm), emit.code);
}
// JMP/CALL reg
const encoder = try Encoder.init(emit.code, 2);
encoder.opcode_1byte(0xff);
encoder.modRm_direct(modrm_ext, ops.reg1.lowId());
return lowerToMEnc(tag, RegisterOrMemory.reg(ops.reg1), emit.code);
}
const CondType = enum {
@ -628,7 +616,7 @@ inline fn getOpCode(tag: Tag, enc: Encoding) ?u8 {
}
}
inline fn getModRmExt(tag: Tag) u3 {
inline fn getModRmExt(tag: Tag) ?u3 {
return switch (tag) {
.adc => 0x2,
.add => 0x0,
@ -639,8 +627,9 @@ inline fn getModRmExt(tag: Tag) u3 {
.sbb => 0x3,
.cmp => 0x7,
.mov => 0x0,
.jmp_near => 0x4,
.call_near => 0x2,
else => unreachable,
else => null,
};
}
@ -692,6 +681,67 @@ fn lowerToDEnc(tag: Tag, imm: i32, code: *std.ArrayList(u8)) InnerError!void {
encoder.imm32(imm);
}
fn lowerToMEnc(tag: Tag, reg_or_mem: RegisterOrMemory, code: *std.ArrayList(u8)) InnerError!void {
const opc = getOpCode(tag, .m).?;
const modrm_ext = getModRmExt(tag).?;
switch (reg_or_mem) {
.register => |reg| {
if (reg.size() != 64) return error.EmitFail;
const encoder = try Encoder.init(code, 3);
encoder.rex(.{
.w = false,
.b = reg.isExtended(),
});
encoder.opcode_1byte(opc);
encoder.modRm_direct(modrm_ext, reg.lowId());
},
.memory => |mem_op| {
const encoder = try Encoder.init(code, 8);
if (mem_op.reg) |reg| {
if (reg.size() != 64) return error.EmitFail;
encoder.rex(.{
.w = false,
.b = reg.isExtended(),
});
encoder.opcode_1byte(opc);
if (reg.lowId() == 4) {
if (mem_op.disp == 0) {
encoder.modRm_SIBDisp0(modrm_ext);
encoder.sib_base(reg.lowId());
} else if (immOpSize(mem_op.disp) == 8) {
encoder.modRm_SIBDisp8(modrm_ext);
encoder.sib_baseDisp8(reg.lowId());
encoder.disp8(@intCast(i8, mem_op.disp));
} else {
encoder.modRm_SIBDisp32(modrm_ext);
encoder.sib_baseDisp32(reg.lowId());
encoder.disp32(mem_op.disp);
}
} else {
if (mem_op.disp == 0) {
encoder.modRm_indirectDisp0(modrm_ext, reg.lowId());
} else if (immOpSize(mem_op.disp) == 8) {
encoder.modRm_indirectDisp8(modrm_ext, reg.lowId());
encoder.disp8(@intCast(i8, mem_op.disp));
} else {
encoder.modRm_indirectDisp32(modrm_ext, reg.lowId());
encoder.disp32(mem_op.disp);
}
}
} else {
encoder.opcode_1byte(opc);
if (mem_op.rip) {
encoder.modRm_RIPDisp32(modrm_ext);
} else {
encoder.modRm_SIBDisp0(modrm_ext);
encoder.sib_disp32();
}
encoder.disp32(mem_op.disp);
}
},
}
}
fn lowerToTdEnc(tag: Tag, moffs: i64, reg: Register, code: *std.ArrayList(u8)) InnerError!void {
return lowerToTdFdEnc(tag, reg, moffs, code, true);
}
@ -772,7 +822,7 @@ fn lowerToOiEnc(tag: Tag, reg: Register, imm: i64, code: *std.ArrayList(u8)) Inn
fn lowerToMiEnc(tag: Tag, reg_or_mem: RegisterOrMemory, imm: i32, code: *std.ArrayList(u8)) InnerError!void {
var opc = getOpCode(tag, .mi).?;
const modrm_ext = getModRmExt(tag);
const modrm_ext = getModRmExt(tag).?;
switch (reg_or_mem) {
.register => |dst_reg| {
if (dst_reg.size() == 8) {
@ -819,16 +869,25 @@ fn lowerToMiEnc(tag: Tag, reg_or_mem: RegisterOrMemory, imm: i32, code: *std.Arr
.b = dst_reg.isExtended(),
});
encoder.opcode_1byte(opc);
if (dst_mem.disp == 0) {
encoder.modRm_indirectDisp0(modrm_ext, dst_reg.lowId());
} else if (immOpSize(dst_mem.disp) == 8) {
encoder.modRm_indirectDisp8(modrm_ext, dst_reg.lowId());
encoder.disp8(@intCast(i8, dst_mem.disp));
} else {
if (dst_reg.lowId() == 4) {
if (dst_reg.lowId() == 4) {
if (dst_mem.disp == 0) {
encoder.modRm_SIBDisp0(modrm_ext);
encoder.sib_base(dst_reg.lowId());
} else if (immOpSize(dst_mem.disp) == 8) {
encoder.modRm_SIBDisp8(modrm_ext);
encoder.sib_baseDisp8(dst_reg.lowId());
encoder.disp8(@intCast(i8, dst_mem.disp));
} else {
encoder.modRm_SIBDisp32(modrm_ext);
encoder.sib_baseDisp32(dst_reg.lowId());
encoder.disp32(dst_mem.disp);
}
} else {
if (dst_mem.disp == 0) {
encoder.modRm_indirectDisp0(modrm_ext, dst_reg.lowId());
} else if (immOpSize(dst_mem.disp) == 8) {
encoder.modRm_indirectDisp8(modrm_ext, dst_reg.lowId());
encoder.disp8(@intCast(i8, dst_mem.disp));
} else {
encoder.modRm_indirectDisp32(modrm_ext, dst_reg.lowId());
encoder.disp32(dst_mem.disp);
@ -886,16 +945,25 @@ fn lowerToRmEnc(
.b = src_reg.isExtended(),
});
encoder.opcode_1byte(opc);
if (src_mem.disp == 0) {
encoder.modRm_indirectDisp0(reg.lowId(), src_reg.lowId());
} else if (immOpSize(src_mem.disp) == 8) {
encoder.modRm_indirectDisp8(reg.lowId(), src_reg.lowId());
encoder.disp8(@intCast(i8, src_mem.disp));
} else {
if (src_reg.lowId() == 4) {
if (src_reg.lowId() == 4) {
if (src_mem.disp == 0) {
encoder.modRm_SIBDisp0(reg.lowId());
encoder.sib_base(src_reg.lowId());
} else if (immOpSize(src_mem.disp) == 8) {
encoder.modRm_SIBDisp8(reg.lowId());
encoder.sib_baseDisp8(src_reg.lowId());
encoder.disp8(@intCast(i8, src_mem.disp));
} else {
encoder.modRm_SIBDisp32(reg.lowId());
encoder.sib_baseDisp32(src_reg.lowId());
encoder.disp32(src_mem.disp);
}
} else {
if (src_mem.disp == 0) {
encoder.modRm_indirectDisp0(reg.lowId(), src_reg.lowId());
} else if (immOpSize(src_mem.disp) == 8) {
encoder.modRm_indirectDisp8(reg.lowId(), src_reg.lowId());
encoder.disp8(@intCast(i8, src_mem.disp));
} else {
encoder.modRm_indirectDisp32(reg.lowId(), src_reg.lowId());
encoder.disp32(src_mem.disp);
@ -960,16 +1028,25 @@ fn lowerToMrEnc(
.b = dst_reg.isExtended(),
});
encoder.opcode_1byte(opc);
if (dst_mem.disp == 0) {
encoder.modRm_indirectDisp0(reg.lowId(), dst_reg.lowId());
} else if (immOpSize(dst_mem.disp) == 8) {
encoder.modRm_indirectDisp8(reg.lowId(), dst_reg.lowId());
encoder.disp8(@intCast(i8, dst_mem.disp));
} else {
if (dst_reg.lowId() == 4) {
if (dst_reg.lowId() == 4) {
if (dst_mem.disp == 0) {
encoder.modRm_SIBDisp0(reg.lowId());
encoder.sib_base(dst_reg.lowId());
} else if (immOpSize(dst_mem.disp) == 8) {
encoder.modRm_SIBDisp8(reg.lowId());
encoder.sib_baseDisp8(dst_reg.lowId());
encoder.disp8(@intCast(i8, dst_mem.disp));
} else {
encoder.modRm_SIBDisp32(reg.lowId());
encoder.sib_baseDisp32(dst_reg.lowId());
encoder.disp32(dst_mem.disp);
}
} else {
if (dst_mem.disp == 0) {
encoder.modRm_indirectDisp0(reg.lowId(), dst_reg.lowId());
} else if (immOpSize(dst_mem.disp) == 8) {
encoder.modRm_indirectDisp8(reg.lowId(), dst_reg.lowId());
encoder.disp8(@intCast(i8, dst_mem.disp));
} else {
encoder.modRm_indirectDisp32(reg.lowId(), dst_reg.lowId());
encoder.disp32(dst_mem.disp);
@ -1099,7 +1176,7 @@ fn mirArithScaleDst(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void
if (ops.reg2 == .none) {
// OP [reg1 + scale*rax + 0], imm32
var opc = getOpCode(tag, .mi).?;
const modrm_ext = getModRmExt(tag);
const modrm_ext = getModRmExt(tag).?;
if (ops.reg1.size() == 8) {
opc -= 1;
}
@ -1153,7 +1230,7 @@ fn mirArithScaleImm(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void
if (ops.reg1.size() == 8) {
opc -= 1;
}
const modrm_ext = getModRmExt(tag);
const modrm_ext = getModRmExt(tag).?;
const encoder = try Encoder.init(emit.code, 2);
encoder.rex(.{
.w = ops.reg1.size() == 64,
@ -1641,3 +1718,24 @@ test "lower FD/TD encoding" {
try lowerToFdEnc(.mov, .al, 0x10, code.buffer());
try expectEqualHexStrings("\xa0\x10", code.emitted(), "mov al, ds:0x10");
}
test "lower M encoding" {
var code = TestEmitCode.init();
defer code.deinit();
try lowerToMEnc(.jmp_near, RegisterOrMemory.reg(.r12), code.buffer());
try expectEqualHexStrings("\x41\xFF\xE4", code.emitted(), "jmp r12");
try lowerToMEnc(.jmp_near, RegisterOrMemory.mem(.r12, 0), code.buffer());
try expectEqualHexStrings("\x41\xFF\x24\x24", code.emitted(), "jmp qword ptr [r12]");
try lowerToMEnc(.jmp_near, RegisterOrMemory.mem(.r12, 0x10), code.buffer());
try expectEqualHexStrings("\x41\xFF\x64\x24\x10", code.emitted(), "jmp qword ptr [r12 + 0x10]");
try lowerToMEnc(.jmp_near, RegisterOrMemory.mem(.r12, 0x1000), code.buffer());
try expectEqualHexStrings(
"\x41\xFF\xA4\x24\x00\x10\x00\x00",
code.emitted(),
"jmp qword ptr [r12 + 0x1000]",
);
try lowerToMEnc(.jmp_near, RegisterOrMemory.rip(0x10), code.buffer());
try expectEqualHexStrings("\xFF\x25\x10\x00\x00\x00", code.emitted(), "jmp qword ptr [rip + 0x10]");
try lowerToMEnc(.jmp_near, RegisterOrMemory.mem(null, 0x10), code.buffer());
try expectEqualHexStrings("\xFF\x24\x25\x10\x00\x00\x00", code.emitted(), "jmp qword ptr [ds:0x10]");
}