Sema: fix generic instantiation false negatives

The problem was that types of non-anytype parameters were being included
as part of the check to see if generic function instantiations were
equal. Now, Module.Fn additionally stores the information for whether each
parameter is anytype or not. `generic_poison` cannot be used to signal
this because the type is still needed for comptime arguments: in that
case the type is not present in the newly generated function
prototype.

This presented one additional challenge: we need to compare equality of
two values where one of them is post-coercion and the other is not. So
we make some minor adjustments to `Value.eql` to support this. I think
this small complexity tradeoff is worth it because it means the compiler
does much less work on the hot path where a generic function is called
and a matching instantiation already exists.

closes #11146
This commit is contained in:
Andrew Kelley 2022-04-14 00:36:54 -07:00
parent 9b82e7f558
commit 2a00df9c09
4 changed files with 178 additions and 66 deletions

View File

@ -1394,7 +1394,15 @@ pub const Fn = struct {
/// there is a `TypedValue` here for each parameter of the function.
/// Non-comptime parameters are marked with a `generic_poison` for the value.
/// Non-anytype parameters are marked with a `generic_poison` for the type.
comptime_args: ?[*]TypedValue = null,
/// These never have .generic_poison for the Type
/// because the Type is needed to pass to `Type.eql` and for inserting comptime arguments
/// into the inst_map when analyzing the body of a generic function instantiation.
/// Instead, the is_anytype knowledge is communicated via `anytype_args`.
comptime_args: ?[*]TypedValue,
/// When comptime_args is null, this is undefined. Otherwise, this flags each
/// parameter and tells whether it is anytype.
/// TODO apply the same enhancement for param_names below to this field.
anytype_args: [*]bool,
/// The ZIR instruction that is a function instruction. Use this to find
/// the body. We store this rather than the body directly so that when ZIR
/// is regenerated on update(), we can map this to the new corresponding
@ -4782,18 +4790,24 @@ pub fn analyzeFnBody(mod: *Module, decl: *Decl, func: *Fn, arena: Allocator) Sem
else => continue,
};
if (func.comptime_args) |comptime_args| {
const param_ty = if (func.comptime_args) |comptime_args| t: {
const arg_tv = comptime_args[total_param_index];
if (arg_tv.val.tag() != .generic_poison) {
// We have a comptime value for this parameter.
const arg = try sema.addConstant(arg_tv.ty, arg_tv.val);
sema.inst_map.putAssumeCapacityNoClobber(inst, arg);
total_param_index += 1;
continue;
}
}
const param_type = fn_ty_info.param_types[runtime_param_index];
const opt_opv = sema.typeHasOnePossibleValue(&inner_block, param.src, param_type) catch |err| switch (err) {
const arg_val = if (arg_tv.val.tag() != .generic_poison)
arg_tv.val
else if (arg_tv.ty.onePossibleValue()) |opv|
opv
else
break :t arg_tv.ty;
const arg = try sema.addConstant(arg_tv.ty, arg_val);
sema.inst_map.putAssumeCapacityNoClobber(inst, arg);
total_param_index += 1;
continue;
} else fn_ty_info.param_types[runtime_param_index];
const opt_opv = sema.typeHasOnePossibleValue(&inner_block, param.src, param_ty) catch |err| switch (err) {
error.NeededSourceLocation => unreachable,
error.GenericPoison => unreachable,
error.ComptimeReturn => unreachable,
@ -4801,7 +4815,7 @@ pub fn analyzeFnBody(mod: *Module, decl: *Decl, func: *Fn, arena: Allocator) Sem
else => |e| return e,
};
if (opt_opv) |opv| {
const arg = try sema.addConstant(param_type, opv);
const arg = try sema.addConstant(param_ty, opv);
sema.inst_map.putAssumeCapacityNoClobber(inst, arg);
total_param_index += 1;
runtime_param_index += 1;
@ -4811,7 +4825,7 @@ pub fn analyzeFnBody(mod: *Module, decl: *Decl, func: *Fn, arena: Allocator) Sem
inner_block.instructions.appendAssumeCapacity(arg_index);
sema.air_instructions.appendAssumeCapacity(.{
.tag = .arg,
.data = .{ .ty = param_type },
.data = .{ .ty = param_ty },
});
sema.inst_map.putAssumeCapacityNoClobber(inst, Air.indexToRef(arg_index));
total_param_index += 1;

View File

@ -4707,6 +4707,8 @@ const GenericCallAdapter = struct {
generic_fn: *Module.Fn,
precomputed_hash: u64,
func_ty_info: Type.Payload.Function.Data,
/// Unlike comptime_args, the Type here is not always present.
/// .generic_poison is used to communicate non-anytype parameters.
comptime_tvs: []const TypedValue,
target: std.Target,
@ -4719,20 +4721,29 @@ const GenericCallAdapter = struct {
const other_comptime_args = other_key.comptime_args.?;
for (other_comptime_args[0..ctx.func_ty_info.param_types.len]) |other_arg, i| {
if (other_arg.ty.tag() != .generic_poison) {
// anytype parameter
if (!other_arg.ty.eql(ctx.comptime_tvs[i].ty, ctx.target)) {
const this_arg = ctx.comptime_tvs[i];
const this_is_comptime = this_arg.val.tag() != .generic_poison;
const other_is_comptime = other_arg.val.tag() != .generic_poison;
const this_is_anytype = this_arg.ty.tag() != .generic_poison;
const other_is_anytype = other_key.anytype_args[i];
if (other_is_anytype != this_is_anytype) return false;
if (other_is_comptime != this_is_comptime) return false;
if (this_is_anytype) {
// Both are anytype parameters.
if (!this_arg.ty.eql(other_arg.ty, ctx.target)) {
return false;
}
}
if (other_arg.val.tag() != .generic_poison) {
// comptime parameter
if (ctx.comptime_tvs[i].val.tag() == .generic_poison) {
// No match because the instantiation has a comptime parameter
// but the callsite does not.
return false;
if (this_is_comptime) {
// Both are comptime and anytype parameters with matching types.
if (!this_arg.val.eql(other_arg.val, other_arg.ty, ctx.target)) {
return false;
}
}
if (!other_arg.val.eql(ctx.comptime_tvs[i].val, other_arg.ty, ctx.target)) {
} else if (this_is_comptime) {
// Both are comptime parameters but not anytype parameters.
if (!this_arg.val.eql(other_arg.val, other_arg.ty, ctx.target)) {
return false;
}
}
@ -5227,28 +5238,61 @@ fn instantiateGenericCall(
const comptime_tvs = try sema.arena.alloc(TypedValue, func_ty_info.param_types.len);
const target = sema.mod.getTarget();
for (func_ty_info.param_types) |param_ty, i| {
const is_comptime = func_ty_info.paramIsComptime(i);
if (is_comptime) {
const arg_src = call_src; // TODO better source location
const casted_arg = try sema.coerce(block, param_ty, uncasted_args[i], arg_src);
if (try sema.resolveMaybeUndefVal(block, arg_src, casted_arg)) |arg_val| {
if (param_ty.tag() != .generic_poison) {
arg_val.hash(param_ty, &hasher, target);
{
var i: usize = 0;
for (fn_info.param_body) |inst| {
var is_comptime = false;
var is_anytype = false;
switch (zir_tags[inst]) {
.param => {
is_comptime = func_ty_info.paramIsComptime(i);
},
.param_comptime => {
is_comptime = true;
},
.param_anytype => {
is_anytype = true;
is_comptime = func_ty_info.paramIsComptime(i);
},
.param_anytype_comptime => {
is_anytype = true;
is_comptime = true;
},
else => continue,
}
if (is_comptime) {
const arg_src = call_src; // TODO better source location
const arg_ty = sema.typeOf(uncasted_args[i]);
const arg_val = try sema.resolveValue(block, arg_src, uncasted_args[i]);
arg_val.hash(arg_ty, &hasher, target);
if (is_anytype) {
arg_ty.hashWithHasher(&hasher, target);
comptime_tvs[i] = .{
.ty = arg_ty,
.val = arg_val,
};
} else {
comptime_tvs[i] = .{
.ty = Type.initTag(.generic_poison),
.val = arg_val,
};
}
} else if (is_anytype) {
const arg_ty = sema.typeOf(uncasted_args[i]);
arg_ty.hashWithHasher(&hasher, target);
comptime_tvs[i] = .{
// This will be different than `param_ty` in the case of `generic_poison`.
.ty = sema.typeOf(casted_arg),
.val = arg_val,
.ty = arg_ty,
.val = Value.initTag(.generic_poison),
};
} else {
return sema.failWithNeededComptime(block, arg_src);
comptime_tvs[i] = .{
.ty = Type.initTag(.generic_poison),
.val = Value.initTag(.generic_poison),
};
}
} else {
comptime_tvs[i] = .{
.ty = sema.typeOf(uncasted_args[i]),
.val = Value.initTag(.generic_poison),
};
i += 1;
}
}
@ -5411,19 +5455,48 @@ fn instantiateGenericCall(
errdefer new_func.deinit(gpa);
assert(new_func == new_module_func);
const anytype_args = try new_decl_arena_allocator.alloc(bool, func_ty_info.param_types.len);
new_func.anytype_args = anytype_args.ptr;
arg_i = 0;
for (fn_info.param_body) |inst| {
var is_comptime = false;
var is_anytype = false;
switch (zir_tags[inst]) {
.param_comptime, .param_anytype_comptime, .param, .param_anytype => {},
.param => {
is_comptime = func_ty_info.paramIsComptime(arg_i);
},
.param_comptime => {
is_comptime = true;
},
.param_anytype => {
is_anytype = true;
is_comptime = func_ty_info.paramIsComptime(arg_i);
},
.param_anytype_comptime => {
is_anytype = true;
is_comptime = true;
},
else => continue,
}
// We populate the Type here regardless because it is needed by
// `GenericCallAdapter.eql` as well as function body analysis.
// Whether it is anytype is communicated by `anytype_args`.
const arg = child_sema.inst_map.get(inst).?;
const copied_arg_ty = try child_sema.typeOf(arg).copy(new_decl_arena_allocator);
if (child_sema.resolveMaybeUndefValAllowVariables(
&child_block,
.unneeded,
arg,
) catch unreachable) |arg_val| {
anytype_args[arg_i] = is_anytype;
const arg_src = call_src; // TODO: better source location
if (try sema.typeRequiresComptime(block, arg_src, copied_arg_ty)) {
is_comptime = true;
}
if (is_comptime) {
const arg_val = (child_sema.resolveMaybeUndefValAllowVariables(
&child_block,
.unneeded,
arg,
) catch unreachable).?;
child_sema.comptime_args[arg_i] = .{
.ty = copied_arg_ty,
.val = try arg_val.copy(new_decl_arena_allocator),
@ -5480,22 +5553,7 @@ fn instantiateGenericCall(
const comptime_args = callee.comptime_args.?;
const new_fn_info = callee.owner_decl.ty.fnInfo();
const runtime_args_len = count: {
var count: u32 = 0;
var arg_i: usize = 0;
for (fn_info.param_body) |inst| {
switch (zir_tags[inst]) {
.param_comptime, .param_anytype_comptime, .param, .param_anytype => {
if (comptime_args[arg_i].val.tag() == .generic_poison) {
count += 1;
}
arg_i += 1;
},
else => continue,
}
}
break :count count;
};
const runtime_args_len = @intCast(u32, new_fn_info.param_types.len);
const runtime_args = try sema.arena.alloc(Air.Inst.Ref, runtime_args_len);
{
var runtime_i: u32 = 0;
@ -5505,7 +5563,9 @@ fn instantiateGenericCall(
.param_comptime, .param_anytype_comptime, .param, .param_anytype => {},
else => continue,
}
const is_runtime = comptime_args[total_i].val.tag() == .generic_poison;
const is_runtime = comptime_args[total_i].val.tag() == .generic_poison and
comptime_args[total_i].ty.hasRuntimeBits() and
!comptime_args[total_i].ty.comptimeOnly();
if (is_runtime) {
const param_ty = new_fn_info.param_types[runtime_i];
const arg_src = call_src; // TODO: better source location
@ -6562,6 +6622,7 @@ fn funcCommon(
.zir_body_inst = func_inst,
.owner_decl = sema.owner_decl,
.comptime_args = comptime_args,
.anytype_args = undefined,
.hash = hash,
.lbrace_line = src_locs.lbrace_line,
.rbrace_line = src_locs.rbrace_line,

View File

@ -954,6 +954,10 @@ pub const Value = extern union {
assert(ty.enumFieldCount() == 1);
break :blk 0;
},
.enum_literal => i: {
const name = val.castTag(.enum_literal).?.data;
break :i ty.enumFieldIndex(name).?;
},
// Assume it is already an integer and return it directly.
else => return val,
};
@ -2023,6 +2027,11 @@ pub const Value = extern union {
/// This function is used by hash maps and so treats floating-point NaNs as equal
/// to each other, and not equal to other floating-point values.
/// Similarly, it treats `undef` as a distinct value from all other values.
/// This function has to be able to support implicit coercion of `a` to `ty`. That is,
/// `ty` will be an exactly correct Type for `b` but it may be a post-coerced Type
/// for `a`. This function must act *as if* `a` has been coerced to `ty`. This complication
/// is required in order to make generic function instantiation efficient - specifically
/// the insertion into the monomorphized function table.
pub fn eql(a: Value, b: Value, ty: Type, target: Target) bool {
const a_tag = a.tag();
const b_tag = b.tag();
@ -2200,8 +2209,18 @@ pub const Value = extern union {
}
return order(a, b, target).compare(.eq);
},
else => return order(a, b, target).compare(.eq),
.Optional => {
if (a.tag() != .opt_payload and b.tag() == .opt_payload) {
var buffer: Payload.SubValue = .{
.base = .{ .tag = .opt_payload },
.data = a,
};
return eql(Value.initPayload(&buffer.base), b, ty, target);
}
},
else => {},
}
return order(a, b, target).compare(.eq);
}
/// This function is used by hash maps and so treats floating-point NaNs as equal

View File

@ -306,3 +306,21 @@ test "anonymous struct return type referencing comptime parameter" {
try expect(s.data == 1234);
try expect(s.end == 5678);
}
// Regression test: calls the same generic function twice with the same
// comptime `T` (`u8`) but with string literals of different lengths as the
// `source` argument.
// NOTE(review): presumably both calls are expected to resolve to a single
// instantiation of `copy`; a second instantiation would run the fixed-name
// `@export` in its body again -- confirm against the linked issue.
test "generic function instantiation non-duplicates" {
// Skipped on these self-hosted backends for now (marked TODO).
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
// Skipped entirely when targeting WASI.
if (builtin.os.tag == .wasi) return error.SkipZigTest;
const S = struct {
// Generic over `T`; copies every element of `source` into the front of `dest`.
fn copy(comptime T: type, dest: []T, source: []const T) void {
// Exports `foo` under a fixed symbol name from inside the generic body.
@export(foo, .{ .name = "test_generic_instantiation_non_dupe" });
for (source) |s, i| dest[i] = s;
}
// Empty C-calling-convention function; it only serves as the export target.
fn foo() callconv(.C) void {}
};
var buffer: [100]u8 = undefined;
// Same `T` and same declared parameter types on both calls; only the string
// literal passed for `source` differs ("hello" vs. "hello2").
S.copy(u8, &buffer, "hello");
S.copy(u8, &buffer, "hello2");
}