AstGen: forward result type through unary float builtins

Uses a new `float_op_result_ty` ZIR instruction tag.
This commit is contained in:
David Rubin 2025-08-28 07:46:12 -07:00 committed by GitHub
parent a31950aa57
commit 73a0b5441b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 109 additions and 26 deletions

View File

@ -9390,24 +9390,25 @@ fn builtinCall(
.embed_file => return simpleUnOp(gz, scope, ri, node, .{ .rl = .{ .coerced_ty = .slice_const_u8_type } }, params[0], .embed_file),
.error_name => return simpleUnOp(gz, scope, ri, node, .{ .rl = .{ .coerced_ty = .anyerror_type } }, params[0], .error_name),
.set_runtime_safety => return simpleUnOp(gz, scope, ri, node, coerced_bool_ri, params[0], .set_runtime_safety),
.sqrt => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, params[0], .sqrt),
.sin => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, params[0], .sin),
.cos => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, params[0], .cos),
.tan => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, params[0], .tan),
.exp => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, params[0], .exp),
.exp2 => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, params[0], .exp2),
.log => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, params[0], .log),
.log2 => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, params[0], .log2),
.log10 => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, params[0], .log10),
.abs => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, params[0], .abs),
.floor => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, params[0], .floor),
.ceil => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, params[0], .ceil),
.trunc => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, params[0], .trunc),
.round => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, params[0], .round),
.tag_name => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, params[0], .tag_name),
.type_name => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, params[0], .type_name),
.Frame => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, params[0], .frame_type),
.sqrt => return floatUnOp(gz, scope, ri, node, params[0], .sqrt),
.sin => return floatUnOp(gz, scope, ri, node, params[0], .sin),
.cos => return floatUnOp(gz, scope, ri, node, params[0], .cos),
.tan => return floatUnOp(gz, scope, ri, node, params[0], .tan),
.exp => return floatUnOp(gz, scope, ri, node, params[0], .exp),
.exp2 => return floatUnOp(gz, scope, ri, node, params[0], .exp2),
.log => return floatUnOp(gz, scope, ri, node, params[0], .log),
.log2 => return floatUnOp(gz, scope, ri, node, params[0], .log2),
.log10 => return floatUnOp(gz, scope, ri, node, params[0], .log10),
.floor => return floatUnOp(gz, scope, ri, node, params[0], .floor),
.ceil => return floatUnOp(gz, scope, ri, node, params[0], .ceil),
.trunc => return floatUnOp(gz, scope, ri, node, params[0], .trunc),
.round => return floatUnOp(gz, scope, ri, node, params[0], .round),
.int_from_float => return typeCast(gz, scope, ri, node, params[0], .int_from_float, builtin_name),
.float_from_int => return typeCast(gz, scope, ri, node, params[0], .float_from_int, builtin_name),
.ptr_from_int => return typeCast(gz, scope, ri, node, params[0], .ptr_from_int, builtin_name),
@ -9860,6 +9861,26 @@ fn simpleUnOp(
return rvalue(gz, ri, result, node);
}
fn floatUnOp(
gz: *GenZir,
scope: *Scope,
ri: ResultInfo,
node: Ast.Node.Index,
operand_node: Ast.Node.Index,
tag: Zir.Inst.Tag,
) InnerError!Zir.Inst.Ref {
const result_type = try ri.rl.resultType(gz, node);
const operand_ri: ResultInfo.Loc = if (result_type) |rt| .{
.ty = try gz.addExtendedPayload(.float_op_result_ty, Zir.Inst.UnNode{
.node = gz.nodeIndexToRelative(node),
.operand = rt,
}),
} else .none;
const operand = try expr(gz, scope, .{ .rl = operand_ri }, operand_node);
const result = try gz.addUnNode(tag, operand, node);
return rvalue(gz, ri, result, node);
}
fn negation(
gz: *GenZir,
scope: *Scope,

View File

@ -889,6 +889,19 @@ fn builtinCall(astrl: *AstRlAnnotate, block: ?*Block, ri: ResultInfo, node: Ast.
.frame_address => return true,
// These builtins take a single argument with a known result type, but do not consume their
// result pointer.
.sqrt,
.sin,
.cos,
.tan,
.exp,
.exp2,
.log,
.log2,
.log10,
.floor,
.ceil,
.trunc,
.round,
.size_of,
.bit_size_of,
.align_of,
@ -918,20 +931,7 @@ fn builtinCall(astrl: *AstRlAnnotate, block: ?*Block, ri: ResultInfo, node: Ast.
// result pointer.
.int_from_ptr,
.int_from_enum,
.sqrt,
.sin,
.cos,
.tan,
.exp,
.exp2,
.log,
.log2,
.log10,
.abs,
.floor,
.ceil,
.trunc,
.round,
.tag_name,
.type_name,
.Frame,

View File

@ -2111,6 +2111,11 @@ pub const Inst = struct {
/// This instruction is always `noreturn`, however, it is not considered as such by ZIR-level queries. This allows AstGen to assume that
/// any code may have gone here, avoiding false-positive "unreachable code" errors.
astgen_error,
/// Given a type, strips away any error unions or optionals stacked
/// on top and returns the base type. That base type must be a float.
/// For example: Provided with error{Foo}!?f64, returns f64.
/// `operand` is `operand: Air.Inst.Ref`.
float_op_result_ty,
pub const InstData = struct {
opcode: Extended,
@ -4436,6 +4441,7 @@ fn findTrackableInner(
.tuple_decl,
.dbg_empty_stmt,
.astgen_error,
.float_op_result_ty,
=> return,
// `@TypeOf` has a body.

View File

@ -1154,6 +1154,10 @@ pub const Inst = struct {
pub fn fromValue(v: Value) Ref {
return .fromIntern(v.toIntern());
}
pub fn fromType(t: Type) Ref {
return .fromIntern(t.toIntern());
}
};
/// All instructions have an 8-byte payload, which is contained within

View File

@ -1468,6 +1468,7 @@ fn analyzeBodyInner(
continue;
},
.astgen_error => return error.AnalysisFail,
.float_op_result_ty => try sema.zirFloatOpResultType(block, extended),
};
},
@ -25922,6 +25923,28 @@ fn zirBranchHint(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstDat
}
}
fn zirFloatOpResultType(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData) CompileError!Air.Inst.Ref {
const pt = sema.pt;
const zcu = pt.zcu;
const extra = sema.code.extraData(Zir.Inst.UnNode, extended.operand).data;
const operand_src = block.builtinCallArgSrc(extra.node, 0);
const raw_ty = try sema.resolveTypeOrPoison(block, operand_src, extra.operand) orelse return .generic_poison_type;
const float_ty = raw_ty.optEuBaseType(zcu);
switch (float_ty.scalarType(zcu).zigTypeTag(zcu)) {
.float, .comptime_float => {},
else => return sema.fail(
block,
operand_src,
"expected vector of floats or float type, found '{f}'",
.{float_ty.fmt(sema.pt)},
),
}
return .fromType(float_ty);
}
fn requireRuntimeBlock(sema: *Sema, block: *Block, src: LazySrcLoc, runtime_src: ?LazySrcLoc) !void {
if (block.isComptime()) {
const msg, const fail_block = msg: {

View File

@ -567,6 +567,7 @@ const Writer = struct {
.work_group_size,
.work_group_id,
.branch_hint,
.float_op_result_ty,
=> {
const inst_data = self.code.extraData(Zir.Inst.UnNode, extended.operand).data;
try self.writeInstRef(stream, inst_data.operand);

View File

@ -1737,3 +1737,31 @@ test "comptime calls are only memoized when float arguments are bit-for-bit equa
try comptime testMemoization();
try comptime testVectorMemoization(@Vector(4, f32));
}
test "result location forwarded through unary float builtins" {
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_spirv) return error.SkipZigTest;
const S = struct {
var x: u32 = 10;
};
var y: f64 = 0.0;
y = @sqrt(@floatFromInt(S.x));
y = @sin(@floatFromInt(S.x));
y = @cos(@floatFromInt(S.x));
y = @tan(@floatFromInt(S.x));
y = @exp(@floatFromInt(S.x));
y = @exp2(@floatFromInt(S.x));
y = @log(@floatFromInt(S.x));
y = @log2(@floatFromInt(S.x));
y = @log10(@floatFromInt(S.x));
y = @floor(@floatFromInt(S.x));
y = @ceil(@floatFromInt(S.x));
y = @trunc(@floatFromInt(S.x));
y = @round(@floatFromInt(S.x));
}