From b5459eb987d89c4759c31123a7baa0a0d962c024 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 15 Apr 2018 13:21:52 -0400 Subject: [PATCH] add @sqrt built-in function See #767 --- CMakeLists.txt | 1 - doc/langref.html.in | 12 +- src/all_types.hpp | 12 +- src/analyze.cpp | 9 +- src/bigfloat.cpp | 4 + src/bigfloat.hpp | 1 + src/codegen.cpp | 24 +++- src/ir.cpp | 94 +++++++++++++ src/ir_print.cpp | 15 ++ std/math/sqrt.zig | 295 +++++---------------------------------- std/math/x86_64/sqrt.zig | 15 -- std/special/builtin.zig | 209 +++++++++++++++++++++++++++ test/cases/math.zig | 16 +++ 13 files changed, 419 insertions(+), 288 deletions(-) delete mode 100644 std/math/x86_64/sqrt.zig diff --git a/CMakeLists.txt b/CMakeLists.txt index 021fd43cf0..bf90a7ef46 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -498,7 +498,6 @@ set(ZIG_STD_FILES "math/tan.zig" "math/tanh.zig" "math/trunc.zig" - "math/x86_64/sqrt.zig" "mem.zig" "net.zig" "os/child_process.zig" diff --git a/doc/langref.html.in b/doc/langref.html.in index 856d62f142..d9436e55b7 100644 --- a/doc/langref.html.in +++ b/doc/langref.html.in @@ -4669,6 +4669,16 @@ pub const FloatMode = enum { The result is a target-specific compile time constant.

{#header_close#} + {#header_open|@sqrt#} +
@sqrt(comptime T: type, value: T) -> T
+

+ Performs the square root of a floating point number. Uses a dedicated hardware instruction + when available. Currently only supports f32 and f64 at runtime. f128 at runtime is TODO. +

+

+ This is a low-level intrinsic. Most code can use std.math.sqrt instead. +

+ {#header_close#} {#header_open|@subWithOverflow#}
@subWithOverflow(comptime T: type, a: T, b: T, result: &T) -> bool

@@ -5991,7 +6001,7 @@ hljs.registerLanguage("zig", function(t) { a = t.IR + "\\s*\\(", c = { keyword: "const align var extern stdcallcc nakedcc volatile export pub noalias inline struct packed enum union break return try catch test continue unreachable comptime and or asm defer errdefer if else switch while for fn use bool f32 f64 void type noreturn error i8 u8 i16 u16 i32 u32 i64 u64 isize usize i8w u8w i16w i32w u32w i64w u64w isizew usizew c_short c_ushort c_int c_uint c_long c_ulong c_longlong c_ulonglong", - built_in: "breakpoint returnAddress frameAddress fieldParentPtr setFloatMode IntType OpaqueType compileError compileLog setCold setRuntimeSafety setEvalBranchQuota offsetOf memcpy inlineCall setGlobalLinkage setGlobalSection divTrunc divFloor enumTagName intToPtr ptrToInt panic canImplicitCast ptrCast bitCast rem mod memset sizeOf alignOf alignCast maxValue minValue memberCount memberName memberType typeOf addWithOverflow subWithOverflow mulWithOverflow shlWithOverflow shlExact shrExact cInclude cDefine cUndef ctz clz import cImport errorName embedFile cmpxchg fence divExact truncate atomicRmw", + built_in: "breakpoint returnAddress frameAddress fieldParentPtr setFloatMode IntType OpaqueType compileError compileLog setCold setRuntimeSafety setEvalBranchQuota offsetOf memcpy inlineCall setGlobalLinkage setGlobalSection divTrunc divFloor enumTagName intToPtr ptrToInt panic canImplicitCast ptrCast bitCast rem mod memset sizeOf alignOf alignCast maxValue minValue memberCount memberName memberType typeOf addWithOverflow subWithOverflow mulWithOverflow shlWithOverflow shlExact shrExact cInclude cDefine cUndef ctz clz import cImport errorName embedFile cmpxchg fence divExact truncate atomicRmw sqrt", literal: "true false null undefined" }, n = [e, t.CLCM, t.CBCM, s, r]; diff --git a/src/all_types.hpp b/src/all_types.hpp index d27a5c7a1c..b43214a60e 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1317,6 +1317,7 @@ enum BuiltinFnId { BuiltinFnIdDivFloor, BuiltinFnIdRem, BuiltinFnIdMod, + BuiltinFnIdSqrt, BuiltinFnIdTruncate, BuiltinFnIdIntType, BuiltinFnIdSetCold, @@ -1413,6 +1414,7 @@ enum ZigLLVMFnId { ZigLLVMFnIdOverflowArithmetic, ZigLLVMFnIdFloor, ZigLLVMFnIdCeil, + ZigLLVMFnIdSqrt, }; enum AddSubMul { @@ -1433,7 +1435,7 @@ struct ZigLLVMFnKey { } clz; struct { uint32_t bit_count; - } floor_ceil; + } floating; struct { AddSubMul add_sub_mul; uint32_t bit_count; @@ -2047,6 +2049,7 @@ enum IrInstructionId { IrInstructionIdAddImplicitReturnType, IrInstructionIdMergeErrRetTraces, IrInstructionIdMarkErrRetTracePtr, + IrInstructionIdSqrt, }; struct IrInstruction { @@ -3036,6 +3039,13 @@ struct IrInstructionMarkErrRetTracePtr { IrInstruction *err_ret_trace_ptr; }; +struct IrInstructionSqrt { + IrInstruction base; + + IrInstruction *type; + IrInstruction *op; +}; + static const size_t slice_ptr_index = 0; static const size_t slice_len_index = 1; diff --git a/src/analyze.cpp b/src/analyze.cpp index c73e6b39e3..9092da6e3b 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -5801,9 +5801,11 @@ uint32_t zig_llvm_fn_key_hash(ZigLLVMFnKey x) { case ZigLLVMFnIdClz: return (uint32_t)(x.data.clz.bit_count) * (uint32_t)2428952817; case ZigLLVMFnIdFloor: - return (uint32_t)(x.data.floor_ceil.bit_count) * (uint32_t)1899859168; + return (uint32_t)(x.data.floating.bit_count) * (uint32_t)1899859168; case ZigLLVMFnIdCeil: - return (uint32_t)(x.data.floor_ceil.bit_count) * (uint32_t)1953839089; + return (uint32_t)(x.data.floating.bit_count) * (uint32_t)1953839089; + case ZigLLVMFnIdSqrt: + return (uint32_t)(x.data.floating.bit_count) * (uint32_t)2225366385; case ZigLLVMFnIdOverflowArithmetic: return ((uint32_t)(x.data.overflow_arithmetic.bit_count) * 87135777) + ((uint32_t)(x.data.overflow_arithmetic.add_sub_mul) * 31640542) + @@ -5822,7 +5824,8 @@ bool zig_llvm_fn_key_eql(ZigLLVMFnKey a, ZigLLVMFnKey b) { return a.data.clz.bit_count == b.data.clz.bit_count; case ZigLLVMFnIdFloor: case ZigLLVMFnIdCeil: - return a.data.floor_ceil.bit_count == b.data.floor_ceil.bit_count; + case ZigLLVMFnIdSqrt: + return a.data.floating.bit_count == b.data.floating.bit_count; case ZigLLVMFnIdOverflowArithmetic: return (a.data.overflow_arithmetic.bit_count == b.data.overflow_arithmetic.bit_count) && (a.data.overflow_arithmetic.add_sub_mul == b.data.overflow_arithmetic.add_sub_mul) && diff --git a/src/bigfloat.cpp b/src/bigfloat.cpp index 2cab9658e8..dcb6db61db 100644 --- a/src/bigfloat.cpp +++ b/src/bigfloat.cpp @@ -181,3 +181,7 @@ bool bigfloat_has_fraction(const BigFloat *bigfloat) { f128M_roundToInt(&bigfloat->value, softfloat_round_minMag, false, &floored); return !f128M_eq(&floored, &bigfloat->value); } + +void bigfloat_sqrt(BigFloat *dest, const BigFloat *op) { + f128M_sqrt(&op->value, &dest->value); +} diff --git a/src/bigfloat.hpp b/src/bigfloat.hpp index 894b252c3a..e212c30c87 100644 --- a/src/bigfloat.hpp +++ b/src/bigfloat.hpp @@ -42,6 +42,7 @@ void bigfloat_div_trunc(BigFloat *dest, const BigFloat *op1, const BigFloat *op2 void bigfloat_div_floor(BigFloat *dest, const BigFloat *op1, const BigFloat *op2); void bigfloat_rem(BigFloat *dest, const BigFloat *op1, const BigFloat *op2); void bigfloat_mod(BigFloat *dest, const BigFloat *op1, const BigFloat *op2); +void bigfloat_sqrt(BigFloat *dest, const BigFloat *op); void bigfloat_append_buf(Buf *buf, const BigFloat *op); Cmp bigfloat_cmp(const BigFloat *op1, const BigFloat *op2); diff --git a/src/codegen.cpp b/src/codegen.cpp index a58832f983..b45214a5e0 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -717,12 +717,12 @@ static LLVMValueRef get_int_overflow_fn(CodeGen *g, TypeTableEntry *type_entry, return fn_val; } -static LLVMValueRef get_floor_ceil_fn(CodeGen *g, TypeTableEntry *type_entry, ZigLLVMFnId fn_id) { +static LLVMValueRef get_float_fn(CodeGen *g, TypeTableEntry *type_entry, ZigLLVMFnId fn_id) { assert(type_entry->id == TypeTableEntryIdFloat); ZigLLVMFnKey key = {}; key.id = fn_id; - key.data.floor_ceil.bit_count = (uint32_t)type_entry->data.floating.bit_count; + key.data.floating.bit_count = (uint32_t)type_entry->data.floating.bit_count; auto existing_entry = g->llvm_fn_table.maybe_get(key); if (existing_entry) @@ -733,6 +733,8 @@ static LLVMValueRef get_floor_ceil_fn(CodeGen *g, TypeTableEntry *type_entry, Zi name = "floor"; } else if (fn_id == ZigLLVMFnIdCeil) { name = "ceil"; + } else if (fn_id == ZigLLVMFnIdSqrt) { + name = "sqrt"; } else { zig_unreachable(); } @@ -1900,7 +1902,7 @@ static LLVMValueRef gen_floor(CodeGen *g, LLVMValueRef val, TypeTableEntry *type if (type_entry->id == TypeTableEntryIdInt) return val; - LLVMValueRef floor_fn = get_floor_ceil_fn(g, type_entry, ZigLLVMFnIdFloor); + LLVMValueRef floor_fn = get_float_fn(g, type_entry, ZigLLVMFnIdFloor); return LLVMBuildCall(g->builder, floor_fn, &val, 1, ""); } @@ -1908,7 +1910,7 @@ static LLVMValueRef gen_ceil(CodeGen *g, LLVMValueRef val, TypeTableEntry *type_ if (type_entry->id == TypeTableEntryIdInt) return val; - LLVMValueRef ceil_fn = get_floor_ceil_fn(g, type_entry, ZigLLVMFnIdCeil); + LLVMValueRef ceil_fn = get_float_fn(g, type_entry, ZigLLVMFnIdCeil); return LLVMBuildCall(g->builder, ceil_fn, &val, 1, ""); } @@ -3247,10 +3249,12 @@ static LLVMValueRef get_int_builtin_fn(CodeGen *g, TypeTableEntry *int_type, Bui fn_name = "cttz"; key.id = ZigLLVMFnIdCtz; key.data.ctz.bit_count = (uint32_t)int_type->data.integral.bit_count; - } else { + } else if (fn_id == BuiltinFnIdClz) { fn_name = "ctlz"; key.id = ZigLLVMFnIdClz; key.data.clz.bit_count = (uint32_t)int_type->data.integral.bit_count; + } else { + zig_unreachable(); } auto existing_entry = g->llvm_fn_table.maybe_get(key); @@ -4402,6 +4406,13 @@ static LLVMValueRef ir_render_mark_err_ret_trace_ptr(CodeGen *g, IrExecutable *e return nullptr; } +static LLVMValueRef ir_render_sqrt(CodeGen *g, IrExecutable *executable, IrInstructionSqrt *instruction) { + LLVMValueRef op = ir_llvm_value(g, instruction->op); + assert(instruction->base.value.type->id == TypeTableEntryIdFloat); + LLVMValueRef fn_val = get_float_fn(g, instruction->base.value.type, ZigLLVMFnIdSqrt); + return LLVMBuildCall(g->builder, fn_val, &op, 1, ""); +} + static void set_debug_location(CodeGen *g, IrInstruction *instruction) { AstNode *source_node = instruction->source_node; Scope *scope = instruction->scope; @@ -4623,6 +4634,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, return ir_render_merge_err_ret_traces(g, executable, (IrInstructionMergeErrRetTraces *)instruction); case IrInstructionIdMarkErrRetTracePtr: return ir_render_mark_err_ret_trace_ptr(g, executable, (IrInstructionMarkErrRetTracePtr *)instruction); + case IrInstructionIdSqrt: + return ir_render_sqrt(g, executable, (IrInstructionSqrt *)instruction); } zig_unreachable(); } @@ -6109,6 +6122,7 @@ static void define_builtin_fns(CodeGen *g) { create_builtin_fn(g, BuiltinFnIdDivFloor, "divFloor", 2); create_builtin_fn(g, BuiltinFnIdRem, "rem", 2); create_builtin_fn(g, BuiltinFnIdMod, "mod", 2); + create_builtin_fn(g, BuiltinFnIdSqrt, "sqrt", 2); create_builtin_fn(g, BuiltinFnIdInlineCall, "inlineCall", SIZE_MAX); create_builtin_fn(g, BuiltinFnIdNoInlineCall, "noInlineCall", SIZE_MAX); create_builtin_fn(g, BuiltinFnIdTypeId, "typeId", 1); diff --git a/src/ir.cpp b/src/ir.cpp index 0fac1bd219..08229b8bb3 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -733,6 +733,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionMarkErrRetTraceP return IrInstructionIdMarkErrRetTracePtr; } +static constexpr IrInstructionId ir_instruction_id(IrInstructionSqrt *) { + return IrInstructionIdSqrt; +} + template static T *ir_create_instruction(IrBuilder *irb, Scope *scope, AstNode *source_node) { T *special_instruction = allocate(1); @@ -2731,6 +2735,17 @@ static IrInstruction *ir_build_mark_err_ret_trace_ptr(IrBuilder *irb, Scope *sco return &instruction->base; } +static IrInstruction *ir_build_sqrt(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *type, IrInstruction *op) { + IrInstructionSqrt *instruction = ir_build_instruction(irb, scope, source_node); + instruction->type = type; + instruction->op = op; + + if (type != nullptr) ir_ref_instruction(type, irb->current_basic_block); + ir_ref_instruction(op, irb->current_basic_block); + + return &instruction->base; +} + static void ir_count_defers(IrBuilder *irb, Scope *inner_scope, Scope *outer_scope, size_t *results) { results[ReturnKindUnconditional] = 0; results[ReturnKindError] = 0; @@ -3845,6 +3860,20 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo return ir_build_bin_op(irb, scope, node, IrBinOpRemMod, arg0_value, arg1_value, true); } + case BuiltinFnIdSqrt: + { + AstNode *arg0_node = node->data.fn_call_expr.params.at(0); + IrInstruction *arg0_value = ir_gen_node(irb, arg0_node, scope); + if (arg0_value == irb->codegen->invalid_instruction) + return arg0_value; + + AstNode *arg1_node = node->data.fn_call_expr.params.at(1); + IrInstruction *arg1_value = ir_gen_node(irb, arg1_node, scope); + if (arg1_value == irb->codegen->invalid_instruction) + return arg1_value; + + return ir_build_sqrt(irb, scope, node, arg0_value, arg1_value); + } case BuiltinFnIdTruncate: { AstNode *arg0_node = node->data.fn_call_expr.params.at(0); @@ -18031,6 +18060,68 @@ static TypeTableEntry *ir_analyze_instruction_mark_err_ret_trace_ptr(IrAnalyze * return result->value.type; } +static TypeTableEntry *ir_analyze_instruction_sqrt(IrAnalyze *ira, IrInstructionSqrt *instruction) { + TypeTableEntry *float_type = ir_resolve_type(ira, instruction->type->other); + if (type_is_invalid(float_type)) + return ira->codegen->builtin_types.entry_invalid; + + IrInstruction *op = instruction->op->other; + if (type_is_invalid(op->value.type)) + return ira->codegen->builtin_types.entry_invalid; + + bool ok_type = float_type->id == TypeTableEntryIdNumLitFloat || float_type->id == TypeTableEntryIdFloat; + if (!ok_type) { + ir_add_error(ira, instruction->type, buf_sprintf("@sqrt does not support type '%s'", buf_ptr(&float_type->name))); + return ira->codegen->builtin_types.entry_invalid; + } + + IrInstruction *casted_op = ir_implicit_cast(ira, op, float_type); + if (type_is_invalid(casted_op->value.type)) + return ira->codegen->builtin_types.entry_invalid; + + if (instr_is_comptime(casted_op)) { + ConstExprValue *val = ir_resolve_const(ira, casted_op, UndefBad); + if (!val) + return ira->codegen->builtin_types.entry_invalid; + + ConstExprValue *out_val = ir_build_const_from(ira, &instruction->base); + + if (float_type->id == TypeTableEntryIdNumLitFloat) { + bigfloat_sqrt(&out_val->data.x_bigfloat, &val->data.x_bigfloat); + } else if (float_type->id == TypeTableEntryIdFloat) { + switch (float_type->data.floating.bit_count) { + case 32: + out_val->data.x_f32 = sqrtf(val->data.x_f32); + break; + case 64: + out_val->data.x_f64 = sqrt(val->data.x_f64); + break; + case 128: + f128M_sqrt(&val->data.x_f128, &out_val->data.x_f128); + break; + default: + zig_unreachable(); + } + } else { + zig_unreachable(); + } + + return float_type; + } + + assert(float_type->id == TypeTableEntryIdFloat); + if (float_type->data.floating.bit_count != 32 && float_type->data.floating.bit_count != 64) { + ir_add_error(ira, instruction->type, buf_sprintf("compiler TODO: add implementation of sqrt for '%s'", buf_ptr(&float_type->name))); + return ira->codegen->builtin_types.entry_invalid; + } + + IrInstruction *result = ir_build_sqrt(&ira->new_irb, instruction->base.scope, + instruction->base.source_node, nullptr, casted_op); + ir_link_new_instruction(result, &instruction->base); + result->value.type = float_type; + return result->value.type; +} + static TypeTableEntry *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstruction *instruction) { switch (instruction->id) { case IrInstructionIdInvalid: @@ -18278,6 +18369,8 @@ static TypeTableEntry *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructi return ir_analyze_instruction_merge_err_ret_traces(ira, (IrInstructionMergeErrRetTraces *)instruction); case IrInstructionIdMarkErrRetTracePtr: return ir_analyze_instruction_mark_err_ret_trace_ptr(ira, (IrInstructionMarkErrRetTracePtr *)instruction); + case IrInstructionIdSqrt: + return ir_analyze_instruction_sqrt(ira, (IrInstructionSqrt *)instruction); } zig_unreachable(); } @@ -18490,6 +18583,7 @@ bool ir_has_side_effects(IrInstruction *instruction) { case IrInstructionIdCoroFree: case IrInstructionIdCoroPromise: case IrInstructionIdPromiseResultType: + case IrInstructionIdSqrt: return false; case IrInstructionIdAsm: diff --git a/src/ir_print.cpp b/src/ir_print.cpp index 99f79ff75e..5f8dd60187 100644 --- a/src/ir_print.cpp +++ b/src/ir_print.cpp @@ -1204,6 +1204,18 @@ static void ir_print_mark_err_ret_trace_ptr(IrPrint *irp, IrInstructionMarkErrRe fprintf(irp->f, ")"); } +static void ir_print_sqrt(IrPrint *irp, IrInstructionSqrt *instruction) { + fprintf(irp->f, "@sqrt("); + if (instruction->type != nullptr) { + ir_print_other_instruction(irp, instruction->type); + } else { + fprintf(irp->f, "null"); + } + fprintf(irp->f, ","); + ir_print_other_instruction(irp, instruction->op); + fprintf(irp->f, ")"); +} + static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) { ir_print_prefix(irp, instruction); switch (instruction->id) { @@ -1590,6 +1602,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) { case IrInstructionIdMarkErrRetTracePtr: ir_print_mark_err_ret_trace_ptr(irp, (IrInstructionMarkErrRetTracePtr *)instruction); break; + case IrInstructionIdSqrt: + ir_print_sqrt(irp, (IrInstructionSqrt *)instruction); + break; } fprintf(irp->f, "\n"); } diff --git a/std/math/sqrt.zig b/std/math/sqrt.zig index 690f8b6901..982bd28b72 100644 --- a/std/math/sqrt.zig +++ b/std/math/sqrt.zig @@ -14,26 +14,8 @@ const TypeId = builtin.TypeId; pub fn sqrt(x: var) (if (@typeId(@typeOf(x)) == TypeId.Int) @IntType(false, @typeOf(x).bit_count / 2) else @typeOf(x)) { const T = @typeOf(x); switch (@typeId(T)) { - TypeId.FloatLiteral => { - return T(sqrt64(x)); - }, - TypeId.Float => { - switch (T) { - f32 => { - switch (builtin.arch) { - builtin.Arch.x86_64 => return @import("x86_64/sqrt.zig").sqrt32(x), - else => return sqrt32(x), - } - }, - f64 => { - switch (builtin.arch) { - builtin.Arch.x86_64 => return @import("x86_64/sqrt.zig").sqrt64(x), - else => return sqrt64(x), - } - }, - else => @compileError("sqrt not implemented for " ++ @typeName(T)), - } - }, + TypeId.FloatLiteral => return T(@sqrt(f64, x)), // TODO upgrade to f128 + TypeId.Float => return @sqrt(T, x), TypeId.IntLiteral => comptime { if (x > @maxValue(u128)) { @compileError("sqrt not implemented for comptime_int greater than 128 bits"); @@ -43,269 +25,58 @@ pub fn sqrt(x: var) (if (@typeId(@typeOf(x)) == TypeId.Int) @IntType(false, @typ } return T(sqrt_int(u128, x)); }, - TypeId.Int => { - return sqrt_int(T, x); - }, + TypeId.Int => return sqrt_int(T, x), else => @compileError("sqrt not implemented for " ++ @typeName(T)), } } -fn sqrt32(x: f32) f32 { - const tiny: f32 = 1.0e-30; - const sign: i32 = @bitCast(i32, u32(0x80000000)); - var ix: i32 = @bitCast(i32, x); - - if ((ix & 0x7F800000) == 0x7F800000) { - return x * x + x; // sqrt(nan) = nan, sqrt(+inf) = +inf, sqrt(-inf) = snan - } - - // zero - if (ix <= 0) { - if (ix & ~sign == 0) { - return x; // sqrt (+-0) = +-0 - } - if (ix < 0) { - return math.snan(f32); - } - } - - // normalize - var m = ix >> 23; - if (m == 0) { - // subnormal - var i: i32 = 0; - while (ix & 0x00800000 == 0) : (i += 1) { - ix <<= 1; - } - m -= i - 1; - } - - m -= 127; // unbias exponent - ix = (ix & 0x007FFFFF) | 0x00800000; - - if (m & 1 != 0) { // odd m, double x to even - ix += ix; - } - - m >>= 1; // m = [m / 2] - - // sqrt(x) bit by bit - ix += ix; - var q: i32 = 0; // q = sqrt(x) - var s: i32 = 0; - var r: i32 = 0x01000000; // r = moving bit right -> left - - while (r != 0) { - const t = s + r; - if (t <= ix) { - s = t + r; - ix -= t; - q += r; - } - ix += ix; - r >>= 1; - } - - // floating add to find rounding direction - if (ix != 0) { - var z = 1.0 - tiny; // inexact - if (z >= 1.0) { - z = 1.0 + tiny; - if (z > 1.0) { - q += 2; - } else { - if (q & 1 != 0) { - q += 1; - } - } - } - } - - ix = (q >> 1) + 0x3f000000; - ix += m << 23; - return @bitCast(f32, ix); -} - -// NOTE: The original code is full of implicit signed -> unsigned assumptions and u32 wraparound -// behaviour. Most intermediate i32 values are changed to u32 where appropriate but there are -// potentially some edge cases remaining that are not handled in the same way. -fn sqrt64(x: f64) f64 { - const tiny: f64 = 1.0e-300; - const sign: u32 = 0x80000000; - const u = @bitCast(u64, x); - - var ix0 = u32(u >> 32); - var ix1 = u32(u & 0xFFFFFFFF); - - // sqrt(nan) = nan, sqrt(+inf) = +inf, sqrt(-inf) = nan - if (ix0 & 0x7FF00000 == 0x7FF00000) { - return x * x + x; - } - - // sqrt(+-0) = +-0 - if (x == 0.0) { - return x; - } - // sqrt(-ve) = snan - if (ix0 & sign != 0) { - return math.snan(f64); - } - - // normalize x - var m = i32(ix0 >> 20); - if (m == 0) { - // subnormal - while (ix0 == 0) { - m -= 21; - ix0 |= ix1 >> 11; - ix1 <<= 21; - } - - // subnormal - var i: u32 = 0; - while (ix0 & 0x00100000 == 0) : (i += 1) { - ix0 <<= 1; - } - m -= i32(i) - 1; - ix0 |= ix1 >> u5(32 - i); - ix1 <<= u5(i); - } - - // unbias exponent - m -= 1023; - ix0 = (ix0 & 0x000FFFFF) | 0x00100000; - if (m & 1 != 0) { - ix0 += ix0 + (ix1 >> 31); - ix1 = ix1 +% ix1; - } - m >>= 1; - - // sqrt(x) bit by bit - ix0 += ix0 + (ix1 >> 31); - ix1 = ix1 +% ix1; - - var q: u32 = 0; - var q1: u32 = 0; - var s0: u32 = 0; - var s1: u32 = 0; - var r: u32 = 0x00200000; - var t: u32 = undefined; - var t1: u32 = undefined; - - while (r != 0) { - t = s0 +% r; - if (t <= ix0) { - s0 = t + r; - ix0 -= t; - q += r; - } - ix0 = ix0 +% ix0 +% (ix1 >> 31); - ix1 = ix1 +% ix1; - r >>= 1; - } - - r = sign; - while (r != 0) { - t = s1 +% r; - t = s0; - if (t < ix0 or (t == ix0 and t1 <= ix1)) { - s1 = t1 +% r; - if (t1 & sign == sign and s1 & sign == 0) { - s0 += 1; - } - ix0 -= t; - if (ix1 < t1) { - ix0 -= 1; - } - ix1 = ix1 -% t1; - q1 += r; - } - ix0 = ix0 +% ix0 +% (ix1 >> 31); - ix1 = ix1 +% ix1; - r >>= 1; - } - - // rounding direction - if (ix0 | ix1 != 0) { - var z = 1.0 - tiny; // raise inexact - if (z >= 1.0) { - z = 1.0 + tiny; - if (q1 == 0xFFFFFFFF) { - q1 = 0; - q += 1; - } else if (z > 1.0) { - if (q1 == 0xFFFFFFFE) { - q += 1; - } - q1 += 2; - } else { - q1 += q1 & 1; - } - } - } - - ix0 = (q >> 1) + 0x3FE00000; - ix1 = q1 >> 1; - if (q & 1 != 0) { - ix1 |= 0x80000000; - } - - // NOTE: musl here appears to rely on signed twos-complement wraparound. +% has the same - // behaviour at least. - var iix0 = i32(ix0); - iix0 = iix0 +% (m << 20); - - const uz = (u64(iix0) << 32) | ix1; - return @bitCast(f64, uz); -} - test "math.sqrt" { - assert(sqrt(f32(0.0)) == sqrt32(0.0)); - assert(sqrt(f64(0.0)) == sqrt64(0.0)); + assert(sqrt(f32(0.0)) == @sqrt(f32, 0.0)); + assert(sqrt(f64(0.0)) == @sqrt(f64, 0.0)); } test "math.sqrt32" { const epsilon = 0.000001; - assert(sqrt32(0.0) == 0.0); - assert(math.approxEq(f32, sqrt32(2.0), 1.414214, epsilon)); - assert(math.approxEq(f32, sqrt32(3.6), 1.897367, epsilon)); - assert(sqrt32(4.0) == 2.0); - assert(math.approxEq(f32, sqrt32(7.539840), 2.745877, epsilon)); - assert(math.approxEq(f32, sqrt32(19.230934), 4.385309, epsilon)); - assert(sqrt32(64.0) == 8.0); - assert(math.approxEq(f32, sqrt32(64.1), 8.006248, epsilon)); - assert(math.approxEq(f32, sqrt32(8942.230469), 94.563370, epsilon)); + assert(@sqrt(f32, 0.0) == 0.0); + assert(math.approxEq(f32, @sqrt(f32, 2.0), 1.414214, epsilon)); + assert(math.approxEq(f32, @sqrt(f32, 3.6), 1.897367, epsilon)); + assert(@sqrt(f32, 4.0) == 2.0); + assert(math.approxEq(f32, @sqrt(f32, 7.539840), 2.745877, epsilon)); + assert(math.approxEq(f32, @sqrt(f32, 19.230934), 4.385309, epsilon)); + assert(@sqrt(f32, 64.0) == 8.0); + assert(math.approxEq(f32, @sqrt(f32, 64.1), 8.006248, epsilon)); + assert(math.approxEq(f32, @sqrt(f32, 8942.230469), 94.563370, epsilon)); } test "math.sqrt64" { const epsilon = 0.000001; - assert(sqrt64(0.0) == 0.0); - assert(math.approxEq(f64, sqrt64(2.0), 1.414214, epsilon)); - assert(math.approxEq(f64, sqrt64(3.6), 1.897367, epsilon)); - assert(sqrt64(4.0) == 2.0); - assert(math.approxEq(f64, sqrt64(7.539840), 2.745877, epsilon)); - assert(math.approxEq(f64, sqrt64(19.230934), 4.385309, epsilon)); - assert(sqrt64(64.0) == 8.0); - assert(math.approxEq(f64, sqrt64(64.1), 8.006248, epsilon)); - assert(math.approxEq(f64, sqrt64(8942.230469), 94.563367, epsilon)); + assert(@sqrt(f64, 0.0) == 0.0); + assert(math.approxEq(f64, @sqrt(f64, 2.0), 1.414214, epsilon)); + assert(math.approxEq(f64, @sqrt(f64, 3.6), 1.897367, epsilon)); + assert(@sqrt(f64, 4.0) == 2.0); + assert(math.approxEq(f64, @sqrt(f64, 7.539840), 2.745877, epsilon)); + assert(math.approxEq(f64, @sqrt(f64, 19.230934), 4.385309, epsilon)); + assert(@sqrt(f64, 64.0) == 8.0); + assert(math.approxEq(f64, @sqrt(f64, 64.1), 8.006248, epsilon)); + assert(math.approxEq(f64, @sqrt(f64, 8942.230469), 94.563367, epsilon)); } test "math.sqrt32.special" { - assert(math.isPositiveInf(sqrt32(math.inf(f32)))); - assert(sqrt32(0.0) == 0.0); - assert(sqrt32(-0.0) == -0.0); - assert(math.isNan(sqrt32(-1.0))); - assert(math.isNan(sqrt32(math.nan(f32)))); + assert(math.isPositiveInf(@sqrt(f32, math.inf(f32)))); + assert(@sqrt(f32, 0.0) == 0.0); + assert(@sqrt(f32, -0.0) == -0.0); + assert(math.isNan(@sqrt(f32, -1.0))); + assert(math.isNan(@sqrt(f32, math.nan(f32)))); } test "math.sqrt64.special" { - assert(math.isPositiveInf(sqrt64(math.inf(f64)))); - assert(sqrt64(0.0) == 0.0); - assert(sqrt64(-0.0) == -0.0); - assert(math.isNan(sqrt64(-1.0))); - assert(math.isNan(sqrt64(math.nan(f64)))); + assert(math.isPositiveInf(@sqrt(f64, math.inf(f64)))); + assert(@sqrt(f64, 0.0) == 0.0); + assert(@sqrt(f64, -0.0) == -0.0); + assert(math.isNan(@sqrt(f64, -1.0))); + assert(math.isNan(@sqrt(f64, math.nan(f64)))); } fn sqrt_int(comptime T: type, value: T) @IntType(false, T.bit_count / 2) { diff --git a/std/math/x86_64/sqrt.zig b/std/math/x86_64/sqrt.zig deleted file mode 100644 index ad9ce0c96c..0000000000 --- a/std/math/x86_64/sqrt.zig +++ /dev/null @@ -1,15 +0,0 @@ -pub fn sqrt32(x: f32) f32 { - return asm ( - \\sqrtss %%xmm0, %%xmm0 - : [ret] "={xmm0}" (-> f32) - : [x] "{xmm0}" (x) - ); -} - -pub fn sqrt64(x: f64) f64 { - return asm ( - \\sqrtsd %%xmm0, %%xmm0 - : [ret] "={xmm0}" (-> f64) - : [x] "{xmm0}" (x) - ); -} diff --git a/std/special/builtin.zig b/std/special/builtin.zig index ac6eefe3d9..56aa2ebaf8 100644 --- a/std/special/builtin.zig +++ b/std/special/builtin.zig @@ -194,3 +194,212 @@ fn isNan(comptime T: type, bits: T) bool { unreachable; } } + +// NOTE: The original code is full of implicit signed -> unsigned assumptions and u32 wraparound +// behaviour. Most intermediate i32 values are changed to u32 where appropriate but there are +// potentially some edge cases remaining that are not handled in the same way. +export fn sqrt(x: f64) f64 { + const tiny: f64 = 1.0e-300; + const sign: u32 = 0x80000000; + const u = @bitCast(u64, x); + + var ix0 = u32(u >> 32); + var ix1 = u32(u & 0xFFFFFFFF); + + // sqrt(nan) = nan, sqrt(+inf) = +inf, sqrt(-inf) = nan + if (ix0 & 0x7FF00000 == 0x7FF00000) { + return x * x + x; + } + + // sqrt(+-0) = +-0 + if (x == 0.0) { + return x; + } + // sqrt(-ve) = snan + if (ix0 & sign != 0) { + return math.snan(f64); + } + + // normalize x + var m = i32(ix0 >> 20); + if (m == 0) { + // subnormal + while (ix0 == 0) { + m -= 21; + ix0 |= ix1 >> 11; + ix1 <<= 21; + } + + // subnormal + var i: u32 = 0; + while (ix0 & 0x00100000 == 0) : (i += 1) { + ix0 <<= 1; + } + m -= i32(i) - 1; + ix0 |= ix1 >> u5(32 - i); + ix1 <<= u5(i); + } + + // unbias exponent + m -= 1023; + ix0 = (ix0 & 0x000FFFFF) | 0x00100000; + if (m & 1 != 0) { + ix0 += ix0 + (ix1 >> 31); + ix1 = ix1 +% ix1; + } + m >>= 1; + + // sqrt(x) bit by bit + ix0 += ix0 + (ix1 >> 31); + ix1 = ix1 +% ix1; + + var q: u32 = 0; + var q1: u32 = 0; + var s0: u32 = 0; + var s1: u32 = 0; + var r: u32 = 0x00200000; + var t: u32 = undefined; + var t1: u32 = undefined; + + while (r != 0) { + t = s0 +% r; + if (t <= ix0) { + s0 = t + r; + ix0 -= t; + q += r; + } + ix0 = ix0 +% ix0 +% (ix1 >> 31); + ix1 = ix1 +% ix1; + r >>= 1; + } + + r = sign; + while (r != 0) { + t = s1 +% r; + t = s0; + if (t < ix0 or (t == ix0 and t1 <= ix1)) { + s1 = t1 +% r; + if (t1 & sign == sign and s1 & sign == 0) { + s0 += 1; + } + ix0 -= t; + if (ix1 < t1) { + ix0 -= 1; + } + ix1 = ix1 -% t1; + q1 += r; + } + ix0 = ix0 +% ix0 +% (ix1 >> 31); + ix1 = ix1 +% ix1; + r >>= 1; + } + + // rounding direction + if (ix0 | ix1 != 0) { + var z = 1.0 - tiny; // raise inexact + if (z >= 1.0) { + z = 1.0 + tiny; + if (q1 == 0xFFFFFFFF) { + q1 = 0; + q += 1; + } else if (z > 1.0) { + if (q1 == 0xFFFFFFFE) { + q += 1; + } + q1 += 2; + } else { + q1 += q1 & 1; + } + } + } + + ix0 = (q >> 1) + 0x3FE00000; + ix1 = q1 >> 1; + if (q & 1 != 0) { + ix1 |= 0x80000000; + } + + // NOTE: musl here appears to rely on signed twos-complement wraparound. +% has the same + // behaviour at least. + var iix0 = i32(ix0); + iix0 = iix0 +% (m << 20); + + const uz = (u64(iix0) << 32) | ix1; + return @bitCast(f64, uz); +} + +export fn sqrtf(x: f32) f32 { + const tiny: f32 = 1.0e-30; + const sign: i32 = @bitCast(i32, u32(0x80000000)); + var ix: i32 = @bitCast(i32, x); + + if ((ix & 0x7F800000) == 0x7F800000) { + return x * x + x; // sqrt(nan) = nan, sqrt(+inf) = +inf, sqrt(-inf) = snan + } + + // zero + if (ix <= 0) { + if (ix & ~sign == 0) { + return x; // sqrt (+-0) = +-0 + } + if (ix < 0) { + return math.snan(f32); + } + } + + // normalize + var m = ix >> 23; + if (m == 0) { + // subnormal + var i: i32 = 0; + while (ix & 0x00800000 == 0) : (i += 1) { + ix <<= 1; + } + m -= i - 1; + } + + m -= 127; // unbias exponent + ix = (ix & 0x007FFFFF) | 0x00800000; + + if (m & 1 != 0) { // odd m, double x to even + ix += ix; + } + + m >>= 1; // m = [m / 2] + + // sqrt(x) bit by bit + ix += ix; + var q: i32 = 0; // q = sqrt(x) + var s: i32 = 0; + var r: i32 = 0x01000000; // r = moving bit right -> left + + while (r != 0) { + const t = s + r; + if (t <= ix) { + s = t + r; + ix -= t; + q += r; + } + ix += ix; + r >>= 1; + } + + // floating add to find rounding direction + if (ix != 0) { + var z = 1.0 - tiny; // inexact + if (z >= 1.0) { + z = 1.0 + tiny; + if (z > 1.0) { + q += 2; + } else { + if (q & 1 != 0) { + q += 1; + } + } + } + } + + ix = (q >> 1) + 0x3f000000; + ix += m << 23; + return @bitCast(f32, ix); +} diff --git a/test/cases/math.zig b/test/cases/math.zig index 574aa39bb1..47d001a590 100644 --- a/test/cases/math.zig +++ b/test/cases/math.zig @@ -402,3 +402,19 @@ test "comptime float rem int" { assert(x == 1.0); } } + +test "@sqrt" { + testSqrt(f64, 12.0); + comptime testSqrt(f64, 12.0); + testSqrt(f32, 13.0); + comptime testSqrt(f32, 13.0); + + const x = 14.0; + const y = x * x; + const z = @sqrt(@typeOf(y), y); + comptime assert(z == x); +} + +fn testSqrt(comptime T: type, x: T) void { + assert(@sqrt(T, x * x) == x); +}