From 8c79438f6b76f1ad4b4941cdb46ae1e7aa12ce14 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sat, 7 May 2016 10:05:59 -0700 Subject: [PATCH] better array concatenation semantics closes #87 --- src/all_types.hpp | 3 +- src/analyze.cpp | 105 ++++++++++++++++++++++++++++--------------- src/ast_render.cpp | 2 +- src/codegen.cpp | 4 +- src/eval.cpp | 4 +- src/parser.cpp | 2 +- std/str.zig | 3 ++ test/run_tests.cpp | 10 ++++- test/self_hosted.zig | 14 ++++++ 9 files changed, 103 insertions(+), 44 deletions(-) diff --git a/src/all_types.hpp b/src/all_types.hpp index 49e1cb510f..ce5fa98c54 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -54,6 +54,7 @@ struct ConstPtrValue { ConstExprValue **ptr; // len should almost always be 1. exceptions include C strings uint64_t len; + bool is_c_str; }; struct ConstErrValue { @@ -341,7 +342,7 @@ enum BinOpType { BinOpTypeDiv, BinOpTypeMod, BinOpTypeUnwrapMaybe, - BinOpTypeStrCat, + BinOpTypeArrayCat, BinOpTypeArrayMult, }; diff --git a/src/analyze.cpp b/src/analyze.cpp index 928f576947..66531052ab 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -2770,6 +2770,7 @@ static TypeTableEntry *resolve_expr_const_val_as_c_string_lit(CodeGen *g, AstNod int len_with_null = buf_len(str) + 1; expr->const_val.data.x_ptr.ptr = allocate(len_with_null); expr->const_val.data.x_ptr.len = len_with_null; + expr->const_val.data.x_ptr.is_c_str = true; ConstExprValue *all_chars = allocate(len_with_null); for (int i = 0; i < buf_len(str); i += 1) { @@ -2974,7 +2975,7 @@ static bool is_op_allowed(TypeTableEntry *type, BinOpType op) { case BinOpTypeDiv: case BinOpTypeMod: case BinOpTypeUnwrapMaybe: - case BinOpTypeStrCat: + case BinOpTypeArrayCat: case BinOpTypeArrayMult: zig_unreachable(); } @@ -3379,19 +3380,42 @@ static TypeTableEntry *analyze_bin_op_expr(CodeGen *g, ImportTableEntry *import, return g->builtin_types.entry_invalid; } } - case BinOpTypeStrCat: + case BinOpTypeArrayCat: { AstNode **op1 = node->data.bin_op_expr.op1->parent_field; AstNode **op2 = node->data.bin_op_expr.op2->parent_field; - TypeTableEntry *str_type = get_slice_type(g, g->builtin_types.entry_u8, true); + TypeTableEntry *op1_type = analyze_expression(g, import, context, nullptr, *op1); + TypeTableEntry *child_type; + if (op1_type->id == TypeTableEntryIdInvalid) { + return g->builtin_types.entry_invalid; + } else if (op1_type->id == TypeTableEntryIdArray) { + child_type = op1_type->data.array.child_type; + } else if (op1_type->id == TypeTableEntryIdPointer && + op1_type->data.pointer.child_type == g->builtin_types.entry_u8) { + child_type = op1_type->data.pointer.child_type; + } else { + add_node_error(g, *op1, buf_sprintf("expected array or C string literal, got '%s'", + buf_ptr(&op1_type->name))); + return g->builtin_types.entry_invalid; + } - TypeTableEntry *op1_type = analyze_expression(g, import, context, str_type, *op1); - TypeTableEntry *op2_type = analyze_expression(g, import, context, str_type, *op2); + TypeTableEntry *op2_type = analyze_expression(g, import, context, nullptr, *op2); - if (op1_type->id == TypeTableEntryIdInvalid || - op2_type->id == TypeTableEntryIdInvalid) - { + if (op2_type->id == TypeTableEntryIdInvalid) { + return g->builtin_types.entry_invalid; + } else if (op2_type->id == TypeTableEntryIdArray) { + if (op2_type->data.array.child_type != child_type) { + add_node_error(g, *op2, buf_sprintf("expected array of type '%s', got '%s'", + buf_ptr(&child_type->name), + buf_ptr(&op2_type->name))); + return g->builtin_types.entry_invalid; + } + } else if (op2_type->id == TypeTableEntryIdPointer && + op2_type->data.pointer.child_type == g->builtin_types.entry_u8) { + } else { + add_node_error(g, *op2, buf_sprintf("expected array or C string literal, got '%s'", + buf_ptr(&op2_type->name))); return g->builtin_types.entry_invalid; } @@ -3407,41 +3431,52 @@ static TypeTableEntry *analyze_bin_op_expr(CodeGen *g, ImportTableEntry *import, bad_node = nullptr; } if (bad_node) { - add_node_error(g, bad_node, buf_sprintf("string concatenation requires constant expression")); + add_node_error(g, bad_node, buf_sprintf("array concatenation requires constant expression")); return g->builtin_types.entry_invalid; } + ConstExprValue *const_val = &get_resolved_expr(node)->const_val; const_val->ok = true; const_val->depends_on_compile_var = op1_val->depends_on_compile_var || op2_val->depends_on_compile_var; - ConstExprValue *all_fields = allocate(2); - ConstExprValue *ptr_field = &all_fields[0]; - ConstExprValue *len_field = &all_fields[1]; - - const_val->data.x_struct.fields = allocate(2); - const_val->data.x_struct.fields[0] = ptr_field; - const_val->data.x_struct.fields[1] = len_field; - - len_field->ok = true; - uint64_t op1_len = op1_val->data.x_struct.fields[1]->data.x_bignum.data.x_uint; - uint64_t op2_len = op2_val->data.x_struct.fields[1]->data.x_bignum.data.x_uint; - uint64_t len = op1_len + op2_len; - bignum_init_unsigned(&len_field->data.x_bignum, len); - - ptr_field->ok = true; - ptr_field->data.x_ptr.ptr = allocate(len); - ptr_field->data.x_ptr.len = len; - - uint64_t i = 0; - for (uint64_t op1_i = 0; op1_i < op1_len; op1_i += 1, i += 1) { - ptr_field->data.x_ptr.ptr[i] = op1_val->data.x_struct.fields[0]->data.x_ptr.ptr[op1_i]; + if (op1_type->id == TypeTableEntryIdArray) { + uint64_t new_len = op1_type->data.array.len + op2_type->data.array.len; + const_val->data.x_array.fields = allocate(new_len); + uint64_t next_index = 0; + for (uint64_t i = 0; i < op1_type->data.array.len; i += 1, next_index += 1) { + const_val->data.x_array.fields[next_index] = op1_val->data.x_array.fields[i]; + } + for (uint64_t i = 0; i < op2_type->data.array.len; i += 1, next_index += 1) { + const_val->data.x_array.fields[next_index] = op2_val->data.x_array.fields[i]; + } + return get_array_type(g, child_type, new_len); + } else if (op1_type->id == TypeTableEntryIdPointer) { + if (!op1_val->data.x_ptr.is_c_str) { + add_node_error(g, *op1, + buf_sprintf("expected array or C string literal, got '%s'", + buf_ptr(&op1_type->name))); + return g->builtin_types.entry_invalid; + } else if (!op2_val->data.x_ptr.is_c_str) { + add_node_error(g, *op2, + buf_sprintf("expected array or C string literal, got '%s'", + buf_ptr(&op2_type->name))); + return g->builtin_types.entry_invalid; + } + const_val->data.x_ptr.is_c_str = true; + const_val->data.x_ptr.len = op1_val->data.x_ptr.len + op2_val->data.x_ptr.len - 1; + const_val->data.x_ptr.ptr = allocate(const_val->data.x_ptr.len); + uint64_t next_index = 0; + for (uint64_t i = 0; i < op1_val->data.x_ptr.len - 1; i += 1, next_index += 1) { + const_val->data.x_ptr.ptr[next_index] = op1_val->data.x_ptr.ptr[i]; + } + for (uint64_t i = 0; i < op2_val->data.x_ptr.len; i += 1, next_index += 1) { + const_val->data.x_ptr.ptr[next_index] = op2_val->data.x_ptr.ptr[i]; + } + return op1_type; + } else { + zig_unreachable(); } - for (uint64_t op2_i = 0; op2_i < op2_len; op2_i += 1, i += 1) { - ptr_field->data.x_ptr.ptr[i] = op2_val->data.x_struct.fields[0]->data.x_ptr.ptr[op2_i]; - } - - return str_type; } case BinOpTypeArrayMult: return analyze_array_mult(g, import, context, expected_type, node); diff --git a/src/ast_render.cpp b/src/ast_render.cpp index d9112715b5..36cb90793d 100644 --- a/src/ast_render.cpp +++ b/src/ast_render.cpp @@ -37,7 +37,7 @@ static const char *bin_op_str(BinOpType bin_op) { case BinOpTypeAssignBoolAnd: return "&&="; case BinOpTypeAssignBoolOr: return "||="; case BinOpTypeUnwrapMaybe: return "??"; - case BinOpTypeStrCat: return "++"; + case BinOpTypeArrayCat: return "++"; case BinOpTypeArrayMult: return "**"; } zig_unreachable(); diff --git a/src/codegen.cpp b/src/codegen.cpp index 8ecdd36358..6c73eff14a 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -1665,7 +1665,7 @@ static LLVMValueRef gen_arithmetic_bin_op(CodeGen *g, AstNode *source_node, case BinOpTypeAssignBoolAnd: case BinOpTypeAssignBoolOr: case BinOpTypeUnwrapMaybe: - case BinOpTypeStrCat: + case BinOpTypeArrayCat: case BinOpTypeArrayMult: zig_unreachable(); } @@ -1972,7 +1972,7 @@ static LLVMValueRef gen_unwrap_maybe_expr(CodeGen *g, AstNode *node) { static LLVMValueRef gen_bin_op_expr(CodeGen *g, AstNode *node) { switch (node->data.bin_op_expr.bin_op) { case BinOpTypeInvalid: - case BinOpTypeStrCat: + case BinOpTypeArrayCat: case BinOpTypeArrayMult: zig_unreachable(); case BinOpTypeAssign: diff --git a/src/eval.cpp b/src/eval.cpp index f7c4f5e2a6..8b7984b46c 100644 --- a/src/eval.cpp +++ b/src/eval.cpp @@ -296,7 +296,7 @@ int eval_const_expr_bin_op(ConstExprValue *op1_val, TypeTableEntry *op1_type, return eval_const_expr_bin_op_bignum(op1_val, op2_val, out_val, bignum_mod, op1_type); case BinOpTypeUnwrapMaybe: zig_panic("TODO"); - case BinOpTypeStrCat: + case BinOpTypeArrayCat: case BinOpTypeArrayMult: case BinOpTypeInvalid: zig_unreachable(); @@ -345,7 +345,7 @@ static bool eval_bin_op_expr(EvalFn *ef, AstNode *node, ConstExprValue *out_val) case BinOpTypeDiv: case BinOpTypeMod: case BinOpTypeUnwrapMaybe: - case BinOpTypeStrCat: + case BinOpTypeArrayCat: case BinOpTypeArrayMult: break; case BinOpTypeInvalid: diff --git a/src/parser.cpp b/src/parser.cpp index 2c25526ec6..5cba265ac0 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -1480,7 +1480,7 @@ static BinOpType tok_to_add_op(Token *token) { switch (token->id) { case TokenIdPlus: return BinOpTypeAdd; case TokenIdDash: return BinOpTypeSub; - case TokenIdPlusPlus: return BinOpTypeStrCat; + case TokenIdPlusPlus: return BinOpTypeArrayCat; default: return BinOpTypeInvalid; } } diff --git a/std/str.zig b/std/str.zig index ca0c17898f..060921d199 100644 --- a/std/str.zig +++ b/std/str.zig @@ -1,5 +1,8 @@ const assert = @import("index.zig").assert; +// fix https://github.com/andrewrk/zig/issues/140 +// and then make this able to run at compile time +#static_eval_enable(false) pub fn len(ptr: &const u8) -> isize { var count: isize = 0; while (ptr[count] != 0; count += 1) {} diff --git a/test/run_tests.cpp b/test/run_tests.cpp index 130b99fd65..0fc63ee81f 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -999,11 +999,17 @@ extern fn foo() -> i32; const x = foo(); )SOURCE", 1, ".tmp_source.zig:3:11: error: global variable initializer requires constant expression"); - add_compile_fail_case("non compile time string concatenation", R"SOURCE( + add_compile_fail_case("array concatenation with wrong type", R"SOURCE( fn f(s: []u8) -> []u8 { s ++ "foo" } - )SOURCE", 1, ".tmp_source.zig:3:5: error: string concatenation requires constant expression"); + )SOURCE", 1, ".tmp_source.zig:3:5: error: expected array or C string literal, got '[]u8'"); + + add_compile_fail_case("non compile time array concatenation", R"SOURCE( +fn f(s: [10]u8) -> []u8 { + s ++ "foo" +} + )SOURCE", 1, ".tmp_source.zig:3:5: error: array concatenation requires constant expression"); add_compile_fail_case("c_import with bogus include", R"SOURCE( const c = @c_import(@c_include("bogus.h")); diff --git a/test/self_hosted.zig b/test/self_hosted.zig index 4c5a32b85a..705dadca3d 100644 --- a/test/self_hosted.zig +++ b/test/self_hosted.zig @@ -1551,3 +1551,17 @@ fn combine_non_wrap_with_wrap() { assert(@typeof(c) == i32w); assert(@typeof(d) == i32w); } + +#attribute("test") +fn c_string_concatenation() { + const a = c"OK" ++ c" IT " ++ c"WORKED"; + const b = c"OK IT WORKED"; + + const len = str.len(b); + const len_with_null = len + 1; + {var i: i32 = 0; while (i < len_with_null; i += 1) { + assert(a[i] == b[i]); + }} + assert(a[len] == 0); + assert(b[len] == 0); +}