diff --git a/src/codegen.cpp b/src/codegen.cpp index dd1f8a4722..5869ef8278 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -305,17 +305,9 @@ static void find_declarations(CodeGen *g, AstNode *node) { case NodeTypeReturnExpr: case NodeTypeRoot: case NodeTypeBlock: - case NodeTypeBoolOrExpr: + case NodeTypeBinOpExpr: case NodeTypeFnCall: case NodeTypeRootExportDecl: - case NodeTypeBoolAndExpr: - case NodeTypeComparisonExpr: - case NodeTypeBinOrExpr: - case NodeTypeBinXorExpr: - case NodeTypeBinAndExpr: - case NodeTypeBitShiftExpr: - case NodeTypeAddExpr: - case NodeTypeMultExpr: case NodeTypeCastExpr: case NodeTypePrimaryExpr: case NodeTypeGroupedExpr: @@ -481,10 +473,9 @@ static void analyze_node(CodeGen *g, AstNode *node) { analyze_node(g, node->data.return_expr.expr); } break; - case NodeTypeBoolOrExpr: - analyze_node(g, node->data.bool_or_expr.op1); - if (node->data.bool_or_expr.op2) - analyze_node(g, node->data.bool_or_expr.op2); + case NodeTypeBinOpExpr: + analyze_node(g, node->data.bin_op_expr.op1); + analyze_node(g, node->data.bin_op_expr.op2); break; case NodeTypeFnCall: { @@ -515,30 +506,6 @@ static void analyze_node(CodeGen *g, AstNode *node) { case NodeTypeDirective: // we looked at directives in the parent node break; - case NodeTypeBoolAndExpr: - zig_panic("TODO"); - break; - case NodeTypeComparisonExpr: - zig_panic("TODO"); - break; - case NodeTypeBinOrExpr: - zig_panic("TODO"); - break; - case NodeTypeBinXorExpr: - zig_panic("TODO"); - break; - case NodeTypeBinAndExpr: - zig_panic("TODO"); - break; - case NodeTypeBitShiftExpr: - zig_panic("TODO"); - break; - case NodeTypeAddExpr: - zig_panic("TODO"); - break; - case NodeTypeMultExpr: - zig_panic("TODO"); - break; case NodeTypeCastExpr: zig_panic("TODO"); break; @@ -752,168 +719,138 @@ static LLVMValueRef gen_cast_expr(CodeGen *g, AstNode *node) { } static LLVMValueRef gen_mult_expr(CodeGen *g, AstNode *node) { - assert(node->type == NodeTypeMultExpr); + assert(node->type == NodeTypeBinOpExpr); - LLVMValueRef val1 = gen_cast_expr(g, node->data.mult_expr.op1); + LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1); + LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2); - if (!node->data.mult_expr.op2) - return val1; - - LLVMValueRef val2 = gen_cast_expr(g, node->data.mult_expr.op2); - - switch (node->data.mult_expr.mult_op) { - case MultOpMult: + switch (node->data.bin_op_expr.bin_op) { + case BinOpTypeMult: // TODO types so we know float vs int add_debug_source_node(g, node); return LLVMBuildMul(g->builder, val1, val2, ""); - case MultOpDiv: + case BinOpTypeDiv: // TODO types so we know float vs int and signed vs unsigned add_debug_source_node(g, node); return LLVMBuildSDiv(g->builder, val1, val2, ""); - case MultOpMod: + case BinOpTypeMod: // TODO types so we know float vs int and signed vs unsigned add_debug_source_node(g, node); return LLVMBuildSRem(g->builder, val1, val2, ""); - case MultOpInvalid: + default: zig_unreachable(); } zig_unreachable(); } static LLVMValueRef gen_add_expr(CodeGen *g, AstNode *node) { - assert(node->type == NodeTypeAddExpr); + assert(node->type == NodeTypeBinOpExpr); - LLVMValueRef val1 = gen_mult_expr(g, node->data.add_expr.op1); + LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1); + LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2); - if (!node->data.add_expr.op2) - return val1; - - LLVMValueRef val2 = gen_mult_expr(g, node->data.add_expr.op2); - - switch (node->data.add_expr.add_op) { - case AddOpAdd: + switch (node->data.bin_op_expr.bin_op) { + case BinOpTypeAdd: add_debug_source_node(g, node); return LLVMBuildAdd(g->builder, val1, val2, ""); - case AddOpSub: + case BinOpTypeSub: add_debug_source_node(g, node); return LLVMBuildSub(g->builder, val1, val2, ""); - case AddOpInvalid: + default: zig_unreachable(); } zig_unreachable(); } static LLVMValueRef gen_bit_shift_expr(CodeGen *g, AstNode *node) { - assert(node->type == NodeTypeBitShiftExpr); + assert(node->type == NodeTypeBinOpExpr); - LLVMValueRef val1 = gen_add_expr(g, node->data.bit_shift_expr.op1); + LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1); + LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2); - if (!node->data.bit_shift_expr.op2) - return val1; - - LLVMValueRef val2 = gen_add_expr(g, node->data.bit_shift_expr.op2); - - switch (node->data.bit_shift_expr.bit_shift_op) { - case BitShiftOpLeft: + switch (node->data.bin_op_expr.bin_op) { + case BinOpTypeBitShiftLeft: add_debug_source_node(g, node); return LLVMBuildShl(g->builder, val1, val2, ""); - case BitShiftOpRight: + case BinOpTypeBitShiftRight: // TODO implement type system so that we know whether to do // logical or arithmetic shifting here. // signed -> arithmetic, unsigned -> logical add_debug_source_node(g, node); return LLVMBuildLShr(g->builder, val1, val2, ""); - case BitShiftOpInvalid: + default: zig_unreachable(); } zig_unreachable(); } static LLVMValueRef gen_bin_and_expr(CodeGen *g, AstNode *node) { - assert(node->type == NodeTypeBinAndExpr); + assert(node->type == NodeTypeBinOpExpr); - LLVMValueRef val1 = gen_bit_shift_expr(g, node->data.bin_and_expr.op1); - - if (!node->data.bin_and_expr.op2) - return val1; - - LLVMValueRef val2 = gen_bit_shift_expr(g, node->data.bin_and_expr.op2); + LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1); + LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2); add_debug_source_node(g, node); return LLVMBuildAnd(g->builder, val1, val2, ""); } static LLVMValueRef gen_bin_xor_expr(CodeGen *g, AstNode *node) { - assert(node->type == NodeTypeBinXorExpr); + assert(node->type == NodeTypeBinOpExpr); - LLVMValueRef val1 = gen_bin_and_expr(g, node->data.bin_xor_expr.op1); - - if (!node->data.bin_xor_expr.op2) - return val1; - - LLVMValueRef val2 = gen_bin_and_expr(g, node->data.bin_xor_expr.op2); + LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1); + LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2); add_debug_source_node(g, node); return LLVMBuildXor(g->builder, val1, val2, ""); } static LLVMValueRef gen_bin_or_expr(CodeGen *g, AstNode *node) { - assert(node->type == NodeTypeBinOrExpr); + assert(node->type == NodeTypeBinOpExpr); - LLVMValueRef val1 = gen_bin_xor_expr(g, node->data.bin_or_expr.op1); - - if (!node->data.bin_or_expr.op2) - return val1; - - LLVMValueRef val2 = gen_bin_xor_expr(g, node->data.bin_or_expr.op2); + LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1); + LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2); add_debug_source_node(g, node); return LLVMBuildOr(g->builder, val1, val2, ""); } -static LLVMIntPredicate cmp_op_to_int_predicate(CmpOp cmp_op, bool is_signed) { +static LLVMIntPredicate cmp_op_to_int_predicate(BinOpType cmp_op, bool is_signed) { switch (cmp_op) { - case CmpOpInvalid: + case BinOpTypeInvalid: zig_unreachable(); - case CmpOpEq: + case BinOpTypeCmpEq: return LLVMIntEQ; - case CmpOpNotEq: + case BinOpTypeCmpNotEq: return LLVMIntNE; - case CmpOpLessThan: + case BinOpTypeCmpLessThan: return is_signed ? LLVMIntSLT : LLVMIntULT; - case CmpOpGreaterThan: + case BinOpTypeCmpGreaterThan: return is_signed ? LLVMIntSGT : LLVMIntUGT; - case CmpOpLessOrEq: + case BinOpTypeCmpLessOrEq: return is_signed ? LLVMIntSLE : LLVMIntULE; - case CmpOpGreaterOrEq: + case BinOpTypeCmpGreaterOrEq: return is_signed ? LLVMIntSGE : LLVMIntUGE; + default: + zig_unreachable(); } - zig_unreachable(); } static LLVMValueRef gen_cmp_expr(CodeGen *g, AstNode *node) { - assert(node->type == NodeTypeComparisonExpr); + assert(node->type == NodeTypeBinOpExpr); - LLVMValueRef val1 = gen_bin_or_expr(g, node->data.comparison_expr.op1); - - if (!node->data.comparison_expr.op2) - return val1; - - LLVMValueRef val2 = gen_bin_or_expr(g, node->data.comparison_expr.op2); + LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1); + LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2); // TODO implement type system so that we know whether to do signed or unsigned comparison here - LLVMIntPredicate pred = cmp_op_to_int_predicate(node->data.comparison_expr.cmp_op, true); + LLVMIntPredicate pred = cmp_op_to_int_predicate(node->data.bin_op_expr.bin_op, true); add_debug_source_node(g, node); return LLVMBuildICmp(g->builder, pred, val1, val2, ""); } static LLVMValueRef gen_bool_and_expr(CodeGen *g, AstNode *node) { - assert(node->type == NodeTypeBoolAndExpr); + assert(node->type == NodeTypeBinOpExpr); - LLVMValueRef val1 = gen_cmp_expr(g, node->data.bool_and_expr.op1); - - if (!node->data.bool_and_expr.op2) - return val1; + LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1); // block for when val1 == true LLVMBasicBlockRef true_block = LLVMAppendBasicBlock(g->cur_fn, "BoolAndTrue"); @@ -926,7 +863,7 @@ static LLVMValueRef gen_bool_and_expr(CodeGen *g, AstNode *node) { LLVMBuildCondBr(g->builder, val1_i1, false_block, true_block); LLVMPositionBuilderAtEnd(g->builder, true_block); - LLVMValueRef val2 = gen_cmp_expr(g, node->data.bool_and_expr.op2); + LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2); add_debug_source_node(g, node); LLVMValueRef val2_i1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, zero, ""); @@ -942,12 +879,9 @@ static LLVMValueRef gen_bool_and_expr(CodeGen *g, AstNode *node) { } static LLVMValueRef gen_bool_or_expr(CodeGen *g, AstNode *expr_node) { - assert(expr_node->type == NodeTypeBoolOrExpr); + assert(expr_node->type == NodeTypeBinOpExpr); - LLVMValueRef val1 = gen_bool_and_expr(g, expr_node->data.bool_or_expr.op1); - - if (!expr_node->data.bool_or_expr.op2) - return val1; + LLVMValueRef val1 = gen_expr(g, expr_node->data.bin_op_expr.op1); // block for when val1 == false LLVMBasicBlockRef false_block = LLVMAppendBasicBlock(g->cur_fn, "BoolOrFalse"); @@ -960,7 +894,7 @@ static LLVMValueRef gen_bool_or_expr(CodeGen *g, AstNode *expr_node) { LLVMBuildCondBr(g->builder, val1_i1, false_block, true_block); LLVMPositionBuilderAtEnd(g->builder, false_block); - LLVMValueRef val2 = gen_bool_and_expr(g, expr_node->data.bool_or_expr.op2); + LLVMValueRef val2 = gen_expr(g, expr_node->data.bin_op_expr.op2); add_debug_source_node(g, expr_node); LLVMValueRef val2_i1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, zero, ""); @@ -975,6 +909,41 @@ static LLVMValueRef gen_bool_or_expr(CodeGen *g, AstNode *expr_node) { return phi; } +static LLVMValueRef gen_bin_op_expr(CodeGen *g, AstNode *node) { + switch (node->data.bin_op_expr.bin_op) { + case BinOpTypeInvalid: + zig_unreachable(); + case BinOpTypeBoolOr: + return gen_bool_or_expr(g, node); + case BinOpTypeBoolAnd: + return gen_bool_and_expr(g, node); + case BinOpTypeCmpEq: + case BinOpTypeCmpNotEq: + case BinOpTypeCmpLessThan: + case BinOpTypeCmpGreaterThan: + case BinOpTypeCmpLessOrEq: + case BinOpTypeCmpGreaterOrEq: + return gen_cmp_expr(g, node); + case BinOpTypeBinOr: + return gen_bin_or_expr(g, node); + case BinOpTypeBinXor: + return gen_bin_xor_expr(g, node); + case BinOpTypeBinAnd: + return gen_bin_and_expr(g, node); + case BinOpTypeBitShiftLeft: + case BinOpTypeBitShiftRight: + return gen_bit_shift_expr(g, node); + case BinOpTypeAdd: + case BinOpTypeSub: + return gen_add_expr(g, node); + case BinOpTypeMult: + case BinOpTypeDiv: + case BinOpTypeMod: + return gen_mult_expr(g, node); + } + zig_unreachable(); +} + static LLVMValueRef gen_return_expr(CodeGen *g, AstNode *node) { assert(node->type == NodeTypeReturnExpr); AstNode *param_node = node->data.return_expr.expr; @@ -993,10 +962,12 @@ Expression : BoolOrExpression | ReturnExpression */ static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) { switch (node->type) { - case NodeTypeBoolOrExpr: - return gen_bool_or_expr(g, node); + case NodeTypeBinOpExpr: + return gen_bin_op_expr(g, node); case NodeTypeReturnExpr: return gen_return_expr(g, node); + case NodeTypeCastExpr: + return gen_cast_expr(g, node); case NodeTypeRoot: case NodeTypeRootExportDecl: case NodeTypeFnProto: @@ -1008,15 +979,6 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) { case NodeTypeFnCall: case NodeTypeExternBlock: case NodeTypeDirective: - case NodeTypeBoolAndExpr: - case NodeTypeComparisonExpr: - case NodeTypeBinOrExpr: - case NodeTypeBinXorExpr: - case NodeTypeBinAndExpr: - case NodeTypeBitShiftExpr: - case NodeTypeAddExpr: - case NodeTypeMultExpr: - case NodeTypeCastExpr: case NodeTypePrimaryExpr: return gen_primary_expr(g, node); case NodeTypeGroupedExpr: diff --git a/src/parser.cpp b/src/parser.cpp index 07ade572f0..693cc6d2c0 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -10,43 +10,27 @@ #include #include -static const char *mult_op_str(MultOp mult_op) { - switch (mult_op) { - case MultOpInvalid: return "(invalid)"; - case MultOpMult: return "*"; - case MultOpDiv: return "/"; - case MultOpMod: return "%"; - } - zig_unreachable(); -} - -static const char *add_op_str(AddOp add_op) { - switch (add_op) { - case AddOpInvalid: return "(invalid)"; - case AddOpAdd: return "+"; - case AddOpSub: return "-"; - } - zig_unreachable(); -} - -static const char *bit_shift_op_str(BitShiftOp bit_shift_op) { - switch (bit_shift_op) { - case BitShiftOpInvalid: return "(invalid)"; - case BitShiftOpLeft: return "<<"; - case BitShiftOpRight: return ">>"; - } - zig_unreachable(); -} - -static const char *cmp_op_str(CmpOp cmp_op) { - switch (cmp_op) { - case CmpOpInvalid: return "(invalid)"; - case CmpOpEq: return "="; - case CmpOpNotEq: return "!="; - case CmpOpLessThan: return "<"; - case CmpOpGreaterThan: return ">"; - case CmpOpLessOrEq: return "<="; - case CmpOpGreaterOrEq: return ">="; +static const char *bin_op_str(BinOpType bin_op) { + switch (bin_op) { + case BinOpTypeInvalid: return "(invalid)"; + case BinOpTypeBoolOr: return "||"; + case BinOpTypeBoolAnd: return "&&"; + case BinOpTypeCmpEq: return "=="; + case BinOpTypeCmpNotEq: return "!="; + case BinOpTypeCmpLessThan: return "<"; + case BinOpTypeCmpGreaterThan: return ">"; + case BinOpTypeCmpLessOrEq: return "<="; + case BinOpTypeCmpGreaterOrEq: return ">="; + case BinOpTypeBinOr: return "|"; + case BinOpTypeBinXor: return "^"; + case BinOpTypeBinAnd: return "&"; + case BinOpTypeBitShiftLeft: return "<<"; + case BinOpTypeBitShiftRight: return ">>"; + case BinOpTypeAdd: return "+"; + case BinOpTypeSub: return "-"; + case BinOpTypeMult: return "*"; + case BinOpTypeDiv: return "/"; + case BinOpTypeMod: return "%"; } zig_unreachable(); } @@ -84,8 +68,8 @@ const char *node_type_str(NodeType node_type) { return "Type"; case NodeTypeBlock: return "Block"; - case NodeTypeBoolOrExpr: - return "BoolOrExpr"; + case NodeTypeBinOpExpr: + return "BinOpExpr"; case NodeTypeFnCall: return "FnCall"; case NodeTypeExternBlock: @@ -94,22 +78,6 @@ const char *node_type_str(NodeType node_type) { return "Directive"; case NodeTypeReturnExpr: return "ReturnExpr"; - case NodeTypeBoolAndExpr: - return "BoolAndExpr"; - case NodeTypeComparisonExpr: - return "ComparisonExpr"; - case NodeTypeBinOrExpr: - return "BinOrExpr"; - case NodeTypeBinXorExpr: - return "BinXorExpr"; - case NodeTypeBinAndExpr: - return "BinAndExpr"; - case NodeTypeBitShiftExpr: - return "BitShiftExpr"; - case NodeTypeAddExpr: - return "AddExpr"; - case NodeTypeMultExpr: - return "MultExpr"; case NodeTypeCastExpr: return "CastExpr"; case NodeTypePrimaryExpr: @@ -214,11 +182,11 @@ void ast_print(AstNode *node, int indent) { fprintf(stderr, "%s\n", node_type_str(node->type)); ast_print(node->data.fn_decl.fn_proto, indent + 2); break; - case NodeTypeBoolOrExpr: - fprintf(stderr, "%s\n", node_type_str(node->type)); - ast_print(node->data.bool_or_expr.op1, indent + 2); - if (node->data.bool_or_expr.op2) - ast_print(node->data.bool_or_expr.op2, indent + 2); + case NodeTypeBinOpExpr: + fprintf(stderr, "%s %s\n", node_type_str(node->type), + bin_op_str(node->data.bin_op_expr.bin_op)); + ast_print(node->data.bin_op_expr.op1, indent + 2); + ast_print(node->data.bin_op_expr.op2, indent + 2); break; case NodeTypeFnCall: fprintf(stderr, "%s '%s'\n", node_type_str(node->type), buf_ptr(&node->data.fn_call.name)); @@ -230,58 +198,6 @@ void ast_print(AstNode *node, int indent) { case NodeTypeDirective: fprintf(stderr, "%s\n", node_type_str(node->type)); break; - case NodeTypeBoolAndExpr: - fprintf(stderr, "%s\n", node_type_str(node->type)); - ast_print(node->data.bool_and_expr.op1, indent + 2); - if (node->data.bool_and_expr.op2) - ast_print(node->data.bool_and_expr.op2, indent + 2); - break; - case NodeTypeComparisonExpr: - fprintf(stderr, "%s %s\n", node_type_str(node->type), - cmp_op_str(node->data.comparison_expr.cmp_op)); - ast_print(node->data.comparison_expr.op1, indent + 2); - if (node->data.comparison_expr.op2) - ast_print(node->data.comparison_expr.op2, indent + 2); - break; - case NodeTypeBinOrExpr: - fprintf(stderr, "%s\n", node_type_str(node->type)); - ast_print(node->data.bin_or_expr.op1, indent + 2); - if (node->data.bin_or_expr.op2) - ast_print(node->data.bin_or_expr.op2, indent + 2); - break; - case NodeTypeBinXorExpr: - fprintf(stderr, "%s\n", node_type_str(node->type)); - ast_print(node->data.bin_xor_expr.op1, indent + 2); - if (node->data.bin_xor_expr.op2) - ast_print(node->data.bin_xor_expr.op2, indent + 2); - break; - case NodeTypeBinAndExpr: - fprintf(stderr, "%s\n", node_type_str(node->type)); - ast_print(node->data.bin_and_expr.op1, indent + 2); - if (node->data.bin_and_expr.op2) - ast_print(node->data.bin_and_expr.op2, indent + 2); - break; - case NodeTypeBitShiftExpr: - fprintf(stderr, "%s %s\n", node_type_str(node->type), - bit_shift_op_str(node->data.bit_shift_expr.bit_shift_op)); - ast_print(node->data.bit_shift_expr.op1, indent + 2); - if (node->data.bit_shift_expr.op2) - ast_print(node->data.bit_shift_expr.op2, indent + 2); - break; - case NodeTypeAddExpr: - fprintf(stderr, "%s %s\n", node_type_str(node->type), - add_op_str(node->data.add_expr.add_op)); - ast_print(node->data.add_expr.op1, indent + 2); - if (node->data.add_expr.op2) - ast_print(node->data.add_expr.op2, indent + 2); - break; - case NodeTypeMultExpr: - fprintf(stderr, "%s %s\n", node_type_str(node->type), - mult_op_str(node->data.mult_expr.mult_op)); - ast_print(node->data.mult_expr.op1, indent + 2); - if (node->data.mult_expr.op2) - ast_print(node->data.mult_expr.op2, indent + 2); - break; case NodeTypeCastExpr: fprintf(stderr, "%s\n", node_type_str(node->type)); ast_print(node->data.cast_expr.primary_expr, indent + 2); @@ -709,26 +625,26 @@ static AstNode *ast_parse_cast_expression(ParseContext *pc, int *token_index, bo return node; } -static MultOp tok_to_mult_op(Token *token) { +static BinOpType tok_to_mult_op(Token *token) { switch (token->id) { - case TokenIdStar: return MultOpMult; - case TokenIdSlash: return MultOpDiv; - case TokenIdPercent: return MultOpMod; - default: return MultOpInvalid; + case TokenIdStar: return BinOpTypeMult; + case TokenIdSlash: return BinOpTypeDiv; + case TokenIdPercent: return BinOpTypeMod; + default: return BinOpTypeInvalid; } } /* MultiplyOperator : token(Star) | token(Slash) | token(Percent) */ -static MultOp ast_parse_mult_op(ParseContext *pc, int *token_index, bool mandatory) { +static BinOpType ast_parse_mult_op(ParseContext *pc, int *token_index, bool mandatory) { Token *token = &pc->tokens->at(*token_index); - MultOp result = tok_to_mult_op(token); - if (result == MultOpInvalid) { + BinOpType result = tok_to_mult_op(token); + if (result == BinOpTypeInvalid) { if (mandatory) { ast_invalid_token_error(pc, token); } else { - return MultOpInvalid; + return BinOpTypeInvalid; } } *token_index += 1; @@ -744,39 +660,39 @@ static AstNode *ast_parse_mult_expr(ParseContext *pc, int *token_index, bool man return nullptr; Token *token = &pc->tokens->at(*token_index); - MultOp mult_op = ast_parse_mult_op(pc, token_index, false); - if (mult_op == MultOpInvalid) + BinOpType mult_op = ast_parse_mult_op(pc, token_index, false); + if (mult_op == BinOpTypeInvalid) return operand_1; AstNode *operand_2 = ast_parse_cast_expression(pc, token_index, true); - AstNode *node = ast_create_node(NodeTypeMultExpr, token); - node->data.mult_expr.op1 = operand_1; - node->data.mult_expr.mult_op = mult_op; - node->data.mult_expr.op2 = operand_2; + AstNode *node = ast_create_node(NodeTypeBinOpExpr, token); + node->data.bin_op_expr.op1 = operand_1; + node->data.bin_op_expr.bin_op = mult_op; + node->data.bin_op_expr.op2 = operand_2; return node; } -static AddOp tok_to_add_op(Token *token) { +static BinOpType tok_to_add_op(Token *token) { switch (token->id) { - case TokenIdPlus: return AddOpAdd; - case TokenIdDash: return AddOpSub; - default: return AddOpInvalid; + case TokenIdPlus: return BinOpTypeAdd; + case TokenIdDash: return BinOpTypeSub; + default: return BinOpTypeInvalid; } } /* AdditionOperator : token(Plus) | token(Minus) */ -static AddOp ast_parse_add_op(ParseContext *pc, int *token_index, bool mandatory) { +static BinOpType ast_parse_add_op(ParseContext *pc, int *token_index, bool mandatory) { Token *token = &pc->tokens->at(*token_index); - AddOp result = tok_to_add_op(token); - if (result == AddOpInvalid) { + BinOpType result = tok_to_add_op(token); + if (result == BinOpTypeInvalid) { if (mandatory) { ast_invalid_token_error(pc, token); } else { - return AddOpInvalid; + return BinOpTypeInvalid; } } *token_index += 1; @@ -792,39 +708,39 @@ static AstNode *ast_parse_add_expr(ParseContext *pc, int *token_index, bool mand return nullptr; Token *token = &pc->tokens->at(*token_index); - AddOp add_op = ast_parse_add_op(pc, token_index, false); - if (add_op == AddOpInvalid) + BinOpType add_op = ast_parse_add_op(pc, token_index, false); + if (add_op == BinOpTypeInvalid) return operand_1; AstNode *operand_2 = ast_parse_mult_expr(pc, token_index, true); - AstNode *node = ast_create_node(NodeTypeAddExpr, token); - node->data.add_expr.op1 = operand_1; - node->data.add_expr.add_op = add_op; - node->data.add_expr.op2 = operand_2; + AstNode *node = ast_create_node(NodeTypeBinOpExpr, token); + node->data.bin_op_expr.op1 = operand_1; + node->data.bin_op_expr.bin_op = add_op; + node->data.bin_op_expr.op2 = operand_2; return node; } -static BitShiftOp tok_to_bit_shift_op(Token *token) { +static BinOpType tok_to_bit_shift_op(Token *token) { switch (token->id) { - case TokenIdBitShiftLeft: return BitShiftOpLeft; - case TokenIdBitShiftRight: return BitShiftOpRight; - default: return BitShiftOpInvalid; + case TokenIdBitShiftLeft: return BinOpTypeBitShiftLeft; + case TokenIdBitShiftRight: return BinOpTypeBitShiftRight; + default: return BinOpTypeInvalid; } } /* BitShiftOperator : token(BitShiftLeft | token(BitShiftRight) */ -static BitShiftOp ast_parse_bit_shift_op(ParseContext *pc, int *token_index, bool mandatory) { +static BinOpType ast_parse_bit_shift_op(ParseContext *pc, int *token_index, bool mandatory) { Token *token = &pc->tokens->at(*token_index); - BitShiftOp result = tok_to_bit_shift_op(token); - if (result == BitShiftOpInvalid) { + BinOpType result = tok_to_bit_shift_op(token); + if (result == BinOpTypeInvalid) { if (mandatory) { ast_invalid_token_error(pc, token); } else { - return BitShiftOpInvalid; + return BinOpTypeInvalid; } } *token_index += 1; @@ -840,16 +756,16 @@ static AstNode *ast_parse_bit_shift_expr(ParseContext *pc, int *token_index, boo return nullptr; Token *token = &pc->tokens->at(*token_index); - BitShiftOp bit_shift_op = ast_parse_bit_shift_op(pc, token_index, false); - if (bit_shift_op == BitShiftOpInvalid) + BinOpType bit_shift_op = ast_parse_bit_shift_op(pc, token_index, false); + if (bit_shift_op == BinOpTypeInvalid) return operand_1; AstNode *operand_2 = ast_parse_add_expr(pc, token_index, true); - AstNode *node = ast_create_node(NodeTypeBitShiftExpr, token); - node->data.bit_shift_expr.op1 = operand_1; - node->data.bit_shift_expr.bit_shift_op = bit_shift_op; - node->data.bit_shift_expr.op2 = operand_2; + AstNode *node = ast_create_node(NodeTypeBinOpExpr, token); + node->data.bin_op_expr.op1 = operand_1; + node->data.bin_op_expr.bin_op = bit_shift_op; + node->data.bin_op_expr.op2 = operand_2; return node; } @@ -870,9 +786,10 @@ static AstNode *ast_parse_bin_and_expr(ParseContext *pc, int *token_index, bool AstNode *operand_2 = ast_parse_bit_shift_expr(pc, token_index, true); - AstNode *node = ast_create_node(NodeTypeBinAndExpr, token); - node->data.bin_and_expr.op1 = operand_1; - node->data.bin_and_expr.op2 = operand_2; + AstNode *node = ast_create_node(NodeTypeBinOpExpr, token); + node->data.bin_op_expr.op1 = operand_1; + node->data.bin_op_expr.bin_op = BinOpTypeBinAnd; + node->data.bin_op_expr.op2 = operand_2; return node; } @@ -892,9 +809,10 @@ static AstNode *ast_parse_bin_xor_expr(ParseContext *pc, int *token_index, bool AstNode *operand_2 = ast_parse_bin_and_expr(pc, token_index, true); - AstNode *node = ast_create_node(NodeTypeBinXorExpr, token); - node->data.bin_xor_expr.op1 = operand_1; - node->data.bin_xor_expr.op2 = operand_2; + AstNode *node = ast_create_node(NodeTypeBinOpExpr, token); + node->data.bin_op_expr.op1 = operand_1; + node->data.bin_op_expr.bin_op = BinOpTypeBinXor; + node->data.bin_op_expr.op2 = operand_2; return node; } @@ -914,33 +832,34 @@ static AstNode *ast_parse_bin_or_expr(ParseContext *pc, int *token_index, bool m AstNode *operand_2 = ast_parse_bin_xor_expr(pc, token_index, true); - AstNode *node = ast_create_node(NodeTypeBinOrExpr, token); - node->data.bin_or_expr.op1 = operand_1; - node->data.bin_or_expr.op2 = operand_2; + AstNode *node = ast_create_node(NodeTypeBinOpExpr, token); + node->data.bin_op_expr.op1 = operand_1; + node->data.bin_op_expr.bin_op = BinOpTypeBinOr; + node->data.bin_op_expr.op2 = operand_2; return node; } -static CmpOp tok_to_cmp_op(Token *token) { +static BinOpType tok_to_cmp_op(Token *token) { switch (token->id) { - case TokenIdCmpEq: return CmpOpEq; - case TokenIdCmpNotEq: return CmpOpNotEq; - case TokenIdCmpLessThan: return CmpOpLessThan; - case TokenIdCmpGreaterThan: return CmpOpGreaterThan; - case TokenIdCmpLessOrEq: return CmpOpLessOrEq; - case TokenIdCmpGreaterOrEq: return CmpOpGreaterOrEq; - default: return CmpOpInvalid; + case TokenIdCmpEq: return BinOpTypeCmpEq; + case TokenIdCmpNotEq: return BinOpTypeCmpNotEq; + case TokenIdCmpLessThan: return BinOpTypeCmpLessThan; + case TokenIdCmpGreaterThan: return BinOpTypeCmpGreaterThan; + case TokenIdCmpLessOrEq: return BinOpTypeCmpLessOrEq; + case TokenIdCmpGreaterOrEq: return BinOpTypeCmpGreaterOrEq; + default: return BinOpTypeInvalid; } } -static CmpOp ast_parse_comparison_operator(ParseContext *pc, int *token_index, bool mandatory) { +static BinOpType ast_parse_comparison_operator(ParseContext *pc, int *token_index, bool mandatory) { Token *token = &pc->tokens->at(*token_index); - CmpOp result = tok_to_cmp_op(token); - if (result == CmpOpInvalid) { + BinOpType result = tok_to_cmp_op(token); + if (result == BinOpTypeInvalid) { if (mandatory) { ast_invalid_token_error(pc, token); } else { - return CmpOpInvalid; + return BinOpTypeInvalid; } } *token_index += 1; @@ -956,16 +875,16 @@ static AstNode *ast_parse_comparison_expr(ParseContext *pc, int *token_index, bo return nullptr; Token *token = &pc->tokens->at(*token_index); - CmpOp cmp_op = ast_parse_comparison_operator(pc, token_index, false); - if (cmp_op == CmpOpInvalid) + BinOpType cmp_op = ast_parse_comparison_operator(pc, token_index, false); + if (cmp_op == BinOpTypeInvalid) return operand_1; AstNode *operand_2 = ast_parse_bin_or_expr(pc, token_index, true); - AstNode *node = ast_create_node(NodeTypeComparisonExpr, token); - node->data.comparison_expr.op1 = operand_1; - node->data.comparison_expr.cmp_op = cmp_op; - node->data.comparison_expr.op2 = operand_2; + AstNode *node = ast_create_node(NodeTypeBinOpExpr, token); + node->data.bin_op_expr.op1 = operand_1; + node->data.bin_op_expr.bin_op = cmp_op; + node->data.bin_op_expr.op2 = operand_2; return node; } @@ -985,9 +904,10 @@ static AstNode *ast_parse_bool_and_expr(ParseContext *pc, int *token_index, bool AstNode *operand_2 = ast_parse_comparison_expr(pc, token_index, true); - AstNode *node = ast_create_node(NodeTypeBoolAndExpr, token); - node->data.bool_and_expr.op1 = operand_1; - node->data.bool_and_expr.op2 = operand_2; + AstNode *node = ast_create_node(NodeTypeBinOpExpr, token); + node->data.bin_op_expr.op1 = operand_1; + node->data.bin_op_expr.bin_op = BinOpTypeBoolAnd; + node->data.bin_op_expr.op2 = operand_2; return node; } @@ -1024,9 +944,10 @@ static AstNode *ast_parse_bool_or_expr(ParseContext *pc, int *token_index, bool AstNode *operand_2 = ast_parse_bool_and_expr(pc, token_index, true); - AstNode *node = ast_create_node(NodeTypeBoolOrExpr, token); - node->data.bool_or_expr.op1 = operand_1; - node->data.bool_or_expr.op2 = operand_2; + AstNode *node = ast_create_node(NodeTypeBinOpExpr, token); + node->data.bin_op_expr.op1 = operand_1; + node->data.bin_op_expr.bin_op = BinOpTypeBoolOr; + node->data.bin_op_expr.op2 = operand_2; return node; } diff --git a/src/parser.hpp b/src/parser.hpp index 626152a973..b946aa33cd 100644 --- a/src/parser.hpp +++ b/src/parser.hpp @@ -28,15 +28,7 @@ enum NodeType { NodeTypeExternBlock, NodeTypeDirective, NodeTypeReturnExpr, - NodeTypeBoolOrExpr, - NodeTypeBoolAndExpr, - NodeTypeComparisonExpr, - NodeTypeBinOrExpr, - NodeTypeBinXorExpr, - NodeTypeBinAndExpr, - NodeTypeBitShiftExpr, - NodeTypeAddExpr, - NodeTypeMultExpr, + NodeTypeBinOpExpr, NodeTypeCastExpr, NodeTypePrimaryExpr, NodeTypeGroupedExpr, @@ -96,9 +88,32 @@ struct AstNodeReturnExpr { AstNode *expr; }; -struct AstNodeBoolOrExpr { +enum BinOpType { + BinOpTypeInvalid, + // TODO: include assignment? + BinOpTypeBoolOr, + BinOpTypeBoolAnd, + BinOpTypeCmpEq, + BinOpTypeCmpNotEq, + BinOpTypeCmpLessThan, + BinOpTypeCmpGreaterThan, + BinOpTypeCmpLessOrEq, + BinOpTypeCmpGreaterOrEq, + BinOpTypeBinOr, + BinOpTypeBinXor, + BinOpTypeBinAnd, + BinOpTypeBitShiftLeft, + BinOpTypeBitShiftRight, + BinOpTypeAdd, + BinOpTypeSub, + BinOpTypeMult, + BinOpTypeDiv, + BinOpTypeMod, +}; + +struct AstNodeBinOpExpr { AstNode *op1; - // if op2 is non-null, do boolean or, otherwise nothing + BinOpType bin_op; AstNode *op2; }; @@ -122,87 +137,6 @@ struct AstNodeRootExportDecl { Buf name; }; -struct AstNodeBoolAndExpr { - AstNode *op1; - // if op2 is non-null, do boolean and, otherwise nothing - AstNode *op2; -}; - -enum CmpOp { - CmpOpInvalid, - CmpOpEq, - CmpOpNotEq, - CmpOpLessThan, - CmpOpGreaterThan, - CmpOpLessOrEq, - CmpOpGreaterOrEq, -}; - -struct AstNodeComparisonExpr { - AstNode *op1; - CmpOp cmp_op; - // if op2 is non-null, do cmp_op, otherwise nothing - AstNode *op2; -}; - -struct AstNodeBinOrExpr { - AstNode *op1; - // if op2 is non-null, do binary or, otherwise nothing - AstNode *op2; -}; - -struct AstNodeBinXorExpr { - AstNode *op1; - // if op2 is non-null, do binary xor, otherwise nothing - AstNode *op2; -}; - -struct AstNodeBinAndExpr { - AstNode *op1; - // if op2 is non-null, do binary and, otherwise nothing - AstNode *op2; -}; - -enum BitShiftOp { - BitShiftOpInvalid, - BitShiftOpLeft, - BitShiftOpRight, -}; - -struct AstNodeBitShiftExpr { - AstNode *op1; - BitShiftOp bit_shift_op; - // if op2 is non-null, do bit_shift_op, otherwise nothing - AstNode *op2; -}; - -enum AddOp { - AddOpInvalid, - AddOpAdd, - AddOpSub, -}; - -struct AstNodeAddExpr { - AstNode *op1; - AddOp add_op; - // if op2 is non-null, do add_op, otherwise nothing - AstNode *op2; -}; - -enum MultOp { - MultOpInvalid, - MultOpMult, - MultOpDiv, - MultOpMod, -}; - -struct AstNodeMultExpr { - AstNode *op1; - MultOp mult_op; - // if op2 is non-null, do mult_op, otherwise nothing - AstNode *op2; -}; - struct AstNodeCastExpr { AstNode *primary_expr; // if type is non-null, do cast, otherwise nothing @@ -249,18 +183,10 @@ struct AstNode { AstNodeParamDecl param_decl; AstNodeBlock block; AstNodeReturnExpr return_expr; - AstNodeBoolOrExpr bool_or_expr; + AstNodeBinOpExpr bin_op_expr; AstNodeFnCall fn_call; AstNodeExternBlock extern_block; AstNodeDirective directive; - AstNodeBoolAndExpr bool_and_expr; - AstNodeComparisonExpr comparison_expr; - AstNodeBinOrExpr bin_or_expr; - AstNodeBinXorExpr bin_xor_expr; - AstNodeBinAndExpr bin_and_expr; - AstNodeBitShiftExpr bit_shift_expr; - AstNodeAddExpr add_expr; - AstNodeMultExpr mult_expr; AstNodeCastExpr cast_expr; AstNodePrimaryExpr primary_expr; AstNodeGroupedExpr grouped_expr;