From 8c664d3f6a59e412f33bca8c969f70ceb3545b11 Mon Sep 17 00:00:00 2001 From: Jakub Konka Date: Thu, 23 Dec 2021 18:49:03 +0100 Subject: [PATCH] stage2: support multibyte opcodes and refactor 1byte opcode changes --- src/arch/x86_64/Emit.zig | 224 +++++++++++++++++++-------------------- 1 file changed, 110 insertions(+), 114 deletions(-) diff --git a/src/arch/x86_64/Emit.zig b/src/arch/x86_64/Emit.zig index 118dc9cc76..e766d45a01 100644 --- a/src/arch/x86_64/Emit.zig +++ b/src/arch/x86_64/Emit.zig @@ -193,8 +193,7 @@ fn mirNop(emit: *Emit) InnerError!void { } fn mirSyscall(emit: *Emit) InnerError!void { - const encoder = try Encoder.init(emit.code, 2); - encoder.opcode_2byte(0x0f, 0x05); + return lowerToZoEnc(.syscall, emit.code); } fn mirPushPop(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void { @@ -470,6 +469,7 @@ const Tag = enum { @"test", brk, nop, + syscall, ret_near, ret_far, }; @@ -509,78 +509,104 @@ const Encoding = enum { td, }; -inline fn getOpCode(tag: Tag, enc: Encoding) ?u8 { +const OpCode = union(enum) { + one_byte: u8, + two_byte: struct { _1: u8, _2: u8 }, + + fn oneByte(opc: u8) OpCode { + return .{ .one_byte = opc }; + } + + fn twoByte(opc1: u8, opc2: u8) OpCode { + return .{ .two_byte = .{ ._1 = opc1, ._2 = opc2 } }; + } + + fn encode(opc: OpCode, encoder: Encoder) void { + switch (opc) { + .one_byte => |v| encoder.opcode_1byte(v), + .two_byte => |v| encoder.opcode_2byte(v._1, v._2), + } + } + + fn encodeWithReg(opc: OpCode, encoder: Encoder, reg: Register) void { + assert(opc == .one_byte); + encoder.opcode_withReg(opc.one_byte, reg.lowId()); + } +}; + +inline fn getOpCode(tag: Tag, enc: Encoding, is_one_byte: bool) ?OpCode { switch (enc) { .zo => return switch (tag) { - .ret_near => 0xc3, - .ret_far => 0xcb, - .brk => 0xcc, - .nop => 0x90, + .ret_near => OpCode.oneByte(0xc3), + .ret_far => OpCode.oneByte(0xcb), + .brk => OpCode.oneByte(0xcc), + .nop => OpCode.oneByte(0x90), + .syscall => OpCode.twoByte(0x0f, 0x05), else => null, }, 
.d => return switch (tag) { - .jmp_near => 0xe9, - .call_near => 0xe8, + .jmp_near => OpCode.oneByte(0xe9), + .call_near => OpCode.oneByte(0xe8), else => null, }, .m => return switch (tag) { - .jmp_near, .call_near, .push => 0xff, - .pop => 0x8f, + .jmp_near, .call_near, .push => OpCode.oneByte(0xff), + .pop => OpCode.oneByte(0x8f), else => null, }, .o => return switch (tag) { - .push => 0x50, - .pop => 0x58, + .push => OpCode.oneByte(0x50), + .pop => OpCode.oneByte(0x58), else => null, }, .i => return switch (tag) { - .push => 0x68, - .@"test" => 0xa9, - .ret_near => 0xc2, - .ret_far => 0xca, + .push => OpCode.oneByte(if (is_one_byte) 0x6a else 0x68), + .@"test" => OpCode.oneByte(if (is_one_byte) 0xa8 else 0xa9), + .ret_near => OpCode.oneByte(0xc2), + .ret_far => OpCode.oneByte(0xca), else => null, }, .mi => return switch (tag) { - .adc, .add, .sub, .xor, .@"and", .@"or", .sbb, .cmp => 0x81, - .mov => 0xc7, - .@"test" => 0xf7, + .adc, .add, .sub, .xor, .@"and", .@"or", .sbb, .cmp => OpCode.oneByte(if (is_one_byte) 0x80 else 0x81), + .mov => OpCode.oneByte(if (is_one_byte) 0xc6 else 0xc7), + .@"test" => OpCode.oneByte(if (is_one_byte) 0xf6 else 0xf7), else => null, }, .mr => return switch (tag) { - .adc => 0x11, - .add => 0x01, - .sub => 0x29, - .xor => 0x31, - .@"and" => 0x21, - .@"or" => 0x09, - .sbb => 0x19, - .cmp => 0x39, - .mov => 0x89, + .adc => OpCode.oneByte(if (is_one_byte) 0x10 else 0x11), + .add => OpCode.oneByte(if (is_one_byte) 0x00 else 0x01), + .sub => OpCode.oneByte(if (is_one_byte) 0x28 else 0x29), + .xor => OpCode.oneByte(if (is_one_byte) 0x30 else 0x31), + .@"and" => OpCode.oneByte(if (is_one_byte) 0x20 else 0x21), + .@"or" => OpCode.oneByte(if (is_one_byte) 0x08 else 0x09), + .sbb => OpCode.oneByte(if (is_one_byte) 0x18 else 0x19), + .cmp => OpCode.oneByte(if (is_one_byte) 0x38 else 0x39), + .mov => OpCode.oneByte(if (is_one_byte) 0x88 else 0x89), else => null, }, .rm => return switch (tag) { - .adc => 0x13, - .add => 0x03, - .sub => 0x2b, - 
.xor => 0x33, - .@"and" => 0x23, - .@"or" => 0x0b, - .sbb => 0x1b, - .cmp => 0x3b, - .mov => 0x8b, - .lea => 0x8d, + .adc => OpCode.oneByte(if (is_one_byte) 0x12 else 0x13), + .add => OpCode.oneByte(if (is_one_byte) 0x02 else 0x03), + .sub => OpCode.oneByte(if (is_one_byte) 0x2a else 0x2b), + .xor => OpCode.oneByte(if (is_one_byte) 0x32 else 0x33), + .@"and" => OpCode.oneByte(if (is_one_byte) 0x22 else 0x23), + .@"or" => OpCode.oneByte(if (is_one_byte) 0x0a else 0x0b), + .sbb => OpCode.oneByte(if (is_one_byte) 0x1a else 0x1b), + .cmp => OpCode.oneByte(if (is_one_byte) 0x3a else 0x3b), + .mov => OpCode.oneByte(if (is_one_byte) 0x8a else 0x8b), + .lea => OpCode.oneByte(if (is_one_byte) 0x8c else 0x8d), else => null, }, .oi => return switch (tag) { - .mov => 0xb8, + .mov => OpCode.oneByte(if (is_one_byte) 0xb0 else 0xb8), else => null, }, .fd => return switch (tag) { - .mov => 0xa1, + .mov => OpCode.oneByte(if (is_one_byte) 0xa0 else 0xa1), else => null, }, .td => return switch (tag) { - .mov => 0xa3, + .mov => OpCode.oneByte(if (is_one_byte) 0xa2 else 0xa3), else => null, }, } @@ -648,32 +674,25 @@ const RegisterOrMemory = union(enum) { }; fn lowerToZoEnc(tag: Tag, code: *std.ArrayList(u8)) InnerError!void { - const opc = getOpCode(tag, .zo).?; - const encoder = try Encoder.init(code, 1); - encoder.opcode_1byte(opc); + const opc = getOpCode(tag, .zo, false).?; + const encoder = try Encoder.init(code, 2); // .syscall emits two opcode bytes + opc.encode(encoder); } fn lowerToIEnc(tag: Tag, imm: i32, code: *std.ArrayList(u8)) InnerError!void { - var opc = getOpCode(tag, .i).?; if (tag == .ret_far or tag == .ret_near) { const encoder = try Encoder.init(code, 3); - encoder.opcode_1byte(opc); + const opc = getOpCode(tag, .i, false).?; + opc.encode(encoder); encoder.imm16(@intCast(i16, imm)); return; } - if (immOpSize(imm) == 8) { - // TODO I think getOpCode should track this - switch (tag) { - .push => opc += 2, - .@"test" => opc -= 1, - else => return error.EmitFail, - } - } + const opc = getOpCode(tag, .i, immOpSize(imm) == 8).?; const encoder = try 
Encoder.init(code, 5); if (immOpSize(imm) == 16) { encoder.opcode_1byte(0x66); } - encoder.opcode_1byte(opc); + opc.encode(encoder); if (immOpSize(imm) == 8) { encoder.imm8(@intCast(i8, imm)); } else if (immOpSize(imm) == 16) { @@ -685,7 +704,7 @@ fn lowerToIEnc(tag: Tag, imm: i32, code: *std.ArrayList(u8)) InnerError!void { fn lowerToOEnc(tag: Tag, reg: Register, code: *std.ArrayList(u8)) InnerError!void { if (reg.size() != 16 and reg.size() != 64) return error.EmitFail; // TODO correct for push/pop, but is it universal? - const opc = getOpCode(tag, .o).?; + const opc = getOpCode(tag, .o, false).?; const encoder = try Encoder.init(code, 3); if (reg.size() == 16) { encoder.opcode_1byte(0x66); @@ -694,18 +713,18 @@ fn lowerToOEnc(tag: Tag, reg: Register, code: *std.ArrayList(u8)) InnerError!voi .w = false, .b = reg.isExtended(), }); - encoder.opcode_withReg(opc, reg.lowId()); + opc.encodeWithReg(encoder, reg); } fn lowerToDEnc(tag: Tag, imm: i32, code: *std.ArrayList(u8)) InnerError!void { - const opc = getOpCode(tag, .d).?; + const opc = getOpCode(tag, .d, false).?; const encoder = try Encoder.init(code, 5); - encoder.opcode_1byte(opc); + opc.encode(encoder); encoder.imm32(imm); } fn lowerToMEnc(tag: Tag, reg_or_mem: RegisterOrMemory, code: *std.ArrayList(u8)) InnerError!void { - const opc = getOpCode(tag, .m).?; + const opc = getOpCode(tag, .m, false).?; const modrm_ext = getModRmExt(tag).?; switch (reg_or_mem) { .register => |reg| { @@ -715,7 +734,7 @@ fn lowerToMEnc(tag: Tag, reg_or_mem: RegisterOrMemory, code: *std.ArrayList(u8)) .w = false, .b = reg.isExtended(), }); - encoder.opcode_1byte(opc); + opc.encode(encoder); encoder.modRm_direct(modrm_ext, reg.lowId()); }, .memory => |mem_op| { @@ -726,7 +745,7 @@ fn lowerToMEnc(tag: Tag, reg_or_mem: RegisterOrMemory, code: *std.ArrayList(u8)) .w = false, .b = reg.isExtended(), }); - encoder.opcode_1byte(opc); + opc.encode(encoder); if (reg.lowId() == 4) { if (mem_op.disp == 0) { encoder.modRm_SIBDisp0(modrm_ext); @@ 
-752,7 +771,7 @@ fn lowerToMEnc(tag: Tag, reg_or_mem: RegisterOrMemory, code: *std.ArrayList(u8)) } } } else { - encoder.opcode_1byte(opc); + opc.encode(encoder); if (mem_op.rip) { encoder.modRm_RIPDisp32(modrm_ext); } else { @@ -776,10 +795,10 @@ fn lowerToFdEnc(tag: Tag, reg: Register, moffs: i64, code: *std.ArrayList(u8)) I fn lowerToTdFdEnc(tag: Tag, reg: Register, moffs: i64, code: *std.ArrayList(u8), td: bool) InnerError!void { if (reg.lowId() != Register.rax.lowId()) return error.EmitFail; if (reg.size() != immOpSize(moffs)) return error.EmitFail; - var opc = if (td) getOpCode(tag, .td).? else getOpCode(tag, .fd).?; - if (reg.size() == 8) { - opc -= 1; - } + const opc = if (td) + getOpCode(tag, .td, reg.size() == 8).? + else + getOpCode(tag, .fd, reg.size() == 8).?; const encoder = try Encoder.init(code, 10); if (reg.size() == 16) { encoder.opcode_1byte(0x66); @@ -787,7 +806,7 @@ fn lowerToTdFdEnc(tag: Tag, reg: Register, moffs: i64, code: *std.ArrayList(u8), encoder.rex(.{ .w = reg.size() == 64, }); - encoder.opcode_1byte(opc); + opc.encode(encoder); switch (reg.size()) { 8 => { const moffs8 = try math.cast(i8, moffs); @@ -809,11 +828,8 @@ fn lowerToTdFdEnc(tag: Tag, reg: Register, moffs: i64, code: *std.ArrayList(u8), } fn lowerToOiEnc(tag: Tag, reg: Register, imm: i64, code: *std.ArrayList(u8)) InnerError!void { - var opc = getOpCode(tag, .oi).?; if (reg.size() != immOpSize(imm)) return error.EmitFail; - if (reg.size() == 8) { - opc -= 8; - } + const opc = getOpCode(tag, .oi, reg.size() == 8).?; const encoder = try Encoder.init(code, 10); if (reg.size() == 16) { encoder.opcode_1byte(0x66); @@ -822,7 +838,7 @@ fn lowerToOiEnc(tag: Tag, reg: Register, imm: i64, code: *std.ArrayList(u8)) Inn .w = reg.size() == 64, .b = reg.isExtended(), }); - encoder.opcode_withReg(opc, reg.lowId()); + opc.encodeWithReg(encoder, reg); switch (reg.size()) { 8 => { const imm8 = try math.cast(i8, imm); @@ -844,13 +860,10 @@ fn lowerToOiEnc(tag: Tag, reg: Register, imm: i64, 
code: *std.ArrayList(u8)) Inn } fn lowerToMiEnc(tag: Tag, reg_or_mem: RegisterOrMemory, imm: i32, code: *std.ArrayList(u8)) InnerError!void { - var opc = getOpCode(tag, .mi).?; const modrm_ext = getModRmExt(tag).?; switch (reg_or_mem) { .register => |dst_reg| { - if (dst_reg.size() == 8) { - opc -= 1; - } + const opc = getOpCode(tag, .mi, dst_reg.size() == 8).?; const encoder = try Encoder.init(code, 7); if (dst_reg.size() == 16) { // 0x66 prefix switches to the non-default size; here we assume a switch from @@ -862,7 +875,7 @@ fn lowerToMiEnc(tag: Tag, reg_or_mem: RegisterOrMemory, imm: i32, code: *std.Arr .w = dst_reg.size() == 64, .b = dst_reg.isExtended(), }); - encoder.opcode_1byte(opc); + opc.encode(encoder); encoder.modRm_direct(modrm_ext, dst_reg.lowId()); switch (dst_reg.size()) { 8 => { @@ -878,6 +891,7 @@ fn lowerToMiEnc(tag: Tag, reg_or_mem: RegisterOrMemory, imm: i32, code: *std.Arr } }, .memory => |dst_mem| { + const opc = getOpCode(tag, .mi, false).?; const encoder = try Encoder.init(code, 12); if (dst_mem.reg) |dst_reg| { // Register dst_reg can either be 64bit or 32bit in size. 
@@ -891,7 +905,7 @@ fn lowerToMiEnc(tag: Tag, reg_or_mem: RegisterOrMemory, imm: i32, code: *std.Arr .w = false, .b = dst_reg.isExtended(), }); - encoder.opcode_1byte(opc); + opc.encode(encoder); if (dst_reg.lowId() == 4) { if (dst_mem.disp == 0) { encoder.modRm_SIBDisp0(modrm_ext); @@ -917,7 +931,7 @@ fn lowerToMiEnc(tag: Tag, reg_or_mem: RegisterOrMemory, imm: i32, code: *std.Arr } } } else { - encoder.opcode_1byte(opc); + opc.encode(encoder); if (dst_mem.rip) { encoder.modRm_RIPDisp32(modrm_ext); } else { @@ -937,10 +951,7 @@ fn lowerToRmEnc( reg_or_mem: RegisterOrMemory, code: *std.ArrayList(u8), ) InnerError!void { - var opc = getOpCode(tag, .rm).?; - if (reg.size() == 8) { - opc -= 1; - } + const opc = getOpCode(tag, .rm, reg.size() == 8).?; switch (reg_or_mem) { .register => |src_reg| { if (reg.size() != src_reg.size()) return error.EmitFail; @@ -950,7 +961,7 @@ fn lowerToRmEnc( .r = reg.isExtended(), .b = src_reg.isExtended(), }); - encoder.opcode_1byte(opc); + opc.encode(encoder); encoder.modRm_direct(reg.lowId(), src_reg.lowId()); }, .memory => |src_mem| { @@ -967,7 +978,7 @@ fn lowerToRmEnc( .r = reg.isExtended(), .b = src_reg.isExtended(), }); - encoder.opcode_1byte(opc); + opc.encode(encoder); if (src_reg.lowId() == 4) { if (src_mem.disp == 0) { encoder.modRm_SIBDisp0(reg.lowId()); @@ -997,7 +1008,7 @@ fn lowerToRmEnc( .w = reg.size() == 64, .r = reg.isExtended(), }); - encoder.opcode_1byte(opc); + opc.encode(encoder); if (src_mem.rip) { encoder.modRm_RIPDisp32(reg.lowId()); } else { @@ -1022,10 +1033,7 @@ fn lowerToMrEnc( // * reg is 32bit - dword ptr // * reg is 16bit - word ptr // * reg is 8bit - byte ptr - var opc = getOpCode(tag, .mr).?; - if (reg.size() == 8) { - opc -= 1; - } + const opc = getOpCode(tag, .mr, reg.size() == 8).?; switch (reg_or_mem) { .register => |dst_reg| { if (dst_reg.size() != reg.size()) return error.EmitFail; @@ -1035,7 +1043,7 @@ fn lowerToMrEnc( .r = reg.isExtended(), .b = dst_reg.isExtended(), }); - 
encoder.opcode_1byte(opc); + opc.encode(encoder); encoder.modRm_direct(reg.lowId(), dst_reg.lowId()); }, .memory => |dst_mem| { @@ -1050,7 +1058,7 @@ fn lowerToMrEnc( .r = reg.isExtended(), .b = dst_reg.isExtended(), }); - encoder.opcode_1byte(opc); + opc.encode(encoder); if (dst_reg.lowId() == 4) { if (dst_mem.disp == 0) { encoder.modRm_SIBDisp0(reg.lowId()); @@ -1080,7 +1088,7 @@ fn lowerToMrEnc( .w = reg.size() == 64, .r = reg.isExtended(), }); - encoder.opcode_1byte(opc); + opc.encode(encoder); if (dst_mem.rip) { encoder.modRm_RIPDisp32(reg.lowId()); } else { @@ -1168,10 +1176,7 @@ fn mirArithScaleSrc(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void const ops = Mir.Ops.decode(emit.mir.instructions.items(.ops)[inst]); const scale = ops.flags; // OP reg1, [reg2 + scale*rcx + imm32] - var opc = getOpCode(tag, .rm).?; - if (ops.reg1.size() == 8) { - opc -= 1; - } + const opc = getOpCode(tag, .rm, ops.reg1.size() == 8).?; const imm = emit.mir.instructions.items(.data)[inst].imm; const encoder = try Encoder.init(emit.code, 8); encoder.rex(.{ @@ -1179,7 +1184,7 @@ fn mirArithScaleSrc(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void .r = ops.reg1.isExtended(), .b = ops.reg2.isExtended(), }); - encoder.opcode_1byte(opc); + opc.encode(encoder); if (imm <= math.maxInt(i8)) { encoder.modRm_SIBDisp8(ops.reg1.lowId()); encoder.sib_scaleIndexBaseDisp8(scale, Register.rcx.lowId(), ops.reg2.lowId()); @@ -1198,17 +1203,14 @@ fn mirArithScaleDst(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void if (ops.reg2 == .none) { // OP [reg1 + scale*rax + 0], imm32 - var opc = getOpCode(tag, .mi).?; + const opc = getOpCode(tag, .mi, ops.reg1.size() == 8).?; const modrm_ext = getModRmExt(tag).?; - if (ops.reg1.size() == 8) { - opc -= 1; - } const encoder = try Encoder.init(emit.code, 8); encoder.rex(.{ .w = ops.reg1.size() == 64, .b = ops.reg1.isExtended(), }); - encoder.opcode_1byte(opc); + opc.encode(encoder); encoder.modRm_SIBDisp0(modrm_ext); 
encoder.sib_scaleIndexBase(scale, Register.rax.lowId(), ops.reg1.lowId()); if (imm <= math.maxInt(i8)) { @@ -1222,17 +1224,14 @@ fn mirArithScaleDst(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void } // OP [reg1 + scale*rax + imm32], reg2 - var opc = getOpCode(tag, .mr).?; - if (ops.reg1.size() == 8) { - opc -= 1; - } + const opc = getOpCode(tag, .mr, ops.reg1.size() == 8).?; const encoder = try Encoder.init(emit.code, 8); encoder.rex(.{ .w = ops.reg1.size() == 64, .r = ops.reg2.isExtended(), .b = ops.reg1.isExtended(), }); - encoder.opcode_1byte(opc); + opc.encode(encoder); if (imm <= math.maxInt(i8)) { encoder.modRm_SIBDisp8(ops.reg2.lowId()); encoder.sib_scaleIndexBaseDisp8(scale, Register.rax.lowId(), ops.reg1.lowId()); @@ -1249,17 +1248,14 @@ fn mirArithScaleImm(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void const scale = ops.flags; const payload = emit.mir.instructions.items(.data)[inst].payload; const imm_pair = emit.mir.extraData(Mir.ImmPair, payload).data; - var opc = getOpCode(tag, .mi).?; - if (ops.reg1.size() == 8) { - opc -= 1; - } + const opc = getOpCode(tag, .mi, ops.reg1.size() == 8).?; const modrm_ext = getModRmExt(tag).?; const encoder = try Encoder.init(emit.code, 2); encoder.rex(.{ .w = ops.reg1.size() == 64, .b = ops.reg1.isExtended(), }); - encoder.opcode_1byte(opc); + opc.encode(encoder); if (imm_pair.dest_off <= math.maxInt(i8)) { encoder.modRm_SIBDisp8(modrm_ext); encoder.sib_scaleIndexBaseDisp8(scale, Register.rax.lowId(), ops.reg1.lowId());