Add: julia-0.6.2
Former-commit-id: ccc667cf67d569f3fb3df39aa57c2134755a7551
This commit is contained in:
488
julia-0.6.2/share/julia/base/pkg/resolve/maxsum.jl
Normal file
488
julia-0.6.2/share/julia/base/pkg/resolve/maxsum.jl
Normal file
@@ -0,0 +1,488 @@
|
||||
# This file is a part of Julia. License is MIT: https://julialang.org/license
|
||||
|
||||
module MaxSum
|
||||
|
||||
include("fieldvalue.jl")
|
||||
|
||||
using .FieldValues, ..VersionWeights, ..PkgToMaxSumInterface
|
||||
|
||||
export UnsatError, Graph, Messages, maxsum
|
||||
|
||||
# An exception type used internally to signal that an unsatisfiable
|
||||
# constraint was detected
|
||||
mutable struct UnsatError <: Exception
|
||||
info
|
||||
end
|
||||
|
||||
# Some parameters to drive the decimation process
|
||||
mutable struct MaxSumParams
|
||||
nondec_iterations # number of initial iterations before starting
|
||||
# decimation
|
||||
dec_interval # number of iterations between decimations
|
||||
dec_fraction # fraction of nodes to decimate at every decimation
|
||||
# step
|
||||
|
||||
function MaxSumParams()
|
||||
accuracy = parse(Int, get(ENV, "JULIA_PKGRESOLVE_ACCURACY", "1"))
|
||||
if accuracy <= 0
|
||||
error("JULIA_PKGRESOLVE_ACCURACY must be > 0")
|
||||
end
|
||||
nondec_iterations = accuracy * 20
|
||||
dec_interval = accuracy * 10
|
||||
dec_fraction = 0.05 / accuracy
|
||||
return new(nondec_iterations, dec_interval, dec_fraction)
|
||||
end
|
||||
end
|
||||
|
||||
# Graph holds the graph structure onto which max-sum is run, in
|
||||
# sparse format
|
||||
mutable struct Graph
|
||||
# adjacency matrix:
|
||||
# for each package, has the list of neighbors
|
||||
# indices (both dependencies and dependants)
|
||||
gadj::Vector{Vector{Int}}
|
||||
|
||||
# compatibility mask:
|
||||
# for each package p0 has a list of bool masks.
|
||||
# Each entry in the list gmsk[p0] is relative to the
|
||||
# package p1 as read from gadj[p0].
|
||||
# Each mask has dimension spp1 x spp0, where
|
||||
# spp0 is the number of states of p0, and
|
||||
# spp1 is the number of states of p1.
|
||||
gmsk::Vector{Vector{BitMatrix}}
|
||||
|
||||
# dependency direction:
|
||||
# keeps track of which direction the dependency goes
|
||||
# takes 3 values:
|
||||
# 1 = dependant
|
||||
# -1 = dependency
|
||||
# 0 = both
|
||||
# Used to break symmetry between dependants and
|
||||
# dependencies (introduces a FieldValue at level l3).
|
||||
# The "both" case is for when there are dependency
|
||||
# relations which go both ways, in which case the
|
||||
# noise is left to discriminate in case of ties
|
||||
gdir::Vector{Vector{Int}}
|
||||
|
||||
# adjacency dict:
|
||||
# allows one to retrieve the indices in gadj, so that
|
||||
# gadj[p0][adjdict[p1][p0]] = p1
|
||||
adjdict::Vector{Dict{Int,Int}}
|
||||
|
||||
# states per package: same as in Interface
|
||||
spp::Vector{Int}
|
||||
|
||||
# update order: shuffled at each iteration
|
||||
perm::Vector{Int}
|
||||
|
||||
# number of packages (all Vectors above have this length)
|
||||
np::Int
|
||||
|
||||
function Graph(interface::Interface)
|
||||
deps = interface.deps
|
||||
np = interface.np
|
||||
|
||||
spp = interface.spp
|
||||
pdict = interface.pdict
|
||||
pvers = interface.pvers
|
||||
vdict = interface.vdict
|
||||
|
||||
gadj = [Int[] for i = 1:np]
|
||||
gmsk = [BitMatrix[] for i = 1:np]
|
||||
gdir = [Int[] for i = 1:np]
|
||||
adjdict = [Dict{Int,Int}() for i = 1:np]
|
||||
|
||||
for (p,d) in deps
|
||||
p0 = pdict[p]
|
||||
vdict0 = vdict[p0]
|
||||
for (vn,a) in d
|
||||
v0 = vdict0[vn]
|
||||
for (rp, rvs) in a.requires
|
||||
p1 = pdict[rp]
|
||||
|
||||
j0 = 1
|
||||
while j0 <= length(gadj[p0]) && gadj[p0][j0] != p1
|
||||
j0 += 1
|
||||
end
|
||||
j1 = 1
|
||||
while j1 <= length(gadj[p1]) && gadj[p1][j1] != p0
|
||||
j1 += 1
|
||||
end
|
||||
@assert (j0 > length(gadj[p0]) && j1 > length(gadj[p1])) ||
|
||||
(j0 <= length(gadj[p0]) && j1 <= length(gadj[p1]))
|
||||
|
||||
if j0 > length(gadj[p0])
|
||||
push!(gadj[p0], p1)
|
||||
push!(gadj[p1], p0)
|
||||
j0 = length(gadj[p0])
|
||||
j1 = length(gadj[p1])
|
||||
|
||||
adjdict[p1][p0] = j0
|
||||
adjdict[p0][p1] = j1
|
||||
|
||||
bm = trues(spp[p1], spp[p0])
|
||||
bmt = bm'
|
||||
|
||||
push!(gmsk[p0], bm)
|
||||
push!(gmsk[p1], bmt)
|
||||
|
||||
push!(gdir[p0], 1)
|
||||
push!(gdir[p1], -1)
|
||||
else
|
||||
bm = gmsk[p0][j0]
|
||||
bmt = gmsk[p1][j1]
|
||||
if gdir[p0][j0] == -1
|
||||
gdir[p0][j0] = 0
|
||||
gdir[p1][j1] = 0
|
||||
end
|
||||
end
|
||||
|
||||
for v1 = 1:length(pvers[p1])
|
||||
if pvers[p1][v1] ∉ rvs
|
||||
bm[v1, v0] = false
|
||||
bmt[v0, v1] = false
|
||||
end
|
||||
end
|
||||
bm[end,v0] = false
|
||||
bmt[v0,end] = false
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
perm = [1:np;]
|
||||
|
||||
return new(gadj, gmsk, gdir, adjdict, spp, perm, np)
|
||||
end
|
||||
end
|
||||
|
||||
# Messages has the cavity messages and the total fields, and
|
||||
# gets updated iteratively (and occasionally decimated) until
|
||||
# convergence
|
||||
mutable struct Messages
|
||||
# cavity incoming messages: for each package p0,
|
||||
# for each neighbor p1 of p0,
|
||||
# msg[p0][p1] is a vector of length spp[p0]
|
||||
# messages are normalized (i.e. the max is always 0)
|
||||
msg::Vector{Vector{Field}}
|
||||
|
||||
# overall fields: for each package p0,
|
||||
# fld[p0] is a vector of length spp[p0]
|
||||
# fields are not normalized
|
||||
fld::Vector{Field}
|
||||
|
||||
# backup of the initial value of fld, to be used when resetting
|
||||
initial_fld::Vector{Field}
|
||||
|
||||
# keep track of which variables have been decimated
|
||||
decimated::BitVector
|
||||
num_nondecimated::Int
|
||||
|
||||
function Messages(interface::Interface, graph::Graph)
|
||||
reqs = interface.reqs
|
||||
pkgs = interface.pkgs
|
||||
np = interface.np
|
||||
spp = interface.spp
|
||||
pvers = interface.pvers
|
||||
pdict = interface.pdict
|
||||
vweight = interface.vweight
|
||||
|
||||
# a "deterministic noise" function based on hashes
|
||||
function noise(p0::Int, v0::Int)
|
||||
s = pkgs[p0] * string(v0 == spp[p0] ? "UNINST" : pvers[p0][v0])
|
||||
Int128(hash(s))
|
||||
end
|
||||
|
||||
# external fields: there are 2 terms, a noise to break potential symmetries
|
||||
# and one to favor newest versions over older, and no-version over all
|
||||
fld = [[FieldValue(0, zero(VersionWeight), vweight[p0][v0], (v0==spp[p0]), 0, noise(p0,v0)) for v0 = 1:spp[p0]] for p0 = 1:np]
|
||||
|
||||
# enforce requirements
|
||||
for (rp, rvs) in reqs
|
||||
p0 = pdict[rp]
|
||||
pvers0 = pvers[p0]
|
||||
fld0 = fld[p0]
|
||||
for v0 = 1:spp[p0]-1
|
||||
vn = pvers0[v0]
|
||||
if !in(vn, rvs)
|
||||
# the state is forbidden by requirements
|
||||
fld0[v0] = FieldValue(-1)
|
||||
else
|
||||
# the state is one of those explicitly requested:
|
||||
# favor it at a higer level than normal (upgrade
|
||||
# FieldValue from l2 to l1)
|
||||
fld0[v0] += FieldValue(0, vweight[p0][v0], -vweight[p0][v0])
|
||||
end
|
||||
end
|
||||
# the uninstalled state is forbidden by requirements
|
||||
fld0[spp[p0]] = FieldValue(-1)
|
||||
end
|
||||
# normalize fields
|
||||
for p0 = 1:np
|
||||
m = maximum(fld[p0])
|
||||
for v0 = 1:spp[p0]
|
||||
fld[p0][v0] -= m
|
||||
end
|
||||
end
|
||||
|
||||
initial_fld = deepcopy(fld)
|
||||
|
||||
# initialize cavity messages to 0
|
||||
gadj = graph.gadj
|
||||
msg = [[zeros(FieldValue, spp[p0]) for p1 = 1:length(gadj[p0])] for p0 = 1:np]
|
||||
|
||||
return new(msg, fld, initial_fld, falses(np), np)
|
||||
end
|
||||
end
|
||||
|
||||
function getsolution(msgs::Messages)
|
||||
# the solution is just the location of the maximum in
|
||||
# each field
|
||||
|
||||
fld = msgs.fld
|
||||
np = length(fld)
|
||||
sol = Vector{Int}(np)
|
||||
for p0 = 1:np
|
||||
fld0 = fld[p0]
|
||||
s0 = indmax(fld0)
|
||||
if !validmax(fld0[s0])
|
||||
throw(UnsatError(p0))
|
||||
end
|
||||
sol[p0] = s0
|
||||
end
|
||||
return sol
|
||||
end
|
||||
|
||||
# This is the core of the max-sum solver:
|
||||
# for a given node p0 (i.e. a package) updates all
|
||||
# input cavity messages and fields of its neighbors
|
||||
function update(p0::Int, graph::Graph, msgs::Messages)
|
||||
gadj = graph.gadj
|
||||
gmsk = graph.gmsk
|
||||
gdir = graph.gdir
|
||||
adjdict = graph.adjdict
|
||||
spp = graph.spp
|
||||
np = graph.np
|
||||
msg = msgs.msg
|
||||
fld = msgs.fld
|
||||
decimated = msgs.decimated
|
||||
|
||||
maxdiff = zero(FieldValue)
|
||||
|
||||
gadj0 = gadj[p0]
|
||||
msg0 = msg[p0]
|
||||
fld0 = fld[p0]
|
||||
spp0 = spp[p0]
|
||||
adjdict0 = adjdict[p0]
|
||||
|
||||
# iterate over all neighbors of p0
|
||||
for j0 in 1:length(gadj0)
|
||||
|
||||
p1 = gadj0[j0]
|
||||
decimated[p1] && continue
|
||||
j1 = adjdict0[p1]
|
||||
#@assert j0 == adjdict[p1][p0]
|
||||
bm1 = gmsk[p1][j1]
|
||||
dir1 = gdir[p1][j1]
|
||||
spp1 = spp[p1]
|
||||
msg1 = msg[p1]
|
||||
|
||||
# compute the output cavity message p0->p1
|
||||
cavmsg = fld0 - msg0[j0]
|
||||
|
||||
if dir1 == -1
|
||||
# p0 depends on p1
|
||||
for v0 = 1:spp0-1
|
||||
cavmsg[v0] += FieldValue(0, VersionWeight(0), VersionWeight(0), 0, v0)
|
||||
end
|
||||
end
|
||||
|
||||
# keep the old input cavity message p0->p1
|
||||
oldmsg = msg1[j1]
|
||||
|
||||
# init the new message to minus infinity
|
||||
newmsg = [FieldValue(-1) for v1 = 1:spp1]
|
||||
|
||||
# compute the new message by passing cavmsg
|
||||
# through the constraint encoded in the bitmask
|
||||
# (nearly equivalent to:
|
||||
# newmsg = [maximum(cavmsg[bm1[:,v1]]) for v1 = 1:spp1]
|
||||
# except for the gnrg term)
|
||||
m = FieldValue(-1)
|
||||
for v1 = 1:spp1
|
||||
for v0 = 1:spp0
|
||||
if bm1[v0, v1]
|
||||
newmsg[v1] = max(newmsg[v1], cavmsg[v0])
|
||||
end
|
||||
end
|
||||
if dir1 == 1 && v1 != spp1
|
||||
# p1 depends on p0
|
||||
newmsg[v1] += FieldValue(0, VersionWeight(0), VersionWeight(0), 0, v1)
|
||||
end
|
||||
m = max(m, newmsg[v1])
|
||||
end
|
||||
if !validmax(m)
|
||||
# No state available without violating some
|
||||
# hard constraint
|
||||
throw(UnsatError(p1))
|
||||
end
|
||||
|
||||
# normalize the new message
|
||||
for v1 = 1:spp1
|
||||
newmsg[v1] -= m
|
||||
end
|
||||
|
||||
diff = newmsg - oldmsg
|
||||
maxdiff = max(maxdiff, maximum(abs.(diff)))
|
||||
|
||||
# update the field of p1
|
||||
fld1 = fld[p1]
|
||||
for v1 = 1:spp1
|
||||
fld1[v1] += diff[v1]
|
||||
end
|
||||
|
||||
# put the newly computed message in place
|
||||
msg1[j1] = newmsg
|
||||
end
|
||||
return maxdiff
|
||||
end
|
||||
|
||||
# A simple shuffling machinery for the update order in iterate()
|
||||
# (woulnd't pass any random quality test but it's arguably enough)
|
||||
let step=1
|
||||
global shuffleperm, shuffleperminit
|
||||
shuffleperminit() = (step = 1)
|
||||
function shuffleperm(graph::Graph)
|
||||
perm = graph.perm
|
||||
np = graph.np
|
||||
for j = np:-1:2
|
||||
k = mod(step,j)+1
|
||||
perm[j], perm[k] = perm[k], perm[j]
|
||||
step += isodd(j) ? 1 : k
|
||||
end
|
||||
#@assert isperm(perm)
|
||||
end
|
||||
end
|
||||
|
||||
# Call update for all nodes (i.e. packages) in
|
||||
# random order
|
||||
function iterate(graph::Graph, msgs::Messages)
|
||||
np = graph.np
|
||||
|
||||
maxdiff = zero(FieldValue)
|
||||
shuffleperm(graph)
|
||||
perm = graph.perm
|
||||
for p0 in perm
|
||||
maxdiff0 = update(p0, graph, msgs)
|
||||
maxdiff = max(maxdiff, maxdiff0)
|
||||
end
|
||||
return maxdiff
|
||||
end
|
||||
|
||||
function decimate1(p0::Int, graph::Graph, msgs::Messages)
|
||||
@assert !msgs.decimated[p0]
|
||||
fld0 = msgs.fld[p0]
|
||||
s0 = indmax(fld0)
|
||||
#println("DECIMATING $p0 ($(packages()[p0]) s0=$s0)")
|
||||
for v0 = 1:length(fld0)
|
||||
if v0 != s0
|
||||
fld0[v0] = FieldValue(-1)
|
||||
end
|
||||
end
|
||||
msgs.decimated[p0] = true
|
||||
msgs.num_nondecimated -= 1
|
||||
end
|
||||
|
||||
function reset_messages!(msgs::Messages)
|
||||
msg = msgs.msg
|
||||
fld = msgs.fld
|
||||
initial_fld = msgs.initial_fld
|
||||
decimated = msgs.decimated
|
||||
np = length(fld)
|
||||
for p0 = 1:np
|
||||
map(m->fill!(m, zero(FieldValue)), msg[p0])
|
||||
decimated[p0] && continue
|
||||
fld[p0] = copy(initial_fld[p0])
|
||||
end
|
||||
return msgs
|
||||
end
|
||||
|
||||
# If normal convergence fails (or is too slow) fix the most
|
||||
# polarized packages by adding extra infinite fields on every state
|
||||
# but the maximum
|
||||
function decimate(n::Int, graph::Graph, msgs::Messages)
|
||||
#println("DECIMATING $n NODES")
|
||||
fld = msgs.fld
|
||||
decimated = msgs.decimated
|
||||
fldorder = sortperm(fld, by=secondmax)
|
||||
for p0 in fldorder
|
||||
decimated[p0] && continue
|
||||
decimate1(p0, graph, msgs)
|
||||
n -= 1
|
||||
n == 0 && break
|
||||
end
|
||||
@assert n == 0
|
||||
reset_messages!(msgs)
|
||||
return
|
||||
end
|
||||
|
||||
# In case ties still exist at convergence, break them and
|
||||
# keep converging
|
||||
function break_ties(msgs::Messages)
|
||||
fld = msgs.fld
|
||||
for p0 = 1:length(fld)
|
||||
fld0 = fld[p0]
|
||||
z = 0
|
||||
m = typemin(FieldValue)
|
||||
for v0 = 1:length(fld0)
|
||||
if fld0[v0] > m
|
||||
m = fld0[v0]
|
||||
z = 1
|
||||
elseif fld0[v0] == m
|
||||
z += 1
|
||||
end
|
||||
end
|
||||
if z > 1
|
||||
#println("TIE! p0=$p0")
|
||||
decimate1(p0, msgs)
|
||||
return false
|
||||
end
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
# Iterative solver: run iterate() until convergence
|
||||
# (occasionally calling decimate())
|
||||
function maxsum(graph::Graph, msgs::Messages)
|
||||
params = MaxSumParams()
|
||||
|
||||
it = 0
|
||||
shuffleperminit()
|
||||
while true
|
||||
it += 1
|
||||
maxdiff = iterate(graph, msgs)
|
||||
#println("it = $it maxdiff = $maxdiff")
|
||||
|
||||
if maxdiff == zero(FieldValue)
|
||||
break_ties(msgs) && break
|
||||
continue
|
||||
end
|
||||
if it >= params.nondec_iterations &&
|
||||
(it - params.nondec_iterations) % params.dec_interval == 0
|
||||
numdec = clamp(floor(Int, params.dec_fraction * graph.np), 1, msgs.num_nondecimated)
|
||||
decimate(numdec, graph, msgs)
|
||||
msgs.num_nondecimated == 0 && break
|
||||
end
|
||||
end
|
||||
|
||||
# Finally, decimate everything just to
|
||||
# check against inconsistencies
|
||||
# (old_numnondec is saved just to prevent
|
||||
# wrong messages about accuracy)
|
||||
old_numnondec = msgs.num_nondecimated
|
||||
decimate(msgs.num_nondecimated, graph, msgs)
|
||||
msgs.num_nondecimated = old_numnondec
|
||||
|
||||
return getsolution(msgs)
|
||||
end
|
||||
|
||||
end
|
||||
Reference in New Issue
Block a user