From 1b0e90f70b4dc26c2ba96b7b5709a3ff269bb48a Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 26 Nov 2017 00:58:11 -0500 Subject: [PATCH] translate-c supports switch statements --- src/translate_c.cpp | 327 +++++++++++++++++++++++++++++++------------ test/translate_c.zig | 80 +++++------ 2 files changed, 280 insertions(+), 127 deletions(-) diff --git a/src/translate_c.cpp b/src/translate_c.cpp index 9e6676063d..27066a08b6 100644 --- a/src/translate_c.cpp +++ b/src/translate_c.cpp @@ -69,6 +69,7 @@ enum TransScopeId { TransScopeIdVar, TransScopeIdBlock, TransScopeIdRoot, + TransScopeIdWhile, }; struct TransScope { @@ -79,6 +80,8 @@ struct TransScope { struct TransScopeSwitch { TransScope base; AstNode *switch_node; + uint32_t case_index; + bool found_default; }; struct TransScopeVar { @@ -96,12 +99,19 @@ struct TransScopeRoot { TransScope base; }; +struct TransScopeWhile { + TransScope base; + AstNode *node; +}; + static TransScopeRoot *trans_scope_root_create(Context *c); +static TransScopeWhile *trans_scope_while_create(Context *c, TransScope *parent_scope); static TransScopeBlock *trans_scope_block_create(Context *c, TransScope *parent_scope); static TransScopeVar *trans_scope_var_create(Context *c, TransScope *parent_scope, Buf *wanted_name); -//static TransScopeSwitch *trans_scope_switch_create(Context *c, TransScope *parent_scope); +static TransScopeSwitch *trans_scope_switch_create(Context *c, TransScope *parent_scope); static TransScopeBlock *trans_scope_block_find(TransScope *scope); +static TransScopeSwitch *trans_scope_switch_find(TransScope *scope); static AstNode *resolve_record_decl(Context *c, const RecordDecl *record_decl); static AstNode *resolve_enum_decl(Context *c, const EnumDecl *enum_decl); @@ -238,6 +248,12 @@ static AstNode *trans_create_node_addr_of(Context *c, bool is_const, bool is_vol return node; } +static AstNode *trans_create_node_bool(Context *c, bool value) { + AstNode *bool_node = trans_create_node(c, NodeTypeBoolLiteral); + bool_node->data.bool_literal.value = value; + return bool_node; +} + static AstNode *trans_create_node_str_lit_c(Context *c, Buf *buf) { AstNode *node = trans_create_node(c, NodeTypeStringLiteral); node->data.string_literal.buf = buf; @@ -965,22 +981,30 @@ static AstNode *trans_qual_type(Context *c, QualType qt, const SourceLocation &s return trans_type(c, qt.getTypePtr(), source_loc); } -static AstNode *trans_compound_stmt(Context *c, TransScope *scope, const CompoundStmt *stmt, - TransScope **out_node_scope) +static int trans_compound_stmt_inline(Context *c, TransScope *scope, const CompoundStmt *stmt, + AstNode *block_node, TransScope **out_node_scope) { - TransScopeBlock *child_scope_block = trans_scope_block_create(c, scope); - scope = &child_scope_block->base; + assert(block_node->type == NodeTypeBlock); for (CompoundStmt::const_body_iterator it = stmt->body_begin(), end_it = stmt->body_end(); it != end_it; ++it) { AstNode *child_node; scope = trans_stmt(c, scope, *it, &child_node); if (scope == nullptr) - return nullptr; + return ErrorUnexpected; if (child_node != nullptr) - child_scope_block->node->data.block.statements.append(child_node); + block_node->data.block.statements.append(child_node); } if (out_node_scope != nullptr) { - *out_node_scope = &child_scope_block->base; + *out_node_scope = scope; } + return ErrorNone; +} + +static AstNode *trans_compound_stmt(Context *c, TransScope *scope, const CompoundStmt *stmt, + TransScope **out_node_scope) +{ + TransScopeBlock *child_scope_block = trans_scope_block_create(c, scope); + if (trans_compound_stmt_inline(c, &child_scope_block->base, stmt, child_scope_block->node, out_node_scope)) + return nullptr; return child_scope_block->node; } @@ -2081,17 +2105,18 @@ static int trans_local_declaration(Context *c, TransScope *scope, const DeclStmt } static AstNode *trans_while_loop(Context *c, TransScope *scope, const WhileStmt *stmt) { - AstNode *while_node = trans_create_node(c, NodeTypeWhileExpr); + TransScopeWhile *while_scope = trans_scope_while_create(c, scope); - while_node->data.while_expr.condition = trans_expr(c, ResultUsedYes, scope, stmt->getCond(), TransRValue); - if (while_node->data.while_expr.condition == nullptr) + while_scope->node->data.while_expr.condition = trans_expr(c, ResultUsedYes, scope, stmt->getCond(), TransRValue); + if (while_scope->node->data.while_expr.condition == nullptr) return nullptr; - TransScope *body_scope = trans_stmt(c, scope, stmt->getBody(), &while_node->data.while_expr.body); + TransScope *body_scope = trans_stmt(c, &while_scope->base, stmt->getBody(), + &while_scope->node->data.while_expr.body); if (body_scope == nullptr) return nullptr; - return while_node; + return while_scope->node; } static AstNode *trans_if_statement(Context *c, TransScope *scope, const IfStmt *stmt) { @@ -2201,11 +2226,9 @@ static AstNode *trans_unary_expr_or_type_trait_expr(Context *c, TransScope *scop } static AstNode *trans_do_loop(Context *c, TransScope *parent_scope, const DoStmt *stmt) { - AstNode *while_node = trans_create_node(c, NodeTypeWhileExpr); + TransScopeWhile *while_scope = trans_scope_while_create(c, parent_scope); - AstNode *true_node = trans_create_node(c, NodeTypeBoolLiteral); - true_node->data.bool_literal.value = true; - while_node->data.while_expr.condition = true_node; + while_scope->node->data.while_expr.condition = trans_create_node_bool(c, true); AstNode *body_node; TransScope *child_scope; @@ -2222,7 +2245,7 @@ static AstNode *trans_do_loop(Context *c, TransScope *parent_scope, const DoStmt // zig: } // We call the low level function so that we can set child_scope to the scope of the generated block. - if (trans_stmt_extra(c, parent_scope, stmt->getBody(), ResultUsedNo, TransRValue, &body_node, + if (trans_stmt_extra(c, &while_scope->base, stmt->getBody(), ResultUsedNo, TransRValue, &body_node, nullptr, &child_scope)) { return nullptr; @@ -2237,7 +2260,7 @@ static AstNode *trans_do_loop(Context *c, TransScope *parent_scope, const DoStmt // zig: a; // zig: if (!cond) break; // zig: } - TransScopeBlock *child_block_scope = trans_scope_block_create(c, parent_scope); + TransScopeBlock *child_block_scope = trans_scope_block_create(c, &while_scope->base); body_node = child_block_scope->node; AstNode *child_statement; child_scope = trans_stmt(c, &child_block_scope->base, stmt->getBody(), &child_statement); @@ -2254,89 +2277,206 @@ static AstNode *trans_do_loop(Context *c, TransScope *parent_scope, const DoStmt body_node->data.block.statements.append(terminator_node); - while_node->data.while_expr.body = body_node; + while_scope->node->data.while_expr.body = body_node; - return while_node; + return while_scope->node; } -//static AstNode *trans_switch_stmt(Context *c, TransScope *scope, const SwitchStmt *stmt) { -// AstNode *switch_block_node = trans_create_node(c, NodeTypeBlock); -// AstNode *switch_node = trans_create_node(c, NodeTypeSwitchExpr); -// const DeclStmt *var_decl_stmt = stmt->getConditionVariableDeclStmt(); -// if (var_decl_stmt != nullptr) { -// AstNode *vars_node = trans_stmt(c, switch_block_node, var_decl_stmt); -// if (vars_node == nullptr) -// return nullptr; -// if (vars_node != nullptr) -// switch_block_node->data.block.statements.append(vars_node); -// } -// switch_block_node->data.block.statements.append(switch_node); -// -// const Expr *cond_expr = stmt->getCond(); -// assert(cond_expr != nullptr); -// -// AstNode *expr_node = trans_expr(c, ResultUsedYes, switch_block_node, cond_expr, TransRValue); -// if (expr_node == nullptr) -// return nullptr; -// switch_node->data.switch_expr.expr = expr_node; -// -// AstNode *body_node = trans_stmt(c, switch_block_node, stmt->getBody()); -// if (body_node == nullptr) -// return nullptr; -// if (body_node != nullptr) -// switch_block_node->data.block.statements.append(body_node); -// -// return switch_block_node; -//} +static AstNode *trans_switch_stmt(Context *c, TransScope *parent_scope, const SwitchStmt *stmt) { + TransScopeWhile *while_scope = trans_scope_while_create(c, parent_scope); + while_scope->node->data.while_expr.condition = trans_create_node_bool(c, true); + + TransScopeBlock *block_scope = trans_scope_block_create(c, &while_scope->base); + while_scope->node->data.while_expr.body = block_scope->node; + + TransScopeSwitch *switch_scope; + + const DeclStmt *var_decl_stmt = stmt->getConditionVariableDeclStmt(); + if (var_decl_stmt == nullptr) { + switch_scope = trans_scope_switch_create(c, &block_scope->base); + } else { + AstNode *vars_node; + TransScope *var_scope = trans_stmt(c, &block_scope->base, var_decl_stmt, &vars_node); + if (var_scope == nullptr) + return nullptr; + if (vars_node != nullptr) + block_scope->node->data.block.statements.append(vars_node); + switch_scope = trans_scope_switch_create(c, var_scope); + } + block_scope->node->data.block.statements.append(switch_scope->switch_node); + + const Expr *cond_expr = stmt->getCond(); + assert(cond_expr != nullptr); + + AstNode *expr_node = trans_expr(c, ResultUsedYes, &block_scope->base, cond_expr, TransRValue); + if (expr_node == nullptr) + return nullptr; + switch_scope->switch_node->data.switch_expr.expr = expr_node; + + AstNode *body_node; + const Stmt *body_stmt = stmt->getBody(); + if (body_stmt->getStmtClass() == Stmt::CompoundStmtClass) { + if (trans_compound_stmt_inline(c, &switch_scope->base, (const CompoundStmt *)body_stmt, + block_scope->node, nullptr)) + { + return nullptr; + } + } else { + TransScope *body_scope = trans_stmt(c, &switch_scope->base, body_stmt, &body_node); + if (body_scope == nullptr) + return nullptr; + if (body_node != nullptr) + block_scope->node->data.block.statements.append(body_node); + } + + if (!switch_scope->found_default && !stmt->isAllEnumCasesCovered()) { + AstNode *prong_node = trans_create_node(c, NodeTypeSwitchProng); + prong_node->data.switch_prong.expr = trans_create_node(c, NodeTypeBreak); + switch_scope->switch_node->data.switch_expr.prongs.append(prong_node); + } + + // This is necessary if the last switch case "falls through" the end of the switch block + block_scope->node->data.block.statements.append(trans_create_node(c, NodeTypeBreak)); + + return while_scope->node; +} + +static int trans_switch_case(Context *c, TransScope *parent_scope, const CaseStmt *stmt, AstNode **out_node, + TransScope **out_scope) +{ + *out_node = nullptr; + + if (stmt->getRHS() != nullptr) { + emit_warning(c, stmt->getLocStart(), "TODO support GNU switch case a ... b extension"); + return ErrorUnexpected; + } + + TransScopeSwitch *switch_scope = trans_scope_switch_find(parent_scope); + assert(switch_scope != nullptr); + + Buf *label_name = buf_sprintf("case_%" PRIu32, switch_scope->case_index); + switch_scope->case_index += 1; + + { + // Add the prong + AstNode *prong_node = trans_create_node(c, NodeTypeSwitchProng); + AstNode *item_node = trans_expr(c, ResultUsedYes, &switch_scope->base, stmt->getLHS(), TransRValue); + if (item_node == nullptr) + return ErrorUnexpected; + prong_node->data.switch_prong.items.append(item_node); + + AstNode *goto_node = trans_create_node(c, NodeTypeGoto); + goto_node->data.goto_expr.name = label_name; + prong_node->data.switch_prong.expr = goto_node; + + switch_scope->switch_node->data.switch_expr.prongs.append(prong_node); + } + + AstNode *label_node = trans_create_node(c, NodeTypeLabel); + label_node->data.label.name = label_name; + + TransScopeBlock *scope_block = trans_scope_block_find(parent_scope); + scope_block->node->data.block.statements.append(label_node); + + AstNode *sub_stmt_node; + TransScope *new_scope = trans_stmt(c, parent_scope, stmt->getSubStmt(), &sub_stmt_node); + if (new_scope == nullptr) + return ErrorUnexpected; + if (sub_stmt_node != nullptr) + scope_block->node->data.block.statements.append(sub_stmt_node); + + *out_scope = new_scope; + return ErrorNone; +} + +static int trans_switch_default(Context *c, TransScope *parent_scope, const DefaultStmt *stmt, AstNode **out_node, + TransScope **out_scope) +{ + *out_node = nullptr; + + TransScopeSwitch *switch_scope = trans_scope_switch_find(parent_scope); + assert(switch_scope != nullptr); + + Buf *label_name = buf_sprintf("default"); + + AstNode *label_node = trans_create_node(c, NodeTypeLabel); + label_node->data.label.name = label_name; + + { + // Add the prong + AstNode *prong_node = trans_create_node(c, NodeTypeSwitchProng); + + AstNode *goto_node = trans_create_node(c, NodeTypeGoto); + goto_node->data.goto_expr.name = label_name; + prong_node->data.switch_prong.expr = goto_node; + + switch_scope->switch_node->data.switch_expr.prongs.append(prong_node); + switch_scope->found_default = true; + } + + TransScopeBlock *scope_block = trans_scope_block_find(parent_scope); + scope_block->node->data.block.statements.append(label_node); + + AstNode *sub_stmt_node; + TransScope *new_scope = trans_stmt(c, parent_scope, stmt->getSubStmt(), &sub_stmt_node); + if (new_scope == nullptr) + return ErrorUnexpected; + if (sub_stmt_node != nullptr) + scope_block->node->data.block.statements.append(sub_stmt_node); + + *out_scope = new_scope; + return ErrorNone; +} static AstNode *trans_for_loop(Context *c, TransScope *parent_scope, const ForStmt *stmt) { AstNode *loop_block_node; - TransScope *inner_scope; - AstNode *while_node = trans_create_node(c, NodeTypeWhileExpr); + TransScopeWhile *while_scope; + TransScope *cond_scope; const Stmt *init_stmt = stmt->getInit(); if (init_stmt == nullptr) { - loop_block_node = while_node; - inner_scope = parent_scope; + while_scope = trans_scope_while_create(c, parent_scope); + loop_block_node = while_scope->node; + cond_scope = parent_scope; } else { TransScopeBlock *child_scope = trans_scope_block_create(c, parent_scope); loop_block_node = child_scope->node; - inner_scope = &child_scope->base; AstNode *vars_node; - inner_scope = trans_stmt(c, &child_scope->base, init_stmt, &vars_node); - if (inner_scope == nullptr) + cond_scope = trans_stmt(c, &child_scope->base, init_stmt, &vars_node); + if (cond_scope == nullptr) return nullptr; if (vars_node != nullptr) child_scope->node->data.block.statements.append(vars_node); - child_scope->node->data.block.statements.append(while_node); + while_scope = trans_scope_while_create(c, cond_scope); + + child_scope->node->data.block.statements.append(while_scope->node); } const Stmt *cond_stmt = stmt->getCond(); if (cond_stmt == nullptr) { - AstNode *true_node = trans_create_node(c, NodeTypeBoolLiteral); - true_node->data.bool_literal.value = true; - while_node->data.while_expr.condition = true_node; + while_scope->node->data.while_expr.condition = trans_create_node_bool(c, true); } else { - TransScope *cond_scope = trans_stmt(c, inner_scope, cond_stmt, &while_node->data.while_expr.condition); - if (cond_scope == nullptr) + TransScope *end_cond_scope = trans_stmt(c, cond_scope, cond_stmt, + &while_scope->node->data.while_expr.condition); + if (end_cond_scope == nullptr) return nullptr; } const Stmt *inc_stmt = stmt->getInc(); if (inc_stmt != nullptr) { AstNode *inc_node; - TransScope *inc_scope = trans_stmt(c, inner_scope, inc_stmt, &inc_node); + TransScope *inc_scope = trans_stmt(c, cond_scope, inc_stmt, &inc_node); if (inc_scope == nullptr) return nullptr; - while_node->data.while_expr.continue_expr = inc_node; + while_scope->node->data.while_expr.continue_expr = inc_node; } - AstNode *child_statement; - TransScope *body_scope = trans_stmt(c, inner_scope, stmt->getBody(), &child_statement); + AstNode *body_statement; + TransScope *body_scope = trans_stmt(c, &while_scope->base, stmt->getBody(), &body_statement); if (body_scope == nullptr) return nullptr; - while_node->data.while_expr.body = child_statement; + while_scope->node->data.while_expr.body = body_statement; return loop_block_node; } @@ -2371,9 +2511,8 @@ static int wrap_stmt(AstNode **out_node, TransScope **out_scope, TransScope *in_ if (result_node == nullptr) return ErrorUnexpected; *out_node = result_node; - if (out_scope != nullptr) { + if (out_scope != nullptr) *out_scope = in_scope; - } return ErrorNone; } @@ -2456,18 +2595,13 @@ static int trans_stmt_extra(Context *c, TransScope *scope, const Stmt *stmt, case Stmt::ParenExprClass: return wrap_stmt(out_node, out_child_scope, scope, trans_expr(c, result_used, scope, ((const ParenExpr*)stmt)->getSubExpr(), lrvalue)); -// case Stmt::SwitchStmtClass: -// return wrap_stmt(out_node, out_child_scope, scope, -// trans_switch_stmt(c, scope, (const SwitchStmt *)stmt)); case Stmt::SwitchStmtClass: - emit_warning(c, stmt->getLocStart(), "TODO handle C SwitchStmtClass"); - return ErrorUnexpected; + return wrap_stmt(out_node, out_child_scope, scope, + trans_switch_stmt(c, scope, (const SwitchStmt *)stmt)); case Stmt::CaseStmtClass: - emit_warning(c, stmt->getLocStart(), "TODO handle C CaseStmtClass"); - return ErrorUnexpected; + return trans_switch_case(c, scope, (const CaseStmt *)stmt, out_node, out_child_scope); case Stmt::DefaultStmtClass: - emit_warning(c, stmt->getLocStart(), "TODO handle C DefaultStmtClass"); - return ErrorUnexpected; + return trans_switch_default(c, scope, (const DefaultStmt *)stmt, out_node, out_child_scope); case Stmt::NoStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C NoStmtClass"); return ErrorUnexpected; @@ -2981,7 +3115,8 @@ static AstNode *trans_expr(Context *c, ResultUsed result_used, TransScope *scope TransLRValue lrval) { AstNode *result_node; - if (trans_stmt_extra(c, scope, expr, result_used, lrval, &result_node, nullptr, nullptr)) { + TransScope *result_scope; + if (trans_stmt_extra(c, scope, expr, result_used, lrval, &result_node, &result_scope, nullptr)) { return nullptr; } return result_node; @@ -3536,6 +3671,14 @@ static TransScopeRoot *trans_scope_root_create(Context *c) { return result; } +static TransScopeWhile *trans_scope_while_create(Context *c, TransScope *parent_scope) { + TransScopeWhile *result = allocate(1); + result->base.id = TransScopeIdWhile; + result->base.parent = parent_scope; + result->node = trans_create_node(c, NodeTypeWhileExpr); + return result; +} + static TransScopeBlock *trans_scope_block_create(Context *c, TransScope *parent_scope) { TransScopeBlock *result = allocate(1); result->base.id = TransScopeIdBlock; @@ -3553,13 +3696,13 @@ static TransScopeVar *trans_scope_var_create(Context *c, TransScope *parent_scop return result; } -//static TransScopeSwitch *trans_scope_switch_create(Context *c, TransScope *parent_scope) { -// TransScopeSwitch *result = allocate(1); -// result->base.id = TransScopeIdSwitch; -// result->base.parent = parent_scope; -// result->switch_node = trans_create_node(c, NodeTypeSwitchExpr); -// return result; -//} +static TransScopeSwitch *trans_scope_switch_create(Context *c, TransScope *parent_scope) { + TransScopeSwitch *result = allocate(1); + result->base.id = TransScopeIdSwitch; + result->base.parent = parent_scope; + result->switch_node = trans_create_node(c, NodeTypeSwitchExpr); + return result; +} static TransScopeBlock *trans_scope_block_find(TransScope *scope) { while (scope != nullptr) { @@ -3571,6 +3714,16 @@ static TransScopeBlock *trans_scope_block_find(TransScope *scope) { return nullptr; } +static TransScopeSwitch *trans_scope_switch_find(TransScope *scope) { + while (scope != nullptr) { + if (scope->id == TransScopeIdSwitch) { + return (TransScopeSwitch *)scope; + } + scope = scope->parent; + } + return nullptr; +} + static void render_aliases(Context *c) { for (size_t i = 0; i < c->aliases.length; i += 1) { Alias *alias = &c->aliases.at(i); diff --git a/test/translate_c.zig b/test/translate_c.zig index b957ac4b05..e9f5e7de42 100644 --- a/test/translate_c.zig +++ b/test/translate_c.zig @@ -1001,46 +1001,46 @@ pub fn addCases(cases: &tests.TranslateCContext) { \\} ); - //cases.add("switch statement", - // \\int foo(int x) { - // \\ switch (x) { - // \\ case 1: - // \\ x += 1; - // \\ case 2: - // \\ break; - // \\ case 3: - // \\ case 4: - // \\ return x + 1; - // \\ default: - // \\ return 10; - // \\ } - // \\ return x + 13; - // \\} - //, - // \\fn foo(_x: i32) -> i32 { - // \\ var x = _x; - // \\ switch (x) { - // \\ 1 => goto switch_case_1; - // \\ 2 => goto switch_case_2; - // \\ 3 => goto switch_case_3; - // \\ 4 => goto switch_case_4; - // \\ else => goto switch_default; - // \\ } - // \\switch_case_1: - // \\ x += 1; - // \\ goto switch_case_2; - // \\switch_case_2: - // \\ goto switch_end; - // \\switch_case_3: - // \\ goto switch_case_4; - // \\switch_case_4: - // \\ return x += 1; - // \\switch_default: - // \\ return 10; - // \\switch_end: - // \\ return x + 13; - // \\} - //); + cases.add("switch statement", + \\int foo(int x) { + \\ switch (x) { + \\ case 1: + \\ x += 1; + \\ case 2: + \\ break; + \\ case 3: + \\ case 4: + \\ return x + 1; + \\ default: + \\ return 10; + \\ } + \\ return x + 13; + \\} + , + \\fn foo(_arg_x: c_int) -> c_int { + \\ var x = _arg_x; + \\ while (true) { + \\ switch (x) { + \\ 1 => goto case_0, + \\ 2 => goto case_1, + \\ 3 => goto case_2, + \\ 4 => goto case_3, + \\ else => goto default, + \\ }; + \\ case_0: + \\ x += 1; + \\ case_1: + \\ break; + \\ case_2: + \\ case_3: + \\ return x + 1; + \\ default: + \\ return 10; + \\ break; + \\ }; + \\ return x + 13; + \\} + ); }