From 293d6bdc73c5fe01b07ebe3d09c9a78613fed093 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 17 Feb 2023 16:39:45 -0700 Subject: [PATCH] AstGen: back to index-based for loops --- src/AstGen.zig | 159 +++++++++++++++++++++------------------------- src/Sema.zig | 63 ++++++++---------- src/Zir.zig | 21 ++---- src/print_zir.zig | 6 +- 4 files changed, 106 insertions(+), 143 deletions(-) diff --git a/src/AstGen.zig b/src/AstGen.zig index 523ac235ac..b90201713e 100644 --- a/src/AstGen.zig +++ b/src/AstGen.zig @@ -88,7 +88,6 @@ fn setExtra(astgen: *AstGen, index: usize, extra: anytype) void { Zir.Inst.BuiltinCall.Flags => @bitCast(u32, @field(extra, field.name)), Zir.Inst.SwitchBlock.Bits => @bitCast(u32, @field(extra, field.name)), Zir.Inst.FuncFancy.Bits => @bitCast(u32, @field(extra, field.name)), - Zir.Inst.ElemPtrImm.Bits => @bitCast(u32, @field(extra, field.name)), else => @compileError("bad field type"), }; i += 1; @@ -1566,9 +1565,7 @@ fn arrayInitExprRlPtrInner( for (elements) |elem_init, i| { const elem_ptr = try gz.addPlNode(.elem_ptr_imm, elem_init, Zir.Inst.ElemPtrImm{ .ptr = result_ptr, - .bits = .{ - .index = @intCast(u31, i), - }, + .index = @intCast(u32, i), }); astgen.extra.items[extra_index] = refToIndex(elem_ptr).?; extra_index += 1; @@ -2601,6 +2598,7 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As .field_base_ptr, .ret_ptr, .ret_type, + .for_len, .@"try", .try_ptr, //.try_inline, @@ -2669,7 +2667,6 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As .validate_deref, .save_err_ret_index, .restore_err_ret_index, - .for_check_lens, => break :b true, .@"defer" => unreachable, @@ -6305,23 +6302,26 @@ fn forExpr( const node_data = tree.nodes.items(.data); const gpa = astgen.gpa; - const allocs = try gpa.alloc(Zir.Inst.Ref, for_full.ast.inputs.len); - defer gpa.free(allocs); + // For counters, this is the start value; for indexables, this is the base + // pointer that can be used with elem_ptr and similar instructions. + // Special value `none` means that this is a counter and its start value is + // zero, indicating that the main index counter can be used directly. + const indexables = try gpa.alloc(Zir.Inst.Ref, for_full.ast.inputs.len); + defer gpa.free(indexables); // elements of this array can be `none`, indicating no length check. const lens = try gpa.alloc(Zir.Inst.Ref, for_full.ast.inputs.len); defer gpa.free(lens); - const alloc_tag: Zir.Inst.Tag = if (is_inline) .alloc_comptime_mut else .alloc_mut; + // We will use a single zero-based counter no matter how many indexables there are. + const index_ptr = blk: { + const alloc_tag: Zir.Inst.Tag = if (is_inline) .alloc_comptime_mut else .alloc; + const index_ptr = try parent_gz.addUnNode(alloc_tag, .usize_type, node); + // initialize to zero + _ = try parent_gz.addBin(.store, index_ptr, .zero_usize); + break :blk index_ptr; + }; - // Tracks the index of allocs/lens that has a length to be checked and is - // used for the end value. - // If this is null, there are no len checks. - var end_input_index: ?u32 = null; - // This is a value to use to find out if the for loop has reached the end - // yet. It prefers to use a counter since the end value is provided directly, - // and otherwise falls back to adding ptr+len of a slice to compute end. - // Corresponds to end_input_index and will be .none in case that value is null. - var cond_end_val: Zir.Inst.Ref = .none; + var any_len_checks = false; { var capture_token = for_full.payload_token; @@ -6341,10 +6341,8 @@ fn forExpr( if (capture_is_ref) { return astgen.failTok(ident_tok, "cannot capture reference to range", .{}); } - const counter_ptr = try parent_gz.addUnNode(alloc_tag, .usize_type, node); const start_node = node_data[input].lhs; const start_val = try expr(parent_gz, scope, .{ .rl = .none }, start_node); - _ = try parent_gz.addBin(.store, counter_ptr, start_val); const end_node = node_data[input].rhs; const end_val = if (end_node != 0) @@ -6352,7 +6350,8 @@ fn forExpr( else .none; - const range_len = if (end_val == .none or nodeIsTriviallyZero(tree, start_node)) + const start_is_zero = nodeIsTriviallyZero(tree, start_node); + const range_len = if (end_val == .none or start_is_zero) end_val else try parent_gz.addPlNode(.sub, input, Zir.Inst.Bin{ @@ -6360,61 +6359,33 @@ fn forExpr( .rhs = start_val, }); - if (range_len != .none and cond_end_val == .none) { - end_input_index = i; - cond_end_val = end_val; - } - - allocs[i] = counter_ptr; + any_len_checks = any_len_checks or range_len != .none; + indexables[i] = if (start_is_zero) .none else start_val; lens[i] = range_len; } else { const indexable = try expr(parent_gz, scope, .{ .rl = .none }, input); - // This instruction has nice compile errors so we put it before the other ones - // even though it is not needed until later in the block. - const ptr_len = try parent_gz.addUnNode(.indexable_ptr_len, indexable, input); - const base_ptr = try parent_gz.addPlNode(.elem_ptr_imm, input, Zir.Inst.ElemPtrImm{ - .ptr = indexable, - .bits = .{ - .index = 0, - .manyptr = true, - }, - }); - const alloc_ty_inst = try parent_gz.addUnNode(.typeof, base_ptr, node); - const alloc = try parent_gz.addUnNode(alloc_tag, alloc_ty_inst, node); - _ = try parent_gz.addBin(.store, alloc, base_ptr); + const indexable_len = try parent_gz.addUnNode(.indexable_ptr_len, indexable, input); - if (end_input_index == null) { - end_input_index = i; - assert(cond_end_val == .none); - } - - allocs[i] = alloc; - lens[i] = ptr_len; + any_len_checks = true; + indexables[i] = indexable; + lens[i] = indexable_len; } } } - // In case there are no counters which already have an end computed, we - // compute an end from base pointer plus length. - if (end_input_index) |i| { - if (cond_end_val == .none) { - cond_end_val = try parent_gz.addPlNode(.add, for_full.ast.inputs[i], Zir.Inst.Bin{ - .lhs = allocs[i], - .rhs = lens[i], - }); - } - } - // We use a dedicated ZIR instruction to assert the lengths to assist with // nicer error reporting as well as fewer ZIR bytes emitted. - if (end_input_index != null) { + const len: Zir.Inst.Ref = len: { + if (!any_len_checks) break :len .none; + const lens_len = @intCast(u32, lens.len); try astgen.extra.ensureUnusedCapacity(gpa, @typeInfo(Zir.Inst.MultiOp).Struct.fields.len + lens_len); - _ = try parent_gz.addPlNode(.for_check_lens, node, Zir.Inst.MultiOp{ + const len = try parent_gz.addPlNode(.for_len, node, Zir.Inst.MultiOp{ .operands_len = lens_len, }); appendRefsAssumeCapacity(astgen, lens); - } + break :len len; + }; const loop_tag: Zir.Inst.Tag = if (is_inline) .block_inline else .loop; const loop_block = try parent_gz.makeBlockInst(loop_tag, node); @@ -6429,22 +6400,14 @@ fn forExpr( var cond_scope = parent_gz.makeSubBlock(&loop_scope.base); defer cond_scope.unstack(); - // Load all the iterables. - const loaded_ptrs = try gpa.alloc(Zir.Inst.Ref, allocs.len); - defer gpa.free(loaded_ptrs); - for (allocs) |alloc, i| { - loaded_ptrs[i] = try cond_scope.addUnNode(.load, alloc, for_full.ast.inputs[i]); - } - // Check the condition. - const input_index = end_input_index orelse { + if (!any_len_checks) { return astgen.failNode(node, "TODO: handle infinite for loop", .{}); - }; - assert(cond_end_val != .none); - - const cond = try cond_scope.addPlNode(.cmp_neq, for_full.ast.inputs[input_index], Zir.Inst.Bin{ - .lhs = loaded_ptrs[input_index], - .rhs = cond_end_val, + } + const index = try cond_scope.addUnNode(.load, index_ptr, node); + const cond = try cond_scope.addPlNode(.cmp_lt, node, Zir.Inst.Bin{ + .lhs = index, + .rhs = len, }); const condbr_tag: Zir.Inst.Tag = if (is_inline) .condbr_inline else .condbr; @@ -6455,14 +6418,12 @@ fn forExpr( // cond_block unstacked now, can add new instructions to loop_scope try loop_scope.instructions.append(gpa, cond_block); - // Increment the loop variables. - for (allocs) |alloc, i| { - const incremented = try loop_scope.addPlNode(.add, node, Zir.Inst.Bin{ - .lhs = loaded_ptrs[i], - .rhs = .one_usize, - }); - _ = try loop_scope.addBin(.store, alloc, incremented); - } + // Increment the index variable. + const index_plus_one = try loop_scope.addPlNode(.add, node, Zir.Inst.Bin{ + .lhs = index, + .rhs = .one_usize, + }); + _ = try loop_scope.addBin(.store, index_ptr, index_plus_one); const repeat_tag: Zir.Inst.Tag = if (is_inline) .repeat_inline else .repeat; _ = try loop_scope.addNode(repeat_tag, node); @@ -6500,21 +6461,43 @@ fn forExpr( const name_str_index = try astgen.identAsString(ident_tok); try astgen.detectLocalShadowing(capture_sub_scope, name_str_index, ident_tok, capture_name, .capture); - const loaded = if (capture_is_ref) - loaded_ptrs[i] - else - try then_scope.addUnNode(.load, loaded_ptrs[i], input); + const capture_inst = inst: { + const is_counter = node_tags[input] == .for_range; + + if (indexables[i] == .none) { + // Special case: the main index can be used directly. + assert(is_counter); + assert(!capture_is_ref); + break :inst index; + } + + // For counters, we add the index variable to the start value; for + // indexables, we use it as an element index. This is so similar + // that they can share the same code paths, branching only on the + // ZIR tag. + const switch_cond = (@as(u2, @boolToInt(capture_is_ref)) << 1) | @boolToInt(is_counter); + const tag: Zir.Inst.Tag = switch (switch_cond) { + 0b00 => .elem_val, + 0b01 => .add, + 0b10 => .elem_ptr, + 0b11 => unreachable, // compile error emitted already + }; + break :inst try then_scope.addPlNode(tag, input, Zir.Inst.Bin{ + .lhs = indexables[i], + .rhs = index, + }); + }; capture_scopes[i] = .{ .parent = capture_sub_scope, .gen_zir = &then_scope, .name = name_str_index, - .inst = loaded, + .inst = capture_inst, .token_src = ident_tok, .id_cat = .capture, }; - try then_scope.addDbgVar(.dbg_var_val, name_str_index, loaded); + try then_scope.addDbgVar(.dbg_var_val, name_str_index, capture_inst); capture_sub_scope = &capture_scopes[i].base; } diff --git a/src/Sema.zig b/src/Sema.zig index c251aa9fbf..cb40a85364 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -1035,6 +1035,7 @@ fn analyzeBodyInner( .@"await" => try sema.zirAwait(block, inst), .array_base_ptr => try sema.zirArrayBasePtr(block, inst), .field_base_ptr => try sema.zirFieldBasePtr(block, inst), + .for_len => try sema.zirForLen(block, inst), .clz => try sema.zirBitCount(block, inst, .clz, Value.clz), .ctz => try sema.zirBitCount(block, inst, .ctz, Value.ctz), @@ -1386,11 +1387,6 @@ fn analyzeBodyInner( i += 1; continue; }, - .for_check_lens => { - try sema.zirForCheckLens(block, inst); - i += 1; - continue; - }, // Special case instructions to handle comptime control flow. .@"break" => { @@ -3924,6 +3920,16 @@ fn zirFieldBasePtr( return sema.failWithStructInitNotSupported(block, src, sema.typeOf(start_ptr).childType()); } +fn zirForLen(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { + const inst_data = sema.code.instructions.items(.data)[inst].pl_node; + const extra = sema.code.extraData(Zir.Inst.MultiOp, inst_data.payload_index); + const args = sema.code.refSlice(extra.end, extra.data.operands_len); + const src = inst_data.src(); + + _ = args; + return sema.fail(block, src, "TODO implement zirForCheckLens", .{}); +} + fn validateArrayInitTy( sema: *Sema, block: *Block, @@ -9649,7 +9655,7 @@ fn zirElemPtr(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data; const array_ptr = try sema.resolveInst(extra.lhs); const elem_index = try sema.resolveInst(extra.rhs); - return sema.elemPtr(block, src, array_ptr, elem_index, src, false, .One); + return sema.elemPtr(block, src, array_ptr, elem_index, src, false); } fn zirElemPtrNode(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { @@ -9662,7 +9668,7 @@ fn zirElemPtrNode(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data; const array_ptr = try sema.resolveInst(extra.lhs); const elem_index = try sema.resolveInst(extra.rhs); - return sema.elemPtr(block, src, array_ptr, elem_index, elem_index_src, false, .One); + return sema.elemPtr(block, src, array_ptr, elem_index, elem_index_src, false); } fn zirElemPtrImm(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { @@ -9673,9 +9679,8 @@ fn zirElemPtrImm(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError! const src = inst_data.src(); const extra = sema.code.extraData(Zir.Inst.ElemPtrImm, inst_data.payload_index).data; const array_ptr = try sema.resolveInst(extra.ptr); - const elem_index = try sema.addIntUnsigned(Type.usize, extra.bits.index); - const size: std.builtin.Type.Pointer.Size = if (extra.bits.manyptr) .Many else .One; - return sema.elemPtr(block, src, array_ptr, elem_index, src, true, size); + const elem_index = try sema.addIntUnsigned(Type.usize, extra.index); + return sema.elemPtr(block, src, array_ptr, elem_index, src, true); } fn zirSliceStart(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { @@ -17102,16 +17107,6 @@ fn zirRestoreErrRetIndex(sema: *Sema, start_block: *Block, inst: Zir.Inst.Index) return sema.popErrorReturnTrace(start_block, src, operand, saved_index); } -fn zirForCheckLens(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!void { - const inst_data = sema.code.instructions.items(.data)[inst].pl_node; - const extra = sema.code.extraData(Zir.Inst.MultiOp, inst_data.payload_index); - const args = sema.code.refSlice(extra.end, extra.data.operands_len); - const src = inst_data.src(); - - _ = args; - return sema.fail(block, src, "TODO implement zirForCheckLens", .{}); -} - fn addToInferredErrorSet(sema: *Sema, uncasted_operand: Air.Inst.Ref) !void { assert(sema.fn_ret_ty.zigTypeTag() == .ErrorUnion); @@ -22906,7 +22901,7 @@ fn panicSentinelMismatch( const actual_sentinel = if (ptr_ty.isSlice()) try parent_block.addBinOp(.slice_elem_val, ptr, sentinel_index) else blk: { - const elem_ptr_ty = try sema.elemPtrType(ptr_ty, null, .One); + const elem_ptr_ty = try sema.elemPtrType(ptr_ty, null); const sentinel_ptr = try parent_block.addPtrElemPtr(ptr, sentinel_index, elem_ptr_ty); break :blk try parent_block.addTyOp(.load, sentinel_ty, sentinel_ptr); }; @@ -24073,7 +24068,6 @@ fn elemPtr( elem_index: Air.Inst.Ref, elem_index_src: LazySrcLoc, init: bool, - size: std.builtin.Type.Pointer.Size, ) CompileError!Air.Inst.Ref { const indexable_ptr_src = src; // TODO better source location const indexable_ptr_ty = sema.typeOf(indexable_ptr); @@ -24100,12 +24094,13 @@ fn elemPtr( const index_val = maybe_index_val orelse break :rs elem_index_src; const index = @intCast(usize, index_val.toUnsignedInt(target)); const elem_ptr = try ptr_val.elemPtr(indexable_ty, sema.arena, index, sema.mod); - const elem_ptr_ty = try sema.elemPtrType(indexable_ty, index, size); - return sema.addConstant(elem_ptr_ty, elem_ptr); + const result_ty = try sema.elemPtrType(indexable_ty, index); + return sema.addConstant(result_ty, elem_ptr); }; - const elem_ptr_ty = try sema.elemPtrType(indexable_ty, null, size); + const result_ty = try sema.elemPtrType(indexable_ty, null); + try sema.requireRuntimeBlock(block, src, runtime_src); - return block.addPtrElemPtr(indexable, elem_index, elem_ptr_ty); + return block.addPtrElemPtr(indexable, elem_index, result_ty); }, .One => { assert(indexable_ty.childType().zigTypeTag() == .Array); // Guaranteed by isIndexable @@ -24167,7 +24162,7 @@ fn elemVal( }, .One => { assert(indexable_ty.childType().zigTypeTag() == .Array); // Guaranteed by isIndexable - const elem_ptr = try sema.elemPtr(block, indexable_src, indexable, elem_index, elem_index_src, false, .One); + const elem_ptr = try sema.elemPtr(block, indexable_src, indexable, elem_index, elem_index_src, false); return sema.analyzeLoad(block, indexable_src, elem_ptr, elem_index_src); }, }, @@ -24405,7 +24400,7 @@ fn elemPtrArray( break :o index; } else null; - const elem_ptr_ty = try sema.elemPtrType(array_ptr_ty, offset, .One); + const elem_ptr_ty = try sema.elemPtrType(array_ptr_ty, offset); if (maybe_undef_array_ptr_val) |array_ptr_val| { if (array_ptr_val.isUndef()) { @@ -24510,7 +24505,7 @@ fn elemPtrSlice( break :o index; } else null; - const elem_ptr_ty = try sema.elemPtrType(slice_ty, offset, .One); + const elem_ptr_ty = try sema.elemPtrType(slice_ty, offset); if (maybe_undef_slice_val) |slice_val| { if (slice_val.isUndef()) { @@ -26240,7 +26235,7 @@ fn storePtr2( const elem_src = operand_src; // TODO better source location const elem = try sema.tupleField(block, operand_src, uncasted_operand, elem_src, i); const elem_index = try sema.addIntUnsigned(Type.usize, i); - const elem_ptr = try sema.elemPtr(block, ptr_src, ptr, elem_index, elem_src, false, .One); + const elem_ptr = try sema.elemPtr(block, ptr_src, ptr, elem_index, elem_src, false); try sema.storePtr2(block, src, elem_ptr, elem_src, elem, elem_src, .store); } return; @@ -33277,12 +33272,7 @@ fn compareVector( /// For []T, returns *T /// Handles const-ness and address spaces in particular. /// This code is duplicated in `analyzePtrArithmetic`. -fn elemPtrType( - sema: *Sema, - ptr_ty: Type, - offset: ?usize, - size: std.builtin.Type.Pointer.Size, -) !Type { +fn elemPtrType(sema: *Sema, ptr_ty: Type, offset: ?usize) !Type { const ptr_info = ptr_ty.ptrInfo().data; const elem_ty = ptr_ty.elemType2(); const allow_zero = ptr_info.@"allowzero" and (offset orelse 0) == 0; @@ -33327,7 +33317,6 @@ fn elemPtrType( break :a new_align; }; return try Type.ptr(sema.arena, sema.mod, .{ - .size = size, .pointee_type = elem_ty, .mutable = ptr_info.mutable, .@"addrspace" = ptr_info.@"addrspace", diff --git a/src/Zir.zig b/src/Zir.zig index edbd70e170..e215dfac10 100644 --- a/src/Zir.zig +++ b/src/Zir.zig @@ -79,7 +79,6 @@ pub fn extraData(code: Zir, comptime T: type, index: usize) struct { data: T, en Inst.BuiltinCall.Flags => @bitCast(Inst.BuiltinCall.Flags, code.extra[i]), Inst.SwitchBlock.Bits => @bitCast(Inst.SwitchBlock.Bits, code.extra[i]), Inst.FuncFancy.Bits => @bitCast(Inst.FuncFancy.Bits, code.extra[i]), - Inst.ElemPtrImm.Bits => @bitCast(Inst.ElemPtrImm.Bits, code.extra[i]), else => @compileError("bad field type"), }; i += 1; @@ -501,14 +500,14 @@ pub const Inst = struct { /// Uses the `node` field. repeat_inline, /// Asserts that all the lengths provided match. Used to build a for loop. - /// Return value is always void. + /// Return value is the length as a usize. /// Uses the `pl_node` field with payload `MultiOp`. /// There is exactly one item corresponding to each AST node inside the for - /// loop condition. Each item may be `none`, indicating an unbounded range. + /// loop condition. Any item may be `none`, indicating an unbounded range. /// Illegal behaviors: /// * If all lengths are unbounded ranges (always a compile error). /// * If any two lengths do not match each other. - for_check_lens, + for_len, /// Merge two error sets into one, `E1 || E2`. /// Uses the `pl_node` field with payload `Bin`. merge_error_sets, @@ -1254,7 +1253,7 @@ pub const Inst = struct { .defer_err_code, .save_err_ret_index, .restore_err_ret_index, - .for_check_lens, + .for_len, => false, .@"break", @@ -1322,7 +1321,6 @@ pub const Inst = struct { .memcpy, .memset, .check_comptime_control_flow, - .for_check_lens, .@"defer", .defer_err_code, .restore_err_ret_index, @@ -1547,6 +1545,7 @@ pub const Inst = struct { .repeat_inline, .panic, .panic_comptime, + .for_len, .@"try", .try_ptr, //.try_inline, @@ -1602,7 +1601,7 @@ pub const Inst = struct { .@"break" = .@"break", .break_inline = .@"break", .check_comptime_control_flow = .un_node, - .for_check_lens = .pl_node, + .for_len = .pl_node, .call = .pl_node, .cmp_lt = .pl_node, .cmp_lte = .pl_node, @@ -2975,13 +2974,7 @@ pub const Inst = struct { pub const ElemPtrImm = struct { ptr: Ref, - bits: Bits, - - pub const Bits = packed struct(u32) { - index: u31, - /// Controls whether the type returned is `*T` or `[*]T`. - manyptr: bool = false, - }; + index: u32, }; /// 0. multi_cases_len: u32 // If has_multi_cases is set. diff --git a/src/print_zir.zig b/src/print_zir.zig index 0977a88d53..e5ce9321f5 100644 --- a/src/print_zir.zig +++ b/src/print_zir.zig @@ -355,7 +355,7 @@ const Writer = struct { .array_type, => try self.writePlNodeBin(stream, inst), - .for_check_lens => try self.writePlNodeMultiOp(stream, inst), + .for_len => try self.writePlNodeMultiOp(stream, inst), .elem_ptr_imm => try self.writeElemPtrImm(stream, inst), @@ -888,9 +888,7 @@ const Writer = struct { const extra = self.code.extraData(Zir.Inst.ElemPtrImm, inst_data.payload_index).data; try self.writeInstRef(stream, extra.ptr); - try stream.print(", {d}", .{extra.bits.index}); - try self.writeFlag(stream, ", manyptr", extra.bits.manyptr); - try stream.writeAll(") "); + try stream.print(", {d}) ", .{extra.index}); try self.writeSrc(stream, inst_data.src()); }