# This file is a part of Julia. License is MIT: https://julialang.org/license
|
|
|
|
module MaxSum
|
|
|
|
include("fieldvalue.jl")
|
|
|
|
using .FieldValues, ..VersionWeights, ..PkgToMaxSumInterface
|
|
|
|
export UnsatError, Graph, Messages, maxsum
|
|
|
|
# Exception type used internally to signal that an unsatisfiable
# constraint was detected.
# `info` is whatever the throw site attaches to the error; every
# throw site in this module passes the index of the offending package.
mutable struct UnsatError <: Exception
    info::Any
end
|
|
|
|
# Some parameters to drive the decimation process.
# All of them are derived from the JULIA_PKGRESOLVE_ACCURACY environment
# variable (a positive integer, default "1"): a higher accuracy means more
# iterations between decimations and fewer nodes decimated at each step.
mutable struct MaxSumParams
    # number of initial iterations before starting decimation
    nondec_iterations::Int
    # number of iterations between decimations
    dec_interval::Int
    # fraction of nodes to decimate at every decimation step
    dec_fraction::Float64

    # Reads the accuracy from the environment and derives the parameters.
    # Throws an ErrorException if the accuracy is not strictly positive
    # (and whatever `parse` throws if it is not an integer).
    function MaxSumParams()
        accuracy = parse(Int, get(ENV, "JULIA_PKGRESOLVE_ACCURACY", "1"))
        if accuracy <= 0
            error("JULIA_PKGRESOLVE_ACCURACY must be > 0")
        end
        nondec_iterations = accuracy * 20
        dec_interval = accuracy * 10
        dec_fraction = 0.05 / accuracy
        return new(nondec_iterations, dec_interval, dec_fraction)
    end
end
|
|
|
|
# Graph holds the graph structure onto which max-sum is run, in
# sparse format
mutable struct Graph
    # adjacency matrix:
    #   for each package, has the list of neighbors
    #   indices (both dependencies and dependants)
    gadj::Vector{Vector{Int}}

    # compatibility mask:
    #   for each package p0 has a list of bool masks.
    #   Each entry in the list gmsk[p0] is relative to the
    #   package p1 as read from gadj[p0].
    #   Each mask has dimension spp1 x spp0, where
    #   spp0 is the number of states of p0, and
    #   spp1 is the number of states of p1.
    gmsk::Vector{Vector{BitMatrix}}

    # dependency direction:
    #   keeps track of which direction the dependency goes
    #   takes 3 values:
    #     1 = dependant
    #    -1 = dependency
    #     0 = both
    #   Used to break symmetry between dependants and
    #   dependencies (introduces a FieldValue at level l3).
    #   The "both" case is for when there are dependency
    #   relations which go both ways, in which case the
    #   noise is left to discriminate in case of ties
    gdir::Vector{Vector{Int}}

    # adjacency dict:
    #   allows one to retrieve the indices in gadj, so that
    #   gadj[p0][adjdict[p1][p0]] = p1
    #   ("At which index does package p1 appear in gadj[p0]?")
    adjdict::Vector{Dict{Int,Int}}

    # states per package: same as in Interface
    spp::Vector{Int}

    # update order: shuffled at each iteration
    perm::Vector{Int}

    # number of packages (all Vectors above have this length)
    np::Int

    # Builds the graph from the dependency table of `interface`:
    # every `requires` entry of every package version adds (or refines)
    # an edge between the requiring package and the required one.
    function Graph(interface::Interface)
        np = interface.np
        spp = interface.spp
        pdict = interface.pdict
        pvers = interface.pvers
        vdict = interface.vdict

        gadj = [Int[] for i in 1:np]
        gmsk = [BitMatrix[] for i in 1:np]
        gdir = [Int[] for i in 1:np]
        adjdict = [Dict{Int,Int}() for i in 1:np]

        # position of the first occurrence of `x` in `list`, or
        # `length(list) + 1` when `x` does not appear in it
        function slot(list::Vector{Int}, x::Int)
            k = 1
            while k <= length(list) && list[k] != x
                k += 1
            end
            return k
        end

        for (pkg, versions) in interface.deps
            p0 = pdict[pkg]
            vdict0 = vdict[p0]
            for (vn, avail) in versions
                v0 = vdict0[vn]
                for (rp, rvs) in avail.requires
                    p1 = pdict[rp]

                    j0 = slot(gadj[p0], p1)
                    j1 = slot(gadj[p1], p0)
                    # the edge is either known on both sides or on neither
                    @assert (j0 > length(gadj[p0]) && j1 > length(gadj[p1])) ||
                            (j0 <= length(gadj[p0]) && j1 <= length(gadj[p1]))

                    if j0 <= length(gadj[p0])
                        # the edge already exists: reuse its masks
                        bm = gmsk[p0][j0]
                        bmt = gmsk[p1][j1]
                        if gdir[p0][j0] == -1
                            # so far p0 was only a dependency of p1, but now
                            # it also depends on it: mark the edge as "both"
                            gdir[p0][j0] = 0
                            gdir[p1][j1] = 0
                        end
                    else
                        # first time this pair is seen: register the edge
                        # on both endpoints
                        push!(gadj[p0], p1)
                        push!(gadj[p1], p0)
                        j0 = length(gadj[p0])
                        j1 = length(gadj[p1])

                        adjdict[p1][p0] = j0
                        adjdict[p0][p1] = j1

                        # start with everything allowed; the two masks are
                        # kept as separate matrices, one the transpose of
                        # the other
                        bm = trues(spp[p1], spp[p0])
                        bmt = trues(spp[p0], spp[p1])

                        push!(gmsk[p0], bm)
                        push!(gmsk[p1], bmt)

                        push!(gdir[p0], 1)
                        push!(gdir[p1], -1)
                    end

                    # version v0 of p0 forbids every version of p1 which
                    # is not in the required set...
                    for (v1, ver1) in enumerate(pvers[p1])
                        in(ver1, rvs) && continue
                        bm[v1, v0] = false
                        bmt[v0, v1] = false
                    end
                    # ...and also forbids the last state of p1
                    # (the uninstalled one)
                    bm[end, v0] = false
                    bmt[v0, end] = false
                end
            end
        end

        perm = collect(1:np)

        return new(gadj, gmsk, gdir, adjdict, spp, perm, np)
    end
end
|
|
|
|
# Messages has the cavity messages and the total fields, and
# gets updated iteratively (and occasionally decimated) until
# convergence
mutable struct Messages
    # cavity incoming messages: for each package p0,
    #                           for each neighbor p1 of p0,
    #                           msg[p0][p1] is a vector of length spp[p0]
    #                           messages are normalized (i.e. the max is always 0)
    msg::Vector{Vector{Field}}

    # overall fields: for each package p0,
    #                 fld[p0] is a vector of length spp[p0]
    #                 fields are not normalized
    fld::Vector{Field}

    # backup of the initial value of fld, to be used when resetting
    initial_fld::Vector{Field}

    # keep track of which variables have been decimated
    decimated::BitVector
    num_nondecimated::Int

    # Builds the initial fields from the version weights and the
    # requirements found in `interface`, and one zero message per
    # edge of `graph`. Nothing is decimated at the start.
    function Messages(interface::Interface, graph::Graph)
        pkgs = interface.pkgs
        np = interface.np
        spp = interface.spp
        pvers = interface.pvers
        pdict = interface.pdict
        vweight = interface.vweight

        # a "deterministic noise" function based on hashes: the hash of
        # the package name followed by the version (or by "UNINST" for
        # the last state)
        function noise(p0::Int, v0::Int)
            label = v0 == spp[p0] ? "UNINST" : pvers[p0][v0]
            return Int128(hash(pkgs[p0] * string(label)))
        end

        # external fields: there are 2 terms, a noise to break potential
        # symmetries and one to favor newest versions over older, and
        # no-version over all
        fld = [[FieldValue(0, zero(VersionWeight), vweight[p0][v0],
                           (v0==spp[p0]), 0, noise(p0,v0))
                for v0 = 1:spp[p0]]
               for p0 = 1:np]

        # enforce requirements
        for (rp, rvs) in interface.reqs
            p0 = pdict[rp]
            fld0 = fld[p0]
            uninst = spp[p0]
            for v0 = 1:uninst-1
                if in(pvers[p0][v0], rvs)
                    # the state is one of those explicitly requested:
                    # favor it at a higher level than normal (upgrade
                    # FieldValue from l2 to l1)
                    fld0[v0] += FieldValue(0, vweight[p0][v0], -vweight[p0][v0])
                else
                    # the state is forbidden by requirements
                    fld0[v0] = FieldValue(-1)
                end
            end
            # the uninstalled state is forbidden by requirements
            fld0[uninst] = FieldValue(-1)
        end

        # normalize fields (subtract the maximum from every entry)
        for p0 = 1:np
            fld0 = fld[p0]
            top = maximum(fld0)
            for v0 = 1:spp[p0]
                fld0[v0] -= top
            end
        end

        initial_fld = deepcopy(fld)

        # initialize cavity messages to 0: one vector of length spp[p0]
        # for each neighbor of p0
        gadj = graph.gadj
        msg = [[zeros(FieldValue, spp[p0]) for p1 = 1:length(gadj[p0])]
               for p0 = 1:np]

        return new(msg, fld, initial_fld, falses(np), np)
    end
end
|
|
|
|
# Reads the solution off the fields: for each package, the chosen
# state is the location of the maximum of its field (as given by
# `indmax`).
# Returns a Vector{Int} with one state index per package.
# Throws UnsatError(p0) if the maximum of the field of package p0
# does not pass `validmax`.
function getsolution(msgs::Messages)
    fld = msgs.fld
    sol = zeros(Int, length(fld))
    for (p0, fld0) in enumerate(fld)
        s0 = indmax(fld0)
        validmax(fld0[s0]) || throw(UnsatError(p0))
        sol[p0] = s0
    end
    return sol
end
|
|
|
|
# This is the core of the max-sum solver:
# for a given node p0 (i.e. a package) updates all
# input cavity messages and fields of its neighbors.
# Neighbors which have been decimated are left untouched.
# Returns the largest absolute change among the recomputed messages
# (zero(FieldValue) if nothing was updated).
# Throws UnsatError(p1) if some neighbor p1 is left with no state
# which doesn't violate a hard constraint.
function update(p0::Int, graph::Graph, msgs::Messages)
    gmsk = graph.gmsk
    gdir = graph.gdir
    spp = graph.spp
    msg = msgs.msg
    fld = msgs.fld
    decimated = msgs.decimated

    neighbors = graph.gadj[p0]
    slots0 = graph.adjdict[p0]
    inbox0 = msg[p0]
    fld0 = fld[p0]
    spp0 = spp[p0]

    maxdiff = zero(FieldValue)

    # iterate over all neighbors of p0
    for (j0, p1) in enumerate(neighbors)
        decimated[p1] && continue

        # j1 is the position of p0 among the neighbors of p1
        # (invariant: j0 == adjdict[p1][p0])
        j1 = slots0[p1]
        mask = gmsk[p1][j1]
        dir1 = gdir[p1][j1]
        spp1 = spp[p1]
        inbox1 = msg[p1]

        # compute the output cavity message p0->p1: the field of p0
        # without the contribution which came from p1
        cavmsg = fld0 - inbox0[j0]

        if dir1 == -1
            # p0 depends on p1: add the direction term to every state
            # of p0 except the last one
            for v0 = 1:spp0-1
                cavmsg[v0] += FieldValue(0, VersionWeight(0), VersionWeight(0), 0, v0)
            end
        end

        # keep the old input cavity message p0->p1
        oldmsg = inbox1[j1]

        # compute the new message by passing cavmsg through the
        # constraint encoded in the bitmask; each entry starts from
        # minus infinity
        # (nearly equivalent to:
        #    newmsg = [maximum(cavmsg[mask[:,v1]]) for v1 = 1:spp1]
        #  except for the gnrg term)
        newmsg = [FieldValue(-1) for v1 = 1:spp1]
        top = FieldValue(-1)
        for v1 = 1:spp1
            best = newmsg[v1]
            for v0 = 1:spp0
                if mask[v0, v1]
                    best = max(best, cavmsg[v0])
                end
            end
            if dir1 == 1 && v1 != spp1
                # p1 depends on p0
                best += FieldValue(0, VersionWeight(0), VersionWeight(0), 0, v1)
            end
            newmsg[v1] = best
            top = max(top, best)
        end

        # No state available without violating some hard constraint
        validmax(top) || throw(UnsatError(p1))

        # normalize the new message
        for v1 = 1:spp1
            newmsg[v1] -= top
        end

        diff = newmsg - oldmsg
        maxdiff = max(maxdiff, maximum(abs.(diff)))

        # update the field of p1 with the change in the message
        fld1 = fld[p1]
        for v1 = 1:spp1
            fld1[v1] += diff[v1]
        end

        # put the newly computed message in place
        inbox1[j1] = newmsg
    end
    return maxdiff
end
|
|
|
|
# A simple shuffling machinery for the update order in iterate()
# (wouldn't pass any random quality test but it's arguably enough).
# The state of the sequence is the counter `step`, shared by the two
# functions below and hidden from the rest of the module.
let step = 1
    global shuffleperm, shuffleperminit

    # restart the shuffling sequence from its initial state
    function shuffleperminit()
        step = 1
        return step
    end

    # shuffle graph.perm in place, swapping each position (from the
    # last down to the second) with one chosen through `step`
    function shuffleperm(graph::Graph)
        perm = graph.perm
        j = graph.np
        while j >= 2
            k = mod(step, j) + 1
            perm[j], perm[k] = perm[k], perm[j]
            step += isodd(j) ? 1 : k
            j -= 1
        end
        # invariant: isperm(perm)
        return nothing
    end
end
|
|
|
|
# Call update for all nodes (i.e. packages) in
# random order (graph.perm is reshuffled first).
# Returns the largest message change reported by any of the updates.
function iterate(graph::Graph, msgs::Messages)
    shuffleperm(graph)

    maxdiff = zero(FieldValue)
    for p0 in graph.perm
        maxdiff = max(maxdiff, update(p0, graph, msgs))
    end
    return maxdiff
end
|
|
|
|
# Try to decimate node p0, i.e. to fix it to the state where its
# field is maximum by setting every other state to FieldValue(-1).
# Returns false, leaving everything untouched, if that state is
# incompatible with the state of a neighbor which was already
# decimated; returns true after a successful decimation.
# p0 must not have been decimated already.
function decimate1(p0::Int, graph::Graph, msgs::Messages)
    decimated = msgs.decimated
    fld = msgs.fld
    slots0 = graph.adjdict[p0]
    gmsk = graph.gmsk

    @assert !decimated[p0]
    fld0 = fld[p0]
    s0 = indmax(fld0)

    # only do the decimation if it is consistent with
    # the previously decimated nodes
    for p1 = 1:length(decimated)
        (decimated[p1] && haskey(slots0, p1)) || continue
        s1 = indmax(fld[p1])
        j1 = slots0[p1]
        gmsk[p1][j1][s0, s1] || return false
    end

    # rule out every state but the chosen one
    for v0 = 1:length(fld0)
        if v0 != s0
            fld0[v0] = FieldValue(-1)
        end
    end
    decimated[p0] = true
    msgs.num_nondecimated -= 1
    return true
end
|
|
|
|
# Bring the messages back to their starting point: all the cavity
# messages are zeroed in place and the field of every node which
# has not been decimated is replaced by a copy of its initial value
# (decimated nodes keep their current field).
# Returns `msgs`.
function reset_messages!(msgs::Messages)
    msg = msgs.msg
    fld = msgs.fld
    initial_fld = msgs.initial_fld
    decimated = msgs.decimated

    for p0 = 1:length(fld)
        for m in msg[p0]
            fill!(m, zero(FieldValue))
        end
        if !decimated[p0]
            fld[p0] = copy(initial_fld[p0])
        end
    end
    return msgs
end
|
|
|
|
# If normal convergence fails (or is too slow) fix the most
# polarized packages by adding extra infinite fields on every state
# but the maximum.
# Attempts to decimate the first `n` not-yet-decimated nodes, in the
# order given by sorting the fields by `secondmax`; if none of those
# attempts succeeds, tries the remaining nodes until one does.
# The messages are reset after a successful decimation.
# Returns nothing. Does nothing at all if every node has already been
# decimated.
# Throws UnsatError(p0) if no node could be decimated, p0 being the
# first node in that order which is still not decimated.
function decimate(n::Int, graph::Graph, msgs::Messages)
    fld = msgs.fld
    decimated = msgs.decimated

    # nothing left to fix: there is no node to decimate nor to blame
    all(decimated) && return

    fldorder = sortperm(fld, by=secondmax)
    did_dec = false
    for p0 in fldorder
        decimated[p0] && continue
        did_dec |= decimate1(p0, graph, msgs)
        # n counts the attempts, whether they succeed or not
        n -= 1
        n == 0 && break
    end
    @assert n == 0
    if !did_dec
        # did not succeed in decimating anything;
        # try to decimate at least one node
        for p0 in fldorder
            decimated[p0] && continue
            if decimate1(p0, graph, msgs)
                did_dec = true
                break
            end
        end
    end
    if !did_dec
        # still didn't succeed, give up.
        # `decimated` is indexed by package, while `fldorder` is a
        # permutation: walk the permutation and test each package,
        # rather than using the mask to index `fldorder` by position
        for p0 in fldorder
            decimated[p0] || throw(UnsatError(p0))
        end
    end

    reset_messages!(msgs)
    return
end
|
|
|
|
# In case ties still exist at convergence, break them and
# keep converging.
# Looks for a node whose field has its maximum in more than one
# state and tries to decimate it (which needs `graph`, to check the
# consistency with the nodes decimated previously).
# Returns false as soon as one tie was broken (the caller must keep
# iterating), true if there were no ties at all.
# Throws UnsatError(p0) if there were ties but none could be broken,
# p0 being the first node with an unbroken tie.
function break_ties(graph::Graph, msgs::Messages)
    fld = msgs.fld
    decimated = msgs.decimated
    unbroken_ties = Int[]
    for p0 = 1:length(fld)
        # decimated nodes are fixed already, and decimate1 must not
        # be called on them
        decimated[p0] && continue
        fld0 = fld[p0]
        # z counts how many states attain the maximum m
        z = 0
        m = typemin(FieldValue)
        for v0 = 1:length(fld0)
            if fld0[v0] > m
                m = fld0[v0]
                z = 1
            elseif fld0[v0] == m
                z += 1
            end
        end
        if z > 1
            decimate1(p0, graph, msgs) && return false
            push!(unbroken_ties, p0)
        end
    end
    # If there were ties, but none were broken, bail out
    # (with the same exception as every other unsatisfiable case here)
    isempty(unbroken_ties) || throw(UnsatError(first(unbroken_ties)))
    return true
end

# Iterative solver: run iterate() until convergence
# (occasionally calling decimate()).
# Returns the solution as computed by getsolution(), i.e. one state
# index per package.
# Throws UnsatError if an unsatisfiable constraint is detected along
# the way.
function maxsum(graph::Graph, msgs::Messages)
    params = MaxSumParams()

    it = 0
    shuffleperminit()
    while true
        it += 1
        maxdiff = iterate(graph, msgs)

        if maxdiff == zero(FieldValue)
            # converged: stop, unless a tie had to be broken
            break_ties(graph, msgs) && break
            continue
        end
        if it >= params.nondec_iterations &&
           (it - params.nondec_iterations) % params.dec_interval == 0
            numdec = clamp(floor(Int, params.dec_fraction * graph.np),
                           1, msgs.num_nondecimated)
            decimate(numdec, graph, msgs)
            msgs.num_nondecimated == 0 && break
        end
    end

    # Finally, decimate everything just to
    # check against inconsistencies
    # (old_numnondec is saved just to prevent
    # wrong messages about accuracy).
    # When everything was decimated already there is nothing to check.
    old_numnondec = msgs.num_nondecimated
    if old_numnondec > 0
        decimate(old_numnondec, graph, msgs)
        msgs.num_nondecimated = old_numnondec
    end

    return getsolution(msgs)
end
|
|
|
|
end
|