diff --git a/src/AstGen.zig b/src/AstGen.zig index d3235ace53..9dc09ecd27 100644 --- a/src/AstGen.zig +++ b/src/AstGen.zig @@ -535,7 +535,7 @@ fn expr(gz: *GenZir, scope: *Scope, rl: ResultLoc, node: Ast.Node.Index) InnerEr return rvalue(gz, rl, .void_value, node); }, .assign_bit_shift_left_sat => { - try assignBinOpExt(gz, scope, node, .shl_with_saturation, Zir.Inst.SaturatingArithmetic); + try assignOpExt(gz, scope, node, .shl_with_saturation, Zir.Inst.SaturatingArithmetic); return rvalue(gz, rl, .void_value, node); }, .assign_bit_shift_right => { @@ -568,7 +568,7 @@ fn expr(gz: *GenZir, scope: *Scope, rl: ResultLoc, node: Ast.Node.Index) InnerEr return rvalue(gz, rl, .void_value, node); }, .assign_sub_sat => { - try assignBinOpExt(gz, scope, node, .sub_with_saturation, Zir.Inst.SaturatingArithmetic); + try assignOpExt(gz, scope, node, .sub_with_saturation, Zir.Inst.SaturatingArithmetic); return rvalue(gz, rl, .void_value, node); }, .assign_mod => { @@ -584,7 +584,7 @@ fn expr(gz: *GenZir, scope: *Scope, rl: ResultLoc, node: Ast.Node.Index) InnerEr return rvalue(gz, rl, .void_value, node); }, .assign_add_sat => { - try assignBinOpExt(gz, scope, node, .add_with_saturation, Zir.Inst.SaturatingArithmetic); + try assignOpExt(gz, scope, node, .add_with_saturation, Zir.Inst.SaturatingArithmetic); return rvalue(gz, rl, .void_value, node); }, .assign_mul => { @@ -596,26 +596,28 @@ fn expr(gz: *GenZir, scope: *Scope, rl: ResultLoc, node: Ast.Node.Index) InnerEr return rvalue(gz, rl, .void_value, node); }, .assign_mul_sat => { - try assignBinOpExt(gz, scope, node, .mul_with_saturation, Zir.Inst.SaturatingArithmetic); + try assignOpExt(gz, scope, node, .mul_with_saturation, Zir.Inst.SaturatingArithmetic); return rvalue(gz, rl, .void_value, node); }, // zig fmt: off .bit_shift_left => return shiftOp(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .shl), - .bit_shift_left_sat => return binOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .shl_with_saturation, Zir.Inst.SaturatingArithmetic), .bit_shift_right => return shiftOp(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .shr), .add => return simpleBinOp(gz, scope, rl, node, .add), .add_wrap => return simpleBinOp(gz, scope, rl, node, .addwrap), - .add_sat => return binOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .add_with_saturation, Zir.Inst.SaturatingArithmetic), .sub => return simpleBinOp(gz, scope, rl, node, .sub), .sub_wrap => return simpleBinOp(gz, scope, rl, node, .subwrap), - .sub_sat => return binOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .sub_with_saturation, Zir.Inst.SaturatingArithmetic), .mul => return simpleBinOp(gz, scope, rl, node, .mul), .mul_wrap => return simpleBinOp(gz, scope, rl, node, .mulwrap), - .mul_sat => return binOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .mul_with_saturation, Zir.Inst.SaturatingArithmetic), .div => return simpleBinOp(gz, scope, rl, node, .div), .mod => return simpleBinOp(gz, scope, rl, node, .mod_rem), + + .add_sat => return simpleBinOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .add_with_saturation, Zir.Inst.SaturatingArithmetic), + .sub_sat => return simpleBinOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .sub_with_saturation, Zir.Inst.SaturatingArithmetic), + .mul_sat => return simpleBinOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .mul_with_saturation, Zir.Inst.SaturatingArithmetic), + .bit_shift_left_sat => return simpleBinOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .shl_with_saturation, Zir.Inst.SaturatingArithmetic), + .bit_and => { const current_ampersand_token = main_tokens[node]; if (token_tags[current_ampersand_token + 1] == .ampersand) { @@ -2713,9 +2715,7 @@ fn assignOp( _ = try gz.addBin(.store, lhs_ptr, result); } -// TODO: is there an existing way to do this? -// TODO: likely rename this to reflect result_loc == .none or add more params to make it more general -fn binOpExt( +fn simpleBinOpExt( gz: *GenZir, scope: *Scope, rl: ResultLoc, @@ -2735,9 +2735,7 @@ fn binOpExt( return rvalue(gz, rl, result, infix_node); } -// TODO: is there an existing method to accomplish this? -// TODO: likely rename this to indicate rhs type coercion or add more params to make it more general -fn assignBinOpExt( +fn assignOpExt( gz: *GenZir, scope: *Scope, infix_node: Ast.Node.Index, diff --git a/src/Sema.zig b/src/Sema.zig index a41d330285..be10b6d663 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -6164,7 +6164,7 @@ fn zirNegate( const lhs = sema.resolveInst(.zero); const rhs = sema.resolveInst(inst_data.operand); - return sema.analyzeArithmetic(block, tag_override, lhs, rhs, src, lhs_src, rhs_src, null); + return sema.analyzeArithmetic(block, tag_override, lhs, rhs, src, lhs_src, rhs_src); } fn zirArithmetic( @@ -6184,7 +6184,7 @@ fn zirArithmetic( const lhs = sema.resolveInst(extra.lhs); const rhs = sema.resolveInst(extra.rhs); - return sema.analyzeArithmetic(block, zir_tag, lhs, rhs, sema.src, lhs_src, rhs_src, null); + return sema.analyzeArithmetic(block, zir_tag, lhs, rhs, sema.src, lhs_src, rhs_src); } fn zirOverflowArithmetic( @@ -6216,11 +6216,90 @@ fn zirSatArithmetic( const lhs = sema.resolveInst(extra.lhs); const rhs = sema.resolveInst(extra.rhs); - return sema.analyzeArithmetic(block, .extended, lhs, rhs, sema.src, lhs_src, rhs_src, extended); + return sema.analyzeSatArithmetic(block, lhs, rhs, sema.src, lhs_src, rhs_src, extended); +} + +fn analyzeSatArithmetic( + sema: *Sema, + block: *Scope.Block, + lhs: Air.Inst.Ref, + rhs: Air.Inst.Ref, + src: LazySrcLoc, + lhs_src: LazySrcLoc, + rhs_src: LazySrcLoc, + extended: Zir.Inst.Extended.InstData, +) CompileError!Air.Inst.Ref { + const lhs_ty = sema.typeOf(lhs); + const rhs_ty = sema.typeOf(rhs); + const lhs_zig_ty_tag = try lhs_ty.zigTypeTagOrPoison(); + const rhs_zig_ty_tag = try rhs_ty.zigTypeTagOrPoison(); + if (lhs_zig_ty_tag == .Vector and rhs_zig_ty_tag == .Vector) { + if (lhs_ty.arrayLen() != rhs_ty.arrayLen()) { + return sema.mod.fail(&block.base, src, "vector length mismatch: {d} and {d}", .{ + lhs_ty.arrayLen(), rhs_ty.arrayLen(), + }); + } + return sema.mod.fail(&block.base, src, "TODO implement support for vectors in zirBinOp", .{}); + } else if (lhs_zig_ty_tag == .Vector or rhs_zig_ty_tag == .Vector) { + return sema.mod.fail(&block.base, src, "mixed scalar and vector operands to binary expression: '{}' and '{}'", .{ + lhs_ty, rhs_ty, + }); + } + + if (lhs_zig_ty_tag == .Pointer or rhs_zig_ty_tag == .Pointer) + return sema.mod.fail(&block.base, src, "TODO implement support for pointers in zirSatArithmetic", .{}); + + const instructions = &[_]Air.Inst.Ref{ lhs, rhs }; + const resolved_type = try sema.resolvePeerTypes(block, src, instructions, .{ .override = &[_]LazySrcLoc{ lhs_src, rhs_src } }); + const casted_lhs = try sema.coerce(block, resolved_type, lhs, lhs_src); + const casted_rhs = try sema.coerce(block, resolved_type, rhs, rhs_src); + + const scalar_type = if (resolved_type.zigTypeTag() == .Vector) + resolved_type.elemType() + else + resolved_type; + + const scalar_tag = scalar_type.zigTypeTag(); + + const is_int = scalar_tag == .Int or scalar_tag == .ComptimeInt; + + if (!is_int) + return sema.mod.fail(&block.base, src, "invalid operands to binary expression: '{s}' and '{s}'", .{ + @tagName(lhs_zig_ty_tag), @tagName(rhs_zig_ty_tag), + }); + + if (try sema.resolveMaybeUndefVal(block, lhs_src, casted_lhs)) |lhs_val| { + if (try sema.resolveMaybeUndefVal(block, rhs_src, casted_rhs)) |rhs_val| { + if (lhs_val.isUndef() or rhs_val.isUndef()) { + return sema.addConstUndef(resolved_type); + } + // incase rhs is 0, simply return lhs without doing any calculations + if (rhs_val.compareWithZero(.eq)) { + switch (extended.opcode) { + .add_with_saturation, .sub_with_saturation => return sema.addConstant(scalar_type, lhs_val), + else => {}, + } + } + + return sema.mod.fail(&block.base, src, "TODO implement comptime saturating arithmetic for operand '{s}'", .{@tagName(extended.opcode)}); + } else { + try sema.requireRuntimeBlock(block, rhs_src); + } + } else { + try sema.requireRuntimeBlock(block, lhs_src); + } + + const air_tag: Air.Inst.Tag = switch (extended.opcode) { + .add_with_saturation => .addsat, + .sub_with_saturation => .subsat, + .mul_with_saturation => .mulsat, + .shl_with_saturation => .shl_sat, + else => return sema.mod.fail(&block.base, src, "TODO implement arithmetic for extended opcode '{s}'", .{@tagName(extended.opcode)}), + }; + + return block.addBinOp(air_tag, casted_lhs, casted_rhs); } -// TODO: audit - not sure if its a good idea to reuse this, adding `opt_extended` param -// FIXME: somehow, rhs of <<| is required to be Log2T. this should accept T fn analyzeArithmetic( sema: *Sema, block: *Scope.Block, @@ -6231,7 +6310,6 @@ fn analyzeArithmetic( src: LazySrcLoc, lhs_src: LazySrcLoc, rhs_src: LazySrcLoc, - opt_extended: ?Zir.Inst.Extended.InstData, ) CompileError!Air.Inst.Ref { const lhs_ty = sema.typeOf(lhs); const rhs_ty = sema.typeOf(rhs);