translate-c: Add support for cast-to-union

Fixes #10955
This commit is contained in:
Evan Haas 2022-02-21 13:18:18 -08:00 committed by Veikka Tuominen
parent 4a0b037464
commit 9716a1c3ab
8 changed files with 95 additions and 6 deletions

View File

@ -30,6 +30,12 @@ pub fn cast(comptime DestType: type, target: anytype) DestType {
else => {},
}
},
.Union => |info| {
inline for (info.fields) |field| {
if (field.field_type == SourceType) return @unionInit(DestType, field.name, target);
}
@compileError("cast to union type '" ++ @typeName(DestType) ++ "' from type '" ++ @typeName(SourceType) ++ "' which is not present in union");
},
else => {},
}
return @as(DestType, target);

View File

@ -258,6 +258,14 @@ pub const CaseStmt = opaque {
extern fn ZigClangCaseStmt_getSubStmt(*const CaseStmt) *const Stmt;
};
pub const CastExpr = opaque {
pub const getCastKind = ZigClangCastExpr_getCastKind;
extern fn ZigClangCastExpr_getCastKind(*const CastExpr) CK;
pub const getTargetFieldForToUnionCast = ZigClangCastExpr_getTargetFieldForToUnionCast;
extern fn ZigClangCastExpr_getTargetFieldForToUnionCast(*const CastExpr, QualType, QualType) ?*const FieldDecl;
};
pub const CharacterLiteral = opaque {
pub const getBeginLoc = ZigClangCharacterLiteral_getBeginLoc;
extern fn ZigClangCharacterLiteral_getBeginLoc(*const CharacterLiteral) SourceLocation;

View File

@ -1791,14 +1791,31 @@ fn transCStyleCastExprClass(
stmt: *const clang.CStyleCastExpr,
result_used: ResultUsed,
) TransError!Node {
const cast_expr = @ptrCast(*const clang.CastExpr, stmt);
const sub_expr = stmt.getSubExpr();
const cast_node = (try transCCast(
const dst_type = stmt.getType();
const src_type = sub_expr.getType();
const sub_expr_node = try transExpr(c, scope, sub_expr, .used);
const loc = stmt.getBeginLoc();
const cast_node = if (cast_expr.getCastKind() == .ToUnion) blk: {
const field_decl = cast_expr.getTargetFieldForToUnionCast(dst_type, src_type).?; // C syntax error if target field is null
const field_name = try c.str(@ptrCast(*const clang.NamedDecl, field_decl).getName_bytes_begin());
const union_ty = try transQualType(c, scope, dst_type, loc);
const inits = [1]ast.Payload.ContainerInit.Initializer{.{ .name = field_name, .value = sub_expr_node }};
break :blk try Tag.container_init.create(c.arena, .{
.lhs = union_ty,
.inits = try c.arena.dupe(ast.Payload.ContainerInit.Initializer, &inits),
});
} else (try transCCast(
c,
scope,
stmt.getBeginLoc(),
stmt.getType(),
sub_expr.getType(),
try transExpr(c, scope, sub_expr, .used),
loc,
dst_type,
src_type,
sub_expr_node,
));
return maybeSuppressResult(c, scope, result_used, cast_node);
}
@ -2370,7 +2387,7 @@ fn cIntTypeForEnum(enum_qt: clang.QualType) clang.QualType {
return enum_decl.getIntegerType();
}
// when modifying this function, make sure to also update std.meta.cast
// when modifying this function, make sure to also update std.zig.c_translation.cast
fn transCCast(
c: *Context,
scope: *Scope,

View File

@ -2986,6 +2986,18 @@ const struct ZigClangCompoundStmt *ZigClangStmtExpr_getSubStmt(const struct ZigC
return reinterpret_cast<const ZigClangCompoundStmt *>(casted->getSubStmt());
}
enum ZigClangCK ZigClangCastExpr_getCastKind(const struct ZigClangCastExpr *self) {
auto casted = reinterpret_cast<const clang::CastExpr *>(self);
return (ZigClangCK)casted->getCastKind();
}
const struct ZigClangFieldDecl *ZigClangCastExpr_getTargetFieldForToUnionCast(const struct ZigClangCastExpr *self, ZigClangQualType union_type, ZigClangQualType op_type) {
clang::QualType union_qt = bitcast(union_type);
clang::QualType op_qt = bitcast(op_type);
auto casted = reinterpret_cast<const clang::CastExpr *>(self);
return reinterpret_cast<const ZigClangFieldDecl *>(casted->getTargetFieldForToUnionCast(union_qt, op_qt));
}
struct ZigClangSourceLocation ZigClangCharacterLiteral_getBeginLoc(const struct ZigClangCharacterLiteral *self) {
auto casted = reinterpret_cast<const clang::CharacterLiteral *>(self);
return bitcast(casted->getBeginLoc());

View File

@ -103,6 +103,7 @@ struct ZigClangBuiltinType;
struct ZigClangCStyleCastExpr;
struct ZigClangCallExpr;
struct ZigClangCaseStmt;
struct ZigClangCastExpr;
struct ZigClangCharacterLiteral;
struct ZigClangChooseExpr;
struct ZigClangCompoundAssignOperator;
@ -1317,6 +1318,9 @@ ZIG_EXTERN_C struct ZigClangQualType ZigClangDecayedType_getDecayedType(const st
ZIG_EXTERN_C const struct ZigClangCompoundStmt *ZigClangStmtExpr_getSubStmt(const struct ZigClangStmtExpr *);
ZIG_EXTERN_C enum ZigClangCK ZigClangCastExpr_getCastKind(const struct ZigClangCastExpr *);
ZIG_EXTERN_C const struct ZigClangFieldDecl *ZigClangCastExpr_getTargetFieldForToUnionCast(const struct ZigClangCastExpr *, struct ZigClangQualType, struct ZigClangQualType);
ZIG_EXTERN_C struct ZigClangSourceLocation ZigClangCharacterLiteral_getBeginLoc(const struct ZigClangCharacterLiteral *);
ZIG_EXTERN_C enum ZigClangCharacterLiteral_CharacterKind ZigClangCharacterLiteral_getKind(const struct ZigClangCharacterLiteral *);
ZIG_EXTERN_C unsigned ZigClangCharacterLiteral_getValue(const struct ZigClangCharacterLiteral *);

View File

@ -15,6 +15,11 @@ struct Foo {
int a;
};
union U {
long l;
double d;
};
#define SIZE_OF_FOO sizeof(struct Foo)
#define MAP_FAILED ((void *) -1)
@ -30,3 +35,5 @@ struct Foo {
#define IGNORE_ME_8(x) (volatile void)(x)
#define IGNORE_ME_9(x) (const volatile void)(x)
#define IGNORE_ME_10(x) (volatile const void)(x)
#define UNION_CAST(X) (union U)(X)

View File

@ -47,3 +47,16 @@ test "cast negative integer to pointer" {
try expectEqual(@intToPtr(?*anyopaque, @bitCast(usize, @as(isize, -1))), h.MAP_FAILED);
}
test "casting to union with a macro" {
if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO Sema.zirUnionInitPtr
const l: c_long = 42;
const d: f64 = 2.0;
var casted = h.UNION_CAST(l);
try expectEqual(l, casted.l);
casted = h.UNION_CAST(d);
try expectEqual(d, casted.d);
}

View File

@ -1829,4 +1829,26 @@ pub fn addCases(cases: *tests.RunTranslatedCContext) void {
\\ return 0;
\\}
, "");
cases.add("Cast-to-union. Issue #10955",
\\#include <stdlib.h>
\\struct S { int x; };
\\union U {
\\ long l;
\\ double d;
\\ struct S s;
\\};
\\union U bar(union U u) { return u; }
\\int main(void) {
\\ union U u = (union U) 42L;
\\ if (u.l != 42L) abort();
\\ u = (union U) 2.0;
\\ if (u.d != 2.0) abort();
\\ u = bar((union U)4.0);
\\ if (u.d != 4.0) abort();
\\ u = (union U)(struct S){ .x = 5 };
\\ if (u.s.x != 5) abort();
\\ return 0;
\\}
, "");
}