diff --git a/src/Module.zig b/src/Module.zig index 29c091abd3..4ed39c9954 100644 --- a/src/Module.zig +++ b/src/Module.zig @@ -66,6 +66,10 @@ import_table: std.StringArrayHashMapUnmanaged(*Scope.File) = .{}, /// to the same function. monomorphed_funcs: MonomorphedFuncsSet = .{}, +/// The set of all comptime function calls that have been cached so that future calls +/// with the same parameters will get the same return value. +memoized_calls: MemoizedCallSet = .{}, + /// We optimize memory usage for a compilation with no compile errors by storing the /// error messages and mapping outside of `Decl`. /// The ErrorMsg memory is owned by the decl, using Module's general purpose allocator. @@ -157,6 +161,60 @@ const MonomorphedFuncsContext = struct { } }; +pub const MemoizedCallSet = std.HashMapUnmanaged( + MemoizedCall.Key, + MemoizedCall.Result, + MemoizedCall, + std.hash_map.default_max_load_percentage, +); + +pub const MemoizedCall = struct { + pub const Key = struct { + func: *Fn, + args: []TypedValue, + }; + + pub const Result = struct { + val: Value, + arena: std.heap.ArenaAllocator.State, + }; + + pub fn eql(ctx: @This(), a: Key, b: Key) bool { + _ = ctx; + + if (a.func != b.func) return false; + + assert(a.args.len == b.args.len); + for (a.args) |a_arg, arg_i| { + const b_arg = b.args[arg_i]; + if (!a_arg.eql(b_arg)) { + return false; + } + } + + return true; + } + + /// Must match `Sema.GenericCallAdapter.hash`. + pub fn hash(ctx: @This(), key: Key) u64 { + _ = ctx; + + var hasher = std.hash.Wyhash.init(0); + + // The generic function Decl is guaranteed to be the first dependency + // of each of its instantiations. + std.hash.autoHash(&hasher, @ptrToInt(key.func)); + + // This logic must be kept in sync with the logic in `analyzeCall` that + // computes the hash. + for (key.args) |arg| { + arg.hash(&hasher); + } + + return hasher.final(); + } +}; + /// A `Module` has zero or one of these depending on whether `-femit-h` is enabled. 
pub const GlobalEmitH = struct { /// Where to put the output. @@ -2255,15 +2313,26 @@ pub fn deinit(mod: *Module) void { } mod.export_owners.deinit(gpa); - var it = mod.global_error_set.keyIterator(); - while (it.next()) |key| { - gpa.free(key.*); + { + var it = mod.global_error_set.keyIterator(); + while (it.next()) |key| { + gpa.free(key.*); + } + mod.global_error_set.deinit(gpa); } - mod.global_error_set.deinit(gpa); mod.error_name_list.deinit(gpa); mod.test_functions.deinit(gpa); mod.monomorphed_funcs.deinit(gpa); + + { + var it = mod.memoized_calls.iterator(); + while (it.next()) |entry| { + gpa.free(entry.key_ptr.args); + entry.value_ptr.arena.promote(gpa).deinit(); + } + mod.memoized_calls.deinit(gpa); + } } fn freeExportList(gpa: *Allocator, export_list: []*Export) void { diff --git a/src/Sema.zig b/src/Sema.zig index 36ac4d224d..d543b59ac0 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -649,6 +649,24 @@ fn resolveValue( return sema.failWithNeededComptime(block, src); } +/// Value Tag `variable` will cause a compile error. +/// Value Tag `undef` may be returned. +fn resolveConstMaybeUndefVal( + sema: *Sema, + block: *Scope.Block, + src: LazySrcLoc, + inst: Air.Inst.Ref, +) CompileError!Value { + if (try sema.resolveMaybeUndefValAllowVariables(block, src, inst)) |val| { + switch (val.tag()) { + .variable => return sema.failWithNeededComptime(block, src), + .generic_poison => return error.GenericPoison, + else => return val, + } + } + return sema.failWithNeededComptime(block, src); +} + /// Will not return Value Tags: `variable`, `undef`. Instead they will emit compile errors. /// See `resolveValue` for an alternative. fn resolveConstValue( @@ -2565,6 +2583,19 @@ fn analyzeCall( defer merges.results.deinit(gpa); defer merges.br_list.deinit(gpa); + // If it's a comptime function call, we need to memoize it as long as no external + // comptime memory is mutated. 
+ var memoized_call_key: Module.MemoizedCall.Key = undefined; + var delete_memoized_call_key = false; + defer if (delete_memoized_call_key) gpa.free(memoized_call_key.args); + if (is_comptime_call) { + memoized_call_key = .{ + .func = module_fn, + .args = try gpa.alloc(TypedValue, func_ty_info.param_types.len), + }; + delete_memoized_call_key = true; + } + try sema.emitBackwardBranch(&child_block, call_src); // This will have return instructions analyzed as break instructions to @@ -2589,12 +2620,32 @@ fn analyzeCall( const arg_src = call_src; // TODO: better source location const casted_arg = try sema.coerce(&child_block, param_ty, uncasted_args[arg_i], arg_src); try sema.inst_map.putNoClobber(gpa, inst, casted_arg); + + if (is_comptime_call) { + const arg_val = try sema.resolveConstMaybeUndefVal(&child_block, arg_src, casted_arg); + memoized_call_key.args[arg_i] = .{ + .ty = param_ty, + .val = arg_val, + }; + } + arg_i += 1; continue; }, .param_anytype, .param_anytype_comptime => { // No coercion needed. - try sema.inst_map.putNoClobber(gpa, inst, uncasted_args[arg_i]); + const uncasted_arg = uncasted_args[arg_i]; + try sema.inst_map.putNoClobber(gpa, inst, uncasted_arg); + + if (is_comptime_call) { + const arg_src = call_src; // TODO: better source location + const arg_val = try sema.resolveConstMaybeUndefVal(&child_block, arg_src, uncasted_arg); + memoized_call_key.args[arg_i] = .{ + .ty = sema.typeOf(uncasted_arg), + .val = arg_val, + }; + } + arg_i += 1; continue; }, @@ -2626,19 +2677,61 @@ fn analyzeCall( sema.fn_ret_ty = fn_ret_ty; defer sema.fn_ret_ty = parent_fn_ret_ty; - _ = try sema.analyzeBody(&child_block, fn_info.body); - const result = try sema.analyzeBlockBody(block, call_src, &child_block, merges); + // This `res2` is here instead of directly breaking from `res` due to a stage1 + // bug generating invalid LLVM IR. 
+ const res2: Air.Inst.Ref = res2: { + if (is_comptime_call) { + if (mod.memoized_calls.get(memoized_call_key)) |result| { + const ty_inst = try sema.addType(fn_ret_ty); + try sema.air_values.append(gpa, result.val); + sema.air_instructions.set(block_inst, .{ + .tag = .constant, + .data = .{ .ty_pl = .{ + .ty = ty_inst, + .payload = @intCast(u32, sema.air_values.items.len - 1), + } }, + }); + break :res2 Air.indexToRef(block_inst); + } + } - // Much like in `Module.semaDecl`, if the result is a struct or union type, - // we need to resolve the field type expressions right here, right now, while - // the child `Sema` is still available, with the AIR instruction map intact, - // because the field type expressions may reference into it. - if (sema.typeOf(result).zigTypeTag() == .Type) { - const ty = try sema.analyzeAsType(&child_block, call_src, result); - try sema.resolveDeclFields(&child_block, call_src, ty); - } + _ = try sema.analyzeBody(&child_block, fn_info.body); + const result = try sema.analyzeBlockBody(block, call_src, &child_block, merges); - break :res result; + if (is_comptime_call) { + const result_val = try sema.resolveConstMaybeUndefVal(block, call_src, result); + + // TODO: check whether any external comptime memory was mutated by the + // comptime function call. If so, then do not memoize the call here. 
+ { + var arena_allocator = std.heap.ArenaAllocator.init(gpa); + errdefer arena_allocator.deinit(); + const arena = &arena_allocator.allocator; + + for (memoized_call_key.args) |*arg| { + arg.* = try arg.*.copy(arena); + } + + try mod.memoized_calls.put(gpa, memoized_call_key, .{ + .val = result_val, + .arena = arena_allocator.state, + }); + delete_memoized_call_key = false; + } + + // Much like in `Module.semaDecl`, if the result is a struct or union type, + // we need to resolve the field type expressions right here, right now, while + // the child `Sema` is still available, with the AIR instruction map intact, + // because the field type expressions may reference into it. + if (sema.typeOf(result).zigTypeTag() == .Type) { + const ty = try sema.analyzeAsType(&child_block, call_src, result); + try sema.resolveDeclFields(&child_block, call_src, ty); + } + } + + break :res2 result; + }; + break :res res2; } else if (func_ty_info.is_generic) res: { const func_val = try sema.resolveConstValue(block, func_src, func); const module_fn = func_val.castTag(.function).?.data; @@ -3305,31 +3398,9 @@ fn zirEnumToInt(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) CompileE } if (try sema.resolveMaybeUndefVal(block, operand_src, enum_tag)) |enum_tag_val| { - if (enum_tag_val.castTag(.enum_field_index)) |enum_field_payload| { - const field_index = enum_field_payload.data; - switch (enum_tag_ty.tag()) { - .enum_full => { - const enum_full = enum_tag_ty.castTag(.enum_full).?.data; - if (enum_full.values.count() != 0) { - const val = enum_full.values.keys()[field_index]; - return sema.addConstant(int_tag_ty, val); - } else { - // Field index and integer values are the same. - const val = try Value.Tag.int_u64.create(arena, field_index); - return sema.addConstant(int_tag_ty, val); - } - }, - .enum_simple => { - // Field index and integer values are the same. 
- const val = try Value.Tag.int_u64.create(arena, field_index); - return sema.addConstant(int_tag_ty, val); - }, - else => unreachable, - } - } else { - // Assume it is already an integer and return it directly. - return sema.addConstant(int_tag_ty, enum_tag_val); - } + var buffer: Value.Payload.U64 = undefined; + const val = enum_tag_val.enumToInt(enum_tag_ty, &buffer); + return sema.addConstant(int_tag_ty, try val.copy(sema.arena)); } try sema.requireRuntimeBlock(block, src); @@ -3414,7 +3485,10 @@ fn zirOptionalPayloadPtr( return sema.mod.fail(&block.base, src, "unable to unwrap null", .{}); } // The same Value represents the pointer to the optional and the payload. - return sema.addConstant(child_pointer, pointer_val); + return sema.addConstant( + child_pointer, + try Value.Tag.opt_payload_ptr.create(sema.arena, pointer_val), + ); } } @@ -3451,7 +3525,8 @@ fn zirOptionalPayload( if (val.isNull()) { return sema.mod.fail(&block.base, src, "unable to unwrap null", .{}); } - return sema.addConstant(child_type, val); + const sub_val = val.castTag(.opt_payload).?.data; + return sema.addConstant(child_type, sub_val); } try sema.requireRuntimeBlock(block, src); @@ -9095,7 +9170,7 @@ fn wrapOptional( inst_src: LazySrcLoc, ) !Air.Inst.Ref { if (try sema.resolveMaybeUndefVal(block, inst_src, inst)) |val| { - return sema.addConstant(dest_type, val); + return sema.addConstant(dest_type, try Value.Tag.opt_payload.create(sema.arena, val)); } try sema.requireRuntimeBlock(block, inst_src); diff --git a/src/TypedValue.zig b/src/TypedValue.zig index 48b2c04970..83242b5329 100644 --- a/src/TypedValue.zig +++ b/src/TypedValue.zig @@ -23,9 +23,18 @@ pub const Managed = struct { }; /// Assumes arena allocation. Does a recursive copy. 
-pub fn copy(self: TypedValue, allocator: *Allocator) error{OutOfMemory}!TypedValue { +pub fn copy(self: TypedValue, arena: *Allocator) error{OutOfMemory}!TypedValue { return TypedValue{ - .ty = try self.ty.copy(allocator), - .val = try self.val.copy(allocator), + .ty = try self.ty.copy(arena), + .val = try self.val.copy(arena), }; } + +pub fn eql(a: TypedValue, b: TypedValue) bool { + if (!a.ty.eql(b.ty)) return false; + return a.val.eql(b.val, a.ty); +} + +pub fn hash(tv: TypedValue, hasher: *std.hash.Wyhash) void { + return tv.val.hash(tv.ty, hasher); +} diff --git a/src/codegen/c.zig b/src/codegen/c.zig index e7994ffd06..bc3e357827 100644 --- a/src/codegen/c.zig +++ b/src/codegen/c.zig @@ -319,18 +319,20 @@ pub const DeclGen = struct { .Bool => return writer.print("{}", .{val.toBool()}), .Optional => { var opt_buf: Type.Payload.ElemType = undefined; - const child_type = t.optionalChild(&opt_buf); + const payload_type = t.optionalChild(&opt_buf); if (t.isPtrLikeOptional()) { - return dg.renderValue(writer, child_type, val); + return dg.renderValue(writer, payload_type, val); } try writer.writeByte('('); try dg.renderType(writer, t); - if (val.tag() == .null_value) { - try writer.writeAll("){ .is_null = true }"); - } else { - try writer.writeAll("){ .is_null = false, .payload = "); - try dg.renderValue(writer, child_type, val); + try writer.writeAll("){"); + if (val.castTag(.opt_payload)) |pl| { + const payload_val = pl.data; + try writer.writeAll(" .is_null = false, .payload = "); + try dg.renderValue(writer, payload_type, payload_val); try writer.writeAll(" }"); + } else { + try writer.writeAll(" .is_null = true }"); } }, .ErrorSet => { diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index 745fb036db..b80dcf0feb 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -810,27 +810,22 @@ pub const DeclGen = struct { return self.todo("handle more array values", .{}); }, .Optional => { - if (!tv.ty.isPtrLikeOptional()) { - var buf: 
Type.Payload.ElemType = undefined; - const child_type = tv.ty.optionalChild(&buf); - const llvm_child_type = try self.llvmType(child_type); - - if (tv.val.tag() == .null_value) { - var optional_values: [2]*const llvm.Value = .{ - llvm_child_type.constNull(), - self.context.intType(1).constNull(), - }; - return self.context.constStruct(&optional_values, optional_values.len, .False); - } else { - var optional_values: [2]*const llvm.Value = .{ - try self.genTypedValue(.{ .ty = child_type, .val = tv.val }), - self.context.intType(1).constAllOnes(), - }; - return self.context.constStruct(&optional_values, optional_values.len, .False); - } - } else { + if (tv.ty.isPtrLikeOptional()) { return self.todo("implement const of optional pointer", .{}); } + var buf: Type.Payload.ElemType = undefined; + const payload_type = tv.ty.optionalChild(&buf); + const is_pl = !tv.val.isNull(); + const llvm_i1 = self.context.intType(1); + + const fields: [2]*const llvm.Value = .{ + try self.genTypedValue(.{ + .ty = payload_type, + .val = if (tv.val.castTag(.opt_payload)) |pl| pl.data else Value.initTag(.undef), + }), + if (is_pl) llvm_i1.constAllOnes() else llvm_i1.constNull(), + }; + return self.context.constStruct(&fields, fields.len, .False); }, .Fn => { const fn_decl = switch (tv.val.tag()) { diff --git a/src/codegen/wasm.zig b/src/codegen/wasm.zig index 4814ba0b55..422afef9c4 100644 --- a/src/codegen/wasm.zig +++ b/src/codegen/wasm.zig @@ -1198,7 +1198,12 @@ pub const Context = struct { // When constant has value 'null', set is_null local to '1' // and payload to '0' - if (val.tag() == .null_value) { + if (val.castTag(.opt_payload)) |pl| { + const payload_val = pl.data; + try writer.writeByte(wasm.opcode(.i32_const)); + try leb.writeILEB128(writer, @as(i32, 0)); + try self.emitConstant(payload_val, payload_type); + } else { try writer.writeByte(wasm.opcode(.i32_const)); try leb.writeILEB128(writer, @as(i32, 1)); @@ -1208,10 +1213,6 @@ pub const Context = struct { }); try 
writer.writeByte(wasm.opcode(opcode)); try leb.writeULEB128(writer, @as(u32, 0)); - } else { - try writer.writeByte(wasm.opcode(.i32_const)); - try leb.writeILEB128(writer, @as(i32, 0)); - try self.emitConstant(val, payload_type); } }, else => |zig_type| return self.fail("Wasm TODO: emitConstant for zigTypeTag {s}", .{zig_type}), diff --git a/src/value.zig b/src/value.zig index 7b3056bfcf..5ac9f142c4 100644 --- a/src/value.zig +++ b/src/value.zig @@ -133,12 +133,21 @@ pub const Value = extern union { /// When the type is error union: /// * If the tag is `.@"error"`, the error union is an error. /// * If the tag is `.eu_payload`, the error union is a payload. - /// * A nested error such as `((anyerror!T1)!T2)` in which the the outer error union + /// * A nested error such as `anyerror!(anyerror!T)` in which the outer error union /// is non-error, but the inner error union is an error, is represented as /// a tag of `.eu_payload`, with a sub-tag of `.@"error"`. eu_payload, /// A pointer to the payload of an error union, based on a pointer to an error union. eu_payload_ptr, + /// When the type is optional: + /// * If the tag is `.null_value`, the optional is null. + /// * If the tag is `.opt_payload`, the optional is a payload. + /// * A nested optional such as `??T` in which the outer optional + /// is non-null, but the inner optional is null, is represented as + /// a tag of `.opt_payload`, with a sub-tag of `.null_value`. + opt_payload, + /// A pointer to the payload of an optional, based on a pointer to an optional. + opt_payload_ptr, /// An instance of a struct. @"struct", /// An instance of a union. 
@@ -238,6 +247,8 @@ pub const Value = extern union { .repeated, .eu_payload, .eu_payload_ptr, + .opt_payload, + .opt_payload_ptr, => Payload.SubValue, .bytes, @@ -459,7 +470,12 @@ pub const Value = extern union { return Value{ .ptr_otherwise = &new_payload.base }; }, .bytes => return self.copyPayloadShallow(allocator, Payload.Bytes), - .repeated, .eu_payload, .eu_payload_ptr => { + .repeated, + .eu_payload, + .eu_payload_ptr, + .opt_payload, + .opt_payload_ptr, + => { const payload = self.cast(Payload.SubValue).?; const new_payload = try allocator.create(Payload.SubValue); new_payload.* = .{ @@ -656,12 +672,20 @@ pub const Value = extern union { try out_stream.writeAll("(eu_payload) "); val = val.castTag(.eu_payload).?.data; }, + .opt_payload => { + try out_stream.writeAll("(opt_payload) "); + val = val.castTag(.opt_payload).?.data; + }, .inferred_alloc => return out_stream.writeAll("(inferred allocation value)"), .inferred_alloc_comptime => return out_stream.writeAll("(inferred comptime allocation value)"), .eu_payload_ptr => { try out_stream.writeAll("(eu_payload_ptr)"); val = val.castTag(.eu_payload_ptr).?.data; }, + .opt_payload_ptr => { + try out_stream.writeAll("(opt_payload_ptr)"); + val = val.castTag(.opt_payload_ptr).?.data; + }, }; } @@ -776,6 +800,38 @@ pub const Value = extern union { } } + pub fn enumToInt(val: Value, ty: Type, buffer: *Payload.U64) Value { + if (val.castTag(.enum_field_index)) |enum_field_payload| { + const field_index = enum_field_payload.data; + switch (ty.tag()) { + .enum_full, .enum_nonexhaustive => { + const enum_full = ty.cast(Type.Payload.EnumFull).?.data; + if (enum_full.values.count() != 0) { + return enum_full.values.keys()[field_index]; + } else { + // Field index and integer values are the same. + buffer.* = .{ + .base = .{ .tag = .int_u64 }, + .data = field_index, + }; + return Value.initPayload(&buffer.base); + } + }, + .enum_simple => { + // Field index and integer values are the same. 
+ buffer.* = .{ + .base = .{ .tag = .int_u64 }, + .data = field_index, + }; + return Value.initPayload(&buffer.base); + }, + else => unreachable, + } + } + // Assume it is already an integer and return it directly. + return val; + } + /// Asserts the value is an integer. pub fn toBigInt(self: Value, space: *BigIntSpace) BigIntConst { switch (self.tag()) { @@ -1132,7 +1188,10 @@ pub const Value = extern union { } pub fn hash(val: Value, ty: Type, hasher: *std.hash.Wyhash) void { - switch (ty.zigTypeTag()) { + const zig_ty_tag = ty.zigTypeTag(); + std.hash.autoHash(hasher, zig_ty_tag); + + switch (zig_ty_tag) { .BoundFn => unreachable, // TODO remove this from the language .Void, @@ -1157,7 +1216,10 @@ pub const Value = extern union { } }, .Float, .ComptimeFloat => { - @panic("TODO implement hashing float values"); + // TODO double check the lang spec. should we do bitwise hashing here, + // or a hash that normalizes the float value? + const float = val.toFloat(f128); + std.hash.autoHash(hasher, @bitCast(u128, float)); }, .Pointer => { @panic("TODO implement hashing pointer values"); @@ -1169,7 +1231,15 @@ pub const Value = extern union { @panic("TODO implement hashing struct values"); }, .Optional => { - @panic("TODO implement hashing optional values"); + if (val.castTag(.opt_payload)) |payload| { + std.hash.autoHash(hasher, true); // non-null + const sub_val = payload.data; + var buffer: Type.Payload.ElemType = undefined; + const sub_ty = ty.optionalChild(&buffer); + sub_val.hash(sub_ty, hasher); + } else { + std.hash.autoHash(hasher, false); // null + } }, .ErrorUnion => { @panic("TODO implement hashing error union values"); @@ -1178,7 +1248,16 @@ pub const Value = extern union { @panic("TODO implement hashing error set values"); }, .Enum => { - @panic("TODO implement hashing enum values"); + var enum_space: Payload.U64 = undefined; + const int_val = val.enumToInt(ty, &enum_space); + + var space: BigIntSpace = undefined; + const big = int_val.toBigInt(&space); 
+ + std.hash.autoHash(hasher, big.positive); + for (big.limbs) |limb| { + std.hash.autoHash(hasher, limb); + } }, .Union => { @panic("TODO implement hashing union values"); @@ -1257,6 +1336,11 @@ pub const Value = extern union { const err_union_val = (try err_union_ptr.pointerDeref(allocator)) orelse return null; break :blk err_union_val.castTag(.eu_payload).?.data; }, + .opt_payload_ptr => blk: { + const opt_ptr = self.castTag(.opt_payload_ptr).?.data; + const opt_val = (try opt_ptr.pointerDeref(allocator)) orelse return null; + break :blk opt_val.castTag(.opt_payload).?.data; + }, .zero, .one, @@ -1354,13 +1438,14 @@ pub const Value = extern union { /// Valid for all types. Asserts the value is not undefined and not unreachable. pub fn isNull(self: Value) bool { return switch (self.tag()) { + .null_value => true, + .opt_payload => false, + .undef => unreachable, .unreachable_value => unreachable, .inferred_alloc => unreachable, .inferred_alloc_comptime => unreachable, - .null_value => true, - - else => false, + else => unreachable, }; } @@ -1390,6 +1475,10 @@ pub const Value = extern union { return switch (val.tag()) { .eu_payload => true, else => false, + + .undef => unreachable, + .inferred_alloc => unreachable, + .inferred_alloc_comptime => unreachable, }; } diff --git a/test/behavior/eval.zig b/test/behavior/eval.zig index 67103e01ff..1b6540cd32 100644 --- a/test/behavior/eval.zig +++ b/test/behavior/eval.zig @@ -148,3 +148,16 @@ const List = blk: { array: T, }; }; + +test "comptime function with the same args is memoized" { + comptime { + try expect(MakeType(i32) == MakeType(i32)); + try expect(MakeType(i32) != MakeType(f64)); + } +} + +fn MakeType(comptime T: type) type { + return struct { + field: T, + }; +} diff --git a/test/behavior/eval_stage1.zig b/test/behavior/eval_stage1.zig index 3599d5a477..644de50fd0 100644 --- a/test/behavior/eval_stage1.zig +++ b/test/behavior/eval_stage1.zig @@ -356,19 +356,6 @@ test "binary math operator in partially inlined 
function" { try expect(s[3] == 0xd0e0f10); } -test "comptime function with the same args is memoized" { - comptime { - try expect(MakeType(i32) == MakeType(i32)); - try expect(MakeType(i32) != MakeType(f64)); - } -} - -fn MakeType(comptime T: type) type { - return struct { - field: T, - }; -} - test "comptime function with mutable pointer is not memoized" { comptime { var x: i32 = 1;