support break in suspend blocks

* you can label suspend blocks
 * labeled break supports suspend blocks

See #803
This commit is contained in:
Andrew Kelley 2018-04-18 22:21:54 -04:00
parent ca4341f7ba
commit 06909ceaab
9 changed files with 136 additions and 13 deletions

View File

@ -5918,7 +5918,7 @@ Defer(body) = ("defer" | "deferror") body
IfExpression(body) = "if" "(" Expression ")" body option("else" BlockExpression(body))
SuspendExpression(body) = "suspend" option(("|" Symbol "|" body))
SuspendExpression(body) = option(Symbol ":") "suspend" option(("|" Symbol "|" body))
IfErrorExpression(body) = "if" "(" Expression ")" option("|" option("*") Symbol "|") body "else" "|" Symbol "|" BlockExpression(body)

View File

@ -867,6 +867,7 @@ struct AstNodeAwaitExpr {
};
struct AstNodeSuspend {
Buf *name;
AstNode *block;
AstNode *promise_symbol;
};
@ -1757,6 +1758,7 @@ enum ScopeId {
ScopeIdVarDecl,
ScopeIdCImport,
ScopeIdLoop,
ScopeIdSuspend,
ScopeIdFnDef,
ScopeIdCompTime,
ScopeIdCoroPrelude,
@ -1852,6 +1854,17 @@ struct ScopeLoop {
ZigList<IrBasicBlock *> *incoming_blocks;
};
// This scope is created for a suspend block in order to have labeled
// suspend for breaking out of a suspend and for detecting if a suspend
// block is inside a suspend block.
struct ScopeSuspend {
Scope base;
Buf *name;
IrBasicBlock *resume_block;
bool reported_err;
};
// This scope is created for a comptime expression.
// NodeTypeCompTime, NodeTypeSwitchExpr
struct ScopeCompTime {

View File

@ -156,6 +156,14 @@ ScopeLoop *create_loop_scope(AstNode *node, Scope *parent) {
return scope;
}
ScopeSuspend *create_suspend_scope(AstNode *node, Scope *parent) {
assert(node->type == NodeTypeSuspend);
ScopeSuspend *scope = allocate<ScopeSuspend>(1);
init_scope(&scope->base, ScopeIdSuspend, node, parent);
scope->name = node->data.suspend.name;
return scope;
}
ScopeFnDef *create_fndef_scope(AstNode *node, Scope *parent, FnTableEntry *fn_entry) {
ScopeFnDef *scope = allocate<ScopeFnDef>(1);
init_scope(&scope->base, ScopeIdFnDef, node, parent);
@ -3616,6 +3624,7 @@ FnTableEntry *scope_get_fn_if_root(Scope *scope) {
case ScopeIdVarDecl:
case ScopeIdCImport:
case ScopeIdLoop:
case ScopeIdSuspend:
case ScopeIdCompTime:
case ScopeIdCoroPrelude:
scope = scope->parent;

View File

@ -104,6 +104,7 @@ ScopeDeferExpr *create_defer_expr_scope(AstNode *node, Scope *parent);
Scope *create_var_scope(AstNode *node, Scope *parent, VariableTableEntry *var);
ScopeCImport *create_cimport_scope(AstNode *node, Scope *parent);
ScopeLoop *create_loop_scope(AstNode *node, Scope *parent);
ScopeSuspend *create_suspend_scope(AstNode *node, Scope *parent);
ScopeFnDef *create_fndef_scope(AstNode *node, Scope *parent, FnTableEntry *fn_entry);
ScopeDecls *create_decls_scope(AstNode *node, Scope *parent, TypeTableEntry *container_type, ImportTableEntry *import);
Scope *create_comptime_scope(AstNode *node, Scope *parent);

View File

@ -654,6 +654,7 @@ static ZigLLVMDIScope *get_di_scope(CodeGen *g, Scope *scope) {
}
case ScopeIdDeferExpr:
case ScopeIdLoop:
case ScopeIdSuspend:
case ScopeIdCompTime:
case ScopeIdCoroPrelude:
return get_di_scope(g, scope->parent);

View File

@ -2829,6 +2829,18 @@ static void ir_set_cursor_at_end_and_append_block(IrBuilder *irb, IrBasicBlock *
ir_set_cursor_at_end(irb, basic_block);
}
static ScopeSuspend *get_scope_suspend(Scope *scope) {
while (scope) {
if (scope->id == ScopeIdSuspend)
return (ScopeSuspend *)scope;
if (scope->id == ScopeIdFnDef)
return nullptr;
scope = scope->parent;
}
return nullptr;
}
static ScopeDeferExpr *get_scope_defer_expr(Scope *scope) {
while (scope) {
if (scope->id == ScopeIdDeferExpr)
@ -5665,6 +5677,15 @@ static IrInstruction *ir_gen_return_from_block(IrBuilder *irb, Scope *break_scop
return ir_build_br(irb, break_scope, node, dest_block, is_comptime);
}
static IrInstruction *ir_gen_break_from_suspend(IrBuilder *irb, Scope *break_scope, AstNode *node, ScopeSuspend *suspend_scope) {
IrInstruction *is_comptime = ir_build_const_bool(irb, break_scope, node, false);
IrBasicBlock *dest_block = suspend_scope->resume_block;
ir_gen_defers_for_block(irb, break_scope, dest_block->scope, false);
return ir_build_br(irb, break_scope, node, dest_block, is_comptime);
}
static IrInstruction *ir_gen_break(IrBuilder *irb, Scope *break_scope, AstNode *node) {
assert(node->type == NodeTypeBreak);
@ -5704,6 +5725,13 @@ static IrInstruction *ir_gen_break(IrBuilder *irb, Scope *break_scope, AstNode *
assert(this_block_scope->end_block != nullptr);
return ir_gen_return_from_block(irb, break_scope, node, this_block_scope);
}
} else if (search_scope->id == ScopeIdSuspend) {
ScopeSuspend *this_suspend_scope = (ScopeSuspend *)search_scope;
if (node->data.break_expr.name != nullptr &&
(this_suspend_scope->name != nullptr && buf_eql_buf(node->data.break_expr.name, this_suspend_scope->name)))
{
return ir_gen_break_from_suspend(irb, break_scope, node, this_suspend_scope);
}
}
search_scope = search_scope->parent;
}
@ -6290,14 +6318,26 @@ static IrInstruction *ir_gen_suspend(IrBuilder *irb, Scope *parent_scope, AstNod
ScopeDeferExpr *scope_defer_expr = get_scope_defer_expr(parent_scope);
if (scope_defer_expr) {
if (!scope_defer_expr->reported_err) {
add_node_error(irb->codegen, node, buf_sprintf("cannot suspend inside defer expression"));
ErrorMsg *msg = add_node_error(irb->codegen, node, buf_sprintf("cannot suspend inside defer expression"));
add_error_note(irb->codegen, msg, scope_defer_expr->base.source_node, buf_sprintf("defer here"));
scope_defer_expr->reported_err = true;
}
return irb->codegen->invalid_instruction;
}
ScopeSuspend *existing_suspend_scope = get_scope_suspend(parent_scope);
if (existing_suspend_scope) {
if (!existing_suspend_scope->reported_err) {
ErrorMsg *msg = add_node_error(irb->codegen, node, buf_sprintf("cannot suspend inside suspend block"));
add_error_note(irb->codegen, msg, existing_suspend_scope->base.source_node, buf_sprintf("other suspend block here"));
existing_suspend_scope->reported_err = true;
}
return irb->codegen->invalid_instruction;
}
Scope *outer_scope = irb->exec->begin_scope;
IrBasicBlock *cleanup_block = ir_create_basic_block(irb, parent_scope, "SuspendCleanup");
IrBasicBlock *resume_block = ir_create_basic_block(irb, parent_scope, "SuspendResume");
IrInstruction *suspend_code;
IrInstruction *const_bool_false = ir_build_const_bool(irb, parent_scope, node, false);
@ -6316,28 +6356,28 @@ static IrInstruction *ir_gen_suspend(IrBuilder *irb, Scope *parent_scope, AstNod
} else {
child_scope = parent_scope;
}
ScopeSuspend *suspend_scope = create_suspend_scope(node, child_scope);
suspend_scope->resume_block = resume_block;
child_scope = &suspend_scope->base;
IrInstruction *save_token = ir_build_coro_save(irb, child_scope, node, irb->exec->coro_handle);
ir_gen_node(irb, node->data.suspend.block, child_scope);
suspend_code = ir_build_coro_suspend(irb, parent_scope, node, save_token, const_bool_false);
suspend_code = ir_mark_gen(ir_build_coro_suspend(irb, parent_scope, node, save_token, const_bool_false));
}
IrBasicBlock *cleanup_block = ir_create_basic_block(irb, parent_scope, "SuspendCleanup");
IrBasicBlock *resume_block = ir_create_basic_block(irb, parent_scope, "SuspendResume");
IrInstructionSwitchBrCase *cases = allocate<IrInstructionSwitchBrCase>(2);
cases[0].value = ir_build_const_u8(irb, parent_scope, node, 0);
cases[0].value = ir_mark_gen(ir_build_const_u8(irb, parent_scope, node, 0));
cases[0].block = resume_block;
cases[1].value = ir_build_const_u8(irb, parent_scope, node, 1);
cases[1].value = ir_mark_gen(ir_build_const_u8(irb, parent_scope, node, 1));
cases[1].block = cleanup_block;
ir_build_switch_br(irb, parent_scope, node, suspend_code, irb->exec->coro_suspend_block,
2, cases, const_bool_false);
ir_mark_gen(ir_build_switch_br(irb, parent_scope, node, suspend_code, irb->exec->coro_suspend_block,
2, cases, const_bool_false));
ir_set_cursor_at_end_and_append_block(irb, cleanup_block);
ir_gen_defers_for_block(irb, parent_scope, outer_scope, true);
ir_mark_gen(ir_build_br(irb, parent_scope, node, irb->exec->coro_final_cleanup_block, const_bool_false));
ir_set_cursor_at_end_and_append_block(irb, resume_block);
return ir_build_const_void(irb, parent_scope, node);
return ir_mark_gen(ir_build_const_void(irb, parent_scope, node));
}
static IrInstruction *ir_gen_node_raw(IrBuilder *irb, AstNode *node, Scope *scope,

View File

@ -648,12 +648,30 @@ static AstNode *ast_parse_asm_expr(ParseContext *pc, size_t *token_index, bool m
}
/*
SuspendExpression(body) = "suspend" "|" Symbol "|" body
SuspendExpression(body) = option(Symbol ":") "suspend" option(("|" Symbol "|" body))
*/
static AstNode *ast_parse_suspend_block(ParseContext *pc, size_t *token_index, bool mandatory) {
size_t orig_token_index = *token_index;
Token *suspend_token = &pc->tokens->at(*token_index);
Token *name_token = nullptr;
Token *token = &pc->tokens->at(*token_index);
if (token->id == TokenIdSymbol) {
*token_index += 1;
Token *colon_token = &pc->tokens->at(*token_index);
if (colon_token->id == TokenIdColon) {
*token_index += 1;
name_token = token;
token = &pc->tokens->at(*token_index);
} else if (mandatory) {
ast_expect_token(pc, colon_token, TokenIdColon);
zig_unreachable();
} else {
*token_index = orig_token_index;
return nullptr;
}
}
Token *suspend_token = token;
if (suspend_token->id == TokenIdKeywordSuspend) {
*token_index += 1;
} else if (mandatory) {
@ -675,6 +693,9 @@ static AstNode *ast_parse_suspend_block(ParseContext *pc, size_t *token_index, b
}
AstNode *node = ast_create_node(pc, NodeTypeSuspend, suspend_token);
if (name_token != nullptr) {
node->data.suspend.name = token_buf(name_token);
}
node->data.suspend.promise_symbol = ast_parse_symbol(pc, token_index);
ast_eat_token(pc, token_index, TokenIdBinOr);
node->data.suspend.block = ast_parse_block(pc, token_index, true);

View File

@ -224,3 +224,21 @@ async fn printTrace(p: promise->error!void) void {
}
};
}
test "break from suspend" {
var buf: [500]u8 = undefined;
var a = &std.heap.FixedBufferAllocator.init(buf[0..]).allocator;
var my_result: i32 = 1;
const p = try async<a> testBreakFromSuspend(&my_result);
cancel p;
std.debug.assert(my_result == 2);
}
async fn testBreakFromSuspend(my_result: &i32) void {
s: suspend |p| {
break :s;
}
*my_result += 1;
suspend;
*my_result += 1;
}

View File

@ -1,6 +1,26 @@
const tests = @import("tests.zig");
pub fn addCases(cases: &tests.CompileErrorContext) void {
cases.add("suspend inside suspend block",
\\const std = @import("std");
\\
\\export fn entry() void {
\\ var buf: [500]u8 = undefined;
\\ var a = &std.heap.FixedBufferAllocator.init(buf[0..]).allocator;
\\ const p = (async<a> foo()) catch unreachable;
\\ cancel p;
\\}
\\
\\async fn foo() void {
\\ suspend |p| {
\\ suspend |p1| {
\\ }
\\ }
\\}
,
".tmp_source.zig:12:9: error: cannot suspend inside suspend block",
".tmp_source.zig:11:5: note: other suspend block here");
cases.add("assign inline fn to non-comptime var",
\\export fn entry() void {
\\ var a = b;