# This file is a part of Julia. License is MIT: https://julialang.org/license

# data stored by the owner of a remote reference
def_rv_channel() = Channel(1)

mutable struct RemoteValue
    c::AbstractChannel
    clientset::IntSet # Set of workerids that have a reference to this channel.
                      # Keeping ids instead of a count aids in cleaning up upon
                      # a worker exit.

    waitingfor::Int   # processor we need to hear from to fill this, or 0

    RemoteValue(c) = new(c, IntSet(), 0)
end

wait(rv::RemoteValue) = wait(rv.c)

## core messages: do, call, fetch, wait, ref, put! ##
mutable struct RemoteException <: Exception
    pid::Int
    captured::CapturedException
end

"""
    RemoteException(captured)

Exceptions on remote computations are captured and rethrown locally. A `RemoteException`
wraps the `pid` of the worker and a captured exception. A `CapturedException` captures the
remote exception and a serializable form of the call stack when the exception was raised.
"""
RemoteException(captured) = RemoteException(myid(), captured)

function showerror(io::IO, re::RemoteException)
    (re.pid != myid()) && print(io, "On worker ", re.pid, ":\n")
    showerror(io, get_root_exception(re.captured))
end

isa_exception_container(ex) = (isa(ex, RemoteException) ||
                               isa(ex, CapturedException) ||
                               isa(ex, CompositeException))

function get_root_exception(ex)
    if isa(ex, RemoteException)
        return get_root_exception(ex.captured)
    elseif isa(ex, CapturedException) && isa_exception_container(ex.ex)
        return get_root_exception(ex.ex)
    elseif isa(ex, CompositeException) && length(ex.exceptions) > 0 && isa_exception_container(ex.exceptions[1])
        return get_root_exception(ex.exceptions[1])
    else
        return ex
    end
end

function run_work_thunk(thunk, print_error)
    local result
    try
        result = thunk()
    catch err
        ce = CapturedException(err, catch_backtrace())
        result = RemoteException(ce)
        print_error && showerror(STDERR, ce)
    end
    return result
end
function run_work_thunk(rv::RemoteValue, thunk)
    put!(rv, run_work_thunk(thunk, false))
    nothing
end

function schedule_call(rid, thunk)
    return lock(client_refs) do
        rv = RemoteValue(def_rv_channel())
        (PGRP::ProcessGroup).refs[rid] = rv
        push!(rv.clientset, rid.whence)
        @schedule run_work_thunk(rv, thunk)
        return rv
    end
end

function deliver_result(sock::IO, msg, oid, value)
    #print("$(myid()) sending result $oid\n")
    if msg === :call_fetch || isa(value, RemoteException)
        val = value
    else
        val = :OK
    end
    try
        send_msg_now(sock, MsgHeader(oid), ResultMsg(val))
    catch e
        # terminate connection in case of serialization error
        # otherwise the reading end would hang
        print(STDERR, "fatal error on ", myid(), ": ")
        display_error(e, catch_backtrace())
        wid = worker_id_from_socket(sock)
        close(sock)
        if myid()==1
            rmprocs(wid)
        elseif wid == 1
            exit(1)
        else
            remote_do(rmprocs, 1, wid)
        end
    end
end
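# Illustrative sketch (comments only; the `rid` and `oid` values below are hypothetical
# RRIDs) of how the pieces above service a remote call: `schedule_call` registers a
# RemoteValue and runs the thunk, and the result channel then feeds `deliver_result`:
#
#     rv = schedule_call(rid, () -> f(args...))               # registers rv in PGRP.refs
#     # run_work_thunk put!s the result, or a RemoteException, into rv.c
#     deliver_result(w_stream, :call_wait, oid, fetch(rv.c))  # replies :OK or the exception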
## message event handlers ##
function process_messages(r_stream::TCPSocket, w_stream::TCPSocket, incoming::Bool=true)
    @schedule process_tcp_streams(r_stream, w_stream, incoming)
end

function process_tcp_streams(r_stream::TCPSocket, w_stream::TCPSocket, incoming::Bool)
    disable_nagle(r_stream)
    wait_connected(r_stream)
    if r_stream != w_stream
        disable_nagle(w_stream)
        wait_connected(w_stream)
    end
    message_handler_loop(r_stream, w_stream, incoming)
end

"""
    Base.process_messages(r_stream::IO, w_stream::IO, incoming::Bool=true)

Called by cluster managers using custom transports. It should be called when the custom
transport implementation receives the first message from a remote worker. The custom
transport must manage a logical connection to the remote worker and provide two
`IO` objects, one for incoming messages and the other for messages addressed to the
remote worker.
If `incoming` is `true`, the remote peer initiated the connection.
Whichever of the pair initiates the connection sends the cluster cookie and its
Julia version number to perform the authentication handshake.

See also [`cluster_cookie`](@ref).
"""
function process_messages(r_stream::IO, w_stream::IO, incoming::Bool=true)
    @schedule message_handler_loop(r_stream, w_stream, incoming)
end

function message_handler_loop(r_stream::IO, w_stream::IO, incoming::Bool)
    wpid = 0          # the worker r_stream is connected to.
    boundary = similar(MSG_BOUNDARY)
    try
        version = process_hdr(r_stream, incoming)
        serializer = ClusterSerializer(r_stream)

        # The first message will associate wpid with r_stream
        header = deserialize_hdr_raw(r_stream)
        msg = deserialize_msg(serializer)
        handle_msg(msg, header, r_stream, w_stream, version)
        wpid = worker_id_from_socket(r_stream)
        @assert wpid > 0

        readbytes!(r_stream, boundary, length(MSG_BOUNDARY))

        while true
            reset_state(serializer)
            header = deserialize_hdr_raw(r_stream)
            # println("header: ", header)

            try
                msg = invokelatest(deserialize_msg, serializer)
            catch e
                # Deserialization error; discard bytes in stream until boundary found
                boundary_idx = 1
                while true
                    # This may throw an EOF error if the terminal boundary was not written
                    # correctly, triggering the higher-scoped catch block below
                    byte = read(r_stream, UInt8)
                    if byte == MSG_BOUNDARY[boundary_idx]
                        boundary_idx += 1
                        if boundary_idx > length(MSG_BOUNDARY)
                            break
                        end
                    else
                        boundary_idx = 1
                    end
                end

                # remotecalls only rethrow RemoteExceptions. Any other exception is treated
                # as data to be returned. Wrap this exception in a RemoteException.
                remote_err = RemoteException(myid(), CapturedException(e, catch_backtrace()))
                # println("Deserialization error. ", remote_err)
                if !null_id(header.response_oid)
                    ref = lookup_ref(header.response_oid)
                    put!(ref, remote_err)
                end
                if !null_id(header.notify_oid)
                    deliver_result(w_stream, :call_fetch, header.notify_oid, remote_err)
                end
                continue
            end
            readbytes!(r_stream, boundary, length(MSG_BOUNDARY))

            # println("got msg: ", typeof(msg))
            handle_msg(msg, header, r_stream, w_stream, version)
        end
    catch e
        # Check again, as wpid may have been set in a message handler but not
        # propagated to the calling block above
        wpid = worker_id_from_socket(r_stream)
        if (wpid < 1)
            println(STDERR, e, CapturedException(e, catch_backtrace()))
            println(STDERR, "Process($(myid())) - Unknown remote, closing connection.")
        else
            werr = worker_from_id(wpid)
            oldstate = werr.state
            set_worker_state(werr, W_TERMINATED)

            # If an unhandleable error occurred while talking to pid 1, exit
            if wpid == 1
                if isopen(w_stream)
                    print(STDERR, "fatal error on ", myid(), ": ")
                    display_error(e, catch_backtrace())
                end
                exit(1)
            end

            # Treat any exception as death of the node and clean up, since we
            # currently do not have a mechanism for workers to reconnect to each
            # other on unhandled errors.
            deregister_worker(wpid)
        end

        isopen(r_stream) && close(r_stream)
        isopen(w_stream) && close(w_stream)

        if (myid() == 1) && (wpid > 1)
            if oldstate != W_TERMINATING
                println(STDERR, "Worker $wpid terminated.")
                rethrow(e)
            end
        end

        return nothing
    end
end

function process_hdr(s, validate_cookie)
    if validate_cookie
        cookie = read(s, HDR_COOKIE_LEN)
        if length(cookie) < HDR_COOKIE_LEN
            error("Cookie read failed. Connection closed by peer.")
        end

        self_cookie = cluster_cookie()
        for i in 1:HDR_COOKIE_LEN
            if UInt8(self_cookie[i]) != cookie[i]
                error("Process($(myid())) - Invalid connection credentials sent by remote.")
            end
        end
    end

    # When incompatible julia versions try to connect to each other and this can be
    # detected, an appropriate error should be raised. For now, just return the version.
    version = read(s, HDR_VERSION_LEN)
    if length(version) < HDR_VERSION_LEN
        error("Version read failed. Connection closed by peer.")
    end

    return VersionNumber(strip(String(version)))
end
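# Illustrative wire layout as consumed above (byte widths are the HDR_* constants used
# in `process_hdr`; the cookie is only read and validated when `incoming == true`):
#
#     [cluster cookie : HDR_COOKIE_LEN bytes]
#     [julia version  : HDR_VERSION_LEN bytes, padded version string]
#     repeated: [raw MsgHeader][serialized message][MSG_BOUNDARY bytes]
#
# The MSG_BOUNDARY trailer is what allows `message_handler_loop` to resynchronize the
# stream after a deserialization failure.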
Connection closed by peer.") end self_cookie = cluster_cookie() for i in 1:HDR_COOKIE_LEN if UInt8(self_cookie[i]) != cookie[i] error("Process($(myid())) - Invalid connection credentials sent by remote.") end end end # When we have incompatible julia versions trying to connect to each other, # and can be detected, raise an appropriate error. # For now, just return the version. version = read(s, HDR_VERSION_LEN) if length(version) < HDR_VERSION_LEN error("Version read failed. Connection closed by peer.") end return VersionNumber(strip(String(version))) end function handle_msg(msg::CallMsg{:call}, header, r_stream, w_stream, version) schedule_call(header.response_oid, ()->msg.f(msg.args...; msg.kwargs...)) end function handle_msg(msg::CallMsg{:call_fetch}, header, r_stream, w_stream, version) @schedule begin v = run_work_thunk(()->msg.f(msg.args...; msg.kwargs...), false) deliver_result(w_stream, :call_fetch, header.notify_oid, v) end end function handle_msg(msg::CallWaitMsg, header, r_stream, w_stream, version) @schedule begin rv = schedule_call(header.response_oid, ()->msg.f(msg.args...; msg.kwargs...)) deliver_result(w_stream, :call_wait, header.notify_oid, fetch(rv.c)) end end function handle_msg(msg::RemoteDoMsg, header, r_stream, w_stream, version) @schedule run_work_thunk(()->msg.f(msg.args...; msg.kwargs...), true) end function handle_msg(msg::ResultMsg, header, r_stream, w_stream, version) put!(lookup_ref(header.response_oid), msg.value) end function handle_msg(msg::IdentifySocketMsg, header, r_stream, w_stream, version) # register a new peer worker connection w=Worker(msg.from_pid, r_stream, w_stream, cluster_manager; version=version) send_connection_hdr(w, false) send_msg_now(w, MsgHeader(), IdentifySocketAckMsg()) end function handle_msg(msg::IdentifySocketAckMsg, header, r_stream, w_stream, version) w = map_sock_wrkr[r_stream] w.version = version end function handle_msg(msg::JoinPGRPMsg, header, r_stream, w_stream, version) LPROC.id = msg.self_pid controller = Worker(1, r_stream, w_stream, cluster_manager; version=version) register_worker(LPROC) topology(msg.topology) if !msg.enable_threaded_blas disable_threaded_libs() end wait_tasks = Task[] for (connect_at, rpid) in msg.other_workers wconfig = WorkerConfig() wconfig.connect_at = connect_at let rpid=rpid, wconfig=wconfig t = @async connect_to_peer(cluster_manager, rpid, wconfig) push!(wait_tasks, t) end end for wt in wait_tasks; wait(wt); end send_connection_hdr(controller, false) send_msg_now(controller, MsgHeader(RRID(0,0), header.notify_oid), JoinCompleteMsg(Sys.CPU_CORES, getpid())) end function connect_to_peer(manager::ClusterManager, rpid::Int, wconfig::WorkerConfig) try (r_s, w_s) = connect(manager, rpid, wconfig) w = Worker(rpid, r_s, w_s, manager; config=wconfig) process_messages(w.r_stream, w.w_stream, false) send_connection_hdr(w, true) send_msg_now(w, MsgHeader(), IdentifySocketMsg(myid())) catch e display_error(e, catch_backtrace()) println(STDERR, "Error [$e] on $(myid()) while connecting to peer $rpid. Exiting.") exit(1) end end function handle_msg(msg::JoinCompleteMsg, header, r_stream, w_stream, version) w = map_sock_wrkr[r_stream] environ = get(w.config.environ, Dict()) environ[:cpu_cores] = msg.cpu_cores w.config.environ = environ w.config.ospid = msg.ospid w.version = version ntfy_channel = lookup_ref(header.notify_oid) put!(ntfy_channel, w.id) push!(default_worker_pool(), w.id) end