From 157af4332a7b78672ff8ad76a00120455547e2fd Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sat, 6 May 2017 23:13:12 -0400 Subject: [PATCH] builtin functions for division and remainder division * add `@divTrunc` and `@divFloor` functions * add `@rem` and `@mod` functions * add compile error for `/` and `%` with signed integers * add `.bit_count` for float primitive types closes #217 --- doc/langref.md | 74 +++++++-- src/all_types.hpp | 26 +-- src/analyze.cpp | 7 + src/bignum.cpp | 64 +++++++- src/bignum.hpp | 3 + src/codegen.cpp | 250 +++++++++++++++++++++-------- src/error.cpp | 2 + src/error.hpp | 2 + src/ir.cpp | 343 +++++++++++++++++++++++----------------- src/ir_print.cpp | 25 ++- src/link.cpp | 2 + std/elf.zig | 7 +- std/fmt.zig | 4 +- std/math.zig | 241 ++++++++++++++++++++++++---- std/mem.zig | 30 +++- std/special/builtin.zig | 92 +++++++++++ std/special/zigrt.zig | 2 +- test/cases/math.zig | 90 ++++++++--- test/cases/misc.zig | 5 + test/compile_errors.zig | 18 ++- test/debug_safety.zig | 4 +- 21 files changed, 976 insertions(+), 315 deletions(-) diff --git a/doc/langref.md b/doc/langref.md index 29bd022c10..a83aecf822 100644 --- a/doc/langref.md +++ b/doc/langref.md @@ -502,15 +502,6 @@ This function performs an atomic compare exchange operation. The `fence` function is used to introduce happens-before edges between operations. -### @divExact(a: T, b: T) -> T - -This function performs integer division `a / b` and returns the result. - -The caller guarantees that this operation will have no remainder. - -In debug mode, a remainder causes a panic. In release mode, a remainder is -undefined behavior. - ### @truncate(comptime T: type, integer) -> T This function truncates bits from an integer type, resulting in a smaller @@ -621,3 +612,68 @@ Converts an enum tag name to a slice of bytes. Example: ### @fieldParentPtr(comptime ParentType: type, comptime field_name: []const u8, field_ptr: &T) -> &ParentType Given a pointer to a field, returns the base pointer of a struct. + +### @rem(numerator: T, denominator: T) -> T + +Remainder division. For unsigned integers this is the same as +`numerator % denominator`. Caller guarantees `denominator > 0`. + + * `@rem(-5, 3) == -2` + * `@divTrunc(a, b) + @rem(a, b) == a` + +See also: + * `std.math.rem` + * `@mod` + +### @mod(numerator: T, denominator: T) -> T + +Modulus division. For unsigned integers this is the same as +`numerator % denominator`. Caller guarantees `denominator > 0`. + + * `@mod(-5, 3) == 1` + * `@divFloor(a, b) + @mod(a, b) == a` + +See also: + * `std.math.mod` + * `@rem` + +### @divTrunc(numerator: T, denominator: T) -> T + +Truncated division. Rounds toward zero. For unsigned integers it is +the same as `numerator / denominator`. Caller guarantees `denominator != 0` and +`!(@isInteger(T) and T.is_signed and numerator == @minValue(T) and denominator == -1)`. + + * `@divTrunc(-5, 3) == -1` + * `@divTrunc(a, b) + @rem(a, b) == a` + +See also: + * `std.math.divTrunc` + * `@divFloor` + * `@divExact` + +### @divFloor(numerator: T, denominator: T) -> T + +Floored division. Rounds toward negative infinity. For unsigned integers it is +the same as `numerator / denominator`. Caller guarantees `denominator != 0` and +`!(@isInteger(T) and T.is_signed and numerator == @minValue(T) and denominator == -1)`. + + * `@divFloor(-5, 3) == -2` + * `@divFloor(a, b) + @mod(a, b) == a` + +See also: + * `std.math.divFloor` + * `@divTrunc` + * `@divExact` + +### @divExact(numerator: T, denominator: T) -> T + +Exact division. Caller guarantees `denominator != 0` and +`@divTrunc(numerator, denominator) * denominator == numerator`. + + * `@divExact(6, 3) == 2` + * `@divExact(a, b) * b == a` + +See also: + * `std.math.divExact` + * `@divTrunc` + * `@divFloor` diff --git a/src/all_types.hpp b/src/all_types.hpp index aced310fc5..2dd53e2e00 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1195,6 +1195,10 @@ enum BuiltinFnId { BuiltinFnIdCmpExchange, BuiltinFnIdFence, BuiltinFnIdDivExact, + BuiltinFnIdDivTrunc, + BuiltinFnIdDivFloor, + BuiltinFnIdRem, + BuiltinFnIdMod, BuiltinFnIdTruncate, BuiltinFnIdIntType, BuiltinFnIdSetDebugSafety, @@ -1270,6 +1274,8 @@ enum ZigLLVMFnId { ZigLLVMFnIdCtz, ZigLLVMFnIdClz, ZigLLVMFnIdOverflowArithmetic, + ZigLLVMFnIdFloor, + ZigLLVMFnIdCeil, }; enum AddSubMul { @@ -1288,6 +1294,9 @@ struct ZigLLVMFnKey { struct { uint32_t bit_count; } clz; + struct { + uint32_t bit_count; + } floor_ceil; struct { AddSubMul add_sub_mul; uint32_t bit_count; @@ -1746,7 +1755,6 @@ enum IrInstructionId { IrInstructionIdEmbedFile, IrInstructionIdCmpxchg, IrInstructionIdFence, - IrInstructionIdDivExact, IrInstructionIdTruncate, IrInstructionIdIntType, IrInstructionIdBoolNot, @@ -1897,8 +1905,13 @@ enum IrBinOp { IrBinOpSubWrap, IrBinOpMult, IrBinOpMultWrap, - IrBinOpDiv, - IrBinOpRem, + IrBinOpDivUnspecified, + IrBinOpDivExact, + IrBinOpDivTrunc, + IrBinOpDivFloor, + IrBinOpRemUnspecified, + IrBinOpRemRem, + IrBinOpRemMod, IrBinOpArrayCat, IrBinOpArrayMult, }; @@ -2250,13 +2263,6 @@ struct IrInstructionFence { AtomicOrder order; }; -struct IrInstructionDivExact { - IrInstruction base; - - IrInstruction *op1; - IrInstruction *op2; -}; - struct IrInstructionTruncate { IrInstruction base; diff --git a/src/analyze.cpp b/src/analyze.cpp index 2b2c303ea8..e1d8ca2869 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -4228,6 +4228,10 @@ uint32_t zig_llvm_fn_key_hash(ZigLLVMFnKey x) { return (uint32_t)(x.data.ctz.bit_count) * (uint32_t)810453934; case ZigLLVMFnIdClz: return (uint32_t)(x.data.clz.bit_count) * (uint32_t)2428952817; + case ZigLLVMFnIdFloor: + return (uint32_t)(x.data.floor_ceil.bit_count) * (uint32_t)1899859168; + case ZigLLVMFnIdCeil: + return (uint32_t)(x.data.floor_ceil.bit_count) * (uint32_t)1953839089; case ZigLLVMFnIdOverflowArithmetic: return ((uint32_t)(x.data.overflow_arithmetic.bit_count) * 87135777) + ((uint32_t)(x.data.overflow_arithmetic.add_sub_mul) * 31640542) + @@ -4244,6 +4248,9 @@ bool zig_llvm_fn_key_eql(ZigLLVMFnKey a, ZigLLVMFnKey b) { return a.data.ctz.bit_count == b.data.ctz.bit_count; case ZigLLVMFnIdClz: return a.data.clz.bit_count == b.data.clz.bit_count; + case ZigLLVMFnIdFloor: + case ZigLLVMFnIdCeil: + return a.data.floor_ceil.bit_count == b.data.floor_ceil.bit_count; case ZigLLVMFnIdOverflowArithmetic: return (a.data.overflow_arithmetic.bit_count == b.data.overflow_arithmetic.bit_count) && (a.data.overflow_arithmetic.add_sub_mul == b.data.overflow_arithmetic.add_sub_mul) && diff --git a/src/bignum.cpp b/src/bignum.cpp index eee85e9a7f..ee79a3477c 100644 --- a/src/bignum.cpp +++ b/src/bignum.cpp @@ -204,6 +204,23 @@ bool bignum_div(BigNum *dest, BigNum *op1, BigNum *op2) { if (dest->kind == BigNumKindFloat) { dest->data.x_float = op1->data.x_float / op2->data.x_float; + } else { + return bignum_div_trunc(dest, op1, op2); + } + return false; +} + +bool bignum_div_trunc(BigNum *dest, BigNum *op1, BigNum *op2) { + assert(op1->kind == op2->kind); + dest->kind = op1->kind; + + if (dest->kind == BigNumKindFloat) { + double result = op1->data.x_float / op2->data.x_float; + if (result >= 0) { + dest->data.x_float = floor(result); + } else { + dest->data.x_float = ceil(result); + } } else { dest->data.x_uint = op1->data.x_uint / op2->data.x_uint; dest->is_negative = op1->is_negative != op2->is_negative; @@ -212,6 +229,29 @@ bool bignum_div(BigNum *dest, BigNum *op1, BigNum *op2) { return false; } +bool bignum_div_floor(BigNum *dest, BigNum *op1, BigNum *op2) { + assert(op1->kind == op2->kind); + dest->kind = op1->kind; + + if (dest->kind == BigNumKindFloat) { + dest->data.x_float = floor(op1->data.x_float / op2->data.x_float); + } else { + if (op1->is_negative != op2->is_negative) { + uint64_t result = op1->data.x_uint / op2->data.x_uint; + if (result * op2->data.x_uint == op1->data.x_uint) { + dest->data.x_uint = result; + } else { + dest->data.x_uint = result + 1; + } + dest->is_negative = true; + } else { + dest->data.x_uint = op1->data.x_uint / op2->data.x_uint; + dest->is_negative = false; + } + } + return false; +} + bool bignum_rem(BigNum *dest, BigNum *op1, BigNum *op2) { assert(op1->kind == op2->kind); dest->kind = op1->kind; @@ -219,10 +259,28 @@ bool bignum_rem(BigNum *dest, BigNum *op1, BigNum *op2) { if (dest->kind == BigNumKindFloat) { dest->data.x_float = fmod(op1->data.x_float, op2->data.x_float); } else { - if (op1->is_negative || op2->is_negative) { - zig_panic("TODO handle remainder division with negative numbers"); - } + assert(!op2->is_negative); dest->data.x_uint = op1->data.x_uint % op2->data.x_uint; + dest->is_negative = op1->is_negative; + bignum_normalize(dest); + } + return false; +} + +bool bignum_mod(BigNum *dest, BigNum *op1, BigNum *op2) { + assert(op1->kind == op2->kind); + dest->kind = op1->kind; + + if (dest->kind == BigNumKindFloat) { + dest->data.x_float = fmod(fmod(op1->data.x_float, op2->data.x_float) + op2->data.x_float, op2->data.x_float); + } else { + assert(!op2->is_negative); + if (op1->is_negative) { + dest->data.x_uint = (op2->data.x_uint - op1->data.x_uint % op2->data.x_uint) % op2->data.x_uint; + } else { + dest->data.x_uint = op1->data.x_uint % op2->data.x_uint; + } + dest->is_negative = false; bignum_normalize(dest); } return false; diff --git a/src/bignum.hpp b/src/bignum.hpp index f8d960c490..6fb22c5a1d 100644 --- a/src/bignum.hpp +++ b/src/bignum.hpp @@ -37,7 +37,10 @@ bool bignum_add(BigNum *dest, BigNum *op1, BigNum *op2); bool bignum_sub(BigNum *dest, BigNum *op1, BigNum *op2); bool bignum_mul(BigNum *dest, BigNum *op1, BigNum *op2); bool bignum_div(BigNum *dest, BigNum *op1, BigNum *op2); +bool bignum_div_trunc(BigNum *dest, BigNum *op1, BigNum *op2); +bool bignum_div_floor(BigNum *dest, BigNum *op1, BigNum *op2); bool bignum_rem(BigNum *dest, BigNum *op1, BigNum *op2); +bool bignum_mod(BigNum *dest, BigNum *op1, BigNum *op2); bool bignum_or(BigNum *dest, BigNum *op1, BigNum *op2); bool bignum_and(BigNum *dest, BigNum *op1, BigNum *op2); diff --git a/src/codegen.cpp b/src/codegen.cpp index 6a5d2b1ef2..49dec3589e 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -538,6 +538,35 @@ static LLVMValueRef get_int_overflow_fn(CodeGen *g, TypeTableEntry *type_entry, return fn_val; } +static LLVMValueRef get_floor_ceil_fn(CodeGen *g, TypeTableEntry *type_entry, ZigLLVMFnId fn_id) { + assert(type_entry->id == TypeTableEntryIdFloat); + + ZigLLVMFnKey key = {}; + key.id = fn_id; + key.data.floor_ceil.bit_count = (uint32_t)type_entry->data.floating.bit_count; + + auto existing_entry = g->llvm_fn_table.maybe_get(key); + if (existing_entry) + return existing_entry->value; + + const char *name; + if (fn_id == ZigLLVMFnIdFloor) { + name = "floor"; + } else if (fn_id == ZigLLVMFnIdCeil) { + name = "ceil"; + } else { + zig_unreachable(); + } + + char fn_name[64]; + sprintf(fn_name, "llvm.%s.f%zu", name, type_entry->data.floating.bit_count); + LLVMTypeRef fn_type = LLVMFunctionType(type_entry->type_ref, &type_entry->type_ref, 1, false); + LLVMValueRef fn_val = LLVMAddFunction(g->module, fn_name, fn_type); + + g->llvm_fn_table.put(key, fn_val); + return fn_val; +} + static LLVMValueRef get_handle_value(CodeGen *g, LLVMValueRef ptr, TypeTableEntry *type, bool is_volatile) { if (type_has_bits(type)) { if (handle_is_ptr(type)) { @@ -618,7 +647,7 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) { case PanicMsgIdDivisionByZero: return buf_create_from_str("division by zero"); case PanicMsgIdRemainderDivisionByZero: - return buf_create_from_str("remainder division by zero"); + return buf_create_from_str("remainder division by zero or negative value"); case PanicMsgIdExactDivisionRemainder: return buf_create_from_str("exact division produced remainder"); case PanicMsgIdSliceWidenRemainder: @@ -1099,12 +1128,34 @@ static LLVMValueRef gen_overflow_shl_op(CodeGen *g, TypeTableEntry *type_entry, return result; } -static LLVMValueRef gen_div(CodeGen *g, bool want_debug_safety, LLVMValueRef val1, LLVMValueRef val2, - TypeTableEntry *type_entry, bool exact) -{ +static LLVMValueRef gen_floor(CodeGen *g, LLVMValueRef val, TypeTableEntry *type_entry) { + if (type_entry->id == TypeTableEntryIdInt) + return val; + LLVMValueRef floor_fn = get_floor_ceil_fn(g, type_entry, ZigLLVMFnIdFloor); + return LLVMBuildCall(g->builder, floor_fn, &val, 1, ""); +} + +static LLVMValueRef gen_ceil(CodeGen *g, LLVMValueRef val, TypeTableEntry *type_entry) { + if (type_entry->id == TypeTableEntryIdInt) + return val; + + LLVMValueRef ceil_fn = get_floor_ceil_fn(g, type_entry, ZigLLVMFnIdCeil); + return LLVMBuildCall(g->builder, ceil_fn, &val, 1, ""); +} + +enum DivKind { + DivKindFloat, + DivKindTrunc, + DivKindFloor, + DivKindExact, +}; + +static LLVMValueRef gen_div(CodeGen *g, bool want_debug_safety, LLVMValueRef val1, LLVMValueRef val2, + TypeTableEntry *type_entry, DivKind div_kind) +{ + LLVMValueRef zero = LLVMConstNull(type_entry->type_ref); if (want_debug_safety) { - LLVMValueRef zero = LLVMConstNull(type_entry->type_ref); LLVMValueRef is_zero_bit; if (type_entry->id == TypeTableEntryIdInt) { is_zero_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, zero, ""); @@ -1140,55 +1191,111 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_debug_safety, LLVMValueRef val } if (type_entry->id == TypeTableEntryIdFloat) { - assert(!exact); - return LLVMBuildFDiv(g->builder, val1, val2, ""); + LLVMValueRef result = LLVMBuildFDiv(g->builder, val1, val2, ""); + switch (div_kind) { + case DivKindFloat: + return result; + case DivKindExact: + if (want_debug_safety) { + LLVMValueRef floored = gen_floor(g, result, type_entry); + LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactOk"); + LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail"); + LLVMValueRef ok_bit = LLVMBuildFCmp(g->builder, LLVMRealOEQ, floored, result, ""); + + LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block); + + LLVMPositionBuilderAtEnd(g->builder, fail_block); + gen_debug_safety_crash(g, PanicMsgIdExactDivisionRemainder); + + LLVMPositionBuilderAtEnd(g->builder, ok_block); + } + return result; + case DivKindTrunc: + { + LLVMValueRef floored = gen_floor(g, result, type_entry); + LLVMValueRef ceiled = gen_ceil(g, result, type_entry); + LLVMValueRef ltz = LLVMBuildFCmp(g->builder, LLVMRealOLT, val1, zero, ""); + return LLVMBuildSelect(g->builder, ltz, ceiled, floored, ""); + } + case DivKindFloor: + return gen_floor(g, result, type_entry); + } + zig_unreachable(); } assert(type_entry->id == TypeTableEntryIdInt); - if (exact) { - if (want_debug_safety) { - LLVMValueRef remainder_val; + switch (div_kind) { + case DivKindFloat: + zig_unreachable(); + case DivKindTrunc: if (type_entry->data.integral.is_signed) { - remainder_val = LLVMBuildSRem(g->builder, val1, val2, ""); + return LLVMBuildSDiv(g->builder, val1, val2, ""); } else { - remainder_val = LLVMBuildURem(g->builder, val1, val2, ""); + return LLVMBuildUDiv(g->builder, val1, val2, ""); } - LLVMValueRef zero = LLVMConstNull(type_entry->type_ref); - LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, remainder_val, zero, ""); + case DivKindExact: + if (want_debug_safety) { + LLVMValueRef remainder_val; + if (type_entry->data.integral.is_signed) { + remainder_val = LLVMBuildSRem(g->builder, val1, val2, ""); + } else { + remainder_val = LLVMBuildURem(g->builder, val1, val2, ""); + } + LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, remainder_val, zero, ""); - LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactOk"); - LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail"); - LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block); + LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactOk"); + LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail"); + LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block); - LLVMPositionBuilderAtEnd(g->builder, fail_block); - gen_debug_safety_crash(g, PanicMsgIdExactDivisionRemainder); + LLVMPositionBuilderAtEnd(g->builder, fail_block); + gen_debug_safety_crash(g, PanicMsgIdExactDivisionRemainder); - LLVMPositionBuilderAtEnd(g->builder, ok_block); - } - if (type_entry->data.integral.is_signed) { - return LLVMBuildExactSDiv(g->builder, val1, val2, ""); - } else { - return LLVMBuildExactUDiv(g->builder, val1, val2, ""); - } - } else { - if (type_entry->data.integral.is_signed) { - return LLVMBuildSDiv(g->builder, val1, val2, ""); - } else { - return LLVMBuildUDiv(g->builder, val1, val2, ""); - } + LLVMPositionBuilderAtEnd(g->builder, ok_block); + } + if (type_entry->data.integral.is_signed) { + return LLVMBuildExactSDiv(g->builder, val1, val2, ""); + } else { + return LLVMBuildExactUDiv(g->builder, val1, val2, ""); + } + case DivKindFloor: + { + if (!type_entry->data.integral.is_signed) { + return LLVMBuildUDiv(g->builder, val1, val2, ""); + } + // const result = @divTrunc(a, b); + // if (result >= 0 or result * b == a) + // return result; + // else + // return result - 1; + + LLVMValueRef result = LLVMBuildSDiv(g->builder, val1, val2, ""); + LLVMValueRef is_pos = LLVMBuildICmp(g->builder, LLVMIntSGE, result, zero, ""); + LLVMValueRef orig_num = LLVMBuildNSWMul(g->builder, result, val2, ""); + LLVMValueRef orig_ok = LLVMBuildICmp(g->builder, LLVMIntEQ, orig_num, val1, ""); + LLVMValueRef ok_bit = LLVMBuildOr(g->builder, orig_ok, is_pos, ""); + LLVMValueRef one = LLVMConstInt(type_entry->type_ref, 1, true); + LLVMValueRef result_minus_1 = LLVMBuildNSWSub(g->builder, result, one, ""); + return LLVMBuildSelect(g->builder, ok_bit, result, result_minus_1, ""); + } } + zig_unreachable(); } -static LLVMValueRef gen_rem(CodeGen *g, bool want_debug_safety, LLVMValueRef val1, LLVMValueRef val2, - TypeTableEntry *type_entry) -{ +enum RemKind { + RemKindRem, + RemKindMod, +}; +static LLVMValueRef gen_rem(CodeGen *g, bool want_debug_safety, LLVMValueRef val1, LLVMValueRef val2, + TypeTableEntry *type_entry, RemKind rem_kind) +{ + LLVMValueRef zero = LLVMConstNull(type_entry->type_ref); if (want_debug_safety) { - LLVMValueRef zero = LLVMConstNull(type_entry->type_ref); LLVMValueRef is_zero_bit; if (type_entry->id == TypeTableEntryIdInt) { - is_zero_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, zero, ""); + LLVMIntPredicate pred = type_entry->data.integral.is_signed ? LLVMIntSLE : LLVMIntEQ; + is_zero_bit = LLVMBuildICmp(g->builder, pred, val2, zero, ""); } else if (type_entry->id == TypeTableEntryIdFloat) { is_zero_bit = LLVMBuildFCmp(g->builder, LLVMRealOEQ, val2, zero, ""); } else { @@ -1202,30 +1309,30 @@ static LLVMValueRef gen_rem(CodeGen *g, bool want_debug_safety, LLVMValueRef val gen_debug_safety_crash(g, PanicMsgIdRemainderDivisionByZero); LLVMPositionBuilderAtEnd(g->builder, rem_zero_ok_block); - - if (type_entry->id == TypeTableEntryIdInt && type_entry->data.integral.is_signed) { - LLVMValueRef neg_1_value = LLVMConstInt(type_entry->type_ref, -1, true); - LLVMValueRef int_min_value = LLVMConstInt(type_entry->type_ref, min_signed_val(type_entry), true); - LLVMBasicBlockRef overflow_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "RemOverflowOk"); - LLVMBasicBlockRef overflow_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "RemOverflowFail"); - LLVMValueRef num_is_int_min = LLVMBuildICmp(g->builder, LLVMIntEQ, val1, int_min_value, ""); - LLVMValueRef den_is_neg_1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, neg_1_value, ""); - LLVMValueRef overflow_fail_bit = LLVMBuildAnd(g->builder, num_is_int_min, den_is_neg_1, ""); - LLVMBuildCondBr(g->builder, overflow_fail_bit, overflow_fail_block, overflow_ok_block); - - LLVMPositionBuilderAtEnd(g->builder, overflow_fail_block); - gen_debug_safety_crash(g, PanicMsgIdIntegerOverflow); - - LLVMPositionBuilderAtEnd(g->builder, overflow_ok_block); - } } if (type_entry->id == TypeTableEntryIdFloat) { - return LLVMBuildFRem(g->builder, val1, val2, ""); + if (rem_kind == RemKindRem) { + return LLVMBuildFRem(g->builder, val1, val2, ""); + } else { + LLVMValueRef a = LLVMBuildFRem(g->builder, val1, val2, ""); + LLVMValueRef b = LLVMBuildFAdd(g->builder, a, val2, ""); + LLVMValueRef c = LLVMBuildFRem(g->builder, b, val2, ""); + LLVMValueRef ltz = LLVMBuildFCmp(g->builder, LLVMRealOLT, val1, zero, ""); + return LLVMBuildSelect(g->builder, ltz, c, a, ""); + } } else { assert(type_entry->id == TypeTableEntryIdInt); if (type_entry->data.integral.is_signed) { - return LLVMBuildSRem(g->builder, val1, val2, ""); + if (rem_kind == RemKindRem) { + return LLVMBuildSRem(g->builder, val1, val2, ""); + } else { + LLVMValueRef a = LLVMBuildSRem(g->builder, val1, val2, ""); + LLVMValueRef b = LLVMBuildNSWAdd(g->builder, a, val2, ""); + LLVMValueRef c = LLVMBuildSRem(g->builder, b, val2, ""); + LLVMValueRef ltz = LLVMBuildICmp(g->builder, LLVMIntSLT, val1, zero, ""); + return LLVMBuildSelect(g->builder, ltz, c, a, ""); + } } else { return LLVMBuildURem(g->builder, val1, val2, ""); } @@ -1252,6 +1359,7 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable, case IrBinOpInvalid: case IrBinOpArrayCat: case IrBinOpArrayMult: + case IrBinOpRemUnspecified: zig_unreachable(); case IrBinOpBoolOr: return LLVMBuildOr(g->builder, op1_value, op2_value, ""); @@ -1367,10 +1475,18 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable, } else { zig_unreachable(); } - case IrBinOpDiv: - return gen_div(g, want_debug_safety, op1_value, op2_value, type_entry, false); - case IrBinOpRem: - return gen_rem(g, want_debug_safety, op1_value, op2_value, type_entry); + case IrBinOpDivUnspecified: + return gen_div(g, want_debug_safety, op1_value, op2_value, type_entry, DivKindFloat); + case IrBinOpDivExact: + return gen_div(g, want_debug_safety, op1_value, op2_value, type_entry, DivKindExact); + case IrBinOpDivTrunc: + return gen_div(g, want_debug_safety, op1_value, op2_value, type_entry, DivKindTrunc); + case IrBinOpDivFloor: + return gen_div(g, want_debug_safety, op1_value, op2_value, type_entry, DivKindFloor); + case IrBinOpRemRem: + return gen_rem(g, want_debug_safety, op1_value, op2_value, type_entry, RemKindRem); + case IrBinOpRemMod: + return gen_rem(g, want_debug_safety, op1_value, op2_value, type_entry, RemKindMod); } zig_unreachable(); } @@ -2353,14 +2469,6 @@ static LLVMValueRef ir_render_fence(CodeGen *g, IrExecutable *executable, IrInst return nullptr; } -static LLVMValueRef ir_render_div_exact(CodeGen *g, IrExecutable *executable, IrInstructionDivExact *instruction) { - LLVMValueRef op1_val = ir_llvm_value(g, instruction->op1); - LLVMValueRef op2_val = ir_llvm_value(g, instruction->op2); - - bool want_debug_safety = ir_want_debug_safety(g, &instruction->base); - return gen_div(g, want_debug_safety, op1_val, op2_val, instruction->base.value.type, true); -} - static LLVMValueRef ir_render_truncate(CodeGen *g, IrExecutable *executable, IrInstructionTruncate *instruction) { LLVMValueRef target_val = ir_llvm_value(g, instruction->target); TypeTableEntry *dest_type = instruction->base.value.type; @@ -2965,8 +3073,6 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, return ir_render_cmpxchg(g, executable, (IrInstructionCmpxchg *)instruction); case IrInstructionIdFence: return ir_render_fence(g, executable, (IrInstructionFence *)instruction); - case IrInstructionIdDivExact: - return ir_render_div_exact(g, executable, (IrInstructionDivExact *)instruction); case IrInstructionIdTruncate: return ir_render_truncate(g, executable, (IrInstructionTruncate *)instruction); case IrInstructionIdBoolNot: @@ -4320,7 +4426,6 @@ static void define_builtin_fns(CodeGen *g) { create_builtin_fn(g, BuiltinFnIdEmbedFile, "embedFile", 1); create_builtin_fn(g, BuiltinFnIdCmpExchange, "cmpxchg", 5); create_builtin_fn(g, BuiltinFnIdFence, "fence", 1); - create_builtin_fn(g, BuiltinFnIdDivExact, "divExact", 2); create_builtin_fn(g, BuiltinFnIdTruncate, "truncate", 2); create_builtin_fn(g, BuiltinFnIdCompileErr, "compileError", 1); create_builtin_fn(g, BuiltinFnIdCompileLog, "compileLog", SIZE_MAX); @@ -4335,6 +4440,11 @@ static void define_builtin_fns(CodeGen *g) { create_builtin_fn(g, BuiltinFnIdEnumTagName, "enumTagName", 1); create_builtin_fn(g, BuiltinFnIdFieldParentPtr, "fieldParentPtr", 3); create_builtin_fn(g, BuiltinFnIdOffsetOf, "offsetOf", 2); + create_builtin_fn(g, BuiltinFnIdDivExact, "divExact", 2); + create_builtin_fn(g, BuiltinFnIdDivTrunc, "divTrunc", 2); + create_builtin_fn(g, BuiltinFnIdDivFloor, "divFloor", 2); + create_builtin_fn(g, BuiltinFnIdRem, "rem", 2); + create_builtin_fn(g, BuiltinFnIdMod, "mod", 2); } static const char *bool_to_str(bool b) { diff --git a/src/error.cpp b/src/error.cpp index bc4f9feba2..32d8601eda 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -23,6 +23,8 @@ const char *err_str(int err) { case ErrorOverflow: return "overflow"; case ErrorPathAlreadyExists: return "path already exists"; case ErrorUnexpected: return "unexpected error"; + case ErrorExactDivRemainder: return "exact division had a remainder"; + case ErrorNegativeDenominator: return "negative denominator"; } return "(invalid error)"; } diff --git a/src/error.hpp b/src/error.hpp index 7619f1c856..08b66d11f3 100644 --- a/src/error.hpp +++ b/src/error.hpp @@ -23,6 +23,8 @@ enum Error { ErrorOverflow, ErrorPathAlreadyExists, ErrorUnexpected, + ErrorExactDivRemainder, + ErrorNegativeDenominator, }; const char *err_str(int err); diff --git a/src/ir.cpp b/src/ir.cpp index 0f0c407308..9462d05e99 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -400,10 +400,6 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionFence *) { return IrInstructionIdFence; } -static constexpr IrInstructionId ir_instruction_id(IrInstructionDivExact *) { - return IrInstructionIdDivExact; -} - static constexpr IrInstructionId ir_instruction_id(IrInstructionTruncate *) { return IrInstructionIdTruncate; } @@ -1628,23 +1624,6 @@ static IrInstruction *ir_build_fence_from(IrBuilder *irb, IrInstruction *old_ins return new_instruction; } -static IrInstruction *ir_build_div_exact(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *op1, IrInstruction *op2) { - IrInstructionDivExact *instruction = ir_build_instruction(irb, scope, source_node); - instruction->op1 = op1; - instruction->op2 = op2; - - ir_ref_instruction(op1, irb->current_basic_block); - ir_ref_instruction(op2, irb->current_basic_block); - - return &instruction->base; -} - -static IrInstruction *ir_build_div_exact_from(IrBuilder *irb, IrInstruction *old_instruction, IrInstruction *op1, IrInstruction *op2) { - IrInstruction *new_instruction = ir_build_div_exact(irb, old_instruction->scope, old_instruction->source_node, op1, op2); - ir_link_new_instruction(new_instruction, old_instruction); - return new_instruction; -} - static IrInstruction *ir_build_truncate(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *dest_type, IrInstruction *target) { IrInstructionTruncate *instruction = ir_build_instruction(irb, scope, source_node); instruction->dest_type = dest_type; @@ -2597,14 +2576,6 @@ static IrInstruction *ir_instruction_fence_get_dep(IrInstructionFence *instructi } } -static IrInstruction *ir_instruction_divexact_get_dep(IrInstructionDivExact *instruction, size_t index) { - switch (index) { - case 0: return instruction->op1; - case 1: return instruction->op2; - default: return nullptr; - } -} - static IrInstruction *ir_instruction_truncate_get_dep(IrInstructionTruncate *instruction, size_t index) { switch (index) { case 0: return instruction->dest_type; @@ -3022,8 +2993,6 @@ static IrInstruction *ir_instruction_get_dep(IrInstruction *instruction, size_t return ir_instruction_cmpxchg_get_dep((IrInstructionCmpxchg *) instruction, index); case IrInstructionIdFence: return ir_instruction_fence_get_dep((IrInstructionFence *) instruction, index); - case IrInstructionIdDivExact: - return ir_instruction_divexact_get_dep((IrInstructionDivExact *) instruction, index); case IrInstructionIdTruncate: return ir_instruction_truncate_get_dep((IrInstructionTruncate *) instruction, index); case IrInstructionIdIntType: @@ -3644,9 +3613,9 @@ static IrInstruction *ir_gen_bin_op(IrBuilder *irb, Scope *scope, AstNode *node) case BinOpTypeAssignTimesWrap: return ir_gen_assign_op(irb, scope, node, IrBinOpMultWrap); case BinOpTypeAssignDiv: - return ir_gen_assign_op(irb, scope, node, IrBinOpDiv); + return ir_gen_assign_op(irb, scope, node, IrBinOpDivUnspecified); case BinOpTypeAssignMod: - return ir_gen_assign_op(irb, scope, node, IrBinOpRem); + return ir_gen_assign_op(irb, scope, node, IrBinOpRemUnspecified); case BinOpTypeAssignPlus: return ir_gen_assign_op(irb, scope, node, IrBinOpAdd); case BinOpTypeAssignPlusWrap: @@ -3712,9 +3681,9 @@ static IrInstruction *ir_gen_bin_op(IrBuilder *irb, Scope *scope, AstNode *node) case BinOpTypeMultWrap: return ir_gen_bin_op_id(irb, scope, node, IrBinOpMultWrap); case BinOpTypeDiv: - return ir_gen_bin_op_id(irb, scope, node, IrBinOpDiv); + return ir_gen_bin_op_id(irb, scope, node, IrBinOpDivUnspecified); case BinOpTypeMod: - return ir_gen_bin_op_id(irb, scope, node, IrBinOpRem); + return ir_gen_bin_op_id(irb, scope, node, IrBinOpRemUnspecified); case BinOpTypeArrayCat: return ir_gen_bin_op_id(irb, scope, node, IrBinOpArrayCat); case BinOpTypeArrayMult: @@ -4138,7 +4107,63 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo if (arg1_value == irb->codegen->invalid_instruction) return arg1_value; - return ir_build_div_exact(irb, scope, node, arg0_value, arg1_value); + return ir_build_bin_op(irb, scope, node, IrBinOpDivExact, arg0_value, arg1_value, true); + } + case BuiltinFnIdDivTrunc: + { + AstNode *arg0_node = node->data.fn_call_expr.params.at(0); + IrInstruction *arg0_value = ir_gen_node(irb, arg0_node, scope); + if (arg0_value == irb->codegen->invalid_instruction) + return arg0_value; + + AstNode *arg1_node = node->data.fn_call_expr.params.at(1); + IrInstruction *arg1_value = ir_gen_node(irb, arg1_node, scope); + if (arg1_value == irb->codegen->invalid_instruction) + return arg1_value; + + return ir_build_bin_op(irb, scope, node, IrBinOpDivTrunc, arg0_value, arg1_value, true); + } + case BuiltinFnIdDivFloor: + { + AstNode *arg0_node = node->data.fn_call_expr.params.at(0); + IrInstruction *arg0_value = ir_gen_node(irb, arg0_node, scope); + if (arg0_value == irb->codegen->invalid_instruction) + return arg0_value; + + AstNode *arg1_node = node->data.fn_call_expr.params.at(1); + IrInstruction *arg1_value = ir_gen_node(irb, arg1_node, scope); + if (arg1_value == irb->codegen->invalid_instruction) + return arg1_value; + + return ir_build_bin_op(irb, scope, node, IrBinOpDivFloor, arg0_value, arg1_value, true); + } + case BuiltinFnIdRem: + { + AstNode *arg0_node = node->data.fn_call_expr.params.at(0); + IrInstruction *arg0_value = ir_gen_node(irb, arg0_node, scope); + if (arg0_value == irb->codegen->invalid_instruction) + return arg0_value; + + AstNode *arg1_node = node->data.fn_call_expr.params.at(1); + IrInstruction *arg1_value = ir_gen_node(irb, arg1_node, scope); + if (arg1_value == irb->codegen->invalid_instruction) + return arg1_value; + + return ir_build_bin_op(irb, scope, node, IrBinOpRemRem, arg0_value, arg1_value, true); + } + case BuiltinFnIdMod: + { + AstNode *arg0_node = node->data.fn_call_expr.params.at(0); + IrInstruction *arg0_value = ir_gen_node(irb, arg0_node, scope); + if (arg0_value == irb->codegen->invalid_instruction) + return arg0_value; + + AstNode *arg1_node = node->data.fn_call_expr.params.at(1); + IrInstruction *arg1_value = ir_gen_node(irb, arg1_node, scope); + if (arg1_value == irb->codegen->invalid_instruction) + return arg1_value; + + return ir_build_bin_op(irb, scope, node, IrBinOpRemMod, arg0_value, arg1_value, true); } case BuiltinFnIdTruncate: { @@ -8024,32 +8049,70 @@ static TypeTableEntry *ir_analyze_bin_op_cmp(IrAnalyze *ira, IrInstructionBinOp return ira->codegen->builtin_types.entry_bool; } +enum EvalBigNumSpecial { + EvalBigNumSpecialNone, + EvalBigNumSpecialWrapping, + EvalBigNumSpecialExact, +}; + static int ir_eval_bignum(ConstExprValue *op1_val, ConstExprValue *op2_val, ConstExprValue *out_val, bool (*bignum_fn)(BigNum *, BigNum *, BigNum *), - TypeTableEntry *type, bool wrapping_op) + TypeTableEntry *type, EvalBigNumSpecial special) { bool is_int = false; bool is_float = false; - if (bignum_fn == bignum_div || bignum_fn == bignum_rem) { - if (type->id == TypeTableEntryIdInt || - type->id == TypeTableEntryIdNumLitInt) - { - is_int = true; - } else if (type->id == TypeTableEntryIdFloat || - type->id == TypeTableEntryIdNumLitFloat) - { - is_float = true; - } + if (type->id == TypeTableEntryIdInt || + type->id == TypeTableEntryIdNumLitInt) + { + is_int = true; + } else if (type->id == TypeTableEntryIdFloat || + type->id == TypeTableEntryIdNumLitFloat) + { + is_float = true; + } else { + zig_unreachable(); + } + if (bignum_fn == bignum_div || bignum_fn == bignum_rem || bignum_fn == bignum_mod || + bignum_fn == bignum_div_trunc || bignum_fn == bignum_div_floor) + { if ((is_int && op2_val->data.x_bignum.data.x_uint == 0) || (is_float && op2_val->data.x_bignum.data.x_float == 0.0)) { return ErrorDivByZero; } } + if (bignum_fn == bignum_rem || bignum_fn == bignum_mod) { + BigNum zero; + if (is_float) { + bignum_init_float(&zero, 0.0); + } else { + bignum_init_unsigned(&zero, 0); + } + if (bignum_cmp_lt(&op2_val->data.x_bignum, &zero)) { + return ErrorNegativeDenominator; + } + } + + if (special == EvalBigNumSpecialExact) { + assert(bignum_fn == bignum_div); + BigNum remainder; + if (bignum_rem(&remainder, &op1_val->data.x_bignum, &op2_val->data.x_bignum)) { + return ErrorOverflow; + } + BigNum zero; + if (is_float) { + bignum_init_float(&zero, 0.0); + } else { + bignum_init_unsigned(&zero, 0); + } + if (bignum_cmp_neq(&remainder, &zero)) { + return ErrorExactDivRemainder; + } + } bool overflow = bignum_fn(&out_val->data.x_bignum, &op1_val->data.x_bignum, &op2_val->data.x_bignum); if (overflow) { - if (wrapping_op) { + if (special == EvalBigNumSpecialWrapping) { zig_panic("TODO compiler bug, implement compile-time wrapping arithmetic for >= 64 bit ints"); } else { return ErrorOverflow; @@ -8059,7 +8122,7 @@ static int ir_eval_bignum(ConstExprValue *op1_val, ConstExprValue *op2_val, if (type->id == TypeTableEntryIdInt && !bignum_fits_in_bits(&out_val->data.x_bignum, type->data.integral.bit_count, type->data.integral.is_signed)) { - if (wrapping_op) { + if (special == EvalBigNumSpecialWrapping) { if (type->data.integral.is_signed) { out_val->data.x_bignum.data.x_uint = max_unsigned_val(type) - out_val->data.x_bignum.data.x_uint + 1; out_val->data.x_bignum.is_negative = !out_val->data.x_bignum.is_negative; @@ -8093,35 +8156,44 @@ static int ir_eval_math_op(TypeTableEntry *canon_type, ConstExprValue *op1_val, case IrBinOpCmpGreaterOrEq: case IrBinOpArrayCat: case IrBinOpArrayMult: + case IrBinOpRemUnspecified: zig_unreachable(); case IrBinOpBinOr: - return ir_eval_bignum(op1_val, op2_val, out_val, bignum_or, canon_type, false); + return ir_eval_bignum(op1_val, op2_val, out_val, bignum_or, canon_type, EvalBigNumSpecialNone); case IrBinOpBinXor: - return ir_eval_bignum(op1_val, op2_val, out_val, bignum_xor, canon_type, false); + return ir_eval_bignum(op1_val, op2_val, out_val, bignum_xor, canon_type, EvalBigNumSpecialNone); case IrBinOpBinAnd: - return ir_eval_bignum(op1_val, op2_val, out_val, bignum_and, canon_type, false); + return ir_eval_bignum(op1_val, op2_val, out_val, bignum_and, canon_type, EvalBigNumSpecialNone); case IrBinOpBitShiftLeft: - return ir_eval_bignum(op1_val, op2_val, out_val, bignum_shl, canon_type, false); + return ir_eval_bignum(op1_val, op2_val, out_val, bignum_shl, canon_type, EvalBigNumSpecialNone); case IrBinOpBitShiftLeftWrap: - return ir_eval_bignum(op1_val, op2_val, out_val, bignum_shl, canon_type, true); + return ir_eval_bignum(op1_val, op2_val, out_val, bignum_shl, canon_type, EvalBigNumSpecialWrapping); case IrBinOpBitShiftRight: - return ir_eval_bignum(op1_val, op2_val, out_val, bignum_shr, canon_type, false); + return ir_eval_bignum(op1_val, op2_val, out_val, bignum_shr, canon_type, EvalBigNumSpecialNone); case IrBinOpAdd: - return ir_eval_bignum(op1_val, op2_val, out_val, bignum_add, canon_type, false); + return ir_eval_bignum(op1_val, op2_val, out_val, bignum_add, canon_type, EvalBigNumSpecialNone); case IrBinOpAddWrap: - return ir_eval_bignum(op1_val, op2_val, out_val, bignum_add, canon_type, true); + return ir_eval_bignum(op1_val, op2_val, out_val, bignum_add, canon_type, EvalBigNumSpecialWrapping); case IrBinOpSub: - return ir_eval_bignum(op1_val, op2_val, out_val, bignum_sub, canon_type, false); + return ir_eval_bignum(op1_val, op2_val, out_val, bignum_sub, canon_type, EvalBigNumSpecialNone); case IrBinOpSubWrap: - return ir_eval_bignum(op1_val, op2_val, out_val, bignum_sub, canon_type, true); + return ir_eval_bignum(op1_val, op2_val, out_val, bignum_sub, canon_type, EvalBigNumSpecialWrapping); case IrBinOpMult: - return ir_eval_bignum(op1_val, op2_val, out_val, bignum_mul, canon_type, false); + return ir_eval_bignum(op1_val, op2_val, out_val, bignum_mul, canon_type, EvalBigNumSpecialNone); case IrBinOpMultWrap: - return ir_eval_bignum(op1_val, op2_val, out_val, bignum_mul, canon_type, true); - case IrBinOpDiv: - return ir_eval_bignum(op1_val, op2_val, out_val, bignum_div, canon_type, false); - case IrBinOpRem: - return ir_eval_bignum(op1_val, op2_val, out_val, bignum_rem, canon_type, false); + return ir_eval_bignum(op1_val, op2_val, out_val, bignum_mul, canon_type, EvalBigNumSpecialWrapping); + case IrBinOpDivUnspecified: + return ir_eval_bignum(op1_val, op2_val, out_val, bignum_div, canon_type, EvalBigNumSpecialNone); + case IrBinOpDivTrunc: + return ir_eval_bignum(op1_val, op2_val, out_val, bignum_div_trunc, canon_type, EvalBigNumSpecialNone); + case IrBinOpDivFloor: + return ir_eval_bignum(op1_val, op2_val, out_val, bignum_div_floor, canon_type, EvalBigNumSpecialNone); + case IrBinOpDivExact: + return ir_eval_bignum(op1_val, op2_val, out_val, bignum_div, canon_type, EvalBigNumSpecialExact); + case IrBinOpRemRem: + return ir_eval_bignum(op1_val, op2_val, out_val, bignum_rem, canon_type, EvalBigNumSpecialNone); + case IrBinOpRemMod: + return ir_eval_bignum(op1_val, op2_val, out_val, bignum_mod, canon_type, EvalBigNumSpecialNone); } zig_unreachable(); } @@ -8135,6 +8207,31 @@ static TypeTableEntry *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstructionBinOp return resolved_type; IrBinOp op_id = bin_op_instruction->op_id; + bool is_int = resolved_type->id == TypeTableEntryIdInt || resolved_type->id == TypeTableEntryIdNumLitInt; + bool is_signed = ((resolved_type->id == TypeTableEntryIdInt && resolved_type->data.integral.is_signed) || + (resolved_type->id == TypeTableEntryIdNumLitInt && + (op1->value.data.x_bignum.is_negative || op2->value.data.x_bignum.is_negative))); + if (op_id == IrBinOpDivUnspecified) { + if (is_signed) { + ir_add_error(ira, &bin_op_instruction->base, + buf_sprintf("division with '%s' and '%s': signed integers must use @divTrunc, @divFloor, or @divExact", + buf_ptr(&op1->value.type->name), + buf_ptr(&op2->value.type->name))); + return ira->codegen->builtin_types.entry_invalid; + } else if (is_int) { + op_id = IrBinOpDivTrunc; + } + } else if (op_id == IrBinOpRemUnspecified) { + if (is_signed) { + ir_add_error(ira, &bin_op_instruction->base, + buf_sprintf("remainder division with '%s' and '%s': signed integers must use @rem or @mod", + buf_ptr(&op1->value.type->name), + buf_ptr(&op2->value.type->name))); + return ira->codegen->builtin_types.entry_invalid; + } + op_id = IrBinOpRemRem; + } + if (resolved_type->id == TypeTableEntryIdInt || resolved_type->id == TypeTableEntryIdNumLitInt) { @@ -8144,8 +8241,12 @@ static TypeTableEntry *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstructionBinOp (op_id == IrBinOpAdd || op_id == IrBinOpSub || op_id == IrBinOpMult || - op_id == IrBinOpDiv || - op_id == IrBinOpRem)) + op_id == IrBinOpDivUnspecified || + op_id == IrBinOpDivTrunc || + op_id == IrBinOpDivFloor || + op_id == IrBinOpDivExact || + op_id == IrBinOpRemRem || + op_id == IrBinOpRemMod)) { // float } else { @@ -8176,20 +8277,25 @@ static TypeTableEntry *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstructionBinOp int err; if ((err = ir_eval_math_op(resolved_type, op1_val, op_id, op2_val, out_val))) { if (err == ErrorDivByZero) { - ir_add_error_node(ira, bin_op_instruction->base.source_node, - buf_sprintf("division by zero is undefined")); + ir_add_error(ira, &bin_op_instruction->base, buf_sprintf("division by zero is undefined")); return ira->codegen->builtin_types.entry_invalid; } else if (err == ErrorOverflow) { - ir_add_error_node(ira, bin_op_instruction->base.source_node, - buf_sprintf("operation caused overflow")); + ir_add_error(ira, &bin_op_instruction->base, buf_sprintf("operation caused overflow")); return ira->codegen->builtin_types.entry_invalid; + } else if (err == ErrorExactDivRemainder) { + ir_add_error(ira, &bin_op_instruction->base, buf_sprintf("exact division had a remainder")); + return ira->codegen->builtin_types.entry_invalid; + } else if (err == ErrorNegativeDenominator) { + ir_add_error(ira, &bin_op_instruction->base, buf_sprintf("negative denominator")); + return ira->codegen->builtin_types.entry_invalid; + } else { + zig_unreachable(); } return ira->codegen->builtin_types.entry_invalid; } ir_num_lit_fits_in_other_type(ira, &bin_op_instruction->base, resolved_type); return resolved_type; - } ir_build_bin_op_from(&ira->new_irb, &bin_op_instruction->base, op_id, @@ -8197,6 +8303,7 @@ static TypeTableEntry *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstructionBinOp return resolved_type; } + static TypeTableEntry *ir_analyze_array_cat(IrAnalyze *ira, IrInstructionBinOp *instruction) { IrInstruction *op1 = instruction->op1->other; TypeTableEntry *op1_type = op1->value.type; @@ -8416,8 +8523,13 @@ static TypeTableEntry *ir_analyze_instruction_bin_op(IrAnalyze *ira, IrInstructi case IrBinOpSubWrap: case IrBinOpMult: case IrBinOpMultWrap: - case IrBinOpDiv: - case IrBinOpRem: + case IrBinOpDivUnspecified: + case IrBinOpDivTrunc: + case IrBinOpDivFloor: + case IrBinOpDivExact: + case IrBinOpRemUnspecified: + case IrBinOpRemRem: + case IrBinOpRemMod: return ir_analyze_bin_op_math(ira, bin_op_instruction); case IrBinOpArrayCat: return ir_analyze_array_cat(ira, bin_op_instruction); @@ -10007,6 +10119,21 @@ static TypeTableEntry *ir_analyze_instruction_field_ptr(IrAnalyze *ira, IrInstru buf_ptr(&child_type->name), buf_ptr(field_name))); return ira->codegen->builtin_types.entry_invalid; } + } else if (child_type->id == TypeTableEntryIdFloat) { + if (buf_eql_str(field_name, "bit_count")) { + bool ptr_is_const = true; + bool ptr_is_volatile = false; + return ir_analyze_const_ptr(ira, &field_ptr_instruction->base, + create_const_unsigned_negative(ira->codegen->builtin_types.entry_num_lit_int, + child_type->data.floating.bit_count, false), + ira->codegen->builtin_types.entry_num_lit_int, + ConstPtrMutComptimeConst, ptr_is_const, ptr_is_volatile); + } else { + ir_add_error(ira, &field_ptr_instruction->base, + buf_sprintf("type '%s' has no member called '%s'", + buf_ptr(&child_type->name), buf_ptr(field_name))); + return ira->codegen->builtin_types.entry_invalid; + } } else { ir_add_error(ira, &field_ptr_instruction->base, buf_sprintf("type '%s' does not support field access", buf_ptr(&child_type->name))); @@ -12030,71 +12157,6 @@ static TypeTableEntry *ir_analyze_instruction_fence(IrAnalyze *ira, IrInstructio return ira->codegen->builtin_types.entry_void; } -static TypeTableEntry *ir_analyze_instruction_div_exact(IrAnalyze *ira, IrInstructionDivExact *instruction) { - IrInstruction *op1 = instruction->op1->other; - if (type_is_invalid(op1->value.type)) - return ira->codegen->builtin_types.entry_invalid; - - IrInstruction *op2 = instruction->op2->other; - if (type_is_invalid(op2->value.type)) - return ira->codegen->builtin_types.entry_invalid; - - - IrInstruction *peer_instructions[] = { op1, op2 }; - TypeTableEntry *result_type = ir_resolve_peer_types(ira, instruction->base.source_node, peer_instructions, 2); - - if (type_is_invalid(result_type)) - return ira->codegen->builtin_types.entry_invalid; - - if (result_type->id != TypeTableEntryIdInt && - result_type->id != TypeTableEntryIdNumLitInt) - { - ir_add_error(ira, &instruction->base, - buf_sprintf("expected integer type, found '%s'", buf_ptr(&result_type->name))); - return ira->codegen->builtin_types.entry_invalid; - } - - IrInstruction *casted_op1 = ir_implicit_cast(ira, op1, result_type); - if (type_is_invalid(casted_op1->value.type)) - return ira->codegen->builtin_types.entry_invalid; - - IrInstruction *casted_op2 = ir_implicit_cast(ira, op2, result_type); - if (type_is_invalid(casted_op2->value.type)) - return ira->codegen->builtin_types.entry_invalid; - - if (casted_op1->value.special == ConstValSpecialStatic && - casted_op2->value.special == ConstValSpecialStatic) - { - ConstExprValue *op1_val = ir_resolve_const(ira, casted_op1, UndefBad); - ConstExprValue *op2_val = ir_resolve_const(ira, casted_op2, UndefBad); - assert(op1_val); - assert(op2_val); - - if (op1_val->data.x_bignum.data.x_uint == 0) { - ir_add_error(ira, &instruction->base, buf_sprintf("division by zero")); - return ira->codegen->builtin_types.entry_invalid; - } - - BigNum remainder; - if (bignum_rem(&remainder, &op1_val->data.x_bignum, &op2_val->data.x_bignum)) { - ir_add_error(ira, &instruction->base, buf_sprintf("integer overflow")); - return ira->codegen->builtin_types.entry_invalid; - } - - if (remainder.data.x_uint != 0) { - ir_add_error(ira, &instruction->base, buf_sprintf("exact division had a remainder")); - return ira->codegen->builtin_types.entry_invalid; - } - - ConstExprValue *out_val = ir_build_const_from(ira, &instruction->base); - bignum_div(&out_val->data.x_bignum, &op1_val->data.x_bignum, &op2_val->data.x_bignum); - return result_type; - } - - ir_build_div_exact_from(&ira->new_irb, &instruction->base, casted_op1, casted_op2); - return result_type; -} - static TypeTableEntry *ir_analyze_instruction_truncate(IrAnalyze *ira, IrInstructionTruncate *instruction) { IrInstruction *dest_type_value = instruction->dest_type->other; TypeTableEntry *dest_type = ir_resolve_type(ira, dest_type_value); @@ -13261,8 +13323,6 @@ static TypeTableEntry *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructi return ir_analyze_instruction_cmpxchg(ira, (IrInstructionCmpxchg *)instruction); case IrInstructionIdFence: return ir_analyze_instruction_fence(ira, (IrInstructionFence *)instruction); - case IrInstructionIdDivExact: - return ir_analyze_instruction_div_exact(ira, (IrInstructionDivExact *)instruction); case IrInstructionIdTruncate: return ir_analyze_instruction_truncate(ira, (IrInstructionTruncate *)instruction); case IrInstructionIdIntType: @@ -13469,7 +13529,6 @@ bool ir_has_side_effects(IrInstruction *instruction) { case IrInstructionIdMinValue: case IrInstructionIdMaxValue: case IrInstructionIdEmbedFile: - case IrInstructionIdDivExact: case IrInstructionIdTruncate: case IrInstructionIdIntType: case IrInstructionIdBoolNot: diff --git a/src/ir_print.cpp b/src/ir_print.cpp index eddbd58752..e676ad73a9 100644 --- a/src/ir_print.cpp +++ b/src/ir_print.cpp @@ -109,10 +109,20 @@ static const char *ir_bin_op_id_str(IrBinOp op_id) { return "*"; case IrBinOpMultWrap: return "*%"; - case IrBinOpDiv: + case IrBinOpDivUnspecified: return "/"; - case IrBinOpRem: + case IrBinOpDivTrunc: + return "@divTrunc"; + case IrBinOpDivFloor: + return "@divFloor"; + case IrBinOpDivExact: + return "@divExact"; + case IrBinOpRemUnspecified: return "%"; + case IrBinOpRemRem: + return "@rem"; + case IrBinOpRemMod: + return "@mod"; case IrBinOpArrayCat: return "++"; case IrBinOpArrayMult: @@ -580,14 +590,6 @@ static void ir_print_fence(IrPrint *irp, IrInstructionFence *instruction) { fprintf(irp->f, ")"); } -static void ir_print_div_exact(IrPrint *irp, IrInstructionDivExact *instruction) { - fprintf(irp->f, "@divExact("); - ir_print_other_instruction(irp, instruction->op1); - fprintf(irp->f, ", "); - ir_print_other_instruction(irp, instruction->op2); - fprintf(irp->f, ")"); -} - static void ir_print_truncate(IrPrint *irp, IrInstructionTruncate *instruction) { fprintf(irp->f, "@truncate("); ir_print_other_instruction(irp, instruction->dest_type); @@ -1056,9 +1058,6 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) { case IrInstructionIdFence: ir_print_fence(irp, (IrInstructionFence *)instruction); break; - case IrInstructionIdDivExact: - ir_print_div_exact(irp, (IrInstructionDivExact *)instruction); - break; case IrInstructionIdTruncate: ir_print_truncate(irp, (IrInstructionTruncate *)instruction); break; diff --git a/src/link.cpp b/src/link.cpp index 09bea8a999..f67b05724a 100644 --- a/src/link.cpp +++ b/src/link.cpp @@ -297,6 +297,7 @@ static void construct_linker_job_elf(LinkJob *lj) { lj->args.append("-lgcc"); lj->args.append("-lgcc_eh"); lj->args.append("-lc"); + lj->args.append("-lm"); lj->args.append("--end-group"); } else { lj->args.append("-lgcc"); @@ -304,6 +305,7 @@ static void construct_linker_job_elf(LinkJob *lj) { lj->args.append("-lgcc_s"); lj->args.append("--no-as-needed"); lj->args.append("-lc"); + lj->args.append("-lm"); lj->args.append("-lgcc"); lj->args.append("--as-needed"); lj->args.append("-lgcc_s"); diff --git a/std/elf.zig b/std/elf.zig index e95e994295..ca302a3cdb 100644 --- a/std/elf.zig +++ b/std/elf.zig @@ -165,9 +165,9 @@ pub const Elf = struct { if (elf.string_section_index >= sh_entry_count) return error.InvalidFormat; const sh_byte_count = u64(sh_entry_size) * u64(sh_entry_count); - const end_sh = %return math.addOverflow(u64, elf.section_header_offset, sh_byte_count); + const end_sh = %return math.add(u64, elf.section_header_offset, sh_byte_count); const ph_byte_count = u64(ph_entry_size) * u64(ph_entry_count); - const end_ph = %return math.addOverflow(u64, elf.program_header_offset, ph_byte_count); + const end_ph = %return math.add(u64, elf.program_header_offset, ph_byte_count); const stream_end = %return elf.in_stream.getEndPos(); if (stream_end < end_sh or stream_end < end_ph) { @@ -214,8 +214,7 @@ pub const Elf = struct { for (elf.section_headers) |*section| { if (section.sh_type != SHT_NOBITS) { - const file_end_offset = %return math.addOverflow(u64, - section.offset, section.size); + const file_end_offset = %return math.add(u64, section.offset, section.size); if (stream_end < file_end_offset) return error.InvalidFormat; } } diff --git a/std/fmt.zig b/std/fmt.zig index 3de5c0eb4c..80ca4af778 100644 --- a/std/fmt.zig +++ b/std/fmt.zig @@ -305,8 +305,8 @@ pub fn parseUnsigned(comptime T: type, buf: []const u8, radix: u8) -> %T { for (buf) |c| { const digit = %return charToDigit(c, radix); - x = %return math.mulOverflow(T, x, radix); - x = %return math.addOverflow(T, x, digit); + x = %return math.mul(T, x, radix); + x = %return math.add(T, x, digit); } return x; diff --git a/std/math.zig b/std/math.zig index c51e45ca4f..c146272d62 100644 --- a/std/math.zig +++ b/std/math.zig @@ -1,37 +1,64 @@ const assert = @import("debug.zig").assert; pub const Cmp = enum { + Less, Equal, Greater, - Less, }; pub fn min(x: var, y: var) -> @typeOf(x + y) { if (x < y) x else y } +test "math.min" { + assert(min(i32(-1), i32(2)) == -1); +} + pub fn max(x: var, y: var) -> @typeOf(x + y) { if (x > y) x else y } +test "math.max" { + assert(max(i32(-1), i32(2)) == 2); +} + error Overflow; -pub fn mulOverflow(comptime T: type, a: T, b: T) -> %T { +pub fn mul(comptime T: type, a: T, b: T) -> %T { var answer: T = undefined; if (@mulWithOverflow(T, a, b, &answer)) error.Overflow else answer } -pub fn addOverflow(comptime T: type, a: T, b: T) -> %T { + +error Overflow; +pub fn add(comptime T: type, a: T, b: T) -> %T { var answer: T = undefined; if (@addWithOverflow(T, a, b, &answer)) error.Overflow else answer } -pub fn subOverflow(comptime T: type, a: T, b: T) -> %T { + +error Overflow; +pub fn sub(comptime T: type, a: T, b: T) -> %T { var answer: T = undefined; if (@subWithOverflow(T, a, b, &answer)) error.Overflow else answer } -pub fn shlOverflow(comptime T: type, a: T, b: T) -> %T { + +error Overflow; +pub fn shl(comptime T: type, a: T, b: T) -> %T { var answer: T = undefined; if (@shlWithOverflow(T, a, b, &answer)) error.Overflow else answer } +test "math overflow functions" { + testOverflow(); + comptime testOverflow(); +} + +fn testOverflow() { + assert(%%mul(i32, 3, 4) == 12); + assert(%%add(i32, 3, 4) == 7); + assert(%%sub(i32, 3, 4) == -1); + assert(%%shl(i32, 0b11, 4) == 0b110000); +} + + pub fn log(comptime base: usize, value: var) -> @typeOf(value) { const T = @typeOf(value); if (@isInteger(T)) { @@ -47,35 +74,191 @@ pub fn log(comptime base: usize, value: var) -> @typeOf(value) { } } -/// x must be an integer or a float -/// Note that this causes undefined behavior if -/// @typeOf(x).is_signed and x == @minValue(@typeOf(x)). -pub fn abs(x: var) -> @typeOf(x) { +error Overflow; +pub fn absInt(x: var) -> %@typeOf(x) { const T = @typeOf(x); - if (@isInteger(T)) { + comptime assert(@isInteger(T)); // must pass an integer to absInt + comptime assert(T.is_signed); // must pass a signed integer to absInt + if (x == @minValue(@typeOf(x))) + return error.Overflow; + { + @setDebugSafety(this, false); return if (x < 0) -x else x; - } else if (@isFloat(T)) { - @compileError("TODO implement abs for floats"); + } +} + +test "math.absInt" { + testAbsInt(); + comptime testAbsInt(); +} +fn testAbsInt() { + assert(%%absInt(i32(-10)) == 10); + assert(%%absInt(i32(10)) == 10); +} + +pub fn absFloat(x: var) -> @typeOf(x) { + comptime assert(@isFloat(@typeOf(x))); + return if (x < 0) -x else x; +} + +test "math.absFloat" { + testAbsFloat(); + comptime testAbsFloat(); +} +fn testAbsFloat() { + assert(absFloat(f32(-10.0)) == 10.0); + assert(absFloat(f32(10.0)) == 10.0); +} + +error DivisionByZero; +error Overflow; +pub fn divTrunc(comptime T: type, numerator: T, denominator: T) -> %T { + @setDebugSafety(this, false); + if (denominator == 0) + return error.DivisionByZero; + if (@isInteger(T) and T.is_signed and numerator == @minValue(T) and denominator == -1) + return error.Overflow; + return @divTrunc(numerator, denominator); +} + +test "math.divTrunc" { + testDivTrunc(); + comptime testDivTrunc(); +} +fn testDivTrunc() { + assert(%%divTrunc(i32, 5, 3) == 1); + assert(%%divTrunc(i32, -5, 3) == -1); + if (divTrunc(i8, -5, 0)) |_| unreachable else |err| assert(err == error.DivisionByZero); + if (divTrunc(i8, -128, -1)) |_| unreachable else |err| assert(err == error.Overflow); + + assert(%%divTrunc(f32, 5.0, 3.0) == 1.0); + assert(%%divTrunc(f32, -5.0, 3.0) == -1.0); +} + +error DivisionByZero; +error Overflow; +pub fn divFloor(comptime T: type, numerator: T, denominator: T) -> %T { + @setDebugSafety(this, false); + if (denominator == 0) + return error.DivisionByZero; + if (@isInteger(T) and T.is_signed and numerator == @minValue(T) and denominator == -1) + return error.Overflow; + return @divFloor(numerator, denominator); +} + +test "math.divFloor" { + testDivFloor(); + comptime testDivFloor(); +} +fn testDivFloor() { + assert(%%divFloor(i32, 5, 3) == 1); + assert(%%divFloor(i32, -5, 3) == -2); + if (divFloor(i8, -5, 0)) |_| unreachable else |err| assert(err == error.DivisionByZero); + if (divFloor(i8, -128, -1)) |_| unreachable else |err| assert(err == error.Overflow); + + assert(%%divFloor(f32, 5.0, 3.0) == 1.0); + assert(%%divFloor(f32, -5.0, 3.0) == -2.0); +} + +error DivisionByZero; +error Overflow; +error UnexpectedRemainder; +pub fn divExact(comptime T: type, numerator: T, denominator: T) -> %T { + @setDebugSafety(this, false); + if (denominator == 0) + return error.DivisionByZero; + if (@isInteger(T) and T.is_signed and numerator == @minValue(T) and denominator == -1) + return error.Overflow; + const result = @divTrunc(numerator, denominator); + if (result * denominator != numerator) + return error.UnexpectedRemainder; + return result; +} + +test "math.divExact" { + testDivExact(); + comptime testDivExact(); +} +fn testDivExact() { + assert(%%divExact(i32, 10, 5) == 2); + assert(%%divExact(i32, -10, 5) == -2); + if (divExact(i8, -5, 0)) |_| unreachable else |err| assert(err == error.DivisionByZero); + if (divExact(i8, -128, -1)) |_| unreachable else |err| assert(err == error.Overflow); + if (divExact(i32, 5, 2)) |_| unreachable else |err| assert(err == error.UnexpectedRemainder); + + assert(%%divExact(f32, 10.0, 5.0) == 2.0); + assert(%%divExact(f32, -10.0, 5.0) == -2.0); + if (divExact(f32, 5.0, 2.0)) |_| unreachable else |err| assert(err == error.UnexpectedRemainder); +} + +error DivisionByZero; +error NegativeDenominator; +pub fn mod(comptime T: type, numerator: T, denominator: T) -> %T { + @setDebugSafety(this, false); + if (denominator == 0) + return error.DivisionByZero; + if (denominator < 0) + return error.NegativeDenominator; + return @mod(numerator, denominator); +} + +test "math.mod" { + testMod(); + comptime testMod(); +} +fn testMod() { + assert(%%mod(i32, -5, 3) == 1); + assert(%%mod(i32, 5, 3) == 2); + if (mod(i32, 10, -1)) |_| unreachable else |err| assert(err == error.NegativeDenominator); + if (mod(i32, 10, 0)) |_| unreachable else |err| assert(err == error.DivisionByZero); + + assert(%%mod(f32, -5, 3) == 1); + assert(%%mod(f32, 5, 3) == 2); + if (mod(f32, 10, -1)) |_| unreachable else |err| assert(err == error.NegativeDenominator); + if (mod(f32, 10, 0)) |_| unreachable else |err| assert(err == error.DivisionByZero); +} + +error DivisionByZero; +error NegativeDenominator; +pub fn rem(comptime T: type, numerator: T, denominator: T) -> %T { + @setDebugSafety(this, false); + if (denominator == 0) + return error.DivisionByZero; + if (denominator < 0) + return error.NegativeDenominator; + return @rem(numerator, denominator); +} + +test "math.rem" { + testRem(); + comptime testRem(); +} +fn testRem() { + assert(%%rem(i32, -5, 3) == -2); + assert(%%rem(i32, 5, 3) == 2); + if (rem(i32, 10, -1)) |_| unreachable else |err| assert(err == error.NegativeDenominator); + if (rem(i32, 10, 0)) |_| unreachable else |err| assert(err == error.DivisionByZero); + + assert(%%rem(f32, -5, 3) == -2); + assert(%%rem(f32, 5, 3) == 2); + if (rem(f32, 10, -1)) |_| unreachable else |err| assert(err == error.NegativeDenominator); + if (rem(f32, 10, 0)) |_| unreachable else |err| assert(err == error.DivisionByZero); +} + +fn isNan(comptime T: type, x: T) -> bool { + assert(@isFloat(T)); + const bits = floatBits(x); + if (T == f32) { + return (bits & 0x7fffffff) > 0x7f800000; + } else if (T == f64) { + return (bits & (@maxValue(u64) >> 1)) > (u64(0x7ff) << 52); } else { unreachable; } } -fn getReturnTypeForAbs(comptime T: type) -> type { - if (@isInteger(T)) { - return @IntType(false, T.bit_count); - } else { - return T; - } -} -test "testMath" { - testMathImpl(); - comptime testMathImpl(); -} - -fn testMathImpl() { - assert(%%mulOverflow(i32, 3, 4) == 12); - assert(%%addOverflow(i32, 3, 4) == 7); - assert(%%subOverflow(i32, 3, 4) == -1); - assert(%%shlOverflow(i32, 0b11, 4) == 0b110000); +fn floatBits(comptime T: type, x: T) -> @IntType(false, T.bit_count) { + assert(@isFloat(T)); + const uint = @IntType(false, T.bit_count); + return *@intToPtr(&const uint, &x); } diff --git a/std/mem.zig b/std/mem.zig index 591005ad98..7d29ddc007 100644 --- a/std/mem.zig +++ b/std/mem.zig @@ -35,12 +35,12 @@ pub const Allocator = struct { } fn alloc(self: &Allocator, comptime T: type, n: usize) -> %[]T { - const byte_count = %return math.mulOverflow(usize, @sizeOf(T), n); + const byte_count = %return math.mul(usize, @sizeOf(T), n); ([]T)(%return self.allocFn(self, byte_count)) } fn realloc(self: &Allocator, comptime T: type, old_mem: []T, n: usize) -> %[]T { - const byte_count = %return math.mulOverflow(usize, @sizeOf(T), n); + const byte_count = %return math.mul(usize, @sizeOf(T), n); ([]T)(%return self.reallocFn(self, ([]u8)(old_mem), byte_count)) } @@ -333,3 +333,29 @@ fn testWriteIntImpl() { assert(eql(u8, bytes, []u8{ 0x34, 0x12, 0x00, 0x00 })); } + +pub fn min(comptime T: type, slice: []const T) -> T { + var best = slice[0]; + var i: usize = 1; + while (i < slice.len) : (i += 1) { + best = math.min(best, slice[i]); + } + return best; +} + +test "mem.min" { + assert(min(u8, "abcdefg") == 'a'); +} + +pub fn max(comptime T: type, slice: []const T) -> T { + var best = slice[0]; + var i: usize = 1; + while (i < slice.len) : (i += 1) { + best = math.max(best, slice[i]); + } + return best; +} + +test "mem.max" { + assert(max(u8, "abcdefg") == 'g'); +} diff --git a/std/special/builtin.zig b/std/special/builtin.zig index f15f13c45f..a21705d82e 100644 --- a/std/special/builtin.zig +++ b/std/special/builtin.zig @@ -29,3 +29,95 @@ export fn __stack_chk_fail() { } @panic("stack smashing detected"); } + +export fn fmodf(x: f32, y: f32) -> f32 { generic_fmod(f32, x, y) } +export fn fmod(x: f64, y: f64) -> f64 { generic_fmod(f64, x, y) } + +fn generic_fmod(comptime T: type, x: T, y: T) -> T { + //@setDebugSafety(this, false); + const uint = @IntType(false, T.bit_count); + const digits = if (T == f32) 23 else 52; + const exp_bits = if (T == f32) 9 else 12; + const bits_minus_1 = T.bit_count - 1; + const mask = if (T == f32) 0xff else 0x7ff; + var ux = *@ptrCast(&const uint, &x); + var uy = *@ptrCast(&const uint, &y); + var ex = i32((ux >> digits) & mask); + var ey = i32((uy >> digits) & mask); + const sx = if (T == f32) u32(ux & 0x80000000) else i32(ux >> bits_minus_1); + var i: uint = undefined; + + if (uy <<% 1 == 0 or isNan(uint, uy) or ex == mask) + return (x * y) / (x * y); + + if (ux <<% 1 <= uy <<% 1) { + if (ux <<% 1 == uy <<% 1) + return 0 * x; + return x; + } + + // normalize x and y + if (ex == 0) { + i = ux <<% exp_bits; + while (i >> bits_minus_1 == 0) : ({ex -= 1; i <<%= 1}) {} + ux <<%= twosComplementCast(uint, -ex + 1); + } else { + ux &= @maxValue(uint) >> exp_bits; + ux |= 1 <<% digits; + } + if (ey == 0) { + i = uy <<% exp_bits; + while (i >> bits_minus_1 == 0) : ({ey -= 1; i <<%= 1}) {} + uy <<= twosComplementCast(uint, -ey + 1); + } else { + uy &= @maxValue(uint) >> exp_bits; + uy |= 1 <<% digits; + } + + // x mod y + while (ex > ey) : (ex -= 1) { + i = ux -% uy; + if (i >> bits_minus_1 == 0) { + if (i == 0) + return 0 * x; + ux = i; + } + ux <<%= 1; + } + i = ux -% uy; + if (i >> bits_minus_1 == 0) { + if (i == 0) + return 0 * x; + ux = i; + } + while (ux >> digits == 0) : ({ux <<%= 1; ex -= 1}) {} + + // scale result up + if (ex > 0) { + ux -%= 1 <<% digits; + ux |= twosComplementCast(uint, ex) <<% digits; + } else { + ux >>= twosComplementCast(uint, -ex + 1); + } + if (T == f32) { + ux |= sx; + } else { + ux |= uint(sx) <<% bits_minus_1; + } + return *@ptrCast(&const T, &ux); +} + +fn isNan(comptime T: type, bits: T) -> bool { + if (T == u32) { + return (bits & 0x7fffffff) > 0x7f800000; + } else if (T == u64) { + return (bits & (@maxValue(u64) >> 1)) > (u64(0x7ff) <<% 52); + } else { + unreachable; + } +} + +// TODO this should be a builtin function and it shouldn't do a ptr cast +fn twosComplementCast(comptime T: type, src: var) -> T { + return *@ptrCast(&const @IntType(T.is_signed, @typeOf(src).bit_count), &src); +} diff --git a/std/special/zigrt.zig b/std/special/zigrt.zig index f5904050c2..b34d9ce78a 100644 --- a/std/special/zigrt.zig +++ b/std/special/zigrt.zig @@ -1,5 +1,5 @@ // This file contains functions that zig depends on to coordinate between -// multiple .o files. The symbols are defined Weak so that multiple +// multiple .o files. The symbols are defined LinkOnce so that multiple // instances of zig_rt.zig do not conflict with each other. const builtin = @import("builtin"); diff --git a/test/cases/math.zig b/test/cases/math.zig index c5636350ab..b3deb5af09 100644 --- a/test/cases/math.zig +++ b/test/cases/math.zig @@ -1,47 +1,76 @@ const assert = @import("std").debug.assert; -test "exactDivision" { - assert(divExact(55, 11) == 5); -} -fn divExact(a: u32, b: u32) -> u32 { - @divExact(a, b) +test "division" { + testDivision(); + comptime testDivision(); } +fn testDivision() { + assert(div(u32, 13, 3) == 4); + assert(div(f32, 1.0, 2.0) == 0.5); -test "floatDivision" { - assert(fdiv32(12.0, 3.0) == 4.0); + assert(divExact(u32, 55, 11) == 5); + assert(divExact(i32, -55, 11) == -5); + assert(divExact(f32, 55.0, 11.0) == 5.0); + assert(divExact(f32, -55.0, 11.0) == -5.0); + + assert(divFloor(i32, 5, 3) == 1); + assert(divFloor(i32, -5, 3) == -2); + assert(divFloor(f32, 5.0, 3.0) == 1.0); + assert(divFloor(f32, -5.0, 3.0) == -2.0); + assert(divFloor(i32, -0x80000000, -2) == 0x40000000); + assert(divFloor(i32, 0, -0x80000000) == 0); + assert(divFloor(i32, -0x40000001, 0x40000000) == -2); + assert(divFloor(i32, -0x80000000, 1) == -0x80000000); + + assert(divTrunc(i32, 5, 3) == 1); + assert(divTrunc(i32, -5, 3) == -1); + assert(divTrunc(f32, 5.0, 3.0) == 1.0); + assert(divTrunc(f32, -5.0, 3.0) == -1.0); } -fn fdiv32(a: f32, b: f32) -> f32 { +fn div(comptime T: type, a: T, b: T) -> T { a / b } +fn divExact(comptime T: type, a: T, b: T) -> T { + @divExact(a, b) +} +fn divFloor(comptime T: type, a: T, b: T) -> T { + @divFloor(a, b) +} +fn divTrunc(comptime T: type, a: T, b: T) -> T { + @divTrunc(a, b) +} -test "overflowIntrinsics" { +test "@addWithOverflow" { var result: u8 = undefined; assert(@addWithOverflow(u8, 250, 100, &result)); assert(!@addWithOverflow(u8, 100, 150, &result)); assert(result == 250); } -test "shlWithOverflow" { +// TODO test mulWithOverflow +// TODO test subWithOverflow + +test "@shlWithOverflow" { var result: u16 = undefined; assert(@shlWithOverflow(u16, 0b0010111111111111, 3, &result)); assert(!@shlWithOverflow(u16, 0b0010111111111111, 2, &result)); assert(result == 0b1011111111111100); } -test "countLeadingZeroes" { +test "@clz" { assert(@clz(u8(0b00001010)) == 4); assert(@clz(u8(0b10001010)) == 0); assert(@clz(u8(0b00000000)) == 8); } -test "countTrailingZeroes" { +test "@ctz" { assert(@ctz(u8(0b10100000)) == 5); assert(@ctz(u8(0b10001010)) == 1); assert(@ctz(u8(0b00000000)) == 8); } -test "modifyOperators" { - var i : i32 = 0; +test "assignment operators" { + var i: u32 = 0; i += 5; assert(i == 5); i -= 2; assert(i == 3); i *= 20; assert(i == 60); @@ -57,6 +86,8 @@ test "modifyOperators" { } test "threeExprInARow" { + testThreeExprInARow(false, true); + comptime testThreeExprInARow(false, true); } fn testThreeExprInARow(f: bool, t: bool) { assertFalse(f or f or f); @@ -72,13 +103,12 @@ fn testThreeExprInARow(f: bool, t: bool) { assertFalse(!!false); assertFalse(i32(7) != --(i32(7))); } - fn assertFalse(b: bool) { assert(!b); } -test "constNumberLiteral" { +test "const number literal" { const one = 1; const eleven = ten + one; @@ -88,8 +118,9 @@ const ten = 10; -test "unsignedWrapping" { +test "unsigned wrapping" { testUnsignedWrappingEval(@maxValue(u32)); + comptime testUnsignedWrappingEval(@maxValue(u32)); } fn testUnsignedWrappingEval(x: u32) { const zero = x +% 1; @@ -98,8 +129,9 @@ fn testUnsignedWrappingEval(x: u32) { assert(orig == @maxValue(u32)); } -test "signedWrapping" { +test "signed wrapping" { testSignedWrappingEval(@maxValue(i32)); + comptime testSignedWrappingEval(@maxValue(i32)); } fn testSignedWrappingEval(x: i32) { const min_val = x +% 1; @@ -108,8 +140,9 @@ fn testSignedWrappingEval(x: i32) { assert(max_val == @maxValue(i32)); } -test "negationWrapping" { +test "negation wrapping" { testNegationWrappingEval(@minValue(i16)); + comptime testNegationWrappingEval(@minValue(i16)); } fn testNegationWrappingEval(x: i16) { assert(x == -32768); @@ -117,20 +150,25 @@ fn testNegationWrappingEval(x: i16) { assert(neg == -32768); } -test "shlWrapping" { +test "shift left wrapping" { testShlWrappingEval(@maxValue(u16)); + comptime testShlWrappingEval(@maxValue(u16)); } fn testShlWrappingEval(x: u16) { const shifted = x <<% 1; assert(shifted == 65534); } -test "unsigned64BitDivision" { - const result = div(1152921504606846976, 34359738365); +test "unsigned 64-bit division" { + test_u64_div(); + comptime test_u64_div(); +} +fn test_u64_div() { + const result = divWithResult(1152921504606846976, 34359738365); assert(result.quotient == 33554432); assert(result.remainder == 100663296); } -fn div(a: u64, b: u64) -> DivResult { +fn divWithResult(a: u64, b: u64) -> DivResult { DivResult { .quotient = a / b, .remainder = a % b, @@ -141,7 +179,7 @@ const DivResult = struct { remainder: u64, }; -test "binaryNot" { +test "binary not" { assert(comptime {~u16(0b1010101010101010) == 0b0101010101010101}); assert(comptime {~u64(2147483647) == 18446744071562067968}); testBinaryNot(0b1010101010101010); @@ -151,7 +189,7 @@ fn testBinaryNot(x: u16) { assert(~x == 0b0101010101010101); } -test "smallIntAddition" { +test "small int addition" { var x: @IntType(false, 2) = 0; assert(x == 0); @@ -170,7 +208,7 @@ test "smallIntAddition" { assert(result == 0); } -test "testFloatEquality" { +test "float equality" { const x: f64 = 0.012; const y: f64 = x + 1.0; diff --git a/test/cases/misc.zig b/test/cases/misc.zig index 56008a344b..a5f19242d9 100644 --- a/test/cases/misc.zig +++ b/test/cases/misc.zig @@ -49,6 +49,11 @@ test "@IntType builtin" { assert(!usize.is_signed); } +test "floating point primitive bit counts" { + assert(f32.bit_count == 32); + assert(f64.bit_count == 64); +} + const u1 = @IntType(false, 1); const u63 = @IntType(false, 63); const i1 = @IntType(true, 1); diff --git a/test/compile_errors.zig b/test/compile_errors.zig index f021898763..ff4003819a 100644 --- a/test/compile_errors.zig +++ b/test/compile_errors.zig @@ -702,7 +702,7 @@ pub fn addCases(cases: &tests.CompileErrorContext) { cases.add("division by zero", \\const lit_int_x = 1 / 0; \\const lit_float_x = 1.0 / 0.0; - \\const int_x = i32(1) / i32(0); + \\const int_x = u32(1) / u32(0); \\const float_x = f32(1.0) / f32(0.0); \\ \\export fn entry1() -> usize { @sizeOf(@typeOf(lit_int_x)) } @@ -792,7 +792,7 @@ pub fn addCases(cases: &tests.CompileErrorContext) { cases.add("compile time division by zero", \\const y = foo(0); - \\fn foo(x: i32) -> i32 { + \\fn foo(x: u32) -> u32 { \\ 1 / x \\} \\ @@ -1709,4 +1709,18 @@ pub fn addCases(cases: &tests.CompileErrorContext) { \\extern fn quux(usize); , ".tmp_source.zig:4:8: error: unable to inline function"); + + cases.add("signed integer division", + \\export fn foo(a: i32, b: i32) -> i32 { + \\ a / b + \\} + , + ".tmp_source.zig:2:7: error: division with 'i32' and 'i32': signed integers must use @divTrunc, @divFloor, or @divExact"); + + cases.add("signed integer remainder division", + \\export fn foo(a: i32, b: i32) -> i32 { + \\ a % b + \\} + , + ".tmp_source.zig:2:7: error: remainder division with 'i32' and 'i32': signed integers must use @rem or @mod"); } diff --git a/test/debug_safety.zig b/test/debug_safety.zig index 0995a0e544..3df411f7cd 100644 --- a/test/debug_safety.zig +++ b/test/debug_safety.zig @@ -97,7 +97,7 @@ pub fn addCases(cases: &tests.CompareOutputContext) { \\ if (x == 32767) return error.Whatever; \\} \\fn div(a: i16, b: i16) -> i16 { - \\ a / b + \\ @divTrunc(a, b) \\} ); @@ -141,7 +141,7 @@ pub fn addCases(cases: &tests.CompareOutputContext) { \\ const x = div0(999, 0); \\} \\fn div0(a: i32, b: i32) -> i32 { - \\ a / b + \\ @divTrunc(a, b) \\} );