diff --git a/std/math/index.zig b/std/math/index.zig index a015450760..a066e20b14 100644 --- a/std/math/index.zig +++ b/std/math/index.zig @@ -219,11 +219,59 @@ pub fn negate(x: var) -> %@typeOf(x) { } error Overflow; -pub fn shl(comptime T: type, a: T, shift_amt: Log2Int(T)) -> %T { +pub fn shlExact(comptime T: type, a: T, shift_amt: Log2Int(T)) -> %T { var answer: T = undefined; if (@shlWithOverflow(T, a, shift_amt, &answer)) error.Overflow else answer } +/// Shifts left. Overflowed bits are truncated. +/// A negative shift amount results in a right shift. +pub fn shl(comptime T: type, a: T, shift_amt: var) -> T { + const abs_shift_amt = absCast(shift_amt); + const casted_shift_amt = if (abs_shift_amt >= T.bit_count) return 0 else Log2Int(T)(abs_shift_amt); + + if (@typeOf(shift_amt).is_signed) { + if (shift_amt >= 0) { + return a << casted_shift_amt; + } else { + return a >> casted_shift_amt; + } + } + + return a << casted_shift_amt; +} + +test "math.shl" { + assert(shl(u8, 0b11111111, usize(3)) == 0b11111000); + assert(shl(u8, 0b11111111, usize(8)) == 0); + assert(shl(u8, 0b11111111, usize(9)) == 0); + assert(shl(u8, 0b11111111, isize(-2)) == 0b00111111); +} + +/// Shifts right. Overflowed bits are truncated. +/// A negative shift amount results in a lefft shift. +pub fn shr(comptime T: type, a: T, shift_amt: var) -> T { + const abs_shift_amt = absCast(shift_amt); + const casted_shift_amt = if (abs_shift_amt >= T.bit_count) return 0 else Log2Int(T)(abs_shift_amt); + + if (@typeOf(shift_amt).is_signed) { + if (shift_amt >= 0) { + return a >> casted_shift_amt; + } else { + return a << casted_shift_amt; + } + } + + return a >> casted_shift_amt; +} + +test "math.shr" { + assert(shr(u8, 0b11111111, usize(3)) == 0b00011111); + assert(shr(u8, 0b11111111, usize(8)) == 0); + assert(shr(u8, 0b11111111, usize(9)) == 0); + assert(shr(u8, 0b11111111, isize(-2)) == 0b11111100); +} + pub fn Log2Int(comptime T: type) -> type { @IntType(false, log2(T.bit_count)) } @@ -237,7 +285,7 @@ fn testOverflow() { assert(%%mul(i32, 3, 4) == 12); assert(%%add(i32, 3, 4) == 7); assert(%%sub(i32, 3, 4) == -1); - assert(%%shl(i32, 0b11, 4) == 0b110000); + assert(%%shlExact(i32, 0b11, 4) == 0b110000); }