From 9aea99a999997e223307d8559e0ff9fa613839a3 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 7 Jan 2016 05:29:11 -0700 Subject: [PATCH] implement array slicing syntax closes #52 --- doc/langref.md | 4 ++- src/analyze.cpp | 55 ++++++++++++++++++++++++++++++++++++- src/analyze.hpp | 2 ++ src/codegen.cpp | 68 ++++++++++++++++++++++++++++++++++++++++------ src/parser.cpp | 50 ++++++++++++++++++++++++++++------ src/parser.hpp | 9 ++++++ test/run_tests.cpp | 29 ++++++++++++++++++++ 7 files changed, 199 insertions(+), 18 deletions(-) diff --git a/doc/langref.md b/doc/langref.md index 140fb97d5f..7d8efef421 100644 --- a/doc/langref.md +++ b/doc/langref.md @@ -148,7 +148,7 @@ CastExpression : CastExpression token(as) Type | PrefixOpExpression PrefixOpExpression : PrefixOp PrefixOpExpression | SuffixOpExpression -SuffixOpExpression : PrimaryExpression option(FnCallExpression | ArrayAccessExpression | FieldAccessExpression) +SuffixOpExpression : PrimaryExpression option(FnCallExpression | ArrayAccessExpression | FieldAccessExpression | SliceExpression) FieldAccessExpression : token(Dot) token(Symbol) @@ -156,6 +156,8 @@ FnCallExpression : token(LParen) list(Expression, token(Comma)) token(RParen) ArrayAccessExpression : token(LBracket) Expression token(RBracket) +SliceExpression : token(LBracket) Expression token(Ellipsis) option(Expression) token(RBracket) option(token(Const)) + PrefixOp : token(Not) | token(Dash) | token(Tilde) | token(Star) | (token(Ampersand) option(token(Const))) PrimaryExpression : token(Number) | token(String) | token(CharLiteral) | KeywordLiteral | GroupedExpression | Goto | token(Break) | token(Continue) | BlockExpression | token(Symbol) | StructValueExpression | CompilerFnType diff --git a/src/analyze.cpp b/src/analyze.cpp index 9d6fc01e77..826f04c57d 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -23,6 +23,8 @@ static AstNode *first_executing_node(AstNode *node) { return first_executing_node(node->data.bin_op_expr.op1); case NodeTypeArrayAccessExpr: return first_executing_node(node->data.array_access_expr.array_ref_expr); + case NodeTypeSliceExpr: + return first_executing_node(node->data.slice_expr.array_ref_expr); case NodeTypeFieldAccessExpr: return first_executing_node(node->data.field_access_expr.struct_expr); case NodeTypeCastExpr: @@ -875,6 +877,7 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import, case NodeTypeBinOpExpr: case NodeTypeFnCallExpr: case NodeTypeArrayAccessExpr: + case NodeTypeSliceExpr: case NodeTypeNumberLiteral: case NodeTypeStringLiteral: case NodeTypeCharLiteral: @@ -950,6 +953,7 @@ static void preview_types(CodeGen *g, ImportTableEntry *import, AstNode *node) { case NodeTypeBinOpExpr: case NodeTypeFnCallExpr: case NodeTypeArrayAccessExpr: + case NodeTypeSliceExpr: case NodeTypeNumberLiteral: case NodeTypeStringLiteral: case NodeTypeCharLiteral: @@ -1349,6 +1353,50 @@ static TypeTableEntry *analyze_field_access_expr(CodeGen *g, ImportTableEntry *i return return_type; } +static TypeTableEntry *analyze_slice_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context, + AstNode *node) +{ + TypeTableEntry *array_type = analyze_expression(g, import, context, nullptr, + node->data.slice_expr.array_ref_expr); + + TypeTableEntry *return_type; + + if (array_type->id == TypeTableEntryIdInvalid) { + return_type = g->builtin_types.entry_invalid; + } else if (array_type->id == TypeTableEntryIdArray) { + return_type = get_unknown_size_array_type(g, import, array_type->data.array.child_type, + node->data.slice_expr.is_const); + } else if (array_type->id == TypeTableEntryIdPointer) { + return_type = get_unknown_size_array_type(g, import, array_type->data.pointer.child_type, + node->data.slice_expr.is_const); + } else if (array_type->id == TypeTableEntryIdStruct && + array_type->data.structure.is_unknown_size_array) + { + return_type = get_unknown_size_array_type(g, import, + array_type->data.structure.fields[0].type_entry, + node->data.slice_expr.is_const); + } else { + add_node_error(g, node, + buf_sprintf("slice of non-array type '%s'", buf_ptr(&array_type->name))); + return_type = g->builtin_types.entry_invalid; + } + + if (return_type->id != TypeTableEntryIdInvalid) { + assert(node->codegen_node); + node->codegen_node->data.struct_val_expr_node.type_entry = return_type; + node->codegen_node->data.struct_val_expr_node.source_node = node; + context->struct_val_expr_alloca_list.append(&node->codegen_node->data.struct_val_expr_node); + } + + analyze_expression(g, import, context, g->builtin_types.entry_usize, node->data.slice_expr.start); + + if (node->data.slice_expr.end) { + analyze_expression(g, import, context, g->builtin_types.entry_usize, node->data.slice_expr.end); + } + + return return_type; +} + static TypeTableEntry *analyze_array_access_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context, AstNode *node) { @@ -1363,7 +1411,8 @@ static TypeTableEntry *analyze_array_access_expr(CodeGen *g, ImportTableEntry *i return_type = array_type->data.pointer.child_type; } else { if (array_type->id != TypeTableEntryIdInvalid) { - add_node_error(g, node, buf_sprintf("array access of non-array")); + add_node_error(g, node, + buf_sprintf("array access of non-array type '%s'", buf_ptr(&array_type->name))); } return_type = g->builtin_types.entry_invalid; } @@ -2197,6 +2246,9 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import, // for reading array access; assignment handled elsewhere return_type = analyze_array_access_expr(g, import, context, node); break; + case NodeTypeSliceExpr: + return_type = analyze_slice_expr(g, import, context, node); + break; case NodeTypeFieldAccessExpr: return_type = analyze_field_access_expr(g, import, context, node); break; @@ -2541,6 +2593,7 @@ static void analyze_top_level_declaration(CodeGen *g, ImportTableEntry *import, case NodeTypeBinOpExpr: case NodeTypeFnCallExpr: case NodeTypeArrayAccessExpr: + case NodeTypeSliceExpr: case NodeTypeNumberLiteral: case NodeTypeStringLiteral: case NodeTypeCharLiteral: diff --git a/src/analyze.hpp b/src/analyze.hpp index d14ae714d7..c34ec84c05 100644 --- a/src/analyze.hpp +++ b/src/analyze.hpp @@ -355,9 +355,11 @@ struct CodeGenNode { StructDeclNode struct_decl_node; // for NodeTypeStructDecl FieldAccessNode field_access_node; // for NodeTypeFieldAccessExpr CastNode cast_node; // for NodeTypeCastExpr + // note: I've been using this field on some non-number literal nodes too. NumberLiteralNode num_lit_node; // for NodeTypeNumberLiteral VarDeclNode var_decl_node; // for NodeTypeVariableDeclaration StructValFieldNode struct_val_field_node; // for NodeTypeStructValueField + // note: I've been using this field on some non-struct val expressions too. StructValExprNode struct_val_expr_node; // for NodeTypeStructValueExpr IfVarNode if_var_node; // for NodeTypeStructValueExpr ParamDeclNode param_decl_node; // for NodeTypeParamDecl diff --git a/src/codegen.cpp b/src/codegen.cpp index 0f4499c48d..850afd71a4 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -215,26 +215,34 @@ static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) { } } -static LLVMValueRef gen_array_ptr(CodeGen *g, AstNode *node) { - assert(node->type == NodeTypeArrayAccessExpr); - - AstNode *array_expr_node = node->data.array_access_expr.array_ref_expr; - TypeTableEntry *type_entry = get_expr_type(array_expr_node); +static LLVMValueRef gen_array_base_ptr(CodeGen *g, AstNode *node) { + TypeTableEntry *type_entry = get_expr_type(node); LLVMValueRef array_ptr; - if (array_expr_node->type == NodeTypeFieldAccessExpr) { - array_ptr = gen_field_access_expr(g, array_expr_node, true); + if (node->type == NodeTypeFieldAccessExpr) { + array_ptr = gen_field_access_expr(g, node, true); if (type_entry->id == TypeTableEntryIdPointer) { // we have a double pointer so we must dereference it once add_debug_source_node(g, node); array_ptr = LLVMBuildLoad(g->builder, array_ptr, ""); } } else { - array_ptr = gen_expr(g, array_expr_node); + array_ptr = gen_expr(g, node); } assert(LLVMGetTypeKind(LLVMTypeOf(array_ptr)) == LLVMPointerTypeKind); + return array_ptr; +} + +static LLVMValueRef gen_array_ptr(CodeGen *g, AstNode *node) { + assert(node->type == NodeTypeArrayAccessExpr); + + AstNode *array_expr_node = node->data.array_access_expr.array_ref_expr; + TypeTableEntry *type_entry = get_expr_type(array_expr_node); + + LLVMValueRef array_ptr = gen_array_base_ptr(g, array_expr_node); + LLVMValueRef subscript_value = gen_expr(g, node->data.array_access_expr.subscript); assert(subscript_value); @@ -299,6 +307,48 @@ static LLVMValueRef gen_field_ptr(CodeGen *g, AstNode *node, TypeTableEntry **ou return LLVMBuildStructGEP(g->builder, struct_ptr, codegen_field_access->field_index, ""); } +static LLVMValueRef gen_slice_expr(CodeGen *g, AstNode *node) { + assert(node->type == NodeTypeSliceExpr); + + AstNode *array_ref_node = node->data.slice_expr.array_ref_expr; + TypeTableEntry *array_type = get_expr_type(array_ref_node); + + LLVMValueRef tmp_struct_ptr = node->codegen_node->data.struct_val_expr_node.ptr; + + if (array_type->id == TypeTableEntryIdArray) { + LLVMValueRef array_ptr = gen_array_base_ptr(g, array_ref_node); + LLVMValueRef start_val = gen_expr(g, node->data.slice_expr.start); + LLVMValueRef end_val; + if (node->data.slice_expr.end) { + end_val = gen_expr(g, node->data.slice_expr.end); + } else { + end_val = LLVMConstInt(g->builtin_types.entry_usize->type_ref, array_type->data.array.len, false); + } + + add_debug_source_node(g, node); + LLVMValueRef ptr_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, 0, ""); + LLVMValueRef indices[] = { + LLVMConstNull(g->builtin_types.entry_usize->type_ref), + start_val, + }; + LLVMValueRef slice_start_ptr = LLVMBuildInBoundsGEP(g->builder, array_ptr, indices, 2, ""); + LLVMBuildStore(g->builder, slice_start_ptr, ptr_field_ptr); + + LLVMValueRef len_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, 1, ""); + LLVMValueRef len_value = LLVMBuildSub(g->builder, end_val, start_val, ""); + LLVMBuildStore(g->builder, len_value, len_field_ptr); + + return tmp_struct_ptr; + } else if (array_type->id == TypeTableEntryIdPointer) { + zig_panic("TODO gen_slice_expr pointer"); + } else if (array_type->id == TypeTableEntryIdStruct) { + assert(array_type->data.structure.is_unknown_size_array); + zig_panic("TODO gen_slice_expr unknown size array"); + } else { + zig_unreachable(); + } +} + static LLVMValueRef gen_array_access_expr(CodeGen *g, AstNode *node, bool is_lvalue) { assert(node->type == NodeTypeArrayAccessExpr); @@ -1443,6 +1493,8 @@ static LLVMValueRef gen_expr_no_cast(CodeGen *g, AstNode *node) { return gen_fn_call_expr(g, node); case NodeTypeArrayAccessExpr: return gen_array_access_expr(g, node, false); + case NodeTypeSliceExpr: + return gen_slice_expr(g, node); case NodeTypeFieldAccessExpr: return gen_field_access_expr(g, node, false); case NodeTypeUnreachable: diff --git a/src/parser.cpp b/src/parser.cpp index 1f78424fb1..e8da95bb3d 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -90,6 +90,8 @@ const char *node_type_str(NodeType node_type) { return "FnCallExpr"; case NodeTypeArrayAccessExpr: return "ArrayAccessExpr"; + case NodeTypeSliceExpr: + return "SliceExpr"; case NodeTypeExternBlock: return "ExternBlock"; case NodeTypeDirective: @@ -298,6 +300,14 @@ void ast_print(AstNode *node, int indent) { ast_print(node->data.array_access_expr.array_ref_expr, indent + 2); ast_print(node->data.array_access_expr.subscript, indent + 2); break; + case NodeTypeSliceExpr: + fprintf(stderr, "%s\n", node_type_str(node->type)); + ast_print(node->data.slice_expr.array_ref_expr, indent + 2); + ast_print(node->data.slice_expr.start, indent + 2); + if (node->data.slice_expr.end) { + ast_print(node->data.slice_expr.end, indent + 2); + } + break; case NodeTypeDirective: fprintf(stderr, "%s\n", node_type_str(node->type)); break; @@ -1381,9 +1391,10 @@ static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool } /* -SuffixOpExpression : PrimaryExpression option(FnCallExpression | ArrayAccessExpression | FieldAccessExpression) +SuffixOpExpression : PrimaryExpression option(FnCallExpression | ArrayAccessExpression | FieldAccessExpression | SliceExpression) FnCallExpression : token(LParen) list(Expression, token(Comma)) token(RParen) ArrayAccessExpression : token(LBracket) Expression token(RBracket) +SliceExpression : token(LBracket) Expression token(Ellipsis) option(Expression) token(RBracket) option(token(Const)) FieldAccessExpression : token(Dot) token(Symbol) */ static AstNode *ast_parse_suffix_op_expr(ParseContext *pc, int *token_index, bool mandatory) { @@ -1405,15 +1416,38 @@ static AstNode *ast_parse_suffix_op_expr(ParseContext *pc, int *token_index, boo } else if (token->id == TokenIdLBracket) { *token_index += 1; - AstNode *node = ast_create_node(pc, NodeTypeArrayAccessExpr, token); - node->data.array_access_expr.array_ref_expr = primary_expr; - node->data.array_access_expr.subscript = ast_parse_expression(pc, token_index, true); + AstNode *expr_node = ast_parse_expression(pc, token_index, true); - Token *r_bracket = &pc->tokens->at(*token_index); - *token_index += 1; - ast_expect_token(pc, r_bracket, TokenIdRBracket); + Token *ellipsis_or_r_bracket = &pc->tokens->at(*token_index); - primary_expr = node; + if (ellipsis_or_r_bracket->id == TokenIdEllipsis) { + *token_index += 1; + + AstNode *node = ast_create_node(pc, NodeTypeSliceExpr, token); + node->data.slice_expr.array_ref_expr = primary_expr; + node->data.slice_expr.start = expr_node; + node->data.slice_expr.end = ast_parse_expression(pc, token_index, false); + + ast_eat_token(pc, token_index, TokenIdRBracket); + + Token *const_tok = &pc->tokens->at(*token_index); + if (const_tok->id == TokenIdKeywordConst) { + *token_index += 1; + node->data.slice_expr.is_const = true; + } + + primary_expr = node; + } else if (ellipsis_or_r_bracket->id == TokenIdRBracket) { + *token_index += 1; + + AstNode *node = ast_create_node(pc, NodeTypeArrayAccessExpr, token); + node->data.array_access_expr.array_ref_expr = primary_expr; + node->data.array_access_expr.subscript = expr_node; + + primary_expr = node; + } else { + ast_invalid_token_error(pc, token); + } } else if (token->id == TokenIdDot) { *token_index += 1; diff --git a/src/parser.hpp b/src/parser.hpp index 4edeb67ed4..75d0f822ef 100644 --- a/src/parser.hpp +++ b/src/parser.hpp @@ -41,6 +41,7 @@ enum NodeType { NodeTypePrefixOpExpr, NodeTypeFnCallExpr, NodeTypeArrayAccessExpr, + NodeTypeSliceExpr, NodeTypeFieldAccessExpr, NodeTypeUse, NodeTypeVoid, @@ -181,6 +182,13 @@ struct AstNodeArrayAccessExpr { AstNode *subscript; }; +struct AstNodeSliceExpr { + AstNode *array_ref_expr; + AstNode *start; + AstNode *end; + bool is_const; +}; + struct AstNodeFieldAccessExpr { AstNode *struct_expr; Buf field_name; @@ -378,6 +386,7 @@ struct AstNode { AstNodePrefixOpExpr prefix_op_expr; AstNodeFnCallExpr fn_call_expr; AstNodeArrayAccessExpr array_access_expr; + AstNodeSliceExpr slice_expr; AstNodeUse use; AstNodeIfBoolExpr if_bool_expr; AstNodeIfVarExpr if_var_expr; diff --git a/test/run_tests.cpp b/test/run_tests.cpp index 1038983d2e..71158e9327 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -907,6 +907,35 @@ pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 { "min i16: -32768\n" "min i32: -2147483648\n" "min i64: -9223372036854775808\n"); + + + add_simple_case("slicing", R"SOURCE( +use "std.zig"; +pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 { + var array : [20]i32; + + array[5] = 1234; + + var slice = array[5...10]; + + if (slice.len != 5) { + print_str("BAD\n"); + } + + if (slice.ptr[0] != 1234) { + print_str("BAD\n"); + } + + var slice_rest = array[10...]; + if (slice_rest.len != 10) { + print_str("BAD\n"); + } + + print_str("OK\n"); + return 0; +} + )SOURCE", "OK\n"); + } ////////////////////////////////////////////////////////////////////////////////////