diff --git a/doc/langref.md b/doc/langref.md index e5090db2c3..65cfe71fc8 100644 --- a/doc/langref.md +++ b/doc/langref.md @@ -25,7 +25,7 @@ UseDecl = "use" Expression ";" ExternDecl = "extern" (FnProto | VariableDeclaration) ";" -FnProto = "fn" option("Symbol") ParamDeclList option("->" TypeExpr) +FnProto = "fn" option("Symbol") option(ParamDeclList) ParamDeclList option("->" TypeExpr) Directive = "#" "Symbol" "(" Expression ")" diff --git a/src/all_types.hpp b/src/all_types.hpp index f7fd0a46a4..7677232d8f 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -193,8 +193,10 @@ struct AstNodeRoot { struct AstNodeFnProto { TopLevelDecl top_level_decl; Buf name; + ZigList generic_params; ZigList params; AstNode *return_type; + bool generic_params_is_var_args; bool is_var_args; bool is_extern; bool is_inline; @@ -206,6 +208,7 @@ struct AstNodeFnProto { FnTableEntry *fn_table_entry; bool skip; Expr resolved_expr; + TypeTableEntry *generic_fn_type; }; struct AstNodeFnDef { @@ -797,6 +800,21 @@ struct FnTypeParamInfo { TypeTableEntry *type; }; +struct GenericParamValue { + TypeTableEntry *type; + AstNode *node; +}; + +struct GenericFnTypeId { + AstNode *decl_node; // the generic fn or container decl node + GenericParamValue *generic_params; + int generic_param_count; +}; + +uint32_t generic_fn_type_id_hash(GenericFnTypeId *id); +bool generic_fn_type_id_eql(GenericFnTypeId *a, GenericFnTypeId *b); + + static const int fn_type_id_prealloc_param_info_count = 4; struct FnTypeId { TypeTableEntry *return_type; @@ -812,7 +830,6 @@ struct FnTypeId { uint32_t fn_type_id_hash(FnTypeId*); bool fn_type_id_eql(FnTypeId *a, FnTypeId *b); - struct TypeTableEntryPointer { TypeTableEntry *child_type; bool is_const; @@ -899,6 +916,10 @@ struct TypeTableEntryFn { LLVMCallConv calling_convention; }; +struct TypeTableEntryGenericFn { + AstNode *decl_node; +}; + struct TypeTableEntryTypeDecl { TypeTableEntry *child_type; TypeTableEntry *canonical_type; @@ -925,6 +946,7 @@ enum TypeTableEntryId { TypeTableEntryIdFn, TypeTableEntryIdTypeDecl, TypeTableEntryIdNamespace, + TypeTableEntryIdGenericFn, }; struct TypeTableEntry { @@ -947,6 +969,7 @@ struct TypeTableEntry { TypeTableEntryEnum enumeration; TypeTableEntryFn fn; TypeTableEntryTypeDecl type_decl; + TypeTableEntryGenericFn generic_fn; } data; // use these fields to make sure we don't duplicate type table entries for the same type @@ -992,6 +1015,7 @@ struct FnTableEntry { bool internal_linkage; bool is_extern; bool is_test; + BlockContext *parent_block_context; ZigList cast_alloca_list; ZigList struct_val_expr_alloca_list; @@ -1047,6 +1071,7 @@ struct CodeGen { HashMap primitive_type_table; HashMap fn_type_table; HashMap error_table; + HashMap generic_table; ZigList import_queue; int import_queue_index; @@ -1172,7 +1197,10 @@ struct VariableTableEntry { LLVMValueRef value_ref; bool is_const; bool is_ptr; // if true, value_ref is a pointer + // which node is the declaration of the variable AstNode *decl_node; + // which node contains the ConstExprValue for this variable's value + AstNode *val_node; LLVMZigDILocalVariable *di_loc_var; int src_arg_index; int gen_arg_index; diff --git a/src/analyze.cpp b/src/analyze.cpp index 7055fffb08..beffacc456 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -41,6 +41,7 @@ static VariableTableEntry *analyze_variable_declaration_raw(CodeGen *g, ImportTa AstNodeVariableDeclaration *variable_declaration, bool expr_is_maybe, AstNode *decl_node); static void scan_decls(CodeGen *g, ImportTableEntry *import, BlockContext *context, AstNode *node); +static void analyze_fn_body(CodeGen *g, FnTableEntry *fn_table_entry); static AstNode *first_executing_node(AstNode *node) { switch (node->type) { @@ -192,6 +193,7 @@ static bool type_is_complete(TypeTableEntry *type_entry) { case TypeTableEntryIdFn: case TypeTableEntryIdTypeDecl: case TypeTableEntryIdNamespace: + case TypeTableEntryIdGenericFn: return true; } zig_unreachable(); @@ -201,6 +203,14 @@ TypeTableEntry *get_smallest_unsigned_int_type(CodeGen *g, uint64_t x) { return get_int_type(g, false, bits_needed_for_unsigned(x)); } +static TypeTableEntry *get_generic_fn_type(CodeGen *g, AstNode *decl_node) { + TypeTableEntry *entry = new_type_table_entry(TypeTableEntryIdGenericFn); + buf_init_from_str(&entry->name, "(generic function)"); + entry->zero_bits = true; + entry->data.generic_fn.decl_node = decl_node; + return entry; +} + TypeTableEntry *get_pointer_to_type(CodeGen *g, TypeTableEntry *child_type, bool is_const) { assert(child_type->id != TypeTableEntryIdInvalid); TypeTableEntry **parent_pointer = &child_type->pointer_parent[(is_const ? 1 : 0)]; @@ -776,7 +786,7 @@ static TypeTableEntry *analyze_fn_proto_type(CodeGen *g, ImportTableEntry *impor } fn_type_id.is_var_args = fn_proto->is_var_args; - fn_type_id.return_type = analyze_type_expr(g, import, import->block_context, node->data.fn_proto.return_type); + fn_type_id.return_type = analyze_type_expr(g, import, context, node->data.fn_proto.return_type); if (fn_type_id.return_type->id == TypeTableEntryIdInvalid) { fn_proto->skip = true; @@ -785,7 +795,7 @@ static TypeTableEntry *analyze_fn_proto_type(CodeGen *g, ImportTableEntry *impor for (int i = 0; i < fn_type_id.param_count; i += 1) { AstNode *child = node->data.fn_proto.params.at(i); assert(child->type == NodeTypeParamDecl); - TypeTableEntry *type_entry = analyze_type_expr(g, import, import->block_context, + TypeTableEntry *type_entry = analyze_type_expr(g, import, context, child->data.param_decl.type); switch (type_entry->id) { case TypeTableEntryIdInvalid: @@ -797,6 +807,7 @@ static TypeTableEntry *analyze_fn_proto_type(CodeGen *g, ImportTableEntry *impor case TypeTableEntryIdMetaType: case TypeTableEntryIdUnreachable: case TypeTableEntryIdNamespace: + case TypeTableEntryIdGenericFn: fn_proto->skip = true; add_node_error(g, child->data.param_decl.type, buf_sprintf("parameter of type '%s' not allowed'", buf_ptr(&type_entry->name))); @@ -880,7 +891,7 @@ static bool resolve_const_expr_bool(CodeGen *g, ImportTableEntry *import, BlockC } static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_table_entry, - ImportTableEntry *import) + ImportTableEntry *import, BlockContext *containing_context) { assert(node->type == NodeTypeFnProto); AstNodeFnProto *fn_proto = &node->data.fn_proto; @@ -946,7 +957,7 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t - TypeTableEntry *fn_type = analyze_fn_proto_type(g, import, import->block_context, nullptr, node, + TypeTableEntry *fn_type = analyze_fn_proto_type(g, import, containing_context, nullptr, node, is_naked, is_cold); fn_table_entry->type_entry = fn_type; @@ -963,6 +974,8 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t } else { symbol_name = buf_sprintf("_%s", buf_ptr(&fn_table_entry->symbol_name)); } + // TODO mangle the name if it's a generic instance + fn_table_entry->fn_value = LLVMAddFunction(g->module, buf_ptr(symbol_name), fn_type->data.fn.raw_type_ref); @@ -992,12 +1005,12 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t unsigned flags = 0; bool is_optimized = g->is_release_build; LLVMZigDISubprogram *subprogram = LLVMZigCreateFunction(g->dbuilder, - import->block_context->di_scope, buf_ptr(&fn_table_entry->symbol_name), "", + containing_context->di_scope, buf_ptr(&fn_table_entry->symbol_name), "", import->di_file, line_number, fn_type->di_type, fn_table_entry->internal_linkage, is_definition, scope_line, flags, is_optimized, nullptr); - BlockContext *context = new_block_context(fn_table_entry->fn_def_node, import->block_context); + BlockContext *context = new_block_context(fn_table_entry->fn_def_node, containing_context); fn_table_entry->fn_def_node->data.fn_def.block_context = context; context->di_scope = LLVMZigSubprogramToScope(subprogram); } @@ -1321,17 +1334,35 @@ static void get_fully_qualified_decl_name(Buf *buf, AstNode *decl_node, uint8_t } } -static void preview_fn_proto(CodeGen *g, ImportTableEntry *import, AstNode *proto_node) { +static void preview_generic_fn_proto(CodeGen *g, ImportTableEntry *import, AstNode *node) { + assert(node->type == NodeTypeFnProto); + + if (node->data.fn_proto.generic_params_is_var_args) { + add_node_error(g, node, buf_sprintf("generic parameters cannot be var args")); + node->data.fn_proto.skip = true; + node->data.fn_proto.generic_fn_type = g->builtin_types.entry_invalid; + return; + } + + node->data.fn_proto.generic_fn_type = get_generic_fn_type(g, node); +} + +static void preview_fn_proto_instance(CodeGen *g, ImportTableEntry *import, AstNode *proto_node, + BlockContext *containing_context) +{ if (proto_node->data.fn_proto.skip) { return; } + bool is_generic_instance = (proto_node->data.fn_proto.generic_params.length > 0); + AstNode *parent_decl = proto_node->data.fn_proto.top_level_decl.parent_decl; + Buf *proto_name = &proto_node->data.fn_proto.name; AstNode *fn_def_node = proto_node->data.fn_proto.fn_def_node; bool is_extern = proto_node->data.fn_proto.is_extern; - Buf *proto_name = &proto_node->data.fn_proto.name; + assert(!is_extern || !is_generic_instance); if (!is_extern && proto_node->data.fn_proto.is_var_args) { add_node_error(g, proto_node, @@ -1352,13 +1383,24 @@ static void preview_fn_proto(CodeGen *g, ImportTableEntry *import, AstNode *prot g->fn_defs.append(fn_table_entry); } - bool is_main_fn = !parent_decl && (import == g->root_import) && buf_eql_str(proto_name, "main"); + bool is_main_fn = !is_generic_instance && + !parent_decl && (import == g->root_import) && + buf_eql_str(proto_name, "main"); if (is_main_fn) { g->main_fn = fn_table_entry; } proto_node->data.fn_proto.fn_table_entry = fn_table_entry; - resolve_function_proto(g, proto_node, fn_table_entry, import); + resolve_function_proto(g, proto_node, fn_table_entry, import, containing_context); +} + +static void preview_fn_proto(CodeGen *g, ImportTableEntry *import, AstNode *proto_node) { + if (proto_node->data.fn_proto.generic_params.length > 0) { + return preview_generic_fn_proto(g, import, proto_node); + } else { + return preview_fn_proto_instance(g, import, proto_node, import->block_context); + } + } static void preview_error_value_decl(CodeGen *g, AstNode *node) { @@ -1539,6 +1581,7 @@ static bool type_has_codegen_value(TypeTableEntry *type_entry) { case TypeTableEntryIdNumLitInt: case TypeTableEntryIdUndefLit: case TypeTableEntryIdNamespace: + case TypeTableEntryIdGenericFn: return false; case TypeTableEntryIdBool: @@ -2433,6 +2476,15 @@ static TypeTableEntry *resolve_expr_const_val_as_fn(CodeGen *g, AstNode *node, F return fn->type_entry; } +static TypeTableEntry *resolve_expr_const_val_as_generic_fn(CodeGen *g, AstNode *node, + TypeTableEntry *type_entry) +{ + Expr *expr = get_resolved_expr(node); + expr->const_val.ok = true; + expr->const_val.data.x_type = type_entry; + return type_entry; +} + static TypeTableEntry *resolve_expr_const_val_as_err(CodeGen *g, AstNode *node, ErrorTableEntry *err) { Expr *expr = get_resolved_expr(node); expr->const_val.ok = true; @@ -2570,14 +2622,10 @@ static TypeTableEntry *analyze_error_literal_expr(CodeGen *g, ImportTableEntry * static TypeTableEntry *analyze_var_ref(CodeGen *g, AstNode *source_node, VariableTableEntry *var) { get_resolved_expr(source_node)->variable = var; - if (var->is_const) { - AstNode *decl_node = var->decl_node; - if (decl_node->type == NodeTypeVariableDeclaration) { - AstNode *expr_node = decl_node->data.variable_declaration.expr; - ConstExprValue *other_const_val = &get_resolved_expr(expr_node)->const_val; - if (other_const_val->ok) { - return resolve_expr_const_val_as_other_expr(g, source_node, expr_node); - } + if (var->is_const && var->val_node) { + ConstExprValue *other_const_val = &get_resolved_expr(var->val_node)->const_val; + if (other_const_val->ok) { + return resolve_expr_const_val_as_other_expr(g, source_node, var->val_node); } } return var->type; @@ -2596,9 +2644,15 @@ static TypeTableEntry *analyze_decl_ref(CodeGen *g, AstNode *source_node, AstNod VariableTableEntry *var = decl_node->data.variable_declaration.variable; return analyze_var_ref(g, source_node, var); } else if (decl_node->type == NodeTypeFnProto) { - FnTableEntry *fn_entry = decl_node->data.fn_proto.fn_table_entry; - assert(fn_entry->type_entry); - return resolve_expr_const_val_as_fn(g, source_node, fn_entry); + if (decl_node->data.fn_proto.generic_params.length > 0) { + TypeTableEntry *type_entry = decl_node->data.fn_proto.generic_fn_type; + assert(type_entry); + return resolve_expr_const_val_as_generic_fn(g, source_node, type_entry); + } else { + FnTableEntry *fn_entry = decl_node->data.fn_proto.fn_table_entry; + assert(fn_entry->type_entry); + return resolve_expr_const_val_as_fn(g, source_node, fn_entry); + } } else if (decl_node->type == NodeTypeStructDecl) { return resolve_expr_const_val_as_type(g, source_node, decl_node->data.struct_decl.type_entry); } else if (decl_node->type == NodeTypeTypeDecl) { @@ -3113,7 +3167,7 @@ static TypeTableEntry *analyze_bin_op_expr(CodeGen *g, ImportTableEntry *import, // Set name to nullptr to make the variable anonymous (not visible to programmer). static VariableTableEntry *add_local_var(CodeGen *g, AstNode *source_node, ImportTableEntry *import, - BlockContext *context, Buf *name, TypeTableEntry *type_entry, bool is_const) + BlockContext *context, Buf *name, TypeTableEntry *type_entry, bool is_const, AstNode *val_node) { VariableTableEntry *variable_entry = allocate(1); variable_entry->type = type_entry; @@ -3160,6 +3214,8 @@ static VariableTableEntry *add_local_var(CodeGen *g, AstNode *source_node, Impor variable_entry->is_const = is_const; variable_entry->is_ptr = true; variable_entry->decl_node = source_node; + variable_entry->val_node = val_node; + return variable_entry; } @@ -3182,7 +3238,7 @@ static TypeTableEntry *analyze_unwrap_error_expr(CodeGen *g, ImportTableEntry *i var_node->block_context = child_context; Buf *var_name = &var_node->data.symbol_expr.symbol; node->data.unwrap_err_expr.var = add_local_var(g, var_node, import, child_context, var_name, - g->builtin_types.entry_pure_error, true); + g->builtin_types.entry_pure_error, true, nullptr); } else { child_context = parent_context; } @@ -3260,7 +3316,8 @@ static VariableTableEntry *analyze_variable_declaration_raw(CodeGen *g, ImportTa assert(type != nullptr); // should have been caught by the parser VariableTableEntry *var = add_local_var(g, source_node, import, context, - &variable_declaration->symbol, type, is_const); + &variable_declaration->symbol, type, is_const, + expr_is_maybe ? nullptr : variable_declaration->expr); variable_declaration->variable = var; @@ -3453,17 +3510,17 @@ static TypeTableEntry *analyze_for_expr(CodeGen *g, ImportTableEntry *import, Bl elem_var_node->block_context = child_context; Buf *elem_var_name = &elem_var_node->data.symbol_expr.symbol; node->data.for_expr.elem_var = add_local_var(g, elem_var_node, import, child_context, elem_var_name, - child_type, true); + child_type, true, nullptr); AstNode *index_var_node = node->data.for_expr.index_node; if (index_var_node) { Buf *index_var_name = &index_var_node->data.symbol_expr.symbol; index_var_node->block_context = child_context; node->data.for_expr.index_var = add_local_var(g, index_var_node, import, child_context, index_var_name, - g->builtin_types.entry_isize, true); + g->builtin_types.entry_isize, true, nullptr); } else { node->data.for_expr.index_var = add_local_var(g, node, import, child_context, nullptr, - g->builtin_types.entry_isize, true); + g->builtin_types.entry_isize, true, nullptr); } AstNode *for_body_node = node->data.for_expr.body; @@ -4330,6 +4387,7 @@ static TypeTableEntry *analyze_builtin_fn_call_expr(CodeGen *g, ImportTableEntry case TypeTableEntryIdNumLitInt: case TypeTableEntryIdUndefLit: case TypeTableEntryIdNamespace: + case TypeTableEntryIdGenericFn: add_node_error(g, expr_node, buf_sprintf("type '%s' not eligible for @typeof", buf_ptr(&type_entry->name))); return g->builtin_types.entry_invalid; @@ -4541,6 +4599,92 @@ static TypeTableEntry *analyze_fn_call_raw(CodeGen *g, ImportTableEntry *import, return analyze_fn_call_ptr(g, import, context, expected_type, node, fn_table_entry->type_entry, struct_type); } +static TypeTableEntry *analyze_generic_fn_call(CodeGen *g, ImportTableEntry *import, BlockContext *parent_context, + TypeTableEntry *expected_type, AstNode *node, TypeTableEntry *generic_fn_type) +{ + assert(node->type == NodeTypeFnCallExpr); + assert(generic_fn_type->id == TypeTableEntryIdGenericFn); + + AstNode *decl_node = generic_fn_type->data.generic_fn.decl_node; + assert(decl_node->type == NodeTypeFnProto); + + int expected_param_count = decl_node->data.fn_proto.generic_params.length; + int actual_param_count = node->data.fn_call_expr.params.length; + + if (actual_param_count != expected_param_count) { + add_node_error(g, first_executing_node(node), + buf_sprintf("expected %d arguments, got %d", expected_param_count, actual_param_count)); + return g->builtin_types.entry_invalid; + } + + GenericFnTypeId *generic_fn_type_id = allocate(1); + generic_fn_type_id->decl_node = decl_node; + generic_fn_type_id->generic_param_count = actual_param_count; + generic_fn_type_id->generic_params = allocate(actual_param_count); + + BlockContext *child_context = import->block_context; + for (int i = 0; i < actual_param_count; i += 1) { + AstNode *generic_param_decl_node = decl_node->data.fn_proto.generic_params.at(i); + assert(generic_param_decl_node->type == NodeTypeParamDecl); + + AstNode **generic_param_type_node = &generic_param_decl_node->data.param_decl.type; + + TypeTableEntry *expected_param_type = analyze_expression(g, decl_node->owner, + decl_node->owner->block_context, nullptr, *generic_param_type_node); + if (expected_param_type->id == TypeTableEntryIdInvalid) { + return expected_param_type; + } + AstNode **param_node = &node->data.fn_call_expr.params.at(i); + + TypeTableEntry *param_type = analyze_expression(g, import, child_context, expected_param_type, + *param_node); + if (param_type->id == TypeTableEntryIdInvalid) { + return param_type; + } + + // set child_context so that the previous param is in scope + child_context = new_block_context(generic_param_decl_node, child_context); + + ConstExprValue *const_val = &get_resolved_expr(*param_node)->const_val; + if (const_val->ok) { + add_local_var(g, generic_param_decl_node, decl_node->owner, child_context, + &generic_param_decl_node->data.param_decl.name, param_type, true, *param_node); + } else { + add_node_error(g, *param_node, buf_sprintf("unable to resolve constant expression")); + + add_local_var(g, generic_param_decl_node, decl_node->owner, child_context, + &generic_param_decl_node->data.param_decl.name, g->builtin_types.entry_invalid, + true, nullptr); + + return g->builtin_types.entry_invalid; + } + + GenericParamValue *generic_param_value = &generic_fn_type_id->generic_params[i]; + generic_param_value->type = param_type; + generic_param_value->node = *param_node; + } + + + auto entry = g->generic_table.maybe_get(generic_fn_type_id); + if (entry) { + AstNode *impl_decl_node = entry->value; + assert(impl_decl_node->type == NodeTypeFnProto); + FnTableEntry *fn_table_entry = impl_decl_node->data.fn_proto.fn_table_entry; + return resolve_expr_const_val_as_fn(g, node, fn_table_entry); + } + + // make a type from the generic parameters supplied + assert(decl_node->type == NodeTypeFnProto); + AstNode *impl_decl_node = ast_clone_subtree(decl_node); + + preview_fn_proto_instance(g, import, decl_node, child_context); + + g->generic_table.put(generic_fn_type_id, impl_decl_node); + + FnTableEntry *fn_table_entry = decl_node->data.fn_proto.fn_table_entry; + return resolve_expr_const_val_as_fn(g, node, fn_table_entry); +} + static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context, TypeTableEntry *expected_type, AstNode *node) { @@ -4627,6 +4771,8 @@ static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import return analyze_fn_call_raw(g, import, context, expected_type, node, const_val->data.x_fn, bare_struct_type); + } else if (invoke_type_entry->id == TypeTableEntryIdGenericFn) { + return analyze_generic_fn_call(g, import, context, expected_type, node, const_val->data.x_type); } else { add_node_error(g, fn_ref_expr, buf_sprintf("type '%s' not a function", buf_ptr(&invoke_type_entry->name))); @@ -4971,7 +5117,7 @@ static TypeTableEntry *analyze_switch_expr(CodeGen *g, ImportTableEntry *import, Buf *var_name = &var_node->data.symbol_expr.symbol; var_node->block_context = child_context; prong_node->data.switch_prong.var = add_local_var(g, var_node, import, - child_context, var_name, var_type, true); + child_context, var_name, var_type, true, nullptr); prong_node->data.switch_prong.var_is_target_expr = var_is_target_expr; } } @@ -5391,7 +5537,7 @@ static void analyze_fn_body(CodeGen *g, FnTableEntry *fn_table_entry) { } VariableTableEntry *var = add_local_var(g, param_decl_node, import, context, ¶m_decl->name, - type, true); + type, true, nullptr); var->src_arg_index = i; param_decl_node->data.param_decl.variable = var; @@ -5413,7 +5559,9 @@ static void add_top_level_decl(CodeGen *g, ImportTableEntry *import, BlockContex tld->import = import; tld->name = name; - if (g->check_unused || g->is_test_build || tld->visib_mod == VisibModExport) { + bool want_as_export = (g->check_unused || g->is_test_build || tld->visib_mod == VisibModExport); + bool is_generic = (node->type == NodeTypeFnProto && node->data.fn_proto.generic_params.length > 0); + if (!is_generic && want_as_export) { g->export_queue.append(node); } @@ -5909,6 +6057,7 @@ bool handle_is_ptr(TypeTableEntry *type_entry) { case TypeTableEntryIdNumLitInt: case TypeTableEntryIdUndefLit: case TypeTableEntryIdNamespace: + case TypeTableEntryIdGenericFn: zig_unreachable(); case TypeTableEntryIdUnreachable: case TypeTableEntryIdVoid: @@ -5965,7 +6114,6 @@ uint32_t fn_type_id_hash(FnTypeId *id) { result += id->is_cold ? 3605523458 : 0; result += id->is_var_args ? 1931444534 : 0; result += hash_ptr(id->return_type); - result += id->param_count; for (int i = 0; i < id->param_count; i += 1) { FnTypeParamInfo *info = &id->param_info[i]; result += info->is_noalias ? 892356923 : 0; @@ -5999,6 +6147,76 @@ bool fn_type_id_eql(FnTypeId *a, FnTypeId *b) { return true; } +static uint32_t hash_const_val(TypeTableEntry *type, ConstExprValue *const_val) { + switch (type->id) { + case TypeTableEntryIdBool: + return const_val->data.x_bool ? 127863866 : 215080464; + case TypeTableEntryIdMetaType: + return hash_ptr(const_val->data.x_type); + case TypeTableEntryIdVoid: + return 4149439618; + case TypeTableEntryIdInt: + case TypeTableEntryIdNumLitInt: + return ((uint32_t)(bignum_to_twos_complement(&const_val->data.x_bignum) % UINT32_MAX)) * 1331471175; + case TypeTableEntryIdFloat: + case TypeTableEntryIdNumLitFloat: + return const_val->data.x_bignum.data.x_float * UINT32_MAX; + case TypeTableEntryIdPointer: + return hash_ptr(const_val->data.x_ptr.ptr); + case TypeTableEntryIdUndefLit: + return 162837799; + case TypeTableEntryIdArray: + // TODO better hashing algorithm + return 1166190605; + case TypeTableEntryIdStruct: + // TODO better hashing algorithm + return 1532530855; + case TypeTableEntryIdMaybe: + if (const_val->data.x_maybe) { + TypeTableEntry *child_type = type->data.maybe.child_type; + return hash_const_val(child_type, const_val->data.x_maybe) * 1992916303; + } else { + return 4016830364; + } + case TypeTableEntryIdErrorUnion: + // TODO better hashing algorithm + return 3415065496; + case TypeTableEntryIdPureError: + // TODO better hashing algorithm + return 2630160122; + case TypeTableEntryIdEnum: + // TODO better hashing algorithm + return 31643936; + case TypeTableEntryIdFn: + return hash_ptr(const_val->data.x_fn); + case TypeTableEntryIdTypeDecl: + return hash_ptr(const_val->data.x_type); + case TypeTableEntryIdNamespace: + return hash_ptr(const_val->data.x_import); + case TypeTableEntryIdGenericFn: + case TypeTableEntryIdInvalid: + case TypeTableEntryIdUnreachable: + zig_unreachable(); + } +} + +uint32_t generic_fn_type_id_hash(GenericFnTypeId *id) { + uint32_t result = 0; + result += hash_ptr(id->decl_node); + for (int i = 0; i < id->generic_param_count; i += 1) { + GenericParamValue *generic_param = &id->generic_params[i]; + ConstExprValue *const_val = &get_resolved_expr(generic_param->node)->const_val; + assert(const_val->ok); + result += hash_const_val(generic_param->type, const_val); + } + return result; +} + +bool generic_fn_type_id_eql(GenericFnTypeId *a, GenericFnTypeId *b) { + // TODO + return true; +} + bool type_has_bits(TypeTableEntry *type_entry) { assert(type_entry); assert(type_entry->id != TypeTableEntryIdInvalid); @@ -6027,6 +6245,7 @@ static TypeTableEntry *type_of_first_thing_in_memory(TypeTableEntry *type_entry) case TypeTableEntryIdMetaType: case TypeTableEntryIdVoid: case TypeTableEntryIdNamespace: + case TypeTableEntryIdGenericFn: zig_unreachable(); case TypeTableEntryIdArray: return type_of_first_thing_in_memory(type_entry->data.array.child_type); diff --git a/src/ast_render.cpp b/src/ast_render.cpp index 4a1d829546..59a287b886 100644 --- a/src/ast_render.cpp +++ b/src/ast_render.cpp @@ -59,15 +59,6 @@ static const char *prefix_op_str(PrefixOp prefix_op) { zig_unreachable(); } -static const char *return_prefix_str(ReturnKind kind) { - switch (kind) { - case ReturnKindError: return "%"; - case ReturnKindMaybe: return "?"; - case ReturnKindUnconditional: return ""; - } - zig_unreachable(); -} - static const char *visib_mod_string(VisibMod mod) { switch (mod) { case VisibModPub: return "pub "; @@ -195,316 +186,36 @@ static const char *node_type_str(NodeType node_type) { zig_unreachable(); } +struct AstPrint { + int indent; + FILE *f; +}; + +static void ast_print_visit(AstNode **node_ptr, void *context) { + AstNode *node = *node_ptr; + AstPrint *ap = (AstPrint *)context; + + for (int i = 0; i < ap->indent; i += 1) { + fprintf(ap->f, " "); + } + + fprintf(ap->f, "%s\n", node_type_str(node->type)); + + AstPrint new_ap; + new_ap.indent = ap->indent + 2; + new_ap.f = ap->f; + + ast_visit_node_children(node, ast_print_visit, &new_ap); +} void ast_print(FILE *f, AstNode *node, int indent) { - for (int i = 0; i < indent; i += 1) { - fprintf(f, " "); - } - assert(node->type == NodeTypeRoot || *node->parent_field == node); - - switch (node->type) { - case NodeTypeRoot: - fprintf(f, "%s\n", node_type_str(node->type)); - for (int i = 0; i < node->data.root.top_level_decls.length; i += 1) { - AstNode *child = node->data.root.top_level_decls.at(i); - ast_print(f, child, indent + 2); - } - break; - case NodeTypeFnDef: - { - fprintf(f, "%s\n", node_type_str(node->type)); - AstNode *child = node->data.fn_def.fn_proto; - ast_print(f, child, indent + 2); - ast_print(f, node->data.fn_def.body, indent + 2); - break; - } - case NodeTypeFnProto: - { - Buf *name_buf = &node->data.fn_proto.name; - fprintf(f, "%s '%s'\n", node_type_str(node->type), buf_ptr(name_buf)); - - for (int i = 0; i < node->data.fn_proto.params.length; i += 1) { - AstNode *child = node->data.fn_proto.params.at(i); - ast_print(f, child, indent + 2); - } - - ast_print(f, node->data.fn_proto.return_type, indent + 2); - - break; - } - case NodeTypeBlock: - { - fprintf(f, "%s\n", node_type_str(node->type)); - for (int i = 0; i < node->data.block.statements.length; i += 1) { - AstNode *child = node->data.block.statements.at(i); - ast_print(f, child, indent + 2); - } - break; - } - case NodeTypeParamDecl: - { - Buf *name_buf = &node->data.param_decl.name; - fprintf(f, "%s '%s'\n", node_type_str(node->type), buf_ptr(name_buf)); - - ast_print(f, node->data.param_decl.type, indent + 2); - - break; - } - case NodeTypeReturnExpr: - { - const char *prefix_str = return_prefix_str(node->data.return_expr.kind); - fprintf(f, "%s%s\n", prefix_str, node_type_str(node->type)); - if (node->data.return_expr.expr) - ast_print(f, node->data.return_expr.expr, indent + 2); - break; - } - case NodeTypeDefer: - { - const char *prefix_str = return_prefix_str(node->data.defer.kind); - fprintf(f, "%s%s\n", prefix_str, node_type_str(node->type)); - if (node->data.defer.expr) - ast_print(f, node->data.defer.expr, indent + 2); - break; - } - case NodeTypeVariableDeclaration: - { - Buf *name_buf = &node->data.variable_declaration.symbol; - fprintf(f, "%s '%s'\n", node_type_str(node->type), buf_ptr(name_buf)); - if (node->data.variable_declaration.type) - ast_print(f, node->data.variable_declaration.type, indent + 2); - if (node->data.variable_declaration.expr) - ast_print(f, node->data.variable_declaration.expr, indent + 2); - break; - } - case NodeTypeTypeDecl: - { - Buf *name_buf = &node->data.type_decl.symbol; - fprintf(f, "%s '%s'\n", node_type_str(node->type), buf_ptr(name_buf)); - ast_print(f, node->data.type_decl.child_type, indent + 2); - break; - } - case NodeTypeErrorValueDecl: - { - Buf *name_buf = &node->data.error_value_decl.name; - fprintf(f, "%s '%s'\n", node_type_str(node->type), buf_ptr(name_buf)); - break; - } - case NodeTypeFnDecl: - fprintf(f, "%s\n", node_type_str(node->type)); - ast_print(f, node->data.fn_decl.fn_proto, indent + 2); - break; - case NodeTypeBinOpExpr: - fprintf(f, "%s %s\n", node_type_str(node->type), - bin_op_str(node->data.bin_op_expr.bin_op)); - ast_print(f, node->data.bin_op_expr.op1, indent + 2); - ast_print(f, node->data.bin_op_expr.op2, indent + 2); - break; - case NodeTypeUnwrapErrorExpr: - fprintf(f, "%s\n", node_type_str(node->type)); - ast_print(f, node->data.unwrap_err_expr.op1, indent + 2); - if (node->data.unwrap_err_expr.symbol) { - ast_print(f, node->data.unwrap_err_expr.symbol, indent + 2); - } - ast_print(f, node->data.unwrap_err_expr.op2, indent + 2); - break; - case NodeTypeFnCallExpr: - fprintf(f, "%s\n", node_type_str(node->type)); - ast_print(f, node->data.fn_call_expr.fn_ref_expr, indent + 2); - for (int i = 0; i < node->data.fn_call_expr.params.length; i += 1) { - AstNode *child = node->data.fn_call_expr.params.at(i); - ast_print(f, child, indent + 2); - } - break; - case NodeTypeArrayAccessExpr: - fprintf(f, "%s\n", node_type_str(node->type)); - ast_print(f, node->data.array_access_expr.array_ref_expr, indent + 2); - ast_print(f, node->data.array_access_expr.subscript, indent + 2); - break; - case NodeTypeSliceExpr: - fprintf(f, "%s\n", node_type_str(node->type)); - ast_print(f, node->data.slice_expr.array_ref_expr, indent + 2); - ast_print(f, node->data.slice_expr.start, indent + 2); - if (node->data.slice_expr.end) { - ast_print(f, node->data.slice_expr.end, indent + 2); - } - break; - case NodeTypeDirective: - fprintf(f, "%s\n", node_type_str(node->type)); - ast_print(f, node->data.directive.expr, indent + 2); - break; - case NodeTypePrefixOpExpr: - fprintf(f, "%s %s\n", node_type_str(node->type), - prefix_op_str(node->data.prefix_op_expr.prefix_op)); - ast_print(f, node->data.prefix_op_expr.primary_expr, indent + 2); - break; - case NodeTypeNumberLiteral: - { - NumLit kind = node->data.number_literal.kind; - const char *name = node_type_str(node->type); - if (kind == NumLitUInt) { - fprintf(f, "%s uint %" PRIu64 "\n", name, node->data.number_literal.data.x_uint); - } else { - fprintf(f, "%s float %f\n", name, node->data.number_literal.data.x_float); - } - break; - } - case NodeTypeStringLiteral: - { - const char *c = node->data.string_literal.c ? "c" : ""; - fprintf(f, "StringLiteral %s'%s'\n", c, - buf_ptr(&node->data.string_literal.buf)); - break; - } - case NodeTypeCharLiteral: - { - fprintf(f, "%s '%c'\n", node_type_str(node->type), node->data.char_literal.value); - break; - } - case NodeTypeSymbol: - fprintf(f, "Symbol %s\n", buf_ptr(&node->data.symbol_expr.symbol)); - break; - case NodeTypeUse: - fprintf(f, "%s\n", node_type_str(node->type)); - ast_print(f, node->data.use.expr, indent + 2); - break; - case NodeTypeBoolLiteral: - fprintf(f, "%s '%s'\n", node_type_str(node->type), - node->data.bool_literal.value ? "true" : "false"); - break; - case NodeTypeNullLiteral: - fprintf(f, "%s\n", node_type_str(node->type)); - break; - case NodeTypeIfBoolExpr: - fprintf(f, "%s\n", node_type_str(node->type)); - if (node->data.if_bool_expr.condition) - ast_print(f, node->data.if_bool_expr.condition, indent + 2); - ast_print(f, node->data.if_bool_expr.then_block, indent + 2); - if (node->data.if_bool_expr.else_node) - ast_print(f, node->data.if_bool_expr.else_node, indent + 2); - break; - case NodeTypeIfVarExpr: - { - Buf *name_buf = &node->data.if_var_expr.var_decl.symbol; - fprintf(f, "%s '%s'\n", node_type_str(node->type), buf_ptr(name_buf)); - if (node->data.if_var_expr.var_decl.type) - ast_print(f, node->data.if_var_expr.var_decl.type, indent + 2); - if (node->data.if_var_expr.var_decl.expr) - ast_print(f, node->data.if_var_expr.var_decl.expr, indent + 2); - ast_print(f, node->data.if_var_expr.then_block, indent + 2); - if (node->data.if_var_expr.else_node) - ast_print(f, node->data.if_var_expr.else_node, indent + 2); - break; - } - case NodeTypeWhileExpr: - fprintf(f, "%s\n", node_type_str(node->type)); - ast_print(f, node->data.while_expr.condition, indent + 2); - ast_print(f, node->data.while_expr.body, indent + 2); - break; - case NodeTypeForExpr: - fprintf(f, "%s\n", node_type_str(node->type)); - ast_print(f, node->data.for_expr.elem_node, indent + 2); - ast_print(f, node->data.for_expr.array_expr, indent + 2); - if (node->data.for_expr.index_node) { - ast_print(f, node->data.for_expr.index_node, indent + 2); - } - ast_print(f, node->data.for_expr.body, indent + 2); - break; - case NodeTypeSwitchExpr: - fprintf(f, "%s\n", node_type_str(node->type)); - ast_print(f, node->data.switch_expr.expr, indent + 2); - for (int i = 0; i < node->data.switch_expr.prongs.length; i += 1) { - AstNode *child_node = node->data.switch_expr.prongs.at(i); - ast_print(f, child_node, indent + 2); - } - break; - case NodeTypeSwitchProng: - fprintf(f, "%s\n", node_type_str(node->type)); - for (int i = 0; i < node->data.switch_prong.items.length; i += 1) { - AstNode *child_node = node->data.switch_prong.items.at(i); - ast_print(f, child_node, indent + 2); - } - if (node->data.switch_prong.var_symbol) { - ast_print(f, node->data.switch_prong.var_symbol, indent + 2); - } - ast_print(f, node->data.switch_prong.expr, indent + 2); - break; - case NodeTypeSwitchRange: - fprintf(f, "%s\n", node_type_str(node->type)); - ast_print(f, node->data.switch_range.start, indent + 2); - ast_print(f, node->data.switch_range.end, indent + 2); - break; - case NodeTypeLabel: - fprintf(f, "%s '%s'\n", node_type_str(node->type), buf_ptr(&node->data.label.name)); - break; - case NodeTypeGoto: - fprintf(f, "%s '%s'\n", node_type_str(node->type), buf_ptr(&node->data.goto_expr.name)); - break; - case NodeTypeBreak: - fprintf(f, "%s\n", node_type_str(node->type)); - break; - case NodeTypeContinue: - fprintf(f, "%s\n", node_type_str(node->type)); - break; - case NodeTypeUndefinedLiteral: - fprintf(f, "%s\n", node_type_str(node->type)); - break; - case NodeTypeAsmExpr: - fprintf(f, "%s\n", node_type_str(node->type)); - break; - case NodeTypeFieldAccessExpr: - fprintf(f, "%s '%s'\n", node_type_str(node->type), - buf_ptr(&node->data.field_access_expr.field_name)); - ast_print(f, node->data.field_access_expr.struct_expr, indent + 2); - break; - case NodeTypeStructDecl: - fprintf(f, "%s '%s'\n", - node_type_str(node->type), buf_ptr(&node->data.struct_decl.name)); - for (int i = 0; i < node->data.struct_decl.fields.length; i += 1) { - AstNode *child = node->data.struct_decl.fields.at(i); - ast_print(f, child, indent + 2); - } - for (int i = 0; i < node->data.struct_decl.fns.length; i += 1) { - AstNode *child = node->data.struct_decl.fns.at(i); - ast_print(f, child, indent + 2); - } - break; - case NodeTypeStructField: - fprintf(f, "%s '%s'\n", node_type_str(node->type), buf_ptr(&node->data.struct_field.name)); - if (node->data.struct_field.type) { - ast_print(f, node->data.struct_field.type, indent + 2); - } - break; - case NodeTypeStructValueField: - fprintf(f, "%s '%s'\n", node_type_str(node->type), buf_ptr(&node->data.struct_val_field.name)); - ast_print(f, node->data.struct_val_field.expr, indent + 2); - break; - case NodeTypeContainerInitExpr: - fprintf(f, "%s\n", node_type_str(node->type)); - ast_print(f, node->data.container_init_expr.type, indent + 2); - for (int i = 0; i < node->data.container_init_expr.entries.length; i += 1) { - AstNode *child = node->data.container_init_expr.entries.at(i); - ast_print(f, child, indent + 2); - } - break; - case NodeTypeArrayType: - { - const char *const_str = node->data.array_type.is_const ? "const" : "var"; - fprintf(f, "%s %s\n", node_type_str(node->type), const_str); - if (node->data.array_type.size) { - ast_print(f, node->data.array_type.size, indent + 2); - } - ast_print(f, node->data.array_type.child_type, indent + 2); - break; - } - case NodeTypeErrorType: - fprintf(f, "%s\n", node_type_str(node->type)); - break; - case NodeTypeTypeLiteral: - fprintf(f, "%s\n", node_type_str(node->type)); - break; - } + AstPrint ap; + ap.indent = indent; + ap.f = f; + ast_visit_node_children(node, ast_print_visit, &ap); } + struct AstRender { int indent; int indent_size; diff --git a/src/ast_render.hpp b/src/ast_render.hpp index a971fd7a2f..7f5b1ba0e9 100644 --- a/src/ast_render.hpp +++ b/src/ast_render.hpp @@ -9,6 +9,7 @@ #define ZIG_AST_RENDER_HPP #include "all_types.hpp" +#include "parser.hpp" #include diff --git a/src/codegen.cpp b/src/codegen.cpp index 49d257b6d8..fc7c036061 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -62,6 +62,7 @@ CodeGen *codegen_create(Buf *root_source_dir, const ZigTarget *target) { g->primitive_type_table.init(32); g->fn_type_table.init(32); g->error_table.init(16); + g->generic_table.init(16); g->is_release_build = false; g->is_test_build = false; g->error_value_count = 1; @@ -2927,6 +2928,7 @@ static LLVMValueRef gen_const_val(CodeGen *g, TypeTableEntry *type_entry, ConstE case TypeTableEntryIdUndefLit: case TypeTableEntryIdVoid: case TypeTableEntryIdNamespace: + case TypeTableEntryIdGenericFn: zig_unreachable(); } diff --git a/src/parser.cpp b/src/parser.cpp index 14ac0fb916..74cf191629 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -2243,7 +2243,7 @@ static AstNode *ast_parse_block(ParseContext *pc, int *token_index, bool mandato } /* -FnProto : "fn" option("Symbol") ParamDeclList option("->" PrefixOpExpression) +FnProto = "fn" option("Symbol") option(ParamDeclList) ParamDeclList option("->" TypeExpr) */ static AstNode *ast_parse_fn_proto(ParseContext *pc, int *token_index, bool mandatory, ZigList *directives, VisibMod visib_mod) @@ -2273,6 +2273,17 @@ static AstNode *ast_parse_fn_proto(ParseContext *pc, int *token_index, bool mand ast_parse_param_decl_list(pc, token_index, &node->data.fn_proto.params, &node->data.fn_proto.is_var_args); + Token *maybe_lparen = &pc->tokens->at(*token_index); + if (maybe_lparen->id == TokenIdLParen) { + for (int i = 0; i < node->data.fn_proto.params.length; i += 1) { + node->data.fn_proto.generic_params.append(node->data.fn_proto.params.at(i)); + } + node->data.fn_proto.generic_params_is_var_args = node->data.fn_proto.is_var_args; + + node->data.fn_proto.params.resize(0); + ast_parse_param_decl_list(pc, token_index, &node->data.fn_proto.params, &node->data.fn_proto.is_var_args); + } + Token *next_token = &pc->tokens->at(*token_index); if (next_token->id == TokenIdArrow) { *token_index += 1; @@ -2626,72 +2637,73 @@ AstNode *ast_parse(Buf *buf, ZigList *tokens, ImportTableEntry *owner, return pc.root; } -static void set_field(AstNode **field) { - if (*field) { - (*field)->parent_field = field; +static void visit_field(AstNode **node, void (*visit)(AstNode **, void *context), void *context) { + if (*node) { + visit(node, context); } } -static void set_list_fields(ZigList *list) { +static void visit_node_list(ZigList *list, void (*visit)(AstNode **, void *context), void *context) { if (list) { for (int i = 0; i < list->length; i += 1) { - set_field(&list->at(i)); + visit(&list->at(i), context); } } } -void normalize_parent_ptrs(AstNode *node) { +void ast_visit_node_children(AstNode *node, void (*visit)(AstNode **, void *context), void *context) { switch (node->type) { case NodeTypeRoot: - set_list_fields(&node->data.root.top_level_decls); + visit_node_list(&node->data.root.top_level_decls, visit, context); break; case NodeTypeFnProto: - set_field(&node->data.fn_proto.return_type); - set_list_fields(node->data.fn_proto.top_level_decl.directives); - set_list_fields(&node->data.fn_proto.params); + visit_field(&node->data.fn_proto.return_type, visit, context); + visit_node_list(node->data.fn_proto.top_level_decl.directives, visit, context); + visit_node_list(&node->data.fn_proto.generic_params, visit, context); + visit_node_list(&node->data.fn_proto.params, visit, context); break; case NodeTypeFnDef: - set_field(&node->data.fn_def.fn_proto); - set_field(&node->data.fn_def.body); + visit_field(&node->data.fn_def.fn_proto, visit, context); + visit_field(&node->data.fn_def.body, visit, context); break; case NodeTypeFnDecl: - set_field(&node->data.fn_decl.fn_proto); + visit_field(&node->data.fn_decl.fn_proto, visit, context); break; case NodeTypeParamDecl: - set_field(&node->data.param_decl.type); + visit_field(&node->data.param_decl.type, visit, context); break; case NodeTypeBlock: - set_list_fields(&node->data.block.statements); + visit_node_list(&node->data.block.statements, visit, context); break; case NodeTypeDirective: - set_field(&node->data.directive.expr); + visit_field(&node->data.directive.expr, visit, context); break; case NodeTypeReturnExpr: - set_field(&node->data.return_expr.expr); + visit_field(&node->data.return_expr.expr, visit, context); break; case NodeTypeDefer: - set_field(&node->data.defer.expr); + visit_field(&node->data.defer.expr, visit, context); break; case NodeTypeVariableDeclaration: - set_list_fields(node->data.variable_declaration.top_level_decl.directives); - set_field(&node->data.variable_declaration.type); - set_field(&node->data.variable_declaration.expr); + visit_node_list(node->data.variable_declaration.top_level_decl.directives, visit, context); + visit_field(&node->data.variable_declaration.type, visit, context); + visit_field(&node->data.variable_declaration.expr, visit, context); break; case NodeTypeTypeDecl: - set_list_fields(node->data.type_decl.top_level_decl.directives); - set_field(&node->data.type_decl.child_type); + visit_node_list(node->data.type_decl.top_level_decl.directives, visit, context); + visit_field(&node->data.type_decl.child_type, visit, context); break; case NodeTypeErrorValueDecl: // none break; case NodeTypeBinOpExpr: - set_field(&node->data.bin_op_expr.op1); - set_field(&node->data.bin_op_expr.op2); + visit_field(&node->data.bin_op_expr.op1, visit, context); + visit_field(&node->data.bin_op_expr.op2, visit, context); break; case NodeTypeUnwrapErrorExpr: - set_field(&node->data.unwrap_err_expr.op1); - set_field(&node->data.unwrap_err_expr.symbol); - set_field(&node->data.unwrap_err_expr.op2); + visit_field(&node->data.unwrap_err_expr.op1, visit, context); + visit_field(&node->data.unwrap_err_expr.symbol, visit, context); + visit_field(&node->data.unwrap_err_expr.op2, visit, context); break; case NodeTypeNumberLiteral: // none @@ -2706,27 +2718,27 @@ void normalize_parent_ptrs(AstNode *node) { // none break; case NodeTypePrefixOpExpr: - set_field(&node->data.prefix_op_expr.primary_expr); + visit_field(&node->data.prefix_op_expr.primary_expr, visit, context); break; case NodeTypeFnCallExpr: - set_field(&node->data.fn_call_expr.fn_ref_expr); - set_list_fields(&node->data.fn_call_expr.params); + visit_field(&node->data.fn_call_expr.fn_ref_expr, visit, context); + visit_node_list(&node->data.fn_call_expr.params, visit, context); break; case NodeTypeArrayAccessExpr: - set_field(&node->data.array_access_expr.array_ref_expr); - set_field(&node->data.array_access_expr.subscript); + visit_field(&node->data.array_access_expr.array_ref_expr, visit, context); + visit_field(&node->data.array_access_expr.subscript, visit, context); break; case NodeTypeSliceExpr: - set_field(&node->data.slice_expr.array_ref_expr); - set_field(&node->data.slice_expr.start); - set_field(&node->data.slice_expr.end); + visit_field(&node->data.slice_expr.array_ref_expr, visit, context); + visit_field(&node->data.slice_expr.start, visit, context); + visit_field(&node->data.slice_expr.end, visit, context); break; case NodeTypeFieldAccessExpr: - set_field(&node->data.field_access_expr.struct_expr); + visit_field(&node->data.field_access_expr.struct_expr, visit, context); break; case NodeTypeUse: - set_field(&node->data.use.expr); - set_list_fields(node->data.use.top_level_decl.directives); + visit_field(&node->data.use.expr, visit, context); + visit_node_list(node->data.use.top_level_decl.directives, visit, context); break; case NodeTypeBoolLiteral: // none @@ -2738,38 +2750,38 @@ void normalize_parent_ptrs(AstNode *node) { // none break; case NodeTypeIfBoolExpr: - set_field(&node->data.if_bool_expr.condition); - set_field(&node->data.if_bool_expr.then_block); - set_field(&node->data.if_bool_expr.else_node); + visit_field(&node->data.if_bool_expr.condition, visit, context); + visit_field(&node->data.if_bool_expr.then_block, visit, context); + visit_field(&node->data.if_bool_expr.else_node, visit, context); break; case NodeTypeIfVarExpr: - set_field(&node->data.if_var_expr.var_decl.type); - set_field(&node->data.if_var_expr.var_decl.expr); - set_field(&node->data.if_var_expr.then_block); - set_field(&node->data.if_var_expr.else_node); + visit_field(&node->data.if_var_expr.var_decl.type, visit, context); + visit_field(&node->data.if_var_expr.var_decl.expr, visit, context); + visit_field(&node->data.if_var_expr.then_block, visit, context); + visit_field(&node->data.if_var_expr.else_node, visit, context); break; case NodeTypeWhileExpr: - set_field(&node->data.while_expr.condition); - set_field(&node->data.while_expr.body); + visit_field(&node->data.while_expr.condition, visit, context); + visit_field(&node->data.while_expr.body, visit, context); break; case NodeTypeForExpr: - set_field(&node->data.for_expr.elem_node); - set_field(&node->data.for_expr.array_expr); - set_field(&node->data.for_expr.index_node); - set_field(&node->data.for_expr.body); + visit_field(&node->data.for_expr.elem_node, visit, context); + visit_field(&node->data.for_expr.array_expr, visit, context); + visit_field(&node->data.for_expr.index_node, visit, context); + visit_field(&node->data.for_expr.body, visit, context); break; case NodeTypeSwitchExpr: - set_field(&node->data.switch_expr.expr); - set_list_fields(&node->data.switch_expr.prongs); + visit_field(&node->data.switch_expr.expr, visit, context); + visit_node_list(&node->data.switch_expr.prongs, visit, context); break; case NodeTypeSwitchProng: - set_list_fields(&node->data.switch_prong.items); - set_field(&node->data.switch_prong.var_symbol); - set_field(&node->data.switch_prong.expr); + visit_node_list(&node->data.switch_prong.items, visit, context); + visit_field(&node->data.switch_prong.var_symbol, visit, context); + visit_field(&node->data.switch_prong.expr, visit, context); break; case NodeTypeSwitchRange: - set_field(&node->data.switch_range.start); - set_field(&node->data.switch_range.end); + visit_field(&node->data.switch_range.start, visit, context); + visit_field(&node->data.switch_range.end, visit, context); break; case NodeTypeLabel: // none @@ -2786,32 +2798,32 @@ void normalize_parent_ptrs(AstNode *node) { case NodeTypeAsmExpr: for (int i = 0; i < node->data.asm_expr.input_list.length; i += 1) { AsmInput *asm_input = node->data.asm_expr.input_list.at(i); - set_field(&asm_input->expr); + visit_field(&asm_input->expr, visit, context); } for (int i = 0; i < node->data.asm_expr.output_list.length; i += 1) { AsmOutput *asm_output = node->data.asm_expr.output_list.at(i); - set_field(&asm_output->return_type); + visit_field(&asm_output->return_type, visit, context); } break; case NodeTypeStructDecl: - set_list_fields(&node->data.struct_decl.fields); - set_list_fields(&node->data.struct_decl.fns); - set_list_fields(node->data.struct_decl.top_level_decl.directives); + visit_node_list(&node->data.struct_decl.fields, visit, context); + visit_node_list(&node->data.struct_decl.fns, visit, context); + visit_node_list(node->data.struct_decl.top_level_decl.directives, visit, context); break; case NodeTypeStructField: - set_field(&node->data.struct_field.type); - set_list_fields(node->data.struct_field.top_level_decl.directives); + visit_field(&node->data.struct_field.type, visit, context); + visit_node_list(node->data.struct_field.top_level_decl.directives, visit, context); break; case NodeTypeContainerInitExpr: - set_field(&node->data.container_init_expr.type); - set_list_fields(&node->data.container_init_expr.entries); + visit_field(&node->data.container_init_expr.type, visit, context); + visit_node_list(&node->data.container_init_expr.entries, visit, context); break; case NodeTypeStructValueField: - set_field(&node->data.struct_val_field.expr); + visit_field(&node->data.struct_val_field.expr, visit, context); break; case NodeTypeArrayType: - set_field(&node->data.array_type.size); - set_field(&node->data.array_type.child_type); + visit_field(&node->data.array_type.size, visit, context); + visit_field(&node->data.array_type.child_type, visit, context); break; case NodeTypeErrorType: // none @@ -2821,3 +2833,29 @@ void normalize_parent_ptrs(AstNode *node) { break; } } + +static void normalize_parent_ptrs_visit(AstNode **node, void *context) { + (*node)->parent_field = node; +} + +void normalize_parent_ptrs(AstNode *node) { + ast_visit_node_children(node, normalize_parent_ptrs_visit, nullptr); +} + +static AstNode *clone_node(AstNode *old_node) { + AstNode *new_node = allocate_nonzero(1); + memcpy(new_node, old_node, sizeof(AstNode)); + return new_node; +} + +static void ast_clone_subtree_visit(AstNode **node, void *context) { + *node = clone_node(*node); + (*node)->parent_field = node; + ast_visit_node_children(*node, ast_clone_subtree_visit, nullptr); +} + +AstNode *ast_clone_subtree(AstNode *old_node) { + AstNode *new_node = clone_node(old_node); + ast_visit_node_children(new_node, ast_clone_subtree_visit, nullptr); + return new_node; +} diff --git a/src/parser.hpp b/src/parser.hpp index 6cdff22329..00f7ad1eed 100644 --- a/src/parser.hpp +++ b/src/parser.hpp @@ -20,10 +20,11 @@ void ast_token_error(Token *token, const char *format, ...); AstNode * ast_parse(Buf *buf, ZigList *tokens, ImportTableEntry *owner, ErrColor err_color, uint32_t *next_node_index); -const char *node_type_str(NodeType node_type); - void ast_print(AstNode *node, int indent); void normalize_parent_ptrs(AstNode *node); +AstNode *ast_clone_subtree(AstNode *node); +void ast_visit_node_children(AstNode *node, void (*visit)(AstNode **, void *context), void *context); + #endif diff --git a/test/self_hosted.zig b/test/self_hosted.zig index ef8dcdef0a..af3c26e8f0 100644 --- a/test/self_hosted.zig +++ b/test/self_hosted.zig @@ -512,6 +512,16 @@ three)"; +#attribute("test") +fn simple_generic_fn() { + assert(max(i32)(3, -1) == 3); +} + +fn max(T: type)(a: T, b: T) -> T { + return if (a > b) a else b; +} + + fn assert(b: bool) { if (!b) unreachable{} }