Add: julia-0.6.2

Former-commit-id: ccc667cf67d569f3fb3df39aa57c2134755a7551
2018-02-10 10:27:19 -07:00
parent 94220957d7
commit 019f8e3064
723 changed files with 276164 additions and 0 deletions


@@ -0,0 +1,76 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license
module Distributed
# imports for extension
import Base: getindex, wait, put!, take!, fetch, isready, push!, length,
hash, ==, connect, kill, serialize, deserialize, close, showerror
# imports for use
using Base: Process, Semaphore, JLOptions, AnyDict, buffer_writes, wait_connected,
VERSION_STRING, sync_begin, sync_add, sync_end, async_run_thunk,
binding_module, notify_error, atexit, julia_exename, julia_cmd,
AsyncGenerator, display_error, acquire, release, invokelatest, warn_once,
shell_escape, uv_error
# NOTE: clusterserialize.jl imports additional symbols from Base.Serializer for use
export
@spawn,
@spawnat,
@fetch,
@fetchfrom,
@everywhere,
@parallel,
addprocs,
CachingPool,
clear!,
ClusterManager,
default_worker_pool,
init_worker,
interrupt,
launch,
manage,
myid,
nprocs,
nworkers,
pmap,
procs,
remote,
remotecall,
remotecall_fetch,
remotecall_wait,
remote_do,
rmprocs,
workers,
WorkerPool,
RemoteChannel,
Future,
WorkerConfig,
RemoteException,
ProcessExitedException,
# Add the following into Base as some packages access them via Base.
# Also documented as such.
process_messages,
remoteref_id,
channel_from_id,
worker_id_from_socket,
cluster_cookie,
start_worker,
# Used only by shared arrays.
check_same_host
include("clusterserialize.jl")
include("cluster.jl") # cluster setup and management, addprocs
include("messages.jl")
include("process_messages.jl") # process incoming messages
include("remotecall.jl") # the remotecall* api
include("macros.jl") # @spawn and friends
include("workerpool.jl")
include("pmap.jl")
include("managers.jl") # LocalManager and SSHManager
end
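# A minimal usage sketch of the exported API above (illustrative only;
# assumes a standard Julia 0.6 session, where these names are available from Base):
#
#   addprocs(2)                 # start two local workers
#   r = @spawn sum(rand(10))    # run on some worker; returns a Future
#   fetch(r)                    # wait for and retrieve the result
#   pmap(x -> x^2, 1:10)        # parallel map across the workers
#   rmprocs(workers())          # tear the workers down again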

File diff suppressed because it is too large.


@@ -0,0 +1,249 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license
using Base.Serializer: object_number, serialize_cycle, deserialize_cycle, writetag,
__deserialized_types__, serialize_typename, deserialize_typename,
TYPENAME_TAG, object_numbers, reset_state, serialize_type
import Base.Serializer: lookup_object_number, remember_object
mutable struct ClusterSerializer{I<:IO} <: AbstractSerializer
io::I
counter::Int
table::ObjectIdDict
pending_refs::Vector{Int}
pid::Int # Worker we are connected to.
tn_obj_sent::Set{UInt64} # TypeName objects sent
glbs_sent::Dict{UInt64, UInt64} # (key,value) -> (object_id, hash_value)
glbs_in_tnobj::Dict{UInt64, Vector{Symbol}} # Track globals referenced in
# anonymous functions.
anonfunc_id::UInt64
function ClusterSerializer{I}(io::I) where I<:IO
new(io, 0, ObjectIdDict(), Int[], Base.worker_id_from_socket(io),
Set{UInt64}(), Dict{UInt64, UInt64}(), Dict{UInt64, Vector{Symbol}}(), 0)
end
end
ClusterSerializer(io::IO) = ClusterSerializer{typeof(io)}(io)
const known_object_data = Dict{UInt64,Any}()
function lookup_object_number(s::ClusterSerializer, n::UInt64)
return get(known_object_data, n, nothing)
end
function remember_object(s::ClusterSerializer, o::ANY, n::UInt64)
known_object_data[n] = o
if isa(o, TypeName) && !haskey(object_numbers, o)
# set up reverse mapping for serialize
object_numbers[o] = n
end
return nothing
end
function deserialize(s::ClusterSerializer, ::Type{TypeName})
full_body_sent = deserialize(s)
number = read(s.io, UInt64)
if !full_body_sent
tn = lookup_object_number(s, number)::TypeName
remember_object(s, tn, number)
deserialize_cycle(s, tn)
else
tn = deserialize_typename(s, number)
end
# Retrieve the arrays of global syms sent, if any, and deserialize them all.
foreach(sym->deserialize_global_from_main(s, sym), deserialize(s))
return tn
end
function serialize(s::ClusterSerializer, t::TypeName)
serialize_cycle(s, t) && return
writetag(s.io, TYPENAME_TAG)
identifier = object_number(t)
send_whole = !(identifier in s.tn_obj_sent)
serialize(s, send_whole)
write(s.io, identifier)
if send_whole
# Track globals referenced in this anonymous function.
# This information is used to resend modified globals when we
# only send the identifier.
prev = s.anonfunc_id
s.anonfunc_id = identifier
serialize_typename(s, t)
s.anonfunc_id = prev
push!(s.tn_obj_sent, identifier)
finalizer(t, x->cleanup_tname_glbs(s, identifier))
end
# Send global refs if required.
syms = syms_2b_sent(s, identifier)
serialize(s, syms)
foreach(sym->serialize_global_from_main(s, sym), syms)
nothing
end
function serialize(s::ClusterSerializer, g::GlobalRef)
# Record if required and then invoke the default GlobalRef serializer.
sym = g.name
if g.mod === Main && isdefined(g.mod, sym)
if (binding_module(Main, sym) === Main) && (s.anonfunc_id != 0) &&
!startswith(string(sym), "#") # Anonymous functions are handled via FULL_GLOBALREF_TAG
push!(get!(s.glbs_in_tnobj, s.anonfunc_id, []), sym)
end
end
invoke(serialize, Tuple{AbstractSerializer, GlobalRef}, s, g)
end
# Send/resend a global object if
# a) it has not been sent previously, i.e., we are seeing this object_id for the first time, or,
# b) its hash value has changed, or
# c) it is a bits type
function syms_2b_sent(s::ClusterSerializer, identifier)
lst = Symbol[]
check_syms = get(s.glbs_in_tnobj, identifier, [])
for sym in check_syms
v = getfield(Main, sym)
if isbits(v)
push!(lst, sym)
else
oid = object_id(v)
if haskey(s.glbs_sent, oid)
# We have sent this object before, see if it has changed.
s.glbs_sent[oid] != hash(sym, hash(v)) && push!(lst, sym)
else
push!(lst, sym)
end
end
end
return unique(lst)
end
function serialize_global_from_main(s::ClusterSerializer, sym)
v = getfield(Main, sym)
oid = object_id(v)
record_v = true
if isbits(v)
record_v = false
elseif !haskey(s.glbs_sent, oid)
# set up a finalizer the first time this object is sent
try
finalizer(v, x->delete_global_tracker(s,x))
catch ex
# Do not track objects that cannot be finalized.
if isa(ex, ErrorException)
record_v = false
else
rethrow(ex)
end
end
end
record_v && (s.glbs_sent[oid] = hash(sym, hash(v)))
serialize(s, isconst(Main, sym))
serialize(s, v)
end
function deserialize_global_from_main(s::ClusterSerializer, sym)
sym_isconst = deserialize(s)
v = deserialize(s)
if sym_isconst
@eval Main const $sym = $v
else
@eval Main $sym = $v
end
end
function delete_global_tracker(s::ClusterSerializer, v)
oid = object_id(v)
if haskey(s.glbs_sent, oid)
delete!(s.glbs_sent, oid)
end
# TODO: A global binding is released and gc'ed here but it continues
# to occupy memory on the remote node. Would be nice to release memory
# if possible.
end
function cleanup_tname_glbs(s::ClusterSerializer, identifier)
delete!(s.glbs_in_tnobj, identifier)
end
# TODO: cleanup from s.tn_obj_sent
# Specialized serialize-deserialize implementations for CapturedException to partially
# recover from any deserialization errors in `CapturedException.ex`
function serialize(s::ClusterSerializer, ex::CapturedException)
serialize_type(s, typeof(ex))
serialize(s, string(typeof(ex.ex))) # String type should not result in a deser error
serialize(s, ex.processed_bt) # Currently should not result in a deser error
serialize(s, ex.ex) # can result in a UndefVarError on the remote node
# if a type used in ex.ex is undefined on the remote node.
end
function original_ex(s::ClusterSerializer, ex_str, remote_stktrace)
local pid_str = ""
try
pid_str = string(" from worker ", worker_id_from_socket(s.io))
end
stk_str = remote_stktrace ? "Remote" : "Local"
ErrorException(string("Error deserializing a remote exception", pid_str, "\n",
"Remote(original) exception of type ", ex_str, "\n",
stk_str, " stacktrace : "))
end
function deserialize(s::ClusterSerializer, t::Type{<:CapturedException})
ex_str = deserialize(s)
local bt
local capex
try
bt = deserialize(s)
catch e
throw(CompositeException([
original_ex(s, ex_str, false),
CapturedException(e, catch_backtrace())
]))
end
try
capex = deserialize(s)
catch e
throw(CompositeException([
CapturedException(original_ex(s, ex_str, true), bt),
CapturedException(e, catch_backtrace())
]))
end
return CapturedException(capex, bt)
end
"""
clear!(syms, pids=workers(); mod=Main)
Clears global bindings in modules by initializing them to `nothing`.
`syms` should be of type `Symbol` or a collection of `Symbol`s. `pids` and `mod`
identify the processes and the module in which global variables are to be
reinitialized. Only those names found to be defined under `mod` are cleared.
An exception is raised if a global constant is requested to be cleared.
"""
function clear!(syms, pids=workers(); mod=Main)
@sync for p in pids
@async remotecall_wait(clear_impl!, p, syms, mod)
end
end
clear!(sym::Symbol, pid::Int; mod=Main) = clear!([sym], [pid]; mod=mod)
clear!(sym::Symbol, pids=workers(); mod=Main) = clear!([sym], pids; mod=mod)
clear!(syms, pid::Int; mod=Main) = clear!(syms, [pid]; mod=mod)
clear_impl!(syms, mod::Module) = foreach(x->clear_impl!(x,mod), syms)
clear_impl!(sym::Symbol, mod::Module) = isdefined(mod, sym) && @eval(mod, global $sym = nothing)
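# Usage sketch for `clear!` (illustrative; assumes workers exist and that
# `foo`, `a`, `b` were previously defined under Main everywhere):
#
#   @everywhere foo = 1
#   clear!(:foo)             # Main.foo is now `nothing` on all workers
#   clear!([:a, :b], [2])    # clear `a` and `b` on worker 2 only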


@@ -0,0 +1,223 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license
let nextidx = 0
global nextproc
function nextproc()
p = workers()[(nextidx % nworkers()) + 1]
nextidx += 1
p
end
end
spawnat(p, thunk) = sync_add(remotecall(thunk, p))
spawn_somewhere(thunk) = spawnat(nextproc(),thunk)
macro spawn(expr)
thunk = esc(:(()->($expr)))
:(spawn_somewhere($thunk))
end
macro spawnat(p, expr)
thunk = esc(:(()->($expr)))
:(spawnat($(esc(p)), $thunk))
end
"""
@fetch
Equivalent to `fetch(@spawn expr)`.
See [`fetch`](@ref) and [`@spawn`](@ref).
"""
macro fetch(expr)
thunk = esc(:(()->($expr)))
:(remotecall_fetch($thunk, nextproc()))
end
"""
@fetchfrom
Equivalent to `fetch(@spawnat p expr)`.
See [`fetch`](@ref) and [`@spawnat`](@ref).
"""
macro fetchfrom(p, expr)
thunk = esc(:(()->($expr)))
:(remotecall_fetch($thunk, $(esc(p))))
end
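# Illustrative uses (assume at least one worker was added via addprocs):
#
#   @fetch myid()        # evaluates myid() on the next free worker
#   @fetchfrom 2 myid()  # evaluates myid() on worker 2, returning 2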
# extract a list of modules to import from an expression
extract_imports!(imports, x) = imports
function extract_imports!(imports, ex::Expr)
if Meta.isexpr(ex, (:import, :using))
return push!(imports, ex.args[1])
elseif Meta.isexpr(ex, :let)
return extract_imports!(imports, ex.args[1])
elseif Meta.isexpr(ex, (:toplevel, :block))
for i in eachindex(ex.args)
extract_imports!(imports, ex.args[i])
end
end
return imports
end
extract_imports(x) = extract_imports!(Symbol[], x)
"""
@everywhere expr
Execute an expression under `Main` everywhere. Equivalent to calling
`eval(Main, expr)` on all processes. Errors on any of the processes are collected into a
`CompositeException` and thrown. For example:
@everywhere bar=1
will define `Main.bar` on all processes.
Unlike [`@spawn`](@ref) and [`@spawnat`](@ref),
`@everywhere` does not capture any local variables. Prefixing
`@everywhere` with [`@eval`](@ref) allows us to broadcast
local variables using interpolation:
foo = 1
@eval @everywhere bar=\$foo
The expression is evaluated under `Main` irrespective of where `@everywhere` is called from.
For example:
module FooBar
foo() = @everywhere bar()=myid()
end
FooBar.foo()
will result in `Main.bar` being defined on all processes and not `FooBar.bar`.
"""
macro everywhere(ex)
imps = [Expr(:import, m) for m in extract_imports(ex)]
quote
$(isempty(imps) ? nothing : Expr(:toplevel, imps...))
sync_begin()
for pid in workers()
async_run_thunk(()->remotecall_fetch(eval_ew_expr, pid, $(Expr(:quote,ex))))
yield() # ensure that the remotecall_fetch has been started
end
# execute locally last as we do not want local execution to block serialization
# of the request to remote nodes.
if nprocs() > 1
async_run_thunk(()->eval_ew_expr($(Expr(:quote,ex))))
end
sync_end()
end
end
eval_ew_expr(ex) = (eval(Main, ex); nothing)
# Statically split range [1,N] into equal sized chunks for np processors
function splitrange(N::Int, np::Int)
each = div(N,np)
extras = rem(N,np)
nchunks = each > 0 ? np : extras
chunks = Vector{UnitRange{Int}}(nchunks)
lo = 1
for i in 1:nchunks
hi = lo + each - 1
if extras > 0
hi += 1
extras -= 1
end
chunks[i] = lo:hi
lo = hi+1
end
return chunks
end
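# Worked example of the split above: splitrange(11, 3) gives chunks of
# near-equal size, with the remainder spread over the leading chunks:
#
#   julia> splitrange(11, 3)
#   3-element Array{UnitRange{Int64},1}:
#    1:4
#    5:8
#    9:11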
function preduce(reducer, f, R)
N = length(R)
chunks = splitrange(N, nworkers())
all_w = workers()[1:length(chunks)]
w_exec = Task[]
for (idx,pid) in enumerate(all_w)
t = Task(()->remotecall_fetch(f, pid, reducer, R, first(chunks[idx]), last(chunks[idx])))
schedule(t)
push!(w_exec, t)
end
reduce(reducer, [wait(t) for t in w_exec])
end
function pfor(f, R)
[@spawn f(R, first(c), last(c)) for c in splitrange(length(R), nworkers())]
end
function make_preduce_body(var, body)
quote
function (reducer, R, lo::Int, hi::Int)
$(esc(var)) = R[lo]
ac = $(esc(body))
if lo != hi
for $(esc(var)) in R[(lo+1):hi]
ac = reducer(ac, $(esc(body)))
end
end
ac
end
end
end
function make_pfor_body(var, body)
quote
function (R, lo::Int, hi::Int)
for $(esc(var)) in R[lo:hi]
$(esc(body))
end
end
end
end
"""
@parallel
A parallel for loop of the form:
@parallel [reducer] for var = range
body
end
The specified range is partitioned and locally executed across all workers. In case an
optional reducer function is specified, `@parallel` performs local reductions on each worker
with a final reduction on the calling process.
Note that without a reducer function, `@parallel` executes asynchronously, i.e. it spawns
independent tasks on all available workers and returns immediately without waiting for
completion. To wait for completion, prefix the call with [`@sync`](@ref), like:
@sync @parallel for var = range
body
end
"""
macro parallel(args...)
na = length(args)
if na==1
loop = args[1]
elseif na==2
reducer = args[1]
loop = args[2]
else
throw(ArgumentError("wrong number of arguments to @parallel"))
end
if !isa(loop,Expr) || loop.head !== :for
error("malformed @parallel loop")
end
var = loop.args[1].args[1]
r = loop.args[1].args[2]
body = loop.args[2]
if na==1
thecall = :(pfor($(make_pfor_body(var, body)), $(esc(r))))
else
thecall = :(preduce($(esc(reducer)), $(make_preduce_body(var, body)), $(esc(r))))
end
thecall
end
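# Example (a sketch; assumes workers have been added): count heads in
# parallel with a (+) reduction across the workers.
#
#   nheads = @parallel (+) for i = 1:100000
#       Int(rand(Bool))
#   end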


@@ -0,0 +1,530 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license
# Built-in SSH and Local Managers
struct SSHManager <: ClusterManager
machines::Dict
function SSHManager(machines)
# machines => array of machine elements
# machine => address or (address, cnt)
# address => string of form `[user@]host[:port] bind_addr[:bind_port]`
# cnt => :auto or number
# :auto launches NUM_CORES number of workers at address
# number launches the specified number of workers at address
mhist = Dict()
for m in machines
if isa(m, Tuple)
host=m[1]
cnt=m[2]
else
host=m
cnt=1
end
current_cnt = get(mhist, host, 0)
if isa(cnt, Number)
mhist[host] = isa(current_cnt, Number) ? current_cnt + Int(cnt) : Int(cnt)
else
mhist[host] = cnt
end
end
new(mhist)
end
end
function check_addprocs_args(kwargs)
valid_kw_names = collect(keys(default_addprocs_params()))
for keyname in kwargs
!(keyname[1] in valid_kw_names) && throw(ArgumentError("Invalid keyword argument $(keyname[1])"))
end
end
# SSHManager
# start and connect to processes via SSH, optionally through an SSH tunnel.
# the tunnel is only used from the head (process 1); the nodes are assumed
# to be mutually reachable without a tunnel, as is often the case in a cluster.
# Default value of kw arg max_parallel is the default value of MaxStartups in sshd_config
# A machine is either a <hostname> or a tuple of (<hostname>, count)
"""
addprocs(machines; tunnel=false, sshflags=\`\`, max_parallel=10, kwargs...) -> List of process identifiers
Add processes on remote machines via SSH. Requires `julia` to be installed in the same
location on each node, or to be available via a shared file system.
`machines` is a vector of machine specifications. Workers are started for each specification.
A machine specification is either a string `machine_spec` or a tuple - `(machine_spec, count)`.
`machine_spec` is a string of the form `[user@]host[:port] [bind_addr[:port]]`. `user`
defaults to the current user and `port` to the standard ssh port. If `[bind_addr[:port]]` is
specified, other workers will connect to this worker at the specified `bind_addr` and
`port`.
`count` is the number of workers to be launched on the specified host. If specified as
`:auto` it will launch as many workers as the number of cores on that host.
Keyword arguments:
* `tunnel`: if `true` then SSH tunneling will be used to connect to the worker from the
master process. Default is `false`.
* `sshflags`: specifies additional ssh options, e.g. ```sshflags=\`-i /home/foo/bar.pem\````
* `max_parallel`: specifies the maximum number of workers connected to in parallel at a
host. Defaults to 10.
* `dir`: specifies the working directory on the workers. Defaults to the host's current
directory (as found by `pwd()`)
* `enable_threaded_blas`: if `true` then BLAS will run on multiple threads in added
processes. Default is `false`.
* `exename`: name of the `julia` executable. Defaults to `"\$JULIA_HOME/julia"` or
`"\$JULIA_HOME/julia-debug"` as the case may be.
* `exeflags`: additional flags passed to the worker processes.
* `topology`: Specifies how the workers connect to each other. Sending a message between
unconnected workers results in an error.
+ `topology=:all_to_all`: All processes are connected to each other. The default.
+ `topology=:master_slave`: Only the driver process, i.e. `pid` 1, connects to the
workers. The workers do not connect to each other.
+ `topology=:custom`: The `launch` method of the cluster manager specifies the
connection topology via fields `ident` and `connect_idents` in `WorkerConfig`.
A worker with a cluster manager identity `ident` will connect to all workers specified
in `connect_idents`.
Environment variables:
If the master process fails to establish a connection with a newly launched worker within
60.0 seconds, the worker treats it as a fatal situation and terminates.
This timeout can be controlled via environment variable `JULIA_WORKER_TIMEOUT`.
The value of `JULIA_WORKER_TIMEOUT` on the master process specifies the number of seconds a
newly launched worker waits for connection establishment.
"""
function addprocs(machines::AbstractVector; tunnel=false, sshflags=``, max_parallel=10, kwargs...)
check_addprocs_args(kwargs)
addprocs(SSHManager(machines); tunnel=tunnel, sshflags=sshflags, max_parallel=max_parallel, kwargs...)
end
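# Example invocations (hypothetical host names; they require passwordless ssh
# and julia at the same path on every node):
#
#   addprocs([("user@host1", 4), "host2"])                    # 4 workers + 1 worker
#   addprocs(["host3"]; tunnel=true, sshflags=`-i key.pem`)   # connect via an SSH tunnel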
function launch(manager::SSHManager, params::Dict, launched::Array, launch_ntfy::Condition)
# Launch one worker on each unique host in parallel. Additional workers are launched later.
# Wait for all launches to complete.
launch_tasks = Vector{Any}(length(manager.machines))
for (i,(machine, cnt)) in enumerate(manager.machines)
let machine=machine, cnt=cnt
launch_tasks[i] = @schedule try
launch_on_machine(manager, machine, cnt, params, launched, launch_ntfy)
catch e
print(STDERR, "exception launching on machine $(machine) : $(e)\n")
end
end
end
for t in launch_tasks
wait(t)
end
notify(launch_ntfy)
end
show(io::IO, manager::SSHManager) = println(io, "SSHManager(machines=", manager.machines, ")")
function launch_on_machine(manager::SSHManager, machine, cnt, params, launched, launch_ntfy::Condition)
dir = params[:dir]
exename = params[:exename]
exeflags = params[:exeflags]
# machine could be of the format [user@]host[:port] bind_addr[:bind_port]
# machine format string is split on whitespace
machine_bind = split(machine)
if isempty(machine_bind)
throw(ArgumentError("invalid machine definition format string: \"$machine\$"))
end
if length(machine_bind) > 1
exeflags = `--bind-to $(machine_bind[2]) $exeflags`
end
exeflags = `$exeflags --worker $(cluster_cookie())`
machine_def = split(machine_bind[1], ':')
# if this machine def has a port number, add the port information to the ssh flags
if length(machine_def) > 2
throw(ArgumentError("invalid machine definition format string: invalid port format \"$machine_def\""))
end
host = machine_def[1]
portopt = ``
if length(machine_def) == 2
portstr = machine_def[2]
if !all(isdigit, portstr) || (p = parse(Int,portstr); p < 1 || p > 65535)
msg = "invalid machine definition format string: invalid port format \"$machine_def\""
throw(ArgumentError(msg))
end
portopt = ` -p $(machine_def[2]) `
end
sshflags = `$(params[:sshflags]) $portopt`
# Build up the ssh command
# the default worker timeout
tval = haskey(ENV, "JULIA_WORKER_TIMEOUT") ?
`export JULIA_WORKER_TIMEOUT=$(ENV["JULIA_WORKER_TIMEOUT"])\;` : ``
# Julia process with passed in command line flag arguments
cmd = `cd $dir '&&' $tval $exename $exeflags`
# shell login (-l) with string command (-c) to launch julia process
cmd = `sh -l -c $(shell_escape(cmd))`
# remote launch with ssh with given ssh flags / host / port information
# -T → disable pseudo-terminal allocation
# -a → disable forwarding of auth agent connection
# -x → disable X11 forwarding
# -o ClearAllForwardings → option if forwarding connections and
# forwarded connections are causing collisions
# -n → Redirects stdin from /dev/null (actually, prevents reading from stdin).
# Used when running ssh in the background.
cmd = `ssh -T -a -x -o ClearAllForwardings=yes -n $sshflags $host $(shell_escape(cmd))`
# launch the remote Julia process
# detach launches the command in a new process group, allowing it to outlive
# the initial julia process (Ctrl-C and teardown methods are handled through messages)
# for the launched processes.
io, pobj = open(pipeline(detach(cmd), stderr=STDERR), "r")
wconfig = WorkerConfig()
wconfig.io = io
wconfig.host = host
wconfig.tunnel = params[:tunnel]
wconfig.sshflags = sshflags
wconfig.exeflags = exeflags
wconfig.exename = exename
wconfig.count = cnt
wconfig.max_parallel = params[:max_parallel]
wconfig.enable_threaded_blas = params[:enable_threaded_blas]
push!(launched, wconfig)
notify(launch_ntfy)
end
function manage(manager::SSHManager, id::Integer, config::WorkerConfig, op::Symbol)
if op == :interrupt
ospid = get(config.ospid, 0)
if ospid > 0
host = get(config.host)
sshflags = get(config.sshflags)
if !success(`ssh -T -a -x -o ClearAllForwardings=yes -n $sshflags $host "kill -2 $ospid"`)
warn(STDERR,"error sending a Ctrl-C to julia worker $id on $host")
end
else
# This state can happen immediately after an addprocs
warn(STDERR,"worker $id cannot be presently interrupted.")
end
end
end
let tunnel_port = 9201
global next_tunnel_port
function next_tunnel_port()
retval = tunnel_port
if tunnel_port > 32000
tunnel_port = 9201
else
tunnel_port += 1
end
retval
end
end
"""
ssh_tunnel(user, host, bind_addr, port, sshflags) -> localport
Establish an SSH tunnel to a remote worker.
Returns a port number `localport` such that `localhost:localport` connects to `host:port`.
"""
function ssh_tunnel(user, host, bind_addr, port, sshflags)
port = Int(port)
cnt = ntries = 100
# if we cannot do port forwarding, bail immediately
# the connection is forwarded to `port` on the remote server over the local port `localport`
# the -f option backgrounds the ssh session
# the `sleep 60` command specifies that an allotted time of 60 seconds is allowed to start the
# remote julia process and establish the network connections specified by the process topology.
# If no connections are made within 60 seconds, ssh will exit and an error will be printed on the
# process that launched the remote process.
ssh = `ssh -T -a -x -o ExitOnForwardFailure=yes`
while cnt > 0
localport = next_tunnel_port()
if success(detach(`$ssh -f $sshflags $user@$host -L $localport:$bind_addr:$port sleep 60`))
return localport
end
cnt -= 1
end
throw(ErrorException(
string("unable to create SSH tunnel after ", ntries, " tries. No free port?")))
end
# LocalManager
struct LocalManager <: ClusterManager
np::Integer
restrict::Bool # Restrict binding to 127.0.0.1 only
end
"""
addprocs(; kwargs...) -> List of process identifiers
Equivalent to `addprocs(Sys.CPU_CORES; kwargs...)`
Note that workers do not run a `.juliarc.jl` startup script, nor do they synchronize their
global state (such as global variables, new method definitions, and loaded modules) with any
of the other running processes.
"""
addprocs(; kwargs...) = addprocs(Sys.CPU_CORES; kwargs...)
"""
addprocs(np::Integer; restrict=true, kwargs...) -> List of process identifiers
Launches workers using the in-built `LocalManager` which only launches workers on the
local host. This can be used to take advantage of multiple cores. `addprocs(4)` will add 4
processes on the local machine. If `restrict` is `true`, binding is restricted to
`127.0.0.1`. Keyword args `dir`, `exename`, `exeflags`, `topology`, and
`enable_threaded_blas` have the same effect as documented for `addprocs(machines)`.
"""
function addprocs(np::Integer; restrict=true, kwargs...)
check_addprocs_args(kwargs)
addprocs(LocalManager(np, restrict); kwargs...)
end
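# Example: add two local workers bound to 127.0.0.1 (the default), or
# allow binding to the external interface with restrict=false:
#
#   addprocs(2)
#   addprocs(2; restrict=false)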
show(io::IO, manager::LocalManager) = println(io, "LocalManager()")
function launch(manager::LocalManager, params::Dict, launched::Array, c::Condition)
dir = params[:dir]
exename = params[:exename]
exeflags = params[:exeflags]
bind_to = manager.restrict ? `127.0.0.1` : `$(LPROC.bind_addr)`
for i in 1:manager.np
io, pobj = open(pipeline(detach(
setenv(`$(julia_cmd(exename)) $exeflags --bind-to $bind_to --worker $(cluster_cookie())`, dir=dir)),
stderr=STDERR), "r")
wconfig = WorkerConfig()
wconfig.process = pobj
wconfig.io = io
wconfig.enable_threaded_blas = params[:enable_threaded_blas]
push!(launched, wconfig)
end
notify(c)
end
function manage(manager::LocalManager, id::Integer, config::WorkerConfig, op::Symbol)
if op == :interrupt
kill(get(config.process), 2)
end
end
"""
launch(manager::ClusterManager, params::Dict, launched::Array, launch_ntfy::Condition)
Implemented by cluster managers. For every Julia worker launched by this function, it should
append a `WorkerConfig` entry to `launched` and notify `launch_ntfy`. The function MUST exit
once all workers requested by `manager` have been launched.
keyword arguments [`addprocs`](@ref) was called with.
"""
launch
"""
manage(manager::ClusterManager, id::Integer, config::WorkerConfig, op::Symbol)
Implemented by cluster managers. It is called on the master process, during a worker's
lifetime, with appropriate `op` values:
- with `:register`/`:deregister` when a worker is added / removed from the Julia worker pool.
- with `:interrupt` when `interrupt(workers)` is called. The `ClusterManager`
should signal the appropriate worker with an interrupt signal.
- with `:finalize` for cleanup purposes.
"""
manage
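# A minimal custom manager sketch following the launch/manage contract
# documented above (illustrative only; `MiniManager` is hypothetical, mirrors
# the LocalManager implementation earlier in this commit, and assumes
# `launch`/`manage` have been imported for extension):
#
#   struct MiniManager <: ClusterManager
#       np::Int
#   end
#
#   function launch(manager::MiniManager, params::Dict, launched::Array, c::Condition)
#       for i in 1:manager.np
#           io, pobj = open(pipeline(detach(
#               `$(params[:exename]) --worker $(cluster_cookie())`), stderr=STDERR), "r")
#           wconfig = WorkerConfig()
#           wconfig.process = pobj
#           wconfig.io = io
#           push!(launched, wconfig)
#       end
#       notify(c)
#   end
#
#   manage(::MiniManager, id::Integer, config::WorkerConfig, op::Symbol) = nothing
#
#   # addprocs(MiniManager(2))  # would launch two workers via this manager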
# DefaultClusterManager for the default TCP transport - used by both SSHManager and LocalManager
struct DefaultClusterManager <: ClusterManager
end
const tunnel_hosts_map = Dict{AbstractString, Semaphore}()
"""
connect(manager::ClusterManager, pid::Int, config::WorkerConfig) -> (instrm::IO, outstrm::IO)
Implemented by cluster managers using custom transports. It should establish a logical
connection to worker with id `pid`, specified by `config` and return a pair of `IO`
objects. Messages from `pid` to current process will be read off `instrm`, while messages to
be sent to `pid` will be written to `outstrm`. The custom transport implementation must
ensure that messages are delivered and received completely and in order.
The default implementation, `connect(manager::ClusterManager, ...)`, sets up TCP/IP
socket connections between workers.
"""
function connect(manager::ClusterManager, pid::Int, config::WorkerConfig)
if !isnull(config.connect_at)
# this is a worker-to-worker setup call.
return connect_w2w(pid, config)
end
# master connecting to workers
if !isnull(config.io)
(bind_addr, port) = read_worker_host_port(get(config.io))
pubhost=get(config.host, bind_addr)
config.host = pubhost
config.port = port
else
pubhost=get(config.host)
port=get(config.port)
bind_addr=get(config.bind_addr, pubhost)
end
tunnel = get(config.tunnel, false)
s = split(pubhost,'@')
user = ""
if length(s) > 1
user = s[1]
pubhost = s[2]
else
if haskey(ENV, "USER")
user = ENV["USER"]
elseif tunnel
error("USER must be specified either in the environment ",
"or as part of the hostname when tunnel option is used")
end
end
if tunnel
if !haskey(tunnel_hosts_map, pubhost)
tunnel_hosts_map[pubhost] = Semaphore(get(config.max_parallel, typemax(Int)))
end
sem = tunnel_hosts_map[pubhost]
sshflags = get(config.sshflags)
acquire(sem)
try
(s, bind_addr) = connect_to_worker(pubhost, bind_addr, port, user, sshflags)
finally
release(sem)
end
else
(s, bind_addr) = connect_to_worker(bind_addr, port)
end
config.bind_addr = bind_addr
# write out a subset of the connect_at required for further worker-worker connection setups
config.connect_at = (bind_addr, port)
if !isnull(config.io)
let pid = pid
redirect_worker_output(pid, get(config.io))
end
end
(s, s)
end
function connect_w2w(pid::Int, config::WorkerConfig)
(rhost, rport) = get(config.connect_at)
config.host = rhost
config.port = rport
(s, bind_addr) = connect_to_worker(rhost, rport)
(s,s)
end
const client_port = Ref{Cushort}(0)
function socket_reuse_port()
@static if is_linux() || is_apple()
s = TCPSocket(delay = false)
# Linux requires the port to be bound before setting REUSEPORT, OSX after.
is_linux() && bind_client_port(s)
rc = ccall(:jl_tcp_reuseport, Int32, (Ptr{Void},), s.handle)
if rc > 0 # SO_REUSEPORT is unsupported, just return the ephemerally bound socket
return s
elseif rc < 0
# This is an issue only on systems with lots of client connections, hence delay the warning
nworkers() > 128 && warn_once("Error trying to reuse client port number, falling back to regular socket.")
# provide a clean new socket
return TCPSocket()
end
is_apple() && bind_client_port(s)
return s
else
return TCPSocket()
end
end
function bind_client_port(s)
err = ccall(:jl_tcp_bind, Int32, (Ptr{Void}, UInt16, UInt32, Cuint),
s.handle, hton(client_port[]), hton(UInt32(0)), 0)
uv_error("bind() failed", err)
_addr, port = Base._sockname(s, true)
client_port[] = port
return s
end
function connect_to_worker(host::AbstractString, port::Integer)
# Revert support for now. Client socket port number reuse
# does not play well in a scenario where worker processes are repeatedly
# created and torn down, i.e., when the new workers end up reusing a
# previous listen port.
s = TCPSocket()
connect(s, host, UInt16(port))
# Avoid calling getaddrinfo if possible - involves a DNS lookup
# host may be a stringified ipv4 / ipv6 address or a dns name
bind_addr = nothing
try
bind_addr = string(parse(IPAddr,host))
catch
bind_addr = string(getaddrinfo(host))
end
(s, bind_addr)
end
function connect_to_worker(host::AbstractString, bind_addr::AbstractString, port::Integer, tunnel_user::AbstractString, sshflags)
s = connect("localhost", ssh_tunnel(tunnel_user, host, bind_addr, UInt16(port), sshflags))
(s, bind_addr)
end
"""
kill(manager::ClusterManager, pid::Int, config::WorkerConfig)
Implemented by cluster managers.
It is called on the master process, by [`rmprocs`](@ref).
It should cause the remote worker specified by `pid` to exit.
The default implementation executes a remote `exit()` on `pid`.
"""
function kill(manager::ClusterManager, pid::Int, config::WorkerConfig)
remote_do(exit, pid) # For TCP based transports this will result in a close of the socket
# at our end, which will result in a cleanup of the worker.
nothing
end


@@ -0,0 +1,215 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license
abstract type AbstractMsg end
const REF_ID = Ref(1)
next_ref_id() = (id = REF_ID[]; REF_ID[] = id+1; id)
struct RRID
whence::Int
id::Int
RRID() = RRID(myid(),next_ref_id())
RRID(whence, id) = new(whence,id)
end
hash(r::RRID, h::UInt) = hash(r.whence, hash(r.id, h))
==(r::RRID, s::RRID) = (r.whence==s.whence && r.id==s.id)
## Wire format description
#
# Each message has three parts, which are written in order to the worker's stream.
# 1) A header of type MsgHeader is serialized to the stream (via `serialize`).
# 2) A message of type AbstractMsg is then serialized.
# 3) Finally, a fixed boundary of 10 bytes is written.
# Message header stored separately from body to be able to send back errors if
# a deserialization error occurs when reading the message body.
struct MsgHeader
response_oid::RRID
notify_oid::RRID
MsgHeader(respond_oid=RRID(0,0), notify_oid=RRID(0,0)) =
new(respond_oid, notify_oid)
end
# A special oid (0,0) is used to indicate a null ID.
# Used instead of Nullable to decrease wire size of header.
null_id(id) = id == RRID(0, 0)
struct CallMsg{Mode} <: AbstractMsg
f::Function
args::Tuple
kwargs::Array
end
struct CallWaitMsg <: AbstractMsg
f::Function
args::Tuple
kwargs::Array
end
struct RemoteDoMsg <: AbstractMsg
f::Function
args::Tuple
kwargs::Array
end
struct ResultMsg <: AbstractMsg
value::Any
end
# Worker initialization messages
struct IdentifySocketMsg <: AbstractMsg
from_pid::Int
end
struct IdentifySocketAckMsg <: AbstractMsg
end
struct JoinPGRPMsg <: AbstractMsg
self_pid::Int
other_workers::Array
topology::Symbol
enable_threaded_blas::Bool
end
struct JoinCompleteMsg <: AbstractMsg
cpu_cores::Int
ospid::Int
end
# Avoiding serializing AbstractMsg containers results in a speedup
# of approximately 10%. Can be removed once module Serializer
# has been suitably improved.
const msgtypes = Any[CallWaitMsg, IdentifySocketAckMsg, IdentifySocketMsg,
JoinCompleteMsg, JoinPGRPMsg, RemoteDoMsg, ResultMsg,
CallMsg{:call}, CallMsg{:call_fetch}]
for (idx, tname) in enumerate(msgtypes)
exprs = Any[ :(serialize(s, o.$fld)) for fld in fieldnames(tname) ]
@eval function serialize_msg(s::AbstractSerializer, o::$tname)
write(s.io, UInt8($idx))
$(exprs...)
return nothing
end
end
let msg_cases = :(assert(false))
for i = length(msgtypes):-1:1
mti = msgtypes[i]
msg_cases = :(if idx == $i
return $(Expr(:call, QuoteNode(mti), fill(:(deserialize(s)), nfields(mti))...))
else
$msg_cases
end)
end
@eval function deserialize_msg(s::AbstractSerializer)
idx = read(s.io, UInt8)
$msg_cases
end
end
function send_msg_unknown(s::IO, header, msg)
error("attempt to send to unknown socket")
end
function send_msg(s::IO, header, msg)
id = worker_id_from_socket(s)
if id > -1
return send_msg(worker_from_id(id), header, msg)
end
send_msg_unknown(s, header, msg)
end
function send_msg_now(s::IO, header, msg::AbstractMsg)
id = worker_id_from_socket(s)
if id > -1
return send_msg_now(worker_from_id(id), header, msg)
end
send_msg_unknown(s, header, msg)
end
function send_msg_now(w::Worker, header, msg)
send_msg_(w, header, msg, true)
end
function send_msg(w::Worker, header, msg)
send_msg_(w, header, msg, false)
end
function flush_gc_msgs(w::Worker)
if !isdefined(w, :w_stream)
return
end
w.gcflag = false
new_array = Any[]
msgs = w.add_msgs
w.add_msgs = new_array
if !isempty(msgs)
remote_do(add_clients, w, msgs)
end
# del_msgs gets populated by finalizers, so be very careful here about ordering of allocations
new_array = Any[]
msgs = w.del_msgs
w.del_msgs = new_array
if !isempty(msgs)
#print("sending delete of $msgs\n")
remote_do(del_clients, w, msgs)
end
end
# Boundary inserted between messages on the wire, used for recovering
# from deserialization errors. Picked arbitrarily.
# A size of 10 bytes gives ~1e24 (256^10 ≈ 1.2e24) possible boundaries, so the chance of collision
# with message contents is negligible.
const MSG_BOUNDARY = UInt8[0x79, 0x8e, 0x8e, 0xf5, 0x6e, 0x9b, 0x2e, 0x97, 0xd5, 0x7d]
# Faster serialization/deserialization of MsgHeader and RRID
function serialize_hdr_raw(io, hdr)
write(io, hdr.response_oid.whence, hdr.response_oid.id, hdr.notify_oid.whence, hdr.notify_oid.id)
end
function deserialize_hdr_raw(io)
data = read(io, Ref{NTuple{4,Int}}())[]
return MsgHeader(RRID(data[1], data[2]), RRID(data[3], data[4]))
end
function send_msg_(w::Worker, header, msg, now::Bool)
check_worker_state(w)
io = w.w_stream
lock(io.lock)
try
reset_state(w.w_serializer)
serialize_hdr_raw(io, header)
invokelatest(serialize_msg, w.w_serializer, msg) # io is wrapped in w_serializer
write(io, MSG_BOUNDARY)
if !now && w.gcflag
flush_gc_msgs(w)
else
flush(io)
end
finally
unlock(io.lock)
end
end
function flush_gc_msgs()
try
for w in (PGRP::ProcessGroup).workers
if isa(w,Worker) && w.gcflag && (w.state == W_CONNECTED)
flush_gc_msgs(w)
end
end
catch e
bt = catch_backtrace()
@schedule showerror(STDERR, e, bt)
end
end
function send_connection_hdr(w::Worker, cookie=true)
# For a connection initiated from the remote side to us, we only send the version,
# else when we initiate a connection we first send the cookie followed by our version.
# The remote side validates the cookie.
if cookie
write(w.w_stream, LPROC.cookie)
end
write(w.w_stream, rpad(VERSION_STRING, HDR_VERSION_LEN)[1:HDR_VERSION_LEN])
end


@@ -0,0 +1,295 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license
mutable struct BatchProcessingError <: Exception
data
ex
end
"""
pgenerate([::WorkerPool], f, c...) -> iterator
Apply `f` to each element of `c` in parallel using available workers and tasks.
For multiple collection arguments, apply `f` elementwise.
Results are returned in order as they become available.
Note that `f` must be made available to all worker processes; see
[Code Availability and Loading Packages](@ref)
for details.
"""
function pgenerate(p::WorkerPool, f, c)
if length(p) == 0
return AsyncGenerator(f, c; ntasks=()->nworkers(p))
end
batches = batchsplit(c, min_batch_count = length(p) * 3)
return Iterators.flatten(AsyncGenerator(remote(p, b -> asyncmap(f, b)), batches))
end
pgenerate(p::WorkerPool, f, c1, c...) = pgenerate(p, a->f(a...), zip(c1, c...))
pgenerate(f, c) = pgenerate(default_worker_pool(), f, c)
pgenerate(f, c1, c...) = pgenerate(a->f(a...), zip(c1, c...))
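# Usage sketch (assumes workers are available): iterate results as they
# are produced instead of collecting them all first.
#
#   for y in pgenerate(x -> x^2, 1:100)
#       println(y)
#   end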
"""
pmap([::AbstractWorkerPool], f, c...; distributed=true, batch_size=1, on_error=nothing, retry_delays=[], retry_check=nothing) -> collection
Transform collection `c` by applying `f` to each element using available
workers and tasks.
For multiple collection arguments, apply `f` elementwise.
Note that `f` must be made available to all worker processes; see
[Code Availability and Loading Packages](@ref) for details.
If a worker pool is not specified, all available workers, i.e., the default worker pool, are used.
By default, `pmap` distributes the computation over all specified workers. To use only the
local process and distribute over tasks, specify `distributed=false`.
This is equivalent to using [`asyncmap`](@ref). For example,
`pmap(f, c; distributed=false)` is equivalent to `asyncmap(f, c; ntasks=()->nworkers())`.
`pmap` can also use a mix of processes and tasks via the `batch_size` argument. For batch sizes
greater than 1, the collection is processed in multiple batches, each of length `batch_size` or less.
A batch is sent as a single request to a free worker, where a local [`asyncmap`](@ref) processes
elements from the batch using multiple concurrent tasks.
Any error stops `pmap` from processing the remainder of the collection. To override this behavior
you can specify an error handling function via argument `on_error` which takes in a single argument, i.e.,
the exception. The function can stop the processing by rethrowing the error, or, to continue, return any value
which is then returned inline with the results to the caller.
Consider the following two examples. The first one returns the exception object inline,
the second a 0 in place of any exception:
```julia-repl
julia> pmap(x->iseven(x) ? error("foo") : x, 1:4; on_error=identity)
4-element Array{Any,1}:
1
ErrorException("foo")
3
ErrorException("foo")
julia> pmap(x->iseven(x) ? error("foo") : x, 1:4; on_error=ex->0)
4-element Array{Int64,1}:
1
0
3
0
```
Errors can also be handled by retrying failed computations. Keyword arguments `retry_delays` and
`retry_check` are passed through to [`retry`](@ref) as keyword arguments `delays` and `check`
respectively. If batching is specified, and an entire batch fails, all items in
the batch are retried.
Note that if both `on_error` and `retry_delays` are specified, the `on_error` hook is called
before retrying. If `on_error` does not throw (or rethrow) an exception, the element will not
be retried.
Example: On errors, retry `f` on an element a maximum of 3 times without any delay between retries.
```julia
pmap(f, c; retry_delays = zeros(3))
```
Example: Retry `f` only if the exception is not of type `InexactError`, with exponentially increasing
delays up to 3 times. Return a `NaN` in place for all `InexactError` occurrences.
```julia
pmap(f, c; on_error = e->(isa(e, InexactError) ? NaN : rethrow(e)), retry_delays = ExponentialBackOff(n = 3))
```
"""
function pmap(p::AbstractWorkerPool, f, c; distributed=true, batch_size=1, on_error=nothing,
retry_delays=[], retry_check=nothing)
f_orig = f
# Don't do remote calls if there are no workers.
if (length(p) == 0) || (length(p) == 1 && fetch(p.channel) == myid())
distributed = false
end
# Don't do batching if not doing remote calls.
if !distributed
batch_size = 1
end
# If not batching, do simple remote call.
if batch_size == 1
if on_error !== nothing
f = wrap_on_error(f, on_error)
end
if distributed
f = remote(p, f)
end
if length(retry_delays) > 0
f = wrap_retry(f, retry_delays, retry_check)
end
return asyncmap(f, c; ntasks=()->nworkers(p))
else
# During batch processing, we need to ensure that if on_error is set, it is called
# for each element in error, and that we return as many elements as the original list.
# retry, if set, has to be applied element-wise, and we make a best effort
# not to call the mapped function on the same element more than length(retry_delays) times.
# This guarantee is not possible in case of worker death / network errors, wherein
# we will retry the entire batch on a new worker.
handle_errors = ((on_error !== nothing) || (length(retry_delays) > 0))
# Unlike the non-batch case, in batch mode, we trap all errors and the on_error hook (if present)
# is processed later in non-batch mode.
if handle_errors
f = wrap_on_error(f, (x,e)->BatchProcessingError(x,e); capture_data=true)
end
f = wrap_batch(f, p, handle_errors)
results = asyncmap(f, c; ntasks=()->nworkers(p), batch_size=batch_size)
# process errors if any.
if handle_errors
process_batch_errors!(p, f_orig, results, on_error, retry_delays, retry_check)
end
return results
end
end
pmap(p::AbstractWorkerPool, f, c1, c...; kwargs...) = pmap(p, a->f(a...), zip(c1, c...); kwargs...)
pmap(f, c; kwargs...) = pmap(default_worker_pool(), f, c; kwargs...)
pmap(f, c1, c...; kwargs...) = pmap(a->f(a...), zip(c1, c...); kwargs...)
function wrap_on_error(f, on_error; capture_data=false)
return x -> begin
try
f(x)
catch e
if capture_data
on_error(x, e)
else
on_error(e)
end
end
end
end
function wrap_retry(f, retry_delays, retry_check)
retry(delays=retry_delays, check=retry_check) do x
try
f(x)
catch e
rethrow(extract_exception(e))
end
end
end
function wrap_batch(f, p, handle_errors)
f = asyncmap_batch(f)
return batch -> begin
try
remotecall_fetch(f, p, batch)
catch e
if handle_errors
return Any[BatchProcessingError(batch[i], e) for i in 1:length(batch)]
else
rethrow(e)
end
end
end
end
asyncmap_batch(f) = batch -> asyncmap(x->f(x...), batch)
extract_exception(e) = isa(e, RemoteException) ? e.captured.ex : e
function process_batch_errors!(p, f, results, on_error, retry_delays, retry_check)
# Handle all the ones in error in another pmap, with batch size set to 1
reprocess = []
for (idx, v) in enumerate(results)
if isa(v, BatchProcessingError)
push!(reprocess, (idx,v))
end
end
if length(reprocess) > 0
errors = [x[2] for x in reprocess]
exceptions = [x.ex for x in errors]
state = start(retry_delays)
if (length(retry_delays) > 0) &&
(retry_check==nothing || all([retry_check(state,ex)[2] for ex in exceptions]))
# BatchProcessingError.data is a tuple of original args
error_processed = pmap(p, x->f(x...), [x.data for x in errors];
on_error = on_error, retry_delays = collect(retry_delays)[2:end], retry_check = retry_check)
elseif on_error !== nothing
error_processed = map(on_error, exceptions)
else
throw(CompositeException(exceptions))
end
for (idx, v) in enumerate(error_processed)
results[reprocess[idx][1]] = v
end
end
nothing
end
"""
head_and_tail(c, n) -> head, tail
Returns `head`: the first `n` elements of `c`;
and `tail`: an iterator over the remaining elements.
```jldoctest
julia> a = 1:10
1:10
julia> b, c = Base.head_and_tail(a, 3)
([1,2,3],Base.Iterators.Rest{UnitRange{Int64},Int64}(1:10,4))
julia> collect(c)
7-element Array{Any,1}:
4
5
6
7
8
9
10
```
"""
function head_and_tail(c, n)
head = Vector{eltype(c)}(n)
s = start(c)
i = 0
while i < n && !done(c, s)
i += 1
head[i], s = next(c, s)
end
return resize!(head, i), Iterators.rest(c, s)
end
"""
batchsplit(c; min_batch_count=1, max_batch_size=100) -> iterator
Split a collection into at least `min_batch_count` batches.
Equivalent to `partition(c, max_batch_size)` when `length(c) >> max_batch_size`.
"""
function batchsplit(c; min_batch_count=1, max_batch_size=100)
if min_batch_count < 1
throw(ArgumentError("min_batch_count must be ≥ 1, got $min_batch_count"))
end
if max_batch_size < 1
throw(ArgumentError("max_batch_size must be ≥ 1, got $max_batch_size"))
end
# Split collection into batches, then peek at the first few batches
batches = Iterators.partition(c, max_batch_size)
head, tail = head_and_tail(batches, min_batch_count)
# If there are not enough batches, use a smaller batch size
if length(head) < min_batch_count
batch_size = max(1, div(sum(length, head), min_batch_count))
return Iterators.partition(collect(Iterators.flatten(head)), batch_size)
end
return Iterators.flatten((head, tail))
end
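# Worked example: 25 elements with min_batch_count=2 and max_batch_size=10
# yield three batches, 1:10, 11:20 and 21:25:
#
#   foreach(println, batchsplit(1:25, min_batch_count=2, max_batch_size=10))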


@@ -0,0 +1,353 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license
# data stored by the owner of a remote reference
def_rv_channel() = Channel(1)
mutable struct RemoteValue
c::AbstractChannel
clientset::IntSet # Set of workerids that have a reference to this channel.
# Keeping ids instead of a count aids in cleaning up upon
# a worker exit.
waitingfor::Int # processor we need to hear from to fill this, or 0
RemoteValue(c) = new(c, IntSet(), 0)
end
wait(rv::RemoteValue) = wait(rv.c)
## core messages: do, call, fetch, wait, ref, put! ##
mutable struct RemoteException <: Exception
pid::Int
captured::CapturedException
end
"""
RemoteException(captured)
Exceptions on remote computations are captured and rethrown locally. A `RemoteException`
wraps the `pid` of the worker and a captured exception. A `CapturedException` captures the
remote exception and a serializable form of the call stack when the exception was raised.
"""
RemoteException(captured) = RemoteException(myid(), captured)
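# Sketch of how a remote error surfaces locally (assumes a worker 2 exists):
#
#   try
#       remotecall_fetch(() -> error("boom"), 2)
#   catch e
#       isa(e, RemoteException)   # true; showerror prints "On worker 2: boom ..."
#   end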
function showerror(io::IO, re::RemoteException)
(re.pid != myid()) && print(io, "On worker ", re.pid, ":\n")
showerror(io, get_root_exception(re.captured))
end
isa_exception_container(ex) = (isa(ex, RemoteException) ||
isa(ex, CapturedException) ||
isa(ex, CompositeException))
function get_root_exception(ex)
if isa(ex, RemoteException)
return get_root_exception(ex.captured)
elseif isa(ex, CapturedException) && isa_exception_container(ex.ex)
return get_root_exception(ex.ex)
elseif isa(ex, CompositeException) && length(ex.exceptions) > 0 && isa_exception_container(ex.exceptions[1])
return get_root_exception(ex.exceptions[1])
else
return ex
end
end
function run_work_thunk(thunk, print_error)
local result
try
result = thunk()
catch err
ce = CapturedException(err, catch_backtrace())
result = RemoteException(ce)
print_error && showerror(STDERR, ce)
end
return result
end
function run_work_thunk(rv::RemoteValue, thunk)
put!(rv, run_work_thunk(thunk, false))
nothing
end
function schedule_call(rid, thunk)
return lock(client_refs) do
rv = RemoteValue(def_rv_channel())
(PGRP::ProcessGroup).refs[rid] = rv
push!(rv.clientset, rid.whence)
@schedule run_work_thunk(rv, thunk)
return rv
end
end
function deliver_result(sock::IO, msg, oid, value)
#print("$(myid()) sending result $oid\n")
if msg === :call_fetch || isa(value, RemoteException)
val = value
else
val = :OK
end
try
send_msg_now(sock, MsgHeader(oid), ResultMsg(val))
catch e
# terminate connection in case of serialization error
# otherwise the reading end would hang
print(STDERR, "fatal error on ", myid(), ": ")
display_error(e, catch_backtrace())
wid = worker_id_from_socket(sock)
close(sock)
if myid()==1
rmprocs(wid)
elseif wid == 1
exit(1)
else
remote_do(rmprocs, 1, wid)
end
end
end
## message event handlers ##
function process_messages(r_stream::TCPSocket, w_stream::TCPSocket, incoming::Bool=true)
@schedule process_tcp_streams(r_stream, w_stream, incoming)
end
function process_tcp_streams(r_stream::TCPSocket, w_stream::TCPSocket, incoming::Bool)
disable_nagle(r_stream)
wait_connected(r_stream)
if r_stream != w_stream
disable_nagle(w_stream)
wait_connected(w_stream)
end
message_handler_loop(r_stream, w_stream, incoming)
end
"""
Base.process_messages(r_stream::IO, w_stream::IO, incoming::Bool=true)
Called by cluster managers using custom transports. It should be called when the custom
transport implementation receives the first message from a remote worker. The custom
transport must manage a logical connection to the remote worker and provide two
`IO` objects, one for incoming messages and the other for messages addressed to the
remote worker.
If `incoming` is `true`, the remote peer initiated the connection.
Whichever of the pair initiates the connection sends the cluster cookie and its
Julia version number to perform the authentication handshake.
See also [`cluster_cookie`](@ref).
"""
function process_messages(r_stream::IO, w_stream::IO, incoming::Bool=true)
@schedule message_handler_loop(r_stream, w_stream, incoming)
end
function message_handler_loop(r_stream::IO, w_stream::IO, incoming::Bool)
wpid=0 # the worker r_stream is connected to.
boundary = similar(MSG_BOUNDARY)
try
version = process_hdr(r_stream, incoming)
serializer = ClusterSerializer(r_stream)
# The first message will associate wpid with r_stream
header = deserialize_hdr_raw(r_stream)
msg = deserialize_msg(serializer)
handle_msg(msg, header, r_stream, w_stream, version)
wpid = worker_id_from_socket(r_stream)
@assert wpid > 0
readbytes!(r_stream, boundary, length(MSG_BOUNDARY))
while true
reset_state(serializer)
header = deserialize_hdr_raw(r_stream)
# println("header: ", header)
try
msg = invokelatest(deserialize_msg, serializer)
catch e
# Deserialization error; discard bytes in stream until boundary found
boundary_idx = 1
while true
# This may throw an EOF error if the terminal boundary was not written
# correctly, triggering the higher-scoped catch block below
byte = read(r_stream, UInt8)
if byte == MSG_BOUNDARY[boundary_idx]
boundary_idx += 1
if boundary_idx > length(MSG_BOUNDARY)
break
end
else
boundary_idx = 1
end
end
# remotecalls only rethrow RemoteExceptions. Any other exception is treated as
# data to be returned. Wrap this exception in a RemoteException.
remote_err = RemoteException(myid(), CapturedException(e, catch_backtrace()))
# println("Deserialization error. ", remote_err)
if !null_id(header.response_oid)
ref = lookup_ref(header.response_oid)
put!(ref, remote_err)
end
if !null_id(header.notify_oid)
deliver_result(w_stream, :call_fetch, header.notify_oid, remote_err)
end
continue
end
readbytes!(r_stream, boundary, length(MSG_BOUNDARY))
# println("got msg: ", typeof(msg))
handle_msg(msg, header, r_stream, w_stream, version)
end
catch e
# Check again as it may have been set in a message handler but not propagated to the calling block above
wpid = worker_id_from_socket(r_stream)
if (wpid < 1)
println(STDERR, e, CapturedException(e, catch_backtrace()))
println(STDERR, "Process($(myid())) - Unknown remote, closing connection.")
else
werr = worker_from_id(wpid)
oldstate = werr.state
set_worker_state(werr, W_TERMINATED)
# If unhandleable error occurred talking to pid 1, exit
if wpid == 1
if isopen(w_stream)
print(STDERR, "fatal error on ", myid(), ": ")
display_error(e, catch_backtrace())
end
exit(1)
end
# Will treat any exception as death of node and cleanup
# since currently we do not have a mechanism for workers to reconnect
# to each other on unhandled errors
deregister_worker(wpid)
end
isopen(r_stream) && close(r_stream)
isopen(w_stream) && close(w_stream)
if (myid() == 1) && (wpid > 1)
if oldstate != W_TERMINATING
println(STDERR, "Worker $wpid terminated.")
rethrow(e)
end
end
return nothing
end
end
function process_hdr(s, validate_cookie)
if validate_cookie
cookie = read(s, HDR_COOKIE_LEN)
if length(cookie) < HDR_COOKIE_LEN
error("Cookie read failed. Connection closed by peer.")
end
self_cookie = cluster_cookie()
for i in 1:HDR_COOKIE_LEN
if UInt8(self_cookie[i]) != cookie[i]
error("Process($(myid())) - Invalid connection credentials sent by remote.")
end
end
end
# When incompatible julia versions try to connect to each other and this
# can be detected, raise an appropriate error.
# For now, just return the version.
version = read(s, HDR_VERSION_LEN)
if length(version) < HDR_VERSION_LEN
error("Version read failed. Connection closed by peer.")
end
return VersionNumber(strip(String(version)))
end
function handle_msg(msg::CallMsg{:call}, header, r_stream, w_stream, version)
schedule_call(header.response_oid, ()->msg.f(msg.args...; msg.kwargs...))
end
function handle_msg(msg::CallMsg{:call_fetch}, header, r_stream, w_stream, version)
@schedule begin
v = run_work_thunk(()->msg.f(msg.args...; msg.kwargs...), false)
deliver_result(w_stream, :call_fetch, header.notify_oid, v)
end
end
function handle_msg(msg::CallWaitMsg, header, r_stream, w_stream, version)
@schedule begin
rv = schedule_call(header.response_oid, ()->msg.f(msg.args...; msg.kwargs...))
deliver_result(w_stream, :call_wait, header.notify_oid, fetch(rv.c))
end
end
function handle_msg(msg::RemoteDoMsg, header, r_stream, w_stream, version)
@schedule run_work_thunk(()->msg.f(msg.args...; msg.kwargs...), true)
end
function handle_msg(msg::ResultMsg, header, r_stream, w_stream, version)
put!(lookup_ref(header.response_oid), msg.value)
end
function handle_msg(msg::IdentifySocketMsg, header, r_stream, w_stream, version)
# register a new peer worker connection
w=Worker(msg.from_pid, r_stream, w_stream, cluster_manager; version=version)
send_connection_hdr(w, false)
send_msg_now(w, MsgHeader(), IdentifySocketAckMsg())
end
function handle_msg(msg::IdentifySocketAckMsg, header, r_stream, w_stream, version)
w = map_sock_wrkr[r_stream]
w.version = version
end
function handle_msg(msg::JoinPGRPMsg, header, r_stream, w_stream, version)
LPROC.id = msg.self_pid
controller = Worker(1, r_stream, w_stream, cluster_manager; version=version)
register_worker(LPROC)
topology(msg.topology)
if !msg.enable_threaded_blas
disable_threaded_libs()
end
wait_tasks = Task[]
for (connect_at, rpid) in msg.other_workers
wconfig = WorkerConfig()
wconfig.connect_at = connect_at
let rpid=rpid, wconfig=wconfig
t = @async connect_to_peer(cluster_manager, rpid, wconfig)
push!(wait_tasks, t)
end
end
for wt in wait_tasks; wait(wt); end
send_connection_hdr(controller, false)
send_msg_now(controller, MsgHeader(RRID(0,0), header.notify_oid), JoinCompleteMsg(Sys.CPU_CORES, getpid()))
end
function connect_to_peer(manager::ClusterManager, rpid::Int, wconfig::WorkerConfig)
try
(r_s, w_s) = connect(manager, rpid, wconfig)
w = Worker(rpid, r_s, w_s, manager; config=wconfig)
process_messages(w.r_stream, w.w_stream, false)
send_connection_hdr(w, true)
send_msg_now(w, MsgHeader(), IdentifySocketMsg(myid()))
catch e
display_error(e, catch_backtrace())
println(STDERR, "Error [$e] on $(myid()) while connecting to peer $rpid. Exiting.")
exit(1)
end
end
function handle_msg(msg::JoinCompleteMsg, header, r_stream, w_stream, version)
w = map_sock_wrkr[r_stream]
environ = get(w.config.environ, Dict())
environ[:cpu_cores] = msg.cpu_cores
w.config.environ = environ
w.config.ospid = msg.ospid
w.version = version
ntfy_channel = lookup_ref(header.notify_oid)
put!(ntfy_channel, w.id)
push!(default_worker_pool(), w.id)
end


@@ -0,0 +1,554 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license
"""
client_refs
Tracks whether a particular `AbstractRemoteRef`
(identified by its RRID) exists on this worker.
The `client_refs` lock is also used to synchronize access to `.refs` and associated `clientset` state.
"""
const client_refs = WeakKeyDict{Any, Void}() # used as a WeakKeySet
abstract type AbstractRemoteRef end
mutable struct Future <: AbstractRemoteRef
where::Int
whence::Int
id::Int
v::Nullable{Any}
Future(w::Int, rrid::RRID) = Future(w, rrid, Nullable{Any}())
Future(w::Int, rrid::RRID, v) = (r = new(w,rrid.whence,rrid.id,v); return test_existing_ref(r))
end
mutable struct RemoteChannel{T<:AbstractChannel} <: AbstractRemoteRef
where::Int
whence::Int
id::Int
function RemoteChannel{T}(w::Int, rrid::RRID) where T<:AbstractChannel
r = new(w, rrid.whence, rrid.id)
return test_existing_ref(r)
end
end
function test_existing_ref(r::AbstractRemoteRef)
found = getkey(client_refs, r, nothing)
if found !== nothing
@assert r.where > 0
if isa(r, Future) && isnull(found.v) && !isnull(r.v)
# we have received the value from another source, probably a deserialized ref; send a del_client message
send_del_client(r)
found.v = r.v
end
return found::typeof(r)
end
client_refs[r] = nothing
finalizer(r, finalize_ref)
return r
end
function finalize_ref(r::AbstractRemoteRef)
if r.where > 0 # Handle the case of the finalizer having been called manually
islocked(client_refs) && return finalizer(r, finalize_ref) # delay finalizer for later, when it's not already locked
delete!(client_refs, r)
if isa(r, RemoteChannel)
send_del_client(r)
else
# send_del_client only if the reference has not been set
isnull(r.v) && send_del_client(r)
r.v = Nullable{Any}()
end
r.where = 0
end
nothing
end
Future(w::LocalProcess) = Future(w.id)
Future(w::Worker) = Future(w.id)
"""
Future(pid::Integer=myid())
Create a `Future` on process `pid`.
The default `pid` is the current process.
"""
Future(pid::Integer=myid()) = Future(pid, RRID())
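# Usage sketch: a locally created `Future` is write-once and can be read repeatedly:
#
#     f = Future()       # owned by the current process
#     put!(f, 42)        # a second put! would throw
#     fetch(f)           # => 42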
"""
RemoteChannel(pid::Integer=myid())
Make a reference to a `Channel{Any}(1)` on process `pid`.
The default `pid` is the current process.
"""
RemoteChannel(pid::Integer=myid()) = RemoteChannel{Channel{Any}}(pid, RRID())
"""
RemoteChannel(f::Function, pid::Integer=myid())
Create references to remote channels of a specific size and type. `f()` is a function that
when executed on `pid` must return an implementation of an `AbstractChannel`.
For example, `RemoteChannel(()->Channel{Int}(10), pid)` will return a reference to a
channel of type `Int` and size 10 on `pid`.
The default `pid` is the current process.
"""
function RemoteChannel(f::Function, pid::Integer=myid())
remotecall_fetch(pid, f, RRID()) do f, rrid
rv=lookup_ref(rrid, f)
RemoteChannel{typeof(rv.c)}(myid(), rrid)
end
end
hash(r::AbstractRemoteRef, h::UInt) = hash(r.whence, hash(r.id, h))
==(r::AbstractRemoteRef, s::AbstractRemoteRef) = (r.whence==s.whence && r.id==s.id)
"""
Base.remoteref_id(r::AbstractRemoteRef) -> RRID
`Future`s and `RemoteChannel`s are identified by fields:
* `where` - refers to the node where the underlying object/storage
referred to by the reference actually exists.
* `whence` - refers to the node the remote reference was created from.
Note that this is different from the node where the underlying object
referred to actually exists. For example, calling `RemoteChannel(2)`
from the master process would result in a `where` value of 2 and
a `whence` value of 1.
* `id` is unique across all references created from the worker specified by `whence`.
Taken together, `whence` and `id` uniquely identify a reference across all workers.
`Base.remoteref_id` is a low-level API which returns a `Base.RRID`
object that wraps `whence` and `id` values of a remote reference.
"""
remoteref_id(r::AbstractRemoteRef) = RRID(r.whence, r.id)
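# Usage sketch (assumes it is run from the master, pid 1, with a worker 2 available):
#
#     rc = RemoteChannel(2)       # backing storage lives on worker 2
#     rc.where                    # => 2
#     Base.remoteref_id(rc)       # RRID with whence == 1, since it was created here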
"""
Base.channel_from_id(id) -> c
A low-level API which returns the backing `AbstractChannel` for an `id` returned by
[`remoteref_id`](@ref).
The call is valid only on the node where the backing channel exists.
"""
function channel_from_id(id)
rv = lock(client_refs) do
return get(PGRP.refs, id, false)
end
if rv === false
throw(ErrorException("Local instance of remote reference not found"))
end
return rv.c
end
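# Usage sketch: valid only on the node owning the reference, e.g. for a ref
# created locally:
#
#     rc = RemoteChannel()                               # owned by this process
#     c = Base.channel_from_id(Base.remoteref_id(rc))    # the backing Channel{Any}(1)
#     put!(c, 1); fetch(rc)                              # => 1, same storage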
lookup_ref(rrid::RRID, f=def_rv_channel) = lookup_ref(PGRP, rrid, f)
function lookup_ref(pg, rrid, f)
return lock(client_refs) do
rv = get(pg.refs, rrid, false)
if rv === false
# first we've heard of this ref
rv = RemoteValue(invokelatest(f))
pg.refs[rrid] = rv
push!(rv.clientset, rrid.whence)
end
return rv
end::RemoteValue
end
"""
isready(rr::Future)
Determine whether a [`Future`](@ref) has a value stored to it.
If the argument `Future` is owned by a different node, this call will block to wait for the answer.
It is recommended to wait for `rr` in a separate task instead
or to use a local [`Channel`](@ref) as a proxy:
    c = Channel(1)
    @async put!(c, remotecall_fetch(long_computation, p))
    isready(c) # will not block
"""
function isready(rr::Future)
!isnull(rr.v) && return true
rid = remoteref_id(rr)
return if rr.where == myid()
isready(lookup_ref(rid).c)
else
remotecall_fetch(rid->isready(lookup_ref(rid).c), rr.where, rid)
end
end
"""
isready(rr::RemoteChannel, args...)
Determine whether a [`RemoteChannel`](@ref) has a value stored to it.
Note that this function can cause race conditions, since by the
time you receive its result it may no longer be true. However,
it can be safely used on a [`Future`](@ref) since they are assigned only once.
"""
function isready(rr::RemoteChannel, args...)
rid = remoteref_id(rr)
return if rr.where == myid()
isready(lookup_ref(rid).c, args...)
else
remotecall_fetch(rid->isready(lookup_ref(rid).c, args...), rr.where, rid)
end
end
del_client(rr::AbstractRemoteRef) = del_client(remoteref_id(rr), myid())
del_client(id, client) = del_client(PGRP, id, client)
function del_client(pg, id, client)
lock(client_refs) do
rv = get(pg.refs, id, false)
if rv !== false
delete!(rv.clientset, client)
if isempty(rv.clientset)
delete!(pg.refs, id)
#print("$(myid()) collected $id\n")
end
end
end
nothing
end
function del_clients(pairs::Vector)
for p in pairs
del_client(p[1], p[2])
end
end
const any_gc_flag = Condition()
function start_gc_msgs_task()
@schedule while true
wait(any_gc_flag)
flush_gc_msgs()
end
end
function send_del_client(rr)
if rr.where == myid()
del_client(rr)
elseif id_in_procs(rr.where) # process only if a valid worker
w = worker_from_id(rr.where)
push!(w.del_msgs, (remoteref_id(rr), myid()))
w.gcflag = true
notify(any_gc_flag)
end
end
function add_client(id, client)
lock(client_refs) do
rv = lookup_ref(id)
push!(rv.clientset, client)
end
nothing
end
function add_clients(pairs::Vector)
for p in pairs
add_client(p[1], p[2]...)
end
end
function send_add_client(rr::AbstractRemoteRef, i)
if rr.where == myid()
add_client(remoteref_id(rr), i)
elseif (i != rr.where) && id_in_procs(rr.where)
# don't need to send add_client if the message is already going
# to the processor that owns the remote ref. it will add_client
# itself inside deserialize().
w = worker_from_id(rr.where)
push!(w.add_msgs, (remoteref_id(rr), i))
w.gcflag = true
notify(any_gc_flag)
end
end
channel_type{T}(rr::RemoteChannel{T}) = T
serialize(s::AbstractSerializer, f::Future) = serialize(s, f, isnull(f.v))
serialize(s::AbstractSerializer, rr::RemoteChannel) = serialize(s, rr, true)
function serialize(s::AbstractSerializer, rr::AbstractRemoteRef, addclient)
if addclient
p = worker_id_from_socket(s.io)
(p !== rr.where) && send_add_client(rr, p)
end
invoke(serialize, Tuple{AbstractSerializer, Any}, s, rr)
end
function deserialize(s::AbstractSerializer, t::Type{<:Future})
f = deserialize_rr(s,t)
Future(f.where, RRID(f.whence, f.id), f.v) # ctor adds to client_refs table
end
function deserialize(s::AbstractSerializer, t::Type{<:RemoteChannel})
rr = deserialize_rr(s,t)
# call ctor to make sure this rr gets added to the client_refs table
RemoteChannel{channel_type(rr)}(rr.where, RRID(rr.whence, rr.id))
end
function deserialize_rr(s, t)
rr = invoke(deserialize, Tuple{AbstractSerializer, DataType}, s, t)
if rr.where == myid()
# send_add_client() is not executed when the ref is being
# serialized to where it exists
add_client(remoteref_id(rr), myid())
end
rr
end
# make a thunk to call f on args in a way that simulates what would happen if
# the function were sent elsewhere
function local_remotecall_thunk(f, args, kwargs)
if isempty(args) && isempty(kwargs)
return f
end
return ()->f(args...; kwargs...)
end
function remotecall(f, w::LocalProcess, args...; kwargs...)
rr = Future(w)
schedule_call(remoteref_id(rr), local_remotecall_thunk(f, args, kwargs))
return rr
end
function remotecall(f, w::Worker, args...; kwargs...)
rr = Future(w)
send_msg(w, MsgHeader(remoteref_id(rr)), CallMsg{:call}(f, args, kwargs))
return rr
end
"""
remotecall(f, id::Integer, args...; kwargs...) -> Future
Call a function `f` asynchronously on the given arguments on the specified process.
Returns a [`Future`](@ref).
Keyword arguments, if any, are passed through to `f`.
"""
remotecall(f, id::Integer, args...; kwargs...) = remotecall(f, worker_from_id(id), args...; kwargs...)
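# Usage sketch (assumes a worker with id 2, e.g. after `addprocs(1)`):
#
#     fut = remotecall(sqrt, 2, 4.0)   # returns a Future immediately
#     fetch(fut)                       # blocks for the result => 2.0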
function remotecall_fetch(f, w::LocalProcess, args...; kwargs...)
v=run_work_thunk(local_remotecall_thunk(f,args, kwargs), false)
return isa(v, RemoteException) ? throw(v) : v
end
function remotecall_fetch(f, w::Worker, args...; kwargs...)
# can be weak, because the program will have no way to refer to the Ref
# itself, it only gets the result.
oid = RRID()
rv = lookup_ref(oid)
rv.waitingfor = w.id
send_msg(w, MsgHeader(RRID(0,0), oid), CallMsg{:call_fetch}(f, args, kwargs))
v = take!(rv)
lock(client_refs) do
delete!(PGRP.refs, oid)
end
return isa(v, RemoteException) ? throw(v) : v
end
"""
remotecall_fetch(f, id::Integer, args...; kwargs...)
Perform `fetch(remotecall(...))` in one message.
Keyword arguments, if any, are passed through to `f`.
Any remote exceptions are captured in a
[`RemoteException`](@ref) and thrown.
See also [`fetch`](@ref) and [`remotecall`](@ref).
"""
remotecall_fetch(f, id::Integer, args...; kwargs...) =
remotecall_fetch(f, worker_from_id(id), args...; kwargs...)
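# Usage sketch (assumes worker 2 exists): a single message round trip, unlike
# `fetch(remotecall(...))` which also sends a separate fetch request:
#
#     remotecall_fetch(myid, 2)   # => 2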
remotecall_wait(f, w::LocalProcess, args...; kwargs...) = wait(remotecall(f, w, args...; kwargs...))
function remotecall_wait(f, w::Worker, args...; kwargs...)
prid = RRID()
rv = lookup_ref(prid)
rv.waitingfor = w.id
rr = Future(w)
send_msg(w, MsgHeader(remoteref_id(rr), prid), CallWaitMsg(f, args, kwargs))
v = fetch(rv.c)
lock(client_refs) do
delete!(PGRP.refs, prid)
end
isa(v, RemoteException) && throw(v)
return rr
end
"""
remotecall_wait(f, id::Integer, args...; kwargs...)
Perform a faster `wait(remotecall(...))` in one message on the `Worker` specified by worker id `id`.
Keyword arguments, if any, are passed through to `f`.
See also [`wait`](@ref) and [`remotecall`](@ref).
"""
remotecall_wait(f, id::Integer, args...; kwargs...) =
remotecall_wait(f, worker_from_id(id), args...; kwargs...)
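# Usage sketch (assumes worker 2 exists):
#
#     fut = remotecall_wait(rand, 2, 3)   # returns only after rand(3) has run on 2
#     fetch(fut)                          # the result is already available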
function remote_do(f, w::LocalProcess, args...; kwargs...)
# the LocalProcess version just performs in local memory what a worker
# does when it gets a :do message.
# same for other messages on LocalProcess.
thk = local_remotecall_thunk(f, args, kwargs)
schedule(Task(thk))
nothing
end
function remote_do(f, w::Worker, args...; kwargs...)
send_msg(w, MsgHeader(), RemoteDoMsg(f, args, kwargs))
nothing
end
"""
remote_do(f, id::Integer, args...; kwargs...) -> nothing
Executes `f` on worker `id` asynchronously.
Unlike [`remotecall`](@ref), it does not store the
result of computation, nor is there a way to wait for its completion.
A successful invocation indicates that the request has been accepted for execution on
the remote node.
While consecutive `remotecall`s to the same worker are serialized in the order they are
invoked, the order of executions on the remote worker is undetermined. For example,
`remote_do(f1, 2); remotecall(f2, 2); remote_do(f3, 2)` will serialize the call
to `f1`, followed by `f2` and `f3` in that order. However, it is not guaranteed that `f1`
is executed before `f3` on worker 2.
Any exceptions thrown by `f` are printed to [`STDERR`](@ref) on the remote worker.
Keyword arguments, if any, are passed through to `f`.
"""
remote_do(f, id::Integer, args...; kwargs...) = remote_do(f, worker_from_id(id), args...; kwargs...)
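# Usage sketch (assumes worker 2 exists): fire-and-forget, nothing is returned:
#
#     remote_do(println, 2, "hello")   # prints on worker 2; with the default
#                                      # setup the output is relayed to pid 1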
# have the owner of rr call f on it
function call_on_owner(f, rr::AbstractRemoteRef, args...)
rid = remoteref_id(rr)
if rr.where == myid()
f(rid, args...)
else
remotecall_fetch(f, rr.where, rid, args...)
end
end
function wait_ref(rid, callee, args...)
v = fetch_ref(rid, args...)
if isa(v, RemoteException)
if myid() == callee
throw(v)
else
return v
end
end
nothing
end
wait(r::Future) = (!isnull(r.v) && return r; call_on_owner(wait_ref, r, myid()); r)
wait(r::RemoteChannel, args...) = (call_on_owner(wait_ref, r, myid(), args...); r)
function fetch(r::Future)
!isnull(r.v) && return get(r.v)
v=call_on_owner(fetch_ref, r)
r.v=v
send_del_client(r)
v
end
fetch_ref(rid, args...) = fetch(lookup_ref(rid).c, args...)
fetch(r::RemoteChannel, args...) = call_on_owner(fetch_ref, r, args...)
"""
fetch(x)
Waits and fetches a value from `x` depending on the type of `x`:
* [`Future`](@ref): Wait for and get the value of a `Future`. The fetched value is cached locally.
Further calls to `fetch` on the same reference return the cached value. If the remote value
is an exception, throws a [`RemoteException`](@ref) which captures the remote exception and backtrace.
* [`RemoteChannel`](@ref): Wait for and get the value of a remote reference. Exceptions raised are
the same as for a `Future`.
Does not remove the item fetched.
"""
fetch(x::ANY) = x
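# Usage sketch (assumes worker 2 exists): a Future caches its fetched value:
#
#     fut = remotecall(rand, 2, 2)
#     a = fetch(fut)    # transfers the array from worker 2
#     b = fetch(fut)    # served from the local cache, a === b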
isready(rv::RemoteValue, args...) = isready(rv.c, args...)
"""
put!(rr::Future, v)
Store a value to a [`Future`](@ref) `rr`.
`Future`s are write-once remote references.
A `put!` on an already set `Future` throws an `Exception`.
All asynchronous remote calls return `Future`s and set the
value to the return value of the call upon completion.
"""
function put!(rr::Future, v)
!isnull(rr.v) && error("Future can be set only once")
call_on_owner(put_future, rr, v, myid())
rr.v = v
rr
end
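# Usage sketch: Futures are write-once:
#
#     f = Future()
#     put!(f, 1)        # ok
#     put!(f, 2)        # throws, the Future is already set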
function put_future(rid, v, callee)
rv = lookup_ref(rid)
isready(rv) && error("Future can be set only once")
put!(rv, v)
# The callee has the value and hence can be removed from the remote store.
del_client(rid, callee)
nothing
end
put!(rv::RemoteValue, args...) = put!(rv.c, args...)
put_ref(rid, args...) = (put!(lookup_ref(rid), args...); nothing)
"""
put!(rr::RemoteChannel, args...)
Store a set of values to the [`RemoteChannel`](@ref).
If the channel is full, blocks until space is available.
Returns its first argument.
"""
put!(rr::RemoteChannel, args...) = (call_on_owner(put_ref, rr, args...); rr)
# take! is not supported on Future
take!(rv::RemoteValue, args...) = take!(rv.c, args...)
function take_ref(rid, callee, args...)
v=take!(lookup_ref(rid), args...)
isa(v, RemoteException) && (myid() == callee) && throw(v)
v
end
"""
take!(rr::RemoteChannel, args...)
Fetch value(s) from a [`RemoteChannel`](@ref) `rr`,
removing the value(s) in the process.
"""
take!(rr::RemoteChannel, args...) = call_on_owner(take_ref, rr, myid(), args...)
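# Usage sketch (assumes worker 2 exists): `take!` removes a value, `fetch` does not:
#
#     rc = RemoteChannel(()->Channel{Int}(4), 2)
#     put!(rc, 1); put!(rc, 2)
#     fetch(rc)    # => 1, value remains in the channel
#     take!(rc)    # => 1, value is removed
#     take!(rc)    # => 2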
# close is not supported on Future
close_ref(rid) = (close(lookup_ref(rid).c); nothing)
close(rr::RemoteChannel) = call_on_owner(close_ref, rr)
getindex(r::RemoteChannel) = fetch(r)
getindex(r::Future) = fetch(r)
getindex(r::Future, args...) = getindex(fetch(r), args...)
function getindex(r::RemoteChannel, args...)
if r.where == myid()
return getindex(fetch(r), args...)
end
return remotecall_fetch(getindex, r.where, r, args...)
end
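# Usage sketch (assumes worker 2 exists): indexing is sugar over fetch:
#
#     fut = remotecall(ones, 2, 3)
#     fut[]     # same as fetch(fut) => [1.0, 1.0, 1.0]
#     fut[2]    # same as getindex(fetch(fut), 2) => 1.0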

View File

@@ -0,0 +1,297 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license
abstract type AbstractWorkerPool end
# An AbstractWorkerPool should implement
#
# `push!` - add a new worker to the overall pool (available + busy)
# `put!` - put back a worker to the available pool
# `take!` - take a worker from the available pool (to be used for remote function execution)
# `length` - number of workers available in the overall pool
# `isready` - return false if a `take!` on the pool would block, else true
#
# The default implementations of the above (on an AbstractWorkerPool) require fields
# channel::Channel{Int}
# workers::Set{Int}
#
mutable struct WorkerPool <: AbstractWorkerPool
channel::Channel{Int}
workers::Set{Int}
ref::RemoteChannel
WorkerPool(c::Channel, ref::RemoteChannel) = new(c, Set{Int}(), ref)
end
function WorkerPool()
wp = WorkerPool(Channel{Int}(typemax(Int)), RemoteChannel())
put!(wp.ref, WeakRef(wp))
wp
end
"""
WorkerPool(workers::Vector{Int})
Create a WorkerPool from a vector of worker ids.
"""
function WorkerPool(workers::Vector{Int})
pool = WorkerPool()
foreach(w->push!(pool, w), workers)
return pool
end
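# Usage sketch (assumes workers 2 and 3, e.g. after `addprocs(2)`):
#
#     wp = WorkerPool([2, 3])
#     remotecall_fetch(myid, wp)   # runs on whichever pool worker is free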
# On workers where this pool has been serialized to, instantiate with a dummy local channel.
WorkerPool(ref::RemoteChannel) = WorkerPool(Channel{Int}(1), ref)
function serialize(S::AbstractSerializer, pool::WorkerPool)
# Allow accessing a worker pool from other processors. When serialized,
# initialize the `ref` to point to self and only send the ref.
# Other workers will forward all `put!` and `take!` calls to the process owning
# the ref (and hence the pool).
Serializer.serialize_type(S, typeof(pool))
serialize(S, pool.ref)
end
deserialize{T<:WorkerPool}(S::AbstractSerializer, t::Type{T}) = T(deserialize(S))
wp_local_push!(pool::AbstractWorkerPool, w::Int) = (push!(pool.workers, w); put!(pool.channel, w); pool)
wp_local_length(pool::AbstractWorkerPool) = length(pool.workers)
wp_local_isready(pool::AbstractWorkerPool) = isready(pool.channel)
function wp_local_put!(pool::AbstractWorkerPool, w::Int)
# In case of default_worker_pool, the master is implicitly considered a worker, i.e.,
# it is not present in pool.workers.
# Confirm that the worker is part of the pool before making it available.
w in pool.workers && put!(pool.channel, w)
w
end
function wp_local_workers(pool::AbstractWorkerPool)
if length(pool) == 0 && pool === default_worker_pool()
return [1]
else
return collect(pool.workers)
end
end
function wp_local_nworkers(pool::AbstractWorkerPool)
if length(pool) == 0 && pool === default_worker_pool()
return 1
else
return length(pool.workers)
end
end
function wp_local_take!(pool::AbstractWorkerPool)
# Find an active worker
worker = 0
while true
if length(pool) == 0
if pool === default_worker_pool()
# No workers, the master process is used as a worker
worker = 1
break
else
throw(ErrorException("No active worker available in pool"))
end
end
worker = take!(pool.channel)
if id_in_procs(worker)
break
else
delete!(pool.workers, worker) # Remove invalid worker from pool
end
end
return worker
end
function remotecall_pool(rc_f, f, pool::AbstractWorkerPool, args...; kwargs...)
worker = take!(pool)
try
rc_f(f, worker, args...; kwargs...)
finally
put!(pool, worker)
end
end
# Check if the pool is local or remote and forward calls if required.
# NOTE: always forwarding via remotecall_fetch would also work, but checking
# locality first avoids the overhead associated with a local remotecall.
for func = (:length, :isready, :workers, :nworkers, :take!)
func_local = Symbol(string("wp_local_", func))
@eval begin
function ($func)(pool::WorkerPool)
if pool.ref.where != myid()
return remotecall_fetch(ref->($func_local)(fetch(ref).value), pool.ref.where, pool.ref)
else
return ($func_local)(pool)
end
end
# default impl
($func)(pool::AbstractWorkerPool) = ($func_local)(pool)
end
end
for func = (:push!, :put!)
func_local = Symbol(string("wp_local_", func))
@eval begin
function ($func)(pool::WorkerPool, w::Int)
if pool.ref.where != myid()
return remotecall_fetch((ref, w)->($func_local)(fetch(ref).value, w), pool.ref.where, pool.ref, w)
else
return ($func_local)(pool, w)
end
end
# default impl
($func)(pool::AbstractWorkerPool, w::Int) = ($func_local)(pool, w)
end
end
"""
remotecall(f, pool::AbstractWorkerPool, args...; kwargs...) -> Future
`WorkerPool` variant of `remotecall(f, pid, ...)`. Waits for and takes a free worker from `pool` and performs a `remotecall` on it.
"""
remotecall(f, pool::AbstractWorkerPool, args...; kwargs...) = remotecall_pool(remotecall, f, pool, args...; kwargs...)
"""
remotecall_wait(f, pool::AbstractWorkerPool, args...; kwargs...) -> Future
`WorkerPool` variant of `remotecall_wait(f, pid, ...)`. Waits for and takes a free worker from `pool` and
performs a `remotecall_wait` on it.
"""
remotecall_wait(f, pool::AbstractWorkerPool, args...; kwargs...) = remotecall_pool(remotecall_wait, f, pool, args...; kwargs...)
"""
remotecall_fetch(f, pool::AbstractWorkerPool, args...; kwargs...) -> result
`WorkerPool` variant of `remotecall_fetch(f, pid, ...)`. Waits for and takes a free worker from `pool` and
performs a `remotecall_fetch` on it.
"""
remotecall_fetch(f, pool::AbstractWorkerPool, args...; kwargs...) = remotecall_pool(remotecall_fetch, f, pool, args...; kwargs...)
"""
remote_do(f, pool::AbstractWorkerPool, args...; kwargs...) -> nothing
`WorkerPool` variant of `remote_do(f, pid, ...)`. Waits for and takes a free worker from `pool` and
performs a `remote_do` on it.
"""
remote_do(f, pool::AbstractWorkerPool, args...; kwargs...) = remotecall_pool(remote_do, f, pool, args...; kwargs...)
const _default_worker_pool = Ref{Nullable}(Nullable{WorkerPool}())
"""
default_worker_pool()
`WorkerPool` containing idle `workers()` - used by `remote(f)` and [`pmap`](@ref) (by default).
"""
function default_worker_pool()
# On workers retrieve the default worker pool from the master when accessed
# for the first time
if isnull(_default_worker_pool[])
if myid() == 1
_default_worker_pool[] = Nullable(WorkerPool())
else
_default_worker_pool[] = Nullable(remotecall_fetch(()->default_worker_pool(), 1))
end
end
return get(_default_worker_pool[])
end
"""
remote([::AbstractWorkerPool], f) -> Function
Returns an anonymous function that executes function `f` on an available worker
using [`remotecall_fetch`](@ref).
"""
remote(f) = (args...; kwargs...)->remotecall_fetch(f, default_worker_pool(), args...; kwargs...)
remote(p::AbstractWorkerPool, f) = (args...; kwargs...)->remotecall_fetch(f, p, args...; kwargs...)
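# Usage sketch (assumes at least one worker in the default pool):
#
#     rsum = remote(sum)   # wraps sum in a remotecall_fetch on a free worker
#     rsum(1:10)           # => 55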
mutable struct CachingPool <: AbstractWorkerPool
channel::Channel{Int}
workers::Set{Int}
# Mapping between a tuple (worker_id, f) and a remote_ref
map_obj2ref::Dict{Tuple{Int, Function}, RemoteChannel}
function CachingPool()
wp = new(Channel{Int}(typemax(Int)), Set{Int}(), Dict{Tuple{Int, Function}, RemoteChannel}())
finalizer(wp, clear!)
wp
end
end
serialize(s::AbstractSerializer, cp::CachingPool) = throw(ErrorException("CachingPool objects are not serializable."))
"""
CachingPool(workers::Vector{Int})
An implementation of an `AbstractWorkerPool`.
[`remote`](@ref), [`remotecall_fetch`](@ref),
[`pmap`](@ref) (and other remote calls which execute functions remotely)
benefit from caching the serialized/deserialized functions on the worker nodes,
especially closures (which may capture large amounts of data).
The remote cache is maintained for the lifetime of the returned `CachingPool` object.
To clear the cache earlier, use `clear!(pool)`.
For global variables, only the bindings are captured in a closure, not the data.
`let` blocks can be used to capture global data.
For example:
```
const foo=rand(10^8);
wp=CachingPool(workers())
let foo=foo
pmap(wp, i->sum(foo)+i, 1:100);
end
```
The above would transfer `foo` only once to each worker.
"""
function CachingPool(workers::Vector{Int})
pool = CachingPool()
for w in workers
push!(pool, w)
end
return pool
end
"""
clear!(pool::CachingPool) -> pool
Removes all cached functions from all participating workers.
"""
function clear!(pool::CachingPool)
for (_,rr) in pool.map_obj2ref
finalize(rr)
end
empty!(pool.map_obj2ref)
pool
end
exec_from_cache(rr::RemoteChannel, args...; kwargs...) = fetch(rr)(args...; kwargs...)
function exec_from_cache(f_ref::Tuple{Function, RemoteChannel}, args...; kwargs...)
put!(f_ref[2], f_ref[1]) # Cache locally
f_ref[1](args...; kwargs...)
end
function remotecall_pool(rc_f, f, pool::CachingPool, args...; kwargs...)
worker = take!(pool)
f_ref = get(pool.map_obj2ref, (worker, f), (f, RemoteChannel(worker)))
isa(f_ref, Tuple) && (pool.map_obj2ref[(worker, f)] = f_ref[2]) # Add to tracker
try
rc_f(exec_from_cache, worker, f_ref, args...; kwargs...)
finally
put!(pool, worker)
end
end