wasm: Implement basic f16 support

This implements binary operations and comparisons
for floats with bitsize 16. It does this by calling into
compiler-rt to first extend the float to 32 bits, perform the operation,
and then finally truncate back to 16 bits. When loading and storing the f16,
we do this as an unsigned 16bit integer.
This commit is contained in:
Luuk de Gram 2022-05-30 21:07:03 +02:00
parent 241180216f
commit 5ebaf49ebb

View File

@ -215,8 +215,7 @@ fn buildOpcode(args: OpcodeBuildArguments) wasm.Opcode {
16 => switch (args.valtype1.?) {
.i32 => if (args.signedness.? == .signed) return .i32_load16_s else return .i32_load16_u,
.i64 => if (args.signedness.? == .signed) return .i64_load16_s else return .i64_load16_u,
.f32 => return .f32_load,
.f64 => unreachable,
.f32, .f64 => unreachable,
},
32 => switch (args.valtype1.?) {
.i64 => if (args.signedness.? == .signed) return .i64_load32_s else return .i64_load32_u,
@ -246,8 +245,7 @@ fn buildOpcode(args: OpcodeBuildArguments) wasm.Opcode {
16 => switch (args.valtype1.?) {
.i32 => return .i32_store16,
.i64 => return .i64_store16,
.f32 => return .f32_store,
.f64 => unreachable,
.f32, .f64 => unreachable,
},
32 => switch (args.valtype1.?) {
.i64 => return .i64_store32,
@ -725,7 +723,8 @@ fn typeToValtype(ty: Type, target: std.Target) wasm.Valtype {
return switch (ty.zigTypeTag()) {
.Float => blk: {
const bits = ty.floatBits(target);
if (bits == 16 or bits == 32) break :blk wasm.Valtype.f32;
if (bits == 16) return wasm.Valtype.i32; // stored/loaded as u16
if (bits == 32) break :blk wasm.Valtype.f32;
if (bits == 64) break :blk wasm.Valtype.f64;
if (bits == 128) break :blk wasm.Valtype.i64;
return wasm.Valtype.i32; // represented as pointer to stack
@ -2013,6 +2012,10 @@ fn binOp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError!WVa
}
}
if (ty.isAnyFloat() and ty.floatBits(self.target) == 16) {
return self.binOpFloat16(lhs, rhs, op);
}
const opcode: wasm.Opcode = buildOpcode(.{
.op = op,
.valtype1 = typeToValtype(ty, self.target),
@ -2029,6 +2032,20 @@ fn binOp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError!WVa
return bin_local;
}
fn binOpFloat16(self: *Self, lhs: WValue, rhs: WValue, op: Op) InnerError!WValue {
const ext_lhs = try self.fpext(lhs, Type.f16, Type.f32);
const ext_rhs = try self.fpext(rhs, Type.f16, Type.f32);
const opcode: wasm.Opcode = buildOpcode(.{ .op = op, .valtype1 = .f32, .signedness = .unsigned });
try self.emitWValue(ext_lhs);
try self.emitWValue(ext_rhs);
try self.addTag(Mir.Inst.Tag.fromOpcode(opcode));
// re-use temporary local
try self.addLabel(.local_set, ext_lhs.local);
return self.fptrunc(ext_lhs, Type.f32, Type.f16);
}
fn binOpBigInt(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError!WValue {
if (ty.intInfo(self.target).bits > 128) {
return self.fail("TODO: Implement binary operation for big integer", .{});
@ -2310,8 +2327,9 @@ fn lowerConstant(self: *Self, val: Value, ty: Type) InnerError!WValue {
},
.Bool => return WValue{ .imm32 = @intCast(u32, val.toUnsignedInt(target)) },
.Float => switch (ty.floatBits(self.target)) {
0...32 => return WValue{ .float32 = val.toFloat(f32) },
33...64 => return WValue{ .float64 = val.toFloat(f64) },
16 => return WValue{ .imm32 = @bitCast(u16, val.toFloat(f16)) },
32 => return WValue{ .float32 = val.toFloat(f32) },
64 => return WValue{ .float64 = val.toFloat(f64) },
else => unreachable,
},
.Pointer => switch (val.tag()) {
@ -2389,8 +2407,9 @@ fn emitUndefined(self: *Self, ty: Type) InnerError!WValue {
else => unreachable,
},
.Float => switch (ty.floatBits(self.target)) {
0...32 => return WValue{ .float32 = @bitCast(f32, @as(u32, 0xaaaaaaaa)) },
33...64 => return WValue{ .float64 = @bitCast(f64, @as(u64, 0xaaaaaaaaaaaaaaaa)) },
16 => return WValue{ .imm32 = 0xaaaaaaaa },
32 => return WValue{ .float32 = @bitCast(f32, @as(u32, 0xaaaaaaaa)) },
64 => return WValue{ .float64 = @bitCast(f64, @as(u64, 0xaaaaaaaaaaaaaaaa)) },
else => unreachable,
},
.Pointer => switch (self.arch()) {
@ -2562,6 +2581,8 @@ fn cmp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: std.math.CompareOper
}
} else if (isByRef(ty, self.target)) {
return self.cmpBigInt(lhs, rhs, ty, op);
} else if (ty.isAnyFloat() and ty.floatBits(self.target) == 16) {
return self.cmpFloat16(lhs, rhs, op);
}
// ensure that when we compare pointers, we emit
@ -2595,6 +2616,31 @@ fn cmp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: std.math.CompareOper
return cmp_tmp;
}
fn cmpFloat16(self: *Self, lhs: WValue, rhs: WValue, op: std.math.CompareOperator) InnerError!WValue {
const ext_lhs = try self.fpext(lhs, Type.f16, Type.f32);
const ext_rhs = try self.fpext(rhs, Type.f16, Type.f32);
const opcode: wasm.Opcode = buildOpcode(.{
.op = switch (op) {
.lt => .lt,
.lte => .le,
.eq => .eq,
.neq => .ne,
.gte => .ge,
.gt => .gt,
},
.valtype1 = .f32,
.signedness = .unsigned,
});
try self.emitWValue(ext_lhs);
try self.emitWValue(ext_rhs);
try self.addTag(Mir.Inst.Tag.fromOpcode(opcode));
const result = try self.allocLocal(Type.initTag(.i32)); // bool is always i32
try self.addLabel(.local_set, result.local);
return result;
}
fn airCmpVector(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
_ = inst;
return self.fail("TODO implement airCmpVector for wasm", .{});
@ -3934,19 +3980,44 @@ fn airFpext(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
const ty_op = self.air.instructions.items(.data)[inst].ty_op;
const dest_ty = self.air.typeOfIndex(inst);
const dest_bits = dest_ty.floatBits(self.target);
const src_bits = self.air.typeOf(ty_op.operand).floatBits(self.target);
const operand = try self.resolveInst(ty_op.operand);
if (dest_bits == 64 and src_bits == 32) {
const result = try self.allocLocal(dest_ty);
return self.fpext(operand, self.air.typeOf(ty_op.operand), dest_ty);
}
fn fpext(self: *Self, operand: WValue, given: Type, wanted: Type) InnerError!WValue {
const given_bits = given.floatBits(self.target);
const wanted_bits = wanted.floatBits(self.target);
if (wanted_bits == 64 and given_bits == 32) {
const result = try self.allocLocal(wanted);
try self.emitWValue(operand);
try self.addTag(.f64_promote_f32);
try self.addLabel(.local_set, result.local);
return result;
} else if (given_bits == 16) {
// call __extendhfsf2(f16) f32
const f32_result = try self.callIntrinsic(
"__extendhfsf2",
&.{Type.f16},
Type.f32,
&.{operand},
);
if (wanted_bits == 32) {
return f32_result;
}
if (wanted_bits == 64) {
const result = try self.allocLocal(wanted);
try self.emitWValue(f32_result);
try self.addTag(.f64_promote_f32);
try self.addLabel(.local_set, result.local);
return result;
}
return self.fail("TODO: Implement 'fpext' for floats with bitsize: {d}", .{wanted_bits});
} else {
// TODO: Emit a call to compiler-rt to extend the float. e.g. __extendhfsf2
return self.fail("TODO: Implement 'fpext' for floats with bitsize: {d}", .{dest_bits});
return self.fail("TODO: Implement 'fpext' for floats with bitsize: {d}", .{wanted_bits});
}
}
@ -3955,19 +4026,34 @@ fn airFptrunc(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
const ty_op = self.air.instructions.items(.data)[inst].ty_op;
const dest_ty = self.air.typeOfIndex(inst);
const dest_bits = dest_ty.floatBits(self.target);
const src_bits = self.air.typeOf(ty_op.operand).floatBits(self.target);
const operand = try self.resolveInst(ty_op.operand);
return self.fptrunc(operand, self.air.typeOf(ty_op.operand), dest_ty);
}
if (dest_bits == 32 and src_bits == 64) {
const result = try self.allocLocal(dest_ty);
fn fptrunc(self: *Self, operand: WValue, given: Type, wanted: Type) InnerError!WValue {
const given_bits = given.floatBits(self.target);
const wanted_bits = wanted.floatBits(self.target);
if (wanted_bits == 32 and given_bits == 64) {
const result = try self.allocLocal(wanted);
try self.emitWValue(operand);
try self.addTag(.f32_demote_f64);
try self.addLabel(.local_set, result.local);
return result;
} else if (wanted_bits == 16) {
const op: WValue = if (given_bits == 64) blk: {
const tmp = try self.allocLocal(Type.f32);
try self.emitWValue(operand);
try self.addTag(.f32_demote_f64);
try self.addLabel(.local_set, tmp.local);
break :blk tmp;
} else operand;
// call __truncsfhf2(f32) f16
return self.callIntrinsic("__truncsfhf2", &.{Type.f32}, Type.f16, &.{op});
} else {
// TODO: Emit a call to compiler-rt to trunc the float. e.g. __truncdfhf2
return self.fail("TODO: Implement 'fptrunc' for floats with bitsize: {d}", .{dest_bits});
return self.fail("TODO: Implement 'fptrunc' for floats with bitsize: {d}", .{wanted_bits});
}
}