mirror of
https://github.com/ziglang/zig.git
synced 2026-01-04 20:43:19 +00:00
wasm: Implement basic f16 support
This implements binary operations and comparisons for floats with bitsize 16. It does this by calling into compiler-rt to first extend the float to 32 bits, perform the operation, and then finally truncate back to 16 bits. When loading and storing the f16, we do this as an unsigned 16bit integer.
This commit is contained in:
parent
241180216f
commit
5ebaf49ebb
@ -215,8 +215,7 @@ fn buildOpcode(args: OpcodeBuildArguments) wasm.Opcode {
|
||||
16 => switch (args.valtype1.?) {
|
||||
.i32 => if (args.signedness.? == .signed) return .i32_load16_s else return .i32_load16_u,
|
||||
.i64 => if (args.signedness.? == .signed) return .i64_load16_s else return .i64_load16_u,
|
||||
.f32 => return .f32_load,
|
||||
.f64 => unreachable,
|
||||
.f32, .f64 => unreachable,
|
||||
},
|
||||
32 => switch (args.valtype1.?) {
|
||||
.i64 => if (args.signedness.? == .signed) return .i64_load32_s else return .i64_load32_u,
|
||||
@ -246,8 +245,7 @@ fn buildOpcode(args: OpcodeBuildArguments) wasm.Opcode {
|
||||
16 => switch (args.valtype1.?) {
|
||||
.i32 => return .i32_store16,
|
||||
.i64 => return .i64_store16,
|
||||
.f32 => return .f32_store,
|
||||
.f64 => unreachable,
|
||||
.f32, .f64 => unreachable,
|
||||
},
|
||||
32 => switch (args.valtype1.?) {
|
||||
.i64 => return .i64_store32,
|
||||
@ -725,7 +723,8 @@ fn typeToValtype(ty: Type, target: std.Target) wasm.Valtype {
|
||||
return switch (ty.zigTypeTag()) {
|
||||
.Float => blk: {
|
||||
const bits = ty.floatBits(target);
|
||||
if (bits == 16 or bits == 32) break :blk wasm.Valtype.f32;
|
||||
if (bits == 16) return wasm.Valtype.i32; // stored/loaded as u16
|
||||
if (bits == 32) break :blk wasm.Valtype.f32;
|
||||
if (bits == 64) break :blk wasm.Valtype.f64;
|
||||
if (bits == 128) break :blk wasm.Valtype.i64;
|
||||
return wasm.Valtype.i32; // represented as pointer to stack
|
||||
@ -2013,6 +2012,10 @@ fn binOp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError!WVa
|
||||
}
|
||||
}
|
||||
|
||||
if (ty.isAnyFloat() and ty.floatBits(self.target) == 16) {
|
||||
return self.binOpFloat16(lhs, rhs, op);
|
||||
}
|
||||
|
||||
const opcode: wasm.Opcode = buildOpcode(.{
|
||||
.op = op,
|
||||
.valtype1 = typeToValtype(ty, self.target),
|
||||
@ -2029,6 +2032,20 @@ fn binOp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError!WVa
|
||||
return bin_local;
|
||||
}
|
||||
|
||||
fn binOpFloat16(self: *Self, lhs: WValue, rhs: WValue, op: Op) InnerError!WValue {
|
||||
const ext_lhs = try self.fpext(lhs, Type.f16, Type.f32);
|
||||
const ext_rhs = try self.fpext(rhs, Type.f16, Type.f32);
|
||||
|
||||
const opcode: wasm.Opcode = buildOpcode(.{ .op = op, .valtype1 = .f32, .signedness = .unsigned });
|
||||
try self.emitWValue(ext_lhs);
|
||||
try self.emitWValue(ext_rhs);
|
||||
try self.addTag(Mir.Inst.Tag.fromOpcode(opcode));
|
||||
|
||||
// re-use temporary local
|
||||
try self.addLabel(.local_set, ext_lhs.local);
|
||||
return self.fptrunc(ext_lhs, Type.f32, Type.f16);
|
||||
}
|
||||
|
||||
fn binOpBigInt(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError!WValue {
|
||||
if (ty.intInfo(self.target).bits > 128) {
|
||||
return self.fail("TODO: Implement binary operation for big integer", .{});
|
||||
@ -2310,8 +2327,9 @@ fn lowerConstant(self: *Self, val: Value, ty: Type) InnerError!WValue {
|
||||
},
|
||||
.Bool => return WValue{ .imm32 = @intCast(u32, val.toUnsignedInt(target)) },
|
||||
.Float => switch (ty.floatBits(self.target)) {
|
||||
0...32 => return WValue{ .float32 = val.toFloat(f32) },
|
||||
33...64 => return WValue{ .float64 = val.toFloat(f64) },
|
||||
16 => return WValue{ .imm32 = @bitCast(u16, val.toFloat(f16)) },
|
||||
32 => return WValue{ .float32 = val.toFloat(f32) },
|
||||
64 => return WValue{ .float64 = val.toFloat(f64) },
|
||||
else => unreachable,
|
||||
},
|
||||
.Pointer => switch (val.tag()) {
|
||||
@ -2389,8 +2407,9 @@ fn emitUndefined(self: *Self, ty: Type) InnerError!WValue {
|
||||
else => unreachable,
|
||||
},
|
||||
.Float => switch (ty.floatBits(self.target)) {
|
||||
0...32 => return WValue{ .float32 = @bitCast(f32, @as(u32, 0xaaaaaaaa)) },
|
||||
33...64 => return WValue{ .float64 = @bitCast(f64, @as(u64, 0xaaaaaaaaaaaaaaaa)) },
|
||||
16 => return WValue{ .imm32 = 0xaaaaaaaa },
|
||||
32 => return WValue{ .float32 = @bitCast(f32, @as(u32, 0xaaaaaaaa)) },
|
||||
64 => return WValue{ .float64 = @bitCast(f64, @as(u64, 0xaaaaaaaaaaaaaaaa)) },
|
||||
else => unreachable,
|
||||
},
|
||||
.Pointer => switch (self.arch()) {
|
||||
@ -2562,6 +2581,8 @@ fn cmp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: std.math.CompareOper
|
||||
}
|
||||
} else if (isByRef(ty, self.target)) {
|
||||
return self.cmpBigInt(lhs, rhs, ty, op);
|
||||
} else if (ty.isAnyFloat() and ty.floatBits(self.target) == 16) {
|
||||
return self.cmpFloat16(lhs, rhs, op);
|
||||
}
|
||||
|
||||
// ensure that when we compare pointers, we emit
|
||||
@ -2595,6 +2616,31 @@ fn cmp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: std.math.CompareOper
|
||||
return cmp_tmp;
|
||||
}
|
||||
|
||||
fn cmpFloat16(self: *Self, lhs: WValue, rhs: WValue, op: std.math.CompareOperator) InnerError!WValue {
|
||||
const ext_lhs = try self.fpext(lhs, Type.f16, Type.f32);
|
||||
const ext_rhs = try self.fpext(rhs, Type.f16, Type.f32);
|
||||
|
||||
const opcode: wasm.Opcode = buildOpcode(.{
|
||||
.op = switch (op) {
|
||||
.lt => .lt,
|
||||
.lte => .le,
|
||||
.eq => .eq,
|
||||
.neq => .ne,
|
||||
.gte => .ge,
|
||||
.gt => .gt,
|
||||
},
|
||||
.valtype1 = .f32,
|
||||
.signedness = .unsigned,
|
||||
});
|
||||
try self.emitWValue(ext_lhs);
|
||||
try self.emitWValue(ext_rhs);
|
||||
try self.addTag(Mir.Inst.Tag.fromOpcode(opcode));
|
||||
|
||||
const result = try self.allocLocal(Type.initTag(.i32)); // bool is always i32
|
||||
try self.addLabel(.local_set, result.local);
|
||||
return result;
|
||||
}
|
||||
|
||||
fn airCmpVector(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
|
||||
_ = inst;
|
||||
return self.fail("TODO implement airCmpVector for wasm", .{});
|
||||
@ -3934,19 +3980,44 @@ fn airFpext(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
|
||||
|
||||
const ty_op = self.air.instructions.items(.data)[inst].ty_op;
|
||||
const dest_ty = self.air.typeOfIndex(inst);
|
||||
const dest_bits = dest_ty.floatBits(self.target);
|
||||
const src_bits = self.air.typeOf(ty_op.operand).floatBits(self.target);
|
||||
const operand = try self.resolveInst(ty_op.operand);
|
||||
|
||||
if (dest_bits == 64 and src_bits == 32) {
|
||||
const result = try self.allocLocal(dest_ty);
|
||||
return self.fpext(operand, self.air.typeOf(ty_op.operand), dest_ty);
|
||||
}
|
||||
|
||||
fn fpext(self: *Self, operand: WValue, given: Type, wanted: Type) InnerError!WValue {
|
||||
const given_bits = given.floatBits(self.target);
|
||||
const wanted_bits = wanted.floatBits(self.target);
|
||||
|
||||
if (wanted_bits == 64 and given_bits == 32) {
|
||||
const result = try self.allocLocal(wanted);
|
||||
try self.emitWValue(operand);
|
||||
try self.addTag(.f64_promote_f32);
|
||||
try self.addLabel(.local_set, result.local);
|
||||
return result;
|
||||
} else if (given_bits == 16) {
|
||||
// call __extendhfsf2(f16) f32
|
||||
const f32_result = try self.callIntrinsic(
|
||||
"__extendhfsf2",
|
||||
&.{Type.f16},
|
||||
Type.f32,
|
||||
&.{operand},
|
||||
);
|
||||
|
||||
if (wanted_bits == 32) {
|
||||
return f32_result;
|
||||
}
|
||||
if (wanted_bits == 64) {
|
||||
const result = try self.allocLocal(wanted);
|
||||
try self.emitWValue(f32_result);
|
||||
try self.addTag(.f64_promote_f32);
|
||||
try self.addLabel(.local_set, result.local);
|
||||
return result;
|
||||
}
|
||||
return self.fail("TODO: Implement 'fpext' for floats with bitsize: {d}", .{wanted_bits});
|
||||
} else {
|
||||
// TODO: Emit a call to compiler-rt to extend the float. e.g. __extendhfsf2
|
||||
return self.fail("TODO: Implement 'fpext' for floats with bitsize: {d}", .{dest_bits});
|
||||
return self.fail("TODO: Implement 'fpext' for floats with bitsize: {d}", .{wanted_bits});
|
||||
}
|
||||
}
|
||||
|
||||
@ -3955,19 +4026,34 @@ fn airFptrunc(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
|
||||
|
||||
const ty_op = self.air.instructions.items(.data)[inst].ty_op;
|
||||
const dest_ty = self.air.typeOfIndex(inst);
|
||||
const dest_bits = dest_ty.floatBits(self.target);
|
||||
const src_bits = self.air.typeOf(ty_op.operand).floatBits(self.target);
|
||||
const operand = try self.resolveInst(ty_op.operand);
|
||||
return self.fptrunc(operand, self.air.typeOf(ty_op.operand), dest_ty);
|
||||
}
|
||||
|
||||
if (dest_bits == 32 and src_bits == 64) {
|
||||
const result = try self.allocLocal(dest_ty);
|
||||
fn fptrunc(self: *Self, operand: WValue, given: Type, wanted: Type) InnerError!WValue {
|
||||
const given_bits = given.floatBits(self.target);
|
||||
const wanted_bits = wanted.floatBits(self.target);
|
||||
|
||||
if (wanted_bits == 32 and given_bits == 64) {
|
||||
const result = try self.allocLocal(wanted);
|
||||
try self.emitWValue(operand);
|
||||
try self.addTag(.f32_demote_f64);
|
||||
try self.addLabel(.local_set, result.local);
|
||||
return result;
|
||||
} else if (wanted_bits == 16) {
|
||||
const op: WValue = if (given_bits == 64) blk: {
|
||||
const tmp = try self.allocLocal(Type.f32);
|
||||
try self.emitWValue(operand);
|
||||
try self.addTag(.f32_demote_f64);
|
||||
try self.addLabel(.local_set, tmp.local);
|
||||
break :blk tmp;
|
||||
} else operand;
|
||||
|
||||
// call __truncsfhf2(f32) f16
|
||||
return self.callIntrinsic("__truncsfhf2", &.{Type.f32}, Type.f16, &.{op});
|
||||
} else {
|
||||
// TODO: Emit a call to compiler-rt to trunc the float. e.g. __truncdfhf2
|
||||
return self.fail("TODO: Implement 'fptrunc' for floats with bitsize: {d}", .{dest_bits});
|
||||
return self.fail("TODO: Implement 'fptrunc' for floats with bitsize: {d}", .{wanted_bits});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user