From 5186c6c4ee520f7f1c87134d45cd7a33f8b2aaea Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 8 Jan 2025 16:25:00 -0800 Subject: [PATCH] wasm linker: distinguish symbol name vs import name, and implement weak --- src/link/Wasm.zig | 87 +++++++++++++++++++++++++++++++++++----- src/link/Wasm/Flush.zig | 15 +++---- src/link/Wasm/Object.zig | 43 +++++++++++++++----- 3 files changed, 115 insertions(+), 30 deletions(-) diff --git a/src/link/Wasm.zig b/src/link/Wasm.zig index 1f0ba70b58..2f8689bb1b 100644 --- a/src/link/Wasm.zig +++ b/src/link/Wasm.zig @@ -91,6 +91,7 @@ objects: std.ArrayListUnmanaged(Object) = .{}, func_types: std.AutoArrayHashMapUnmanaged(FunctionType, void) = .empty, /// Provides a mapping of both imports and provided functions to symbol name. /// Local functions may be unnamed. +/// Key is symbol name, however the `FunctionImport` may have a name override for the import name. object_function_imports: std.AutoArrayHashMapUnmanaged(String, FunctionImport) = .empty, /// All functions for all objects. 
object_functions: std.ArrayListUnmanaged(ObjectFunction) = .empty, @@ -164,7 +165,7 @@ object_host_name: OptionalString, /// Memory section memories: std.wasm.Memory = .{ .limits = .{ .min = 0, - .max = undefined, + .max = 0, .flags = .{ .has_max = false, .is_shared = false }, } }, @@ -371,6 +372,17 @@ pub const OutputFunctionIndex = enum(u32) { return fromResolution(wasm, .fromObjectFunction(wasm, index)).?; } + pub fn fromObjectFunctionHandlingWeak(wasm: *const Wasm, index: ObjectFunctionIndex) OutputFunctionIndex { + const ptr = index.ptr(wasm); + if (ptr.flags.binding == .weak) { + const name = ptr.name.unwrap().?; + const import = wasm.object_function_imports.getPtr(name).?; + assert(import.resolution != .unresolved); + return fromResolution(wasm, import.resolution).?; + } + return fromResolution(wasm, .fromObjectFunction(wasm, index)).?; + } + pub fn fromIpIndex(wasm: *const Wasm, ip_index: InternPool.Index) OutputFunctionIndex { const zcu = wasm.base.comp.zcu.?; const ip = &zcu.intern_pool; @@ -923,6 +935,8 @@ const DebugSection = struct {}; pub const FunctionImport = extern struct { flags: SymbolFlags, module_name: OptionalString, + /// May be different than the key which is a symbol name. 
+ name: String, source_location: SourceLocation, resolution: Resolution, type: FunctionType.Index, @@ -1042,10 +1056,14 @@ pub const FunctionImport = extern struct { return &wasm.object_function_imports.values()[@intFromEnum(index)]; } - pub fn name(index: Index, wasm: *const Wasm) String { + pub fn symbolName(index: Index, wasm: *const Wasm) String { return index.key(wasm).*; } + pub fn importName(index: Index, wasm: *const Wasm) String { + return index.value(wasm).name; + } + pub fn moduleName(index: Index, wasm: *const Wasm) OptionalString { return index.value(wasm).module_name; } @@ -1079,6 +1097,8 @@ pub const ObjectFunction = extern struct { pub const GlobalImport = extern struct { flags: SymbolFlags, module_name: OptionalString, + /// May be different than the key which is a symbol name. + name: String, source_location: SourceLocation, resolution: Resolution, @@ -1194,10 +1214,14 @@ pub const GlobalImport = extern struct { return &wasm.object_global_imports.values()[@intFromEnum(index)]; } - pub fn name(index: Index, wasm: *const Wasm) String { + pub fn symbolName(index: Index, wasm: *const Wasm) String { return index.key(wasm).*; } + pub fn importName(index: Index, wasm: *const Wasm) String { + return index.value(wasm).name; + } + pub fn moduleName(index: Index, wasm: *const Wasm) OptionalString { return index.value(wasm).module_name; } @@ -1260,6 +1284,8 @@ pub const RefType1 = enum(u1) { pub const TableImport = extern struct { flags: SymbolFlags, module_name: String, + /// May be different than the key which is a symbol name. 
+ name: String, source_location: SourceLocation, resolution: Resolution, limits_min: u32, @@ -1387,6 +1413,15 @@ pub const ObjectTableIndex = enum(u32) { pub fn ptr(index: ObjectTableIndex, wasm: *const Wasm) *Table { return &wasm.object_tables.items[@intFromEnum(index)]; } + + pub fn chaseWeak(i: ObjectTableIndex, wasm: *const Wasm) ObjectTableIndex { + const table = ptr(i, wasm); + if (table.flags.binding != .weak) return i; + const name = table.name.unwrap().?; + const import = wasm.object_table_imports.getPtr(name).?; + assert(import.resolution != .unresolved); // otherwise it should resolve to this one. + return import.resolution.unpack().object_table; + } }; /// Index into `Wasm.object_globals`. @@ -1400,6 +1435,15 @@ pub const ObjectGlobalIndex = enum(u32) { pub fn name(index: ObjectGlobalIndex, wasm: *const Wasm) OptionalString { return index.ptr(wasm).name; } + + pub fn chaseWeak(i: ObjectGlobalIndex, wasm: *const Wasm) ObjectGlobalIndex { + const global = ptr(i, wasm); + if (global.flags.binding != .weak) return i; + const import_name = global.name.unwrap().?; + const import = wasm.object_global_imports.getPtr(import_name).?; + assert(import.resolution != .unresolved); // otherwise it should resolve to this one. + return import.resolution.unpack(wasm).object_global; + } }; pub const ObjectMemory = extern struct { @@ -1442,6 +1486,15 @@ pub const ObjectFunctionIndex = enum(u32) { assert(result != .none); return result; } + + pub fn chaseWeak(i: ObjectFunctionIndex, wasm: *const Wasm) ObjectFunctionIndex { + const func = ptr(i, wasm); + if (func.flags.binding != .weak) return i; + const name = func.name.unwrap().?; + const import = wasm.object_function_imports.getPtr(name).?; + assert(import.resolution != .unresolved); // otherwise it should resolve to this one. + return import.resolution.unpack(wasm).object_function; + } }; /// Index into `object_functions`, or null. 
@@ -2131,7 +2184,7 @@ pub const ZcuImportIndex = enum(u32) { return &wasm.imports.keys()[@intFromEnum(index)]; } - pub fn name(index: ZcuImportIndex, wasm: *const Wasm) String { + pub fn importName(index: ZcuImportIndex, wasm: *const Wasm) String { const zcu = wasm.base.comp.zcu.?; const ip = &zcu.intern_pool; const nav_index = index.ptr(wasm).*; @@ -2217,9 +2270,9 @@ pub const FunctionImportId = enum(u32) { } } - pub fn name(id: FunctionImportId, wasm: *const Wasm) String { + pub fn importName(id: FunctionImportId, wasm: *const Wasm) String { return switch (unpack(id, wasm)) { - inline .object_function_import, .zcu_import => |i| i.name(wasm), + inline .object_function_import, .zcu_import => |i| i.importName(wasm), }; } @@ -2300,9 +2353,9 @@ pub const GlobalImportId = enum(u32) { } } - pub fn name(id: GlobalImportId, wasm: *const Wasm) String { + pub fn importName(id: GlobalImportId, wasm: *const Wasm) String { return switch (unpack(id, wasm)) { - inline .object_global_import, .zcu_import => |i| i.name(wasm), + inline .object_global_import, .zcu_import => |i| i.importName(wasm), }; } @@ -3297,6 +3350,14 @@ pub fn prelink(wasm: *Wasm, prog_node: std.Progress.Node) link.File.FlushError!v try markDataImport(wasm, name, import, @enumFromInt(i)); } } + + // This is a wild ass guess at how to merge memories, haven't checked yet + // what the proper way to do this is. 
+ for (wasm.object_memory_imports.values()) |*memory_import| { + wasm.memories.limits.min = @min(wasm.memories.limits.min, memory_import.limits_min); + wasm.memories.limits.max = @max(wasm.memories.limits.max, memory_import.limits_max); + wasm.memories.limits.flags.has_max = wasm.memories.limits.flags.has_max or memory_import.limits_has_max; + } } fn markFunctionImport( @@ -3532,12 +3593,12 @@ fn markRelocations(wasm: *Wasm, relocs: ObjectRelocation.IterableSlice) link.Fil .table_index_i64, .table_index_rel_sleb, .table_index_rel_sleb64, - => try markFunction(wasm, pointee.function), + => try markFunction(wasm, pointee.function.chaseWeak(wasm)), .global_index_leb, .global_index_i32, - => try markGlobal(wasm, pointee.global), + => try markGlobal(wasm, pointee.global.chaseWeak(wasm)), .table_number_leb, - => try wasm.tables.put(wasm.base.comp.gpa, .fromObjectTable(pointee.table), {}), + => try markTable(wasm, pointee.table.chaseWeak(wasm)), .section_offset_i32 => { log.warn("TODO: ensure section {d} is included in output", .{pointee.section}); @@ -3561,6 +3622,10 @@ fn markRelocations(wasm: *Wasm, relocs: ObjectRelocation.IterableSlice) link.Fil } } +fn markTable(wasm: *Wasm, i: ObjectTableIndex) link.File.FlushError!void { + try wasm.tables.put(wasm.base.comp.gpa, .fromObjectTable(i), {}); +} + pub fn flushModule( wasm: *Wasm, arena: Allocator, diff --git a/src/link/Wasm/Flush.zig b/src/link/Wasm/Flush.zig index 23b590696b..1e4f256a8f 100644 --- a/src/link/Wasm/Flush.zig +++ b/src/link/Wasm/Flush.zig @@ -459,7 +459,7 @@ pub fn finish(f: *Flush, wasm: *Wasm) !void { try leb.writeUleb128(binary_writer, @as(u32, @intCast(module_name.len))); try binary_writer.writeAll(module_name); - const name = id.name(wasm).slice(wasm); + const name = id.importName(wasm).slice(wasm); try leb.writeUleb128(binary_writer, @as(u32, @intCast(name.len))); try binary_writer.writeAll(name); @@ -474,7 +474,7 @@ pub fn finish(f: *Flush, wasm: *Wasm) !void { try leb.writeUleb128(binary_writer, 
@as(u32, @intCast(module_name.len))); try binary_writer.writeAll(module_name); - const name = id.key(wasm).slice(wasm); + const name = table_import.name.slice(wasm); try leb.writeUleb128(binary_writer, @as(u32, @intCast(name.len))); try binary_writer.writeAll(name); @@ -484,10 +484,7 @@ pub fn finish(f: *Flush, wasm: *Wasm) !void { } total_imports += wasm.table_imports.entries.len; - for (wasm.object_memory_imports.keys(), wasm.object_memory_imports.values()) |name, *memory_import| { - try emitMemoryImport(wasm, binary_bytes, name, memory_import); - total_imports += 1; - } else if (import_memory) { + if (import_memory) { const name = if (is_obj) wasm.preloaded_strings.__linear_memory else wasm.preloaded_strings.memory; try emitMemoryImport(wasm, binary_bytes, name, &.{ // TODO the import_memory option needs to specify from which module @@ -506,7 +503,7 @@ pub fn finish(f: *Flush, wasm: *Wasm) !void { try leb.writeUleb128(binary_writer, @as(u32, @intCast(module_name.len))); try binary_writer.writeAll(module_name); - const name = id.name(wasm).slice(wasm); + const name = id.importName(wasm).slice(wasm); try leb.writeUleb128(binary_writer, @as(u32, @intCast(name.len))); try binary_writer.writeAll(name); @@ -1458,8 +1455,8 @@ fn applyRelocs(code: []u8, code_offset: u32, relocs: Wasm.ObjectRelocation.Itera if (offset >= relocs.end) break; const sliced_code = code[offset - code_offset ..]; switch (tag) { - .function_index_i32 => reloc_u32_function(sliced_code, .fromObjectFunction(wasm, pointee.function)), - .function_index_leb => reloc_leb_function(sliced_code, .fromObjectFunction(wasm, pointee.function)), + .function_index_i32 => reloc_u32_function(sliced_code, .fromObjectFunctionHandlingWeak(wasm, pointee.function)), + .function_index_leb => reloc_leb_function(sliced_code, .fromObjectFunctionHandlingWeak(wasm, pointee.function)), .function_offset_i32 => @panic("TODO this value is not known yet"), .function_offset_i64 => @panic("TODO this value is not known yet"), 
.table_index_i32 => @panic("TODO indirect function table needs to support object functions too"), diff --git a/src/link/Wasm/Object.zig b/src/link/Wasm/Object.zig index 9f877a0a05..91354c3d03 100644 --- a/src/link/Wasm/Object.zig +++ b/src/link/Wasm/Object.zig @@ -972,7 +972,7 @@ pub fn parse( for (ss.symbol_table.items) |symbol| switch (symbol.pointee) { .function_import => |index| { const ptr = index.ptr(ss); - const name = symbol.name.unwrap().?; + const name = symbol.name.unwrap() orelse ptr.name; if (symbol.flags.binding == .local) { diags.addParseError(path, "local symbol '{s}' references import", .{name.slice(wasm)}); continue; @@ -1000,10 +1000,18 @@ pub fn parse( source_location.addNote(&err, "module '{s}' here", .{ptr.module_name.slice(wasm)}); continue; } + if (gop.value_ptr.name != ptr.name) { + var err = try diags.addErrorWithNotes(2); + try err.addMsg("symbol '{s}' mismatching import names", .{name.slice(wasm)}); + gop.value_ptr.source_location.addNote(&err, "imported as '{s}' here", .{gop.value_ptr.name.slice(wasm)}); + source_location.addNote(&err, "imported as '{s}' here", .{ptr.name.slice(wasm)}); + continue; + } } else { gop.value_ptr.* = .{ .flags = symbol.flags, .module_name = ptr.module_name.toOptional(), + .name = ptr.name, .source_location = source_location, .resolution = .unresolved, .type = fn_ty_index, @@ -1012,7 +1020,7 @@ pub fn parse( }, .global_import => |index| { const ptr = index.ptr(ss); - const name = symbol.name.unwrap().?; + const name = symbol.name.unwrap() orelse ptr.name; if (symbol.flags.binding == .local) { diags.addParseError(path, "local symbol '{s}' references import", .{name.slice(wasm)}); continue; @@ -1049,10 +1057,18 @@ pub fn parse( source_location.addNote(&err, "module '{s}' here", .{ptr.module_name.slice(wasm)}); continue; } + if (gop.value_ptr.name != ptr.name) { + var err = try diags.addErrorWithNotes(2); + try err.addMsg("symbol '{s}' mismatching import names", .{name.slice(wasm)}); + 
gop.value_ptr.source_location.addNote(&err, "imported as '{s}' here", .{gop.value_ptr.name.slice(wasm)}); + source_location.addNote(&err, "imported as '{s}' here", .{ptr.name.slice(wasm)}); + continue; + } } else { gop.value_ptr.* = .{ .flags = symbol.flags, .module_name = ptr.module_name.toOptional(), + .name = ptr.name, .source_location = source_location, .resolution = .unresolved, }; @@ -1064,7 +1080,7 @@ pub fn parse( }, .table_import => |index| { const ptr = index.ptr(ss); - const name = symbol.name.unwrap().?; + const name = symbol.name.unwrap() orelse ptr.name; if (symbol.flags.binding == .local) { diags.addParseError(path, "local symbol '{s}' references import", .{name.slice(wasm)}); continue; @@ -1088,6 +1104,13 @@ pub fn parse( source_location.addNote(&err, "module '{s}' here", .{ptr.module_name.slice(wasm)}); continue; } + if (gop.value_ptr.name != ptr.name) { + var err = try diags.addErrorWithNotes(2); + try err.addMsg("symbol '{s}' mismatching import names", .{name.slice(wasm)}); + gop.value_ptr.source_location.addNote(&err, "imported as '{s}' here", .{gop.value_ptr.name.slice(wasm)}); + source_location.addNote(&err, "imported as '{s}' here", .{ptr.name.slice(wasm)}); + continue; + } if (symbol.flags.binding == .strong) gop.value_ptr.flags.binding = .strong; if (!symbol.flags.visibility_hidden) gop.value_ptr.flags.visibility_hidden = false; if (symbol.flags.no_strip) gop.value_ptr.flags.no_strip = true; @@ -1095,6 +1118,7 @@ pub fn parse( gop.value_ptr.* = .{ .flags = symbol.flags, .module_name = ptr.module_name, + .name = ptr.name, .source_location = source_location, .resolution = .unresolved, .limits_min = ptr.limits_min, @@ -1158,6 +1182,7 @@ pub fn parse( gop.value_ptr.* = .{ .flags = symbol.flags, .module_name = host_name, + .name = name, .source_location = source_location, .resolution = .fromObjectFunction(wasm, index), .type = ptr.type_index, @@ -1214,8 +1239,9 @@ pub fn parse( gop.value_ptr.* = .{ .flags = symbol.flags, .module_name = .none, + 
.name = name, .source_location = source_location, - .resolution = .unresolved, + .resolution = .fromObjectGlobal(wasm, index), }; gop.value_ptr.flags.global_type = .{ .valtype = .from(new_ty.valtype), @@ -1258,7 +1284,7 @@ pub fn parse( gop.value_ptr.* = .{ .flags = symbol.flags, .source_location = source_location, - .resolution = .unresolved, + .resolution = .fromObjectDataIndex(wasm, index), }; } }, @@ -1280,11 +1306,8 @@ pub fn parse( switch (exp.pointee) { inline .function, .table, .memory, .global => |index| { const ptr = index.ptr(wasm); - if (ptr.name == .none) { - // Missing symbol table entry; use defaults for exported things. - ptr.name = exp.name.toOptional(); - ptr.flags.exported = true; - } + ptr.name = exp.name.toOptional(); + ptr.flags.exported = true; }, } }