diff --git a/src/clang.zig b/src/clang.zig index fc7a25fe12..954cfee6b2 100644 --- a/src/clang.zig +++ b/src/clang.zig @@ -848,7 +848,10 @@ pub const UnaryOperator = opaque { extern fn ZigClangUnaryOperator_getBeginLoc(*const UnaryOperator) SourceLocation; }; -pub const ValueDecl = opaque {}; +pub const ValueDecl = opaque { + pub const getType = ZigClangValueDecl_getType; + extern fn ZigClangValueDecl_getType(*const ValueDecl) QualType; +}; pub const VarDecl = opaque { pub const getLocation = ZigClangVarDecl_getLocation; diff --git a/src/translate_c.zig b/src/translate_c.zig index 11dbacefa2..e3652ddbfb 100644 --- a/src/translate_c.zig +++ b/src/translate_c.zig @@ -3208,6 +3208,38 @@ fn transArrayAccess(rp: RestorePoint, scope: *Scope, stmt: *const clang.ArraySub return maybeSuppressResult(rp, scope, result_used, &node.base); } +/// Check if an expression is ultimately a reference to a function declaration +/// (which means it should not be unwrapped with `.?` in translated code) +fn cIsFunctionDeclRef(expr: *const clang.Expr) bool { + switch (expr.getStmtClass()) { + .ParenExprClass => { + const op_expr = @ptrCast(*const clang.ParenExpr, expr).getSubExpr(); + return cIsFunctionDeclRef(op_expr); + }, + .DeclRefExprClass => { + const decl_ref = @ptrCast(*const clang.DeclRefExpr, expr); + const value_decl = decl_ref.getDecl(); + const qt = value_decl.getType(); + return qualTypeChildIsFnProto(qt); + }, + .ImplicitCastExprClass => { + const implicit_cast = @ptrCast(*const clang.ImplicitCastExpr, expr); + const cast_kind = implicit_cast.getCastKind(); + if (cast_kind == .BuiltinFnToFnPtr) return true; + if (cast_kind == .FunctionToPointerDecay) { + return cIsFunctionDeclRef(implicit_cast.getSubExpr()); + } + return false; + }, + .UnaryOperatorClass => { + const un_op = @ptrCast(*const clang.UnaryOperator, expr); + const opcode = un_op.getOpcode(); + return (opcode == .AddrOf or opcode == .Deref) and cIsFunctionDeclRef(un_op.getSubExpr()); + }, + else => return false, + } +} + fn transCallExpr(rp: RestorePoint, scope: *Scope, stmt: *const clang.CallExpr, result_used: ResultUsed) TransError!*ast.Node { const callee = stmt.getCallee(); var raw_fn_expr = try transExpr(rp, scope, callee, .used, .r_value); @@ -3215,24 +3247,9 @@ fn transCallExpr(rp: RestorePoint, scope: *Scope, stmt: *const clang.CallExpr, r var is_ptr = false; const fn_ty = qualTypeGetFnProto(callee.getType(), &is_ptr); - const fn_expr = if (is_ptr and fn_ty != null) blk: { - if (callee.getStmtClass() == .ImplicitCastExprClass) { - const implicit_cast = @ptrCast(*const clang.ImplicitCastExpr, callee); - const cast_kind = implicit_cast.getCastKind(); - if (cast_kind == .BuiltinFnToFnPtr) break :blk raw_fn_expr; - if (cast_kind == .FunctionToPointerDecay) { - const subexpr = implicit_cast.getSubExpr(); - if (subexpr.getStmtClass() == .DeclRefExprClass) { - const decl_ref = @ptrCast(*const clang.DeclRefExpr, subexpr); - const named_decl = decl_ref.getFoundDecl(); - if (@ptrCast(*const clang.Decl, named_decl).getKind() == .Function) { - break :blk raw_fn_expr; - } - } - } - } - break :blk try transCreateNodeUnwrapNull(rp.c, raw_fn_expr); - } else + const fn_expr = if (is_ptr and fn_ty != null and !cIsFunctionDeclRef(callee)) + try transCreateNodeUnwrapNull(rp.c, raw_fn_expr) + else raw_fn_expr; const num_args = stmt.getNumArgs(); @@ -3379,6 +3396,9 @@ fn transUnaryOperator(rp: RestorePoint, scope: *Scope, stmt: *const clang.UnaryO else return transCreatePreCrement(rp, scope, stmt, .AssignSub, .MinusEqual, "-=", used), .AddrOf => { + if (cIsFunctionDeclRef(op_expr)) { + return transExpr(rp, scope, op_expr, used, .r_value); + } const op_node = try transCreateNodeSimplePrefixOp(rp.c, .AddressOf, .Ampersand, "&"); op_node.rhs = try transExpr(rp, scope, op_expr, used, .r_value); return &op_node.base; diff --git a/src/zig_clang.cpp b/src/zig_clang.cpp index 9bd68859e8..8dc6a0823b 100644 --- a/src/zig_clang.cpp +++ b/src/zig_clang.cpp @@ -2773,6 +2773,11 @@ struct ZigClangSourceLocation ZigClangUnaryOperator_getBeginLoc(const struct Zig return bitcast(casted->getBeginLoc()); } +struct ZigClangQualType ZigClangValueDecl_getType(const struct ZigClangValueDecl *self) { + auto casted = reinterpret_cast(self); + return bitcast(casted->getType()); +} + const struct ZigClangExpr *ZigClangWhileStmt_getCond(const struct ZigClangWhileStmt *self) { auto casted = reinterpret_cast(self); return reinterpret_cast(casted->getCond()); diff --git a/src/zig_clang.h b/src/zig_clang.h index 42587b1719..6fe1da0bc1 100644 --- a/src/zig_clang.h +++ b/src/zig_clang.h @@ -1200,6 +1200,8 @@ ZIG_EXTERN_C struct ZigClangQualType ZigClangUnaryOperator_getType(const struct ZIG_EXTERN_C const struct ZigClangExpr *ZigClangUnaryOperator_getSubExpr(const struct ZigClangUnaryOperator *); ZIG_EXTERN_C struct ZigClangSourceLocation ZigClangUnaryOperator_getBeginLoc(const struct ZigClangUnaryOperator *); +ZIG_EXTERN_C struct ZigClangQualType ZigClangValueDecl_getType(const struct ZigClangValueDecl *); + ZIG_EXTERN_C const struct ZigClangExpr *ZigClangWhileStmt_getCond(const struct ZigClangWhileStmt *); ZIG_EXTERN_C const struct ZigClangStmt *ZigClangWhileStmt_getBody(const struct ZigClangWhileStmt *); diff --git a/test/run_translated_c.zig b/test/run_translated_c.zig index a99271eb41..a8a3a0e21b 100644 --- a/test/run_translated_c.zig +++ b/test/run_translated_c.zig @@ -818,4 +818,60 @@ pub fn addCases(cases: *tests.RunTranslatedCContext) void { \\ return 0; \\} , ""); + + cases.add("Address of function is no-op", + \\#include + \\#include + \\typedef int (*myfunc)(int); + \\int a(int arg) { return arg + 1;} + \\int b(int arg) { return arg + 2;} + \\int caller(myfunc fn, int arg) { + \\ return fn(arg); + \\} + \\int main() { + \\ myfunc arr[3] = {&a, &b, a}; + \\ myfunc foo = a; + \\ myfunc bar = &(a); + \\ if (foo != bar) abort(); + \\ if (arr[0] == arr[1]) abort(); + \\ if (arr[0] != arr[2]) abort(); + \\ if (caller(b, 40) != 42) abort(); + \\ if (caller(&b, 40) != 42) abort(); + \\ return 0; + \\} + , ""); + + cases.add("Obscure ways of calling functions; issue #4124", + \\#include + \\static int add(int a, int b) { + \\ return a + b; + \\} + \\typedef int (*adder)(int, int); + \\typedef void (*funcptr)(void); + \\int main() { + \\ if ((add)(1, 2) != 3) abort(); + \\ if ((&add)(1, 2) != 3) abort(); + \\ if (add(3, 1) != 4) abort(); + \\ if ((*add)(2, 3) != 5) abort(); + \\ if ((**add)(7, -1) != 6) abort(); + \\ if ((***add)(-2, 9) != 7) abort(); + \\ + \\ int (*ptr)(int a, int b); + \\ ptr = add; + \\ + \\ if (ptr(1, 2) != 3) abort(); + \\ if ((*ptr)(3, 1) != 4) abort(); + \\ if ((**ptr)(2, 3) != 5) abort(); + \\ if ((***ptr)(7, -1) != 6) abort(); + \\ if ((****ptr)(-2, 9) != 7) abort(); + \\ + \\ funcptr addr1 = (funcptr)(add); + \\ funcptr addr2 = (funcptr)(&add); + \\ + \\ if (addr1 != addr2) abort(); + \\ if (((int(*)(int, int))addr1)(1, 2) != 3) abort(); + \\ if (((adder)addr2)(1, 2) != 3) abort(); + \\ return 0; + \\} + , ""); } diff --git a/test/translate_c.zig b/test/translate_c.zig index 10ac76a2c5..75d00d12f4 100644 --- a/test/translate_c.zig +++ b/test/translate_c.zig @@ -2802,8 +2802,8 @@ pub fn addCases(cases: *tests.TranslateCContext) void { \\ fn_f64(3); \\ fn_bool(@as(c_int, 123) != 0); \\ fn_bool(@as(c_int, 0) != 0); - \\ fn_bool(@ptrToInt(&fn_int) != 0); - \\ fn_int(@intCast(c_int, @ptrToInt(&fn_int))); + \\ fn_bool(@ptrToInt(fn_int) != 0); + \\ fn_int(@intCast(c_int, @ptrToInt(fn_int))); \\ fn_ptr(@intToPtr(?*c_void, @as(c_int, 42))); \\} });