diff --git a/src/arch/x86_64/Emit.zig b/src/arch/x86_64/Emit.zig index 0c9027b260..3f3c696dfb 100644 --- a/src/arch/x86_64/Emit.zig +++ b/src/arch/x86_64/Emit.zig @@ -304,25 +304,13 @@ fn mirJmpCall(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void { }); return; } - const modrm_ext: u3 = switch (tag) { - .jmp_near => 0x4, - .call_near => 0x2, - else => unreachable, - }; if (ops.reg1 == .none) { // JMP/CALL [imm] const imm = emit.mir.instructions.items(.data)[inst].imm; - const encoder = try Encoder.init(emit.code, 7); - encoder.opcode_1byte(0xff); - encoder.modRm_SIBDisp0(modrm_ext); - encoder.sib_disp32(); - encoder.imm32(imm); - return; + return lowerToMEnc(tag, RegisterOrMemory.mem(null, imm), emit.code); } // JMP/CALL reg - const encoder = try Encoder.init(emit.code, 2); - encoder.opcode_1byte(0xff); - encoder.modRm_direct(modrm_ext, ops.reg1.lowId()); + return lowerToMEnc(tag, RegisterOrMemory.reg(ops.reg1), emit.code); } const CondType = enum { @@ -628,7 +616,7 @@ inline fn getOpCode(tag: Tag, enc: Encoding) ?u8 { } } -inline fn getModRmExt(tag: Tag) u3 { +inline fn getModRmExt(tag: Tag) ?u3 { return switch (tag) { .adc => 0x2, .add => 0x0, @@ -639,8 +627,9 @@ inline fn getModRmExt(tag: Tag) u3 { .sbb => 0x3, .cmp => 0x7, .mov => 0x0, + .jmp_near => 0x4, .call_near => 0x2, - else => unreachable, + else => null, }; } @@ -692,6 +681,67 @@ fn lowerToDEnc(tag: Tag, imm: i32, code: *std.ArrayList(u8)) InnerError!void { encoder.imm32(imm); } +fn lowerToMEnc(tag: Tag, reg_or_mem: RegisterOrMemory, code: *std.ArrayList(u8)) InnerError!void { + const opc = getOpCode(tag, .m).?; + const modrm_ext = getModRmExt(tag).?; + switch (reg_or_mem) { + .register => |reg| { + if (reg.size() != 64) return error.EmitFail; + const encoder = try Encoder.init(code, 3); + encoder.rex(.{ + .w = false, + .b = reg.isExtended(), + }); + encoder.opcode_1byte(opc); + encoder.modRm_direct(modrm_ext, reg.lowId()); + }, + .memory => |mem_op| { + const encoder = try Encoder.init(code, 8); + if (mem_op.reg) |reg| { + if (reg.size() != 64) return error.EmitFail; + encoder.rex(.{ + .w = false, + .b = reg.isExtended(), + }); + encoder.opcode_1byte(opc); + if (reg.lowId() == 4) { + if (mem_op.disp == 0) { + encoder.modRm_SIBDisp0(modrm_ext); + encoder.sib_base(reg.lowId()); + } else if (immOpSize(mem_op.disp) == 8) { + encoder.modRm_SIBDisp8(modrm_ext); + encoder.sib_baseDisp8(reg.lowId()); + encoder.disp8(@intCast(i8, mem_op.disp)); + } else { + encoder.modRm_SIBDisp32(modrm_ext); + encoder.sib_baseDisp32(reg.lowId()); + encoder.disp32(mem_op.disp); + } + } else { + if (mem_op.disp == 0) { + encoder.modRm_indirectDisp0(modrm_ext, reg.lowId()); + } else if (immOpSize(mem_op.disp) == 8) { + encoder.modRm_indirectDisp8(modrm_ext, reg.lowId()); + encoder.disp8(@intCast(i8, mem_op.disp)); + } else { + encoder.modRm_indirectDisp32(modrm_ext, reg.lowId()); + encoder.disp32(mem_op.disp); + } + } + } else { + encoder.opcode_1byte(opc); + if (mem_op.rip) { + encoder.modRm_RIPDisp32(modrm_ext); + } else { + encoder.modRm_SIBDisp0(modrm_ext); + encoder.sib_disp32(); + } + encoder.disp32(mem_op.disp); + } + }, + } +} + fn lowerToTdEnc(tag: Tag, moffs: i64, reg: Register, code: *std.ArrayList(u8)) InnerError!void { return lowerToTdFdEnc(tag, reg, moffs, code, true); } @@ -772,7 +822,7 @@ fn lowerToOiEnc(tag: Tag, reg: Register, imm: i64, code: *std.ArrayList(u8)) Inn fn lowerToMiEnc(tag: Tag, reg_or_mem: RegisterOrMemory, imm: i32, code: *std.ArrayList(u8)) InnerError!void { var opc = getOpCode(tag, .mi).?; - const modrm_ext = getModRmExt(tag); + const modrm_ext = getModRmExt(tag).?; switch (reg_or_mem) { .register => |dst_reg| { if (dst_reg.size() == 8) { @@ -819,16 +869,25 @@ fn lowerToMiEnc(tag: Tag, reg_or_mem: RegisterOrMemory, imm: i32, code: *std.Arr .b = dst_reg.isExtended(), }); encoder.opcode_1byte(opc); - if (dst_mem.disp == 0) { - encoder.modRm_indirectDisp0(modrm_ext, dst_reg.lowId()); - } else if (immOpSize(dst_mem.disp) == 8) { - encoder.modRm_indirectDisp8(modrm_ext, dst_reg.lowId()); - encoder.disp8(@intCast(i8, dst_mem.disp)); - } else { - if (dst_reg.lowId() == 4) { + if (dst_reg.lowId() == 4) { + if (dst_mem.disp == 0) { + encoder.modRm_SIBDisp0(modrm_ext); + encoder.sib_base(dst_reg.lowId()); + } else if (immOpSize(dst_mem.disp) == 8) { + encoder.modRm_SIBDisp8(modrm_ext); + encoder.sib_baseDisp8(dst_reg.lowId()); + encoder.disp8(@intCast(i8, dst_mem.disp)); + } else { encoder.modRm_SIBDisp32(modrm_ext); encoder.sib_baseDisp32(dst_reg.lowId()); encoder.disp32(dst_mem.disp); + } + } else { + if (dst_mem.disp == 0) { + encoder.modRm_indirectDisp0(modrm_ext, dst_reg.lowId()); + } else if (immOpSize(dst_mem.disp) == 8) { + encoder.modRm_indirectDisp8(modrm_ext, dst_reg.lowId()); + encoder.disp8(@intCast(i8, dst_mem.disp)); } else { encoder.modRm_indirectDisp32(modrm_ext, dst_reg.lowId()); encoder.disp32(dst_mem.disp); @@ -886,16 +945,25 @@ fn lowerToRmEnc( .b = src_reg.isExtended(), }); encoder.opcode_1byte(opc); - if (src_mem.disp == 0) { - encoder.modRm_indirectDisp0(reg.lowId(), src_reg.lowId()); - } else if (immOpSize(src_mem.disp) == 8) { - encoder.modRm_indirectDisp8(reg.lowId(), src_reg.lowId()); - encoder.disp8(@intCast(i8, src_mem.disp)); - } else { - if (src_reg.lowId() == 4) { + if (src_reg.lowId() == 4) { + if (src_mem.disp == 0) { + encoder.modRm_SIBDisp0(reg.lowId()); + encoder.sib_base(src_reg.lowId()); + } else if (immOpSize(src_mem.disp) == 8) { + encoder.modRm_SIBDisp8(reg.lowId()); + encoder.sib_baseDisp8(src_reg.lowId()); + encoder.disp8(@intCast(i8, src_mem.disp)); + } else { encoder.modRm_SIBDisp32(reg.lowId()); encoder.sib_baseDisp32(src_reg.lowId()); encoder.disp32(src_mem.disp); + } + } else { + if (src_mem.disp == 0) { + encoder.modRm_indirectDisp0(reg.lowId(), src_reg.lowId()); + } else if (immOpSize(src_mem.disp) == 8) { + encoder.modRm_indirectDisp8(reg.lowId(), src_reg.lowId()); + encoder.disp8(@intCast(i8, src_mem.disp)); } else { encoder.modRm_indirectDisp32(reg.lowId(), src_reg.lowId()); encoder.disp32(src_mem.disp); @@ -960,16 +1028,25 @@ fn lowerToMrEnc( .b = dst_reg.isExtended(), }); encoder.opcode_1byte(opc); - if (dst_mem.disp == 0) { - encoder.modRm_indirectDisp0(reg.lowId(), dst_reg.lowId()); - } else if (immOpSize(dst_mem.disp) == 8) { - encoder.modRm_indirectDisp8(reg.lowId(), dst_reg.lowId()); - encoder.disp8(@intCast(i8, dst_mem.disp)); - } else { - if (dst_reg.lowId() == 4) { + if (dst_reg.lowId() == 4) { + if (dst_mem.disp == 0) { + encoder.modRm_SIBDisp0(reg.lowId()); + encoder.sib_base(dst_reg.lowId()); + } else if (immOpSize(dst_mem.disp) == 8) { + encoder.modRm_SIBDisp8(reg.lowId()); + encoder.sib_baseDisp8(dst_reg.lowId()); + encoder.disp8(@intCast(i8, dst_mem.disp)); + } else { encoder.modRm_SIBDisp32(reg.lowId()); encoder.sib_baseDisp32(dst_reg.lowId()); encoder.disp32(dst_mem.disp); + } + } else { + if (dst_mem.disp == 0) { + encoder.modRm_indirectDisp0(reg.lowId(), dst_reg.lowId()); + } else if (immOpSize(dst_mem.disp) == 8) { + encoder.modRm_indirectDisp8(reg.lowId(), dst_reg.lowId()); + encoder.disp8(@intCast(i8, dst_mem.disp)); } else { encoder.modRm_indirectDisp32(reg.lowId(), dst_reg.lowId()); encoder.disp32(dst_mem.disp); @@ -1099,7 +1176,7 @@ fn mirArithScaleDst(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void if (ops.reg2 == .none) { // OP [reg1 + scale*rax + 0], imm32 var opc = getOpCode(tag, .mi).?; - const modrm_ext = getModRmExt(tag); + const modrm_ext = getModRmExt(tag).?; if (ops.reg1.size() == 8) { opc -= 1; } @@ -1153,7 +1230,7 @@ fn mirArithScaleImm(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void if (ops.reg1.size() == 8) { opc -= 1; } - const modrm_ext = getModRmExt(tag); + const modrm_ext = getModRmExt(tag).?; const encoder = try Encoder.init(emit.code, 2); encoder.rex(.{ .w = ops.reg1.size() == 64, @@ -1641,3 +1718,24 @@ test "lower FD/TD encoding" { try lowerToFdEnc(.mov, .al, 0x10, code.buffer()); try expectEqualHexStrings("\xa0\x10", code.emitted(), "mov al, ds:0x10"); } + +test "lower M encoding" { + var code = TestEmitCode.init(); + defer code.deinit(); + try lowerToMEnc(.jmp_near, RegisterOrMemory.reg(.r12), code.buffer()); + try expectEqualHexStrings("\x41\xFF\xE4", code.emitted(), "jmp r12"); + try lowerToMEnc(.jmp_near, RegisterOrMemory.mem(.r12, 0), code.buffer()); + try expectEqualHexStrings("\x41\xFF\x24\x24", code.emitted(), "jmp qword ptr [r12]"); + try lowerToMEnc(.jmp_near, RegisterOrMemory.mem(.r12, 0x10), code.buffer()); + try expectEqualHexStrings("\x41\xFF\x64\x24\x10", code.emitted(), "jmp qword ptr [r12 + 0x10]"); + try lowerToMEnc(.jmp_near, RegisterOrMemory.mem(.r12, 0x1000), code.buffer()); + try expectEqualHexStrings( + "\x41\xFF\xA4\x24\x00\x10\x00\x00", + code.emitted(), + "jmp qword ptr [r12 + 0x1000]", + ); + try lowerToMEnc(.jmp_near, RegisterOrMemory.rip(0x10), code.buffer()); + try expectEqualHexStrings("\xFF\x25\x10\x00\x00\x00", code.emitted(), "jmp qword ptr [rip + 0x10]"); + try lowerToMEnc(.jmp_near, RegisterOrMemory.mem(null, 0x10), code.buffer()); + try expectEqualHexStrings("\xFF\x24\x25\x10\x00\x00\x00", code.emitted(), "jmp qword ptr [ds:0x10]"); +}