mollusk 0e4acfb8f2 fix incorrect folder name for julia-0.6.x
Former-commit-id: ef2c7401e0876f22d2f7762d182cfbcd5a7d9c70
2018-06-11 03:28:36 -07:00

525 lines
15 KiB
Julia

# This file is a part of Julia. License is MIT: https://julialang.org/license
module MaxSum
include("fieldvalue.jl")
using .FieldValues, ..VersionWeights, ..PkgToMaxSumInterface
export UnsatError, Graph, Messages, maxsum
# An exception type used internally to signal that an unsatisfiable
# constraint was detected
mutable struct UnsatError <: Exception
info
end
# Some parameters to drive the decimation process
mutable struct MaxSumParams
nondec_iterations # number of initial iterations before starting
# decimation
dec_interval # number of iterations between decimations
dec_fraction # fraction of nodes to decimate at every decimation
# step
function MaxSumParams()
accuracy = parse(Int, get(ENV, "JULIA_PKGRESOLVE_ACCURACY", "1"))
if accuracy <= 0
error("JULIA_PKGRESOLVE_ACCURACY must be > 0")
end
nondec_iterations = accuracy * 20
dec_interval = accuracy * 10
dec_fraction = 0.05 / accuracy
return new(nondec_iterations, dec_interval, dec_fraction)
end
end
# Graph holds the graph structure onto which max-sum is run, in
# sparse format
mutable struct Graph
# adjacency matrix:
# for each package, has the list of neighbors
# indices (both dependencies and dependants)
gadj::Vector{Vector{Int}}
# compatibility mask:
# for each package p0 has a list of bool masks.
# Each entry in the list gmsk[p0] is relative to the
# package p1 as read from gadj[p0].
# Each mask has dimension spp1 x spp0, where
# spp0 is the number of states of p0, and
# spp1 is the number of states of p1.
gmsk::Vector{Vector{BitMatrix}}
# dependency direction:
# keeps track of which direction the dependency goes
# takes 3 values:
# 1 = dependant
# -1 = dependency
# 0 = both
# Used to break symmetry between dependants and
# dependencies (introduces a FieldValue at level l3).
# The "both" case is for when there are dependency
# relations which go both ways, in which case the
# noise is left to discriminate in case of ties
gdir::Vector{Vector{Int}}
# adjacency dict:
# allows one to retrieve the indices in gadj, so that
# gadj[p0][adjdict[p1][p0]] = p1
# ("At which index does package p1 appear in gadj[p0]?")
adjdict::Vector{Dict{Int,Int}}
# states per package: same as in Interface
spp::Vector{Int}
# update order: shuffled at each iteration
perm::Vector{Int}
# number of packages (all Vectors above have this length)
np::Int
function Graph(interface::Interface)
deps = interface.deps
np = interface.np
spp = interface.spp
pdict = interface.pdict
pvers = interface.pvers
vdict = interface.vdict
gadj = [Int[] for i = 1:np]
gmsk = [BitMatrix[] for i = 1:np]
gdir = [Int[] for i = 1:np]
adjdict = [Dict{Int,Int}() for i = 1:np]
for (p,d) in deps
p0 = pdict[p]
vdict0 = vdict[p0]
for (vn,a) in d
v0 = vdict0[vn]
for (rp, rvs) in a.requires
p1 = pdict[rp]
j0 = 1
while j0 <= length(gadj[p0]) && gadj[p0][j0] != p1
j0 += 1
end
j1 = 1
while j1 <= length(gadj[p1]) && gadj[p1][j1] != p0
j1 += 1
end
@assert (j0 > length(gadj[p0]) && j1 > length(gadj[p1])) ||
(j0 <= length(gadj[p0]) && j1 <= length(gadj[p1]))
if j0 > length(gadj[p0])
push!(gadj[p0], p1)
push!(gadj[p1], p0)
j0 = length(gadj[p0])
j1 = length(gadj[p1])
adjdict[p1][p0] = j0
adjdict[p0][p1] = j1
bm = trues(spp[p1], spp[p0])
bmt = bm'
push!(gmsk[p0], bm)
push!(gmsk[p1], bmt)
push!(gdir[p0], 1)
push!(gdir[p1], -1)
else
bm = gmsk[p0][j0]
bmt = gmsk[p1][j1]
if gdir[p0][j0] == -1
gdir[p0][j0] = 0
gdir[p1][j1] = 0
end
end
for v1 = 1:length(pvers[p1])
if pvers[p1][v1] rvs
bm[v1, v0] = false
bmt[v0, v1] = false
end
end
bm[end,v0] = false
bmt[v0,end] = false
end
end
end
perm = [1:np;]
return new(gadj, gmsk, gdir, adjdict, spp, perm, np)
end
end
# Messages has the cavity messages and the total fields, and
# gets updated iteratively (and occasionally decimated) until
# convergence
mutable struct Messages
# cavity incoming messages: for each package p0,
# for each neighbor p1 of p0,
# msg[p0][p1] is a vector of length spp[p0]
# messages are normalized (i.e. the max is always 0)
msg::Vector{Vector{Field}}
# overall fields: for each package p0,
# fld[p0] is a vector of length spp[p0]
# fields are not normalized
fld::Vector{Field}
# backup of the initial value of fld, to be used when resetting
initial_fld::Vector{Field}
# keep track of which variables have been decimated
decimated::BitVector
num_nondecimated::Int
function Messages(interface::Interface, graph::Graph)
reqs = interface.reqs
pkgs = interface.pkgs
np = interface.np
spp = interface.spp
pvers = interface.pvers
pdict = interface.pdict
vweight = interface.vweight
# a "deterministic noise" function based on hashes
function noise(p0::Int, v0::Int)
s = pkgs[p0] * string(v0 == spp[p0] ? "UNINST" : pvers[p0][v0])
Int128(hash(s))
end
# external fields: there are 2 terms, a noise to break potential symmetries
# and one to favor newest versions over older, and no-version over all
fld = [[FieldValue(0, zero(VersionWeight), vweight[p0][v0], (v0==spp[p0]), 0, noise(p0,v0)) for v0 = 1:spp[p0]] for p0 = 1:np]
# enforce requirements
for (rp, rvs) in reqs
p0 = pdict[rp]
pvers0 = pvers[p0]
fld0 = fld[p0]
for v0 = 1:spp[p0]-1
vn = pvers0[v0]
if !in(vn, rvs)
# the state is forbidden by requirements
fld0[v0] = FieldValue(-1)
else
# the state is one of those explicitly requested:
# favor it at a higer level than normal (upgrade
# FieldValue from l2 to l1)
fld0[v0] += FieldValue(0, vweight[p0][v0], -vweight[p0][v0])
end
end
# the uninstalled state is forbidden by requirements
fld0[spp[p0]] = FieldValue(-1)
end
# normalize fields
for p0 = 1:np
m = maximum(fld[p0])
for v0 = 1:spp[p0]
fld[p0][v0] -= m
end
end
initial_fld = deepcopy(fld)
# initialize cavity messages to 0
gadj = graph.gadj
msg = [[zeros(FieldValue, spp[p0]) for p1 = 1:length(gadj[p0])] for p0 = 1:np]
return new(msg, fld, initial_fld, falses(np), np)
end
end
function getsolution(msgs::Messages)
# the solution is just the location of the maximum in
# each field
fld = msgs.fld
np = length(fld)
sol = Vector{Int}(np)
for p0 = 1:np
fld0 = fld[p0]
s0 = indmax(fld0)
if !validmax(fld0[s0])
throw(UnsatError(p0))
end
sol[p0] = s0
end
return sol
end
# This is the core of the max-sum solver:
# for a given node p0 (i.e. a package) updates all
# input cavity messages and fields of its neighbors
function update(p0::Int, graph::Graph, msgs::Messages)
gadj = graph.gadj
gmsk = graph.gmsk
gdir = graph.gdir
adjdict = graph.adjdict
spp = graph.spp
np = graph.np
msg = msgs.msg
fld = msgs.fld
decimated = msgs.decimated
maxdiff = zero(FieldValue)
gadj0 = gadj[p0]
msg0 = msg[p0]
fld0 = fld[p0]
spp0 = spp[p0]
adjdict0 = adjdict[p0]
# iterate over all neighbors of p0
for j0 in 1:length(gadj0)
p1 = gadj0[j0]
decimated[p1] && continue
j1 = adjdict0[p1]
#@assert j0 == adjdict[p1][p0]
bm1 = gmsk[p1][j1]
dir1 = gdir[p1][j1]
spp1 = spp[p1]
msg1 = msg[p1]
# compute the output cavity message p0->p1
cavmsg = fld0 - msg0[j0]
if dir1 == -1
# p0 depends on p1
for v0 = 1:spp0-1
cavmsg[v0] += FieldValue(0, VersionWeight(0), VersionWeight(0), 0, v0)
end
end
# keep the old input cavity message p0->p1
oldmsg = msg1[j1]
# init the new message to minus infinity
newmsg = [FieldValue(-1) for v1 = 1:spp1]
# compute the new message by passing cavmsg
# through the constraint encoded in the bitmask
# (nearly equivalent to:
# newmsg = [maximum(cavmsg[bm1[:,v1]]) for v1 = 1:spp1]
# except for the gnrg term)
m = FieldValue(-1)
for v1 = 1:spp1
for v0 = 1:spp0
if bm1[v0, v1]
newmsg[v1] = max(newmsg[v1], cavmsg[v0])
end
end
if dir1 == 1 && v1 != spp1
# p1 depends on p0
newmsg[v1] += FieldValue(0, VersionWeight(0), VersionWeight(0), 0, v1)
end
m = max(m, newmsg[v1])
end
if !validmax(m)
# No state available without violating some
# hard constraint
throw(UnsatError(p1))
end
# normalize the new message
for v1 = 1:spp1
newmsg[v1] -= m
end
diff = newmsg - oldmsg
maxdiff = max(maxdiff, maximum(abs.(diff)))
# update the field of p1
fld1 = fld[p1]
for v1 = 1:spp1
fld1[v1] += diff[v1]
end
# put the newly computed message in place
msg1[j1] = newmsg
end
return maxdiff
end
# A simple shuffling machinery for the update order in iterate()
# (woulnd't pass any random quality test but it's arguably enough)
let step=1
global shuffleperm, shuffleperminit
shuffleperminit() = (step = 1)
function shuffleperm(graph::Graph)
perm = graph.perm
np = graph.np
for j = np:-1:2
k = mod(step,j)+1
perm[j], perm[k] = perm[k], perm[j]
step += isodd(j) ? 1 : k
end
#@assert isperm(perm)
end
end
# Call update for all nodes (i.e. packages) in
# random order
function iterate(graph::Graph, msgs::Messages)
np = graph.np
maxdiff = zero(FieldValue)
shuffleperm(graph)
perm = graph.perm
for p0 in perm
maxdiff0 = update(p0, graph, msgs)
maxdiff = max(maxdiff, maxdiff0)
end
return maxdiff
end
function decimate1(p0::Int, graph::Graph, msgs::Messages)
decimated = msgs.decimated
fld = msgs.fld
adjdict = graph.adjdict
gmsk = graph.gmsk
@assert !decimated[p0]
fld0 = fld[p0]
s0 = indmax(fld0)
# only do the decimation if it is consistent with
# the previously decimated nodes
for p1 in find(decimated)
haskey(adjdict[p0], p1) || continue
s1 = indmax(fld[p1])
j1 = adjdict[p0][p1]
gmsk[p1][j1][s0,s1] || return false
end
#println("DECIMATING $p0 (s0=$s0 fld=$fld0)")
for v0 = 1:length(fld0)
v0 == s0 && continue
fld0[v0] = FieldValue(-1)
end
msgs.decimated[p0] = true
msgs.num_nondecimated -= 1
return true
end
function reset_messages!(msgs::Messages)
msg = msgs.msg
fld = msgs.fld
initial_fld = msgs.initial_fld
decimated = msgs.decimated
np = length(fld)
for p0 = 1:np
map(m->fill!(m, zero(FieldValue)), msg[p0])
decimated[p0] && continue
fld[p0] = copy(initial_fld[p0])
end
return msgs
end
# If normal convergence fails (or is too slow) fix the most
# polarized packages by adding extra infinite fields on every state
# but the maximum
function decimate(n::Int, graph::Graph, msgs::Messages)
#println("DECIMATING $n NODES")
adjdict = graph.adjdict
fld = msgs.fld
decimated = msgs.decimated
fldorder = sortperm(fld, by=secondmax)
did_dec = false
for p0 in fldorder
decimated[p0] && continue
did_dec |= decimate1(p0, graph, msgs)
n -= 1
n == 0 && break
end
@assert n == 0
if !did_dec
# did not succeed in decimating anything;
# try to decimate at least one node
for p0 in fldorder
decimated[p0] && continue
if decimate1(p0, graph, msgs)
did_dec = true
break
end
end
end
if !did_dec
# still didn't succeed, give up
p0 = first(fldorder[.~(decimated)])
throw(UnsatError(p0))
end
reset_messages!(msgs)
return
end
# In case ties still exist at convergence, break them and
# keep converging
function break_ties(msgs::Messages)
fld = msgs.fld
unbroken_ties = Int[]
for p0 = 1:length(fld)
fld0 = fld[p0]
z = 0
m = typemin(FieldValue)
for v0 = 1:length(fld0)
if fld0[v0] > m
m = fld0[v0]
z = 1
elseif fld0[v0] == m
z += 1
end
end
if z > 1
#println("TIE! p0=$p0")
decimate1(p0, msgs) && return false
push!(unbroken_ties, p0)
end
end
# If there were ties, but none were broken, bail out
isempty(unbroken_ties) || throw(PkgError(first(unbroken_ties)))
return true
end
# Iterative solver: run iterate() until convergence
# (occasionally calling decimate())
function maxsum(graph::Graph, msgs::Messages)
params = MaxSumParams()
it = 0
shuffleperminit()
while true
it += 1
maxdiff = iterate(graph, msgs)
#println("it = $it maxdiff = $maxdiff")
if maxdiff == zero(FieldValue)
break_ties(msgs) && break
continue
end
if it >= params.nondec_iterations &&
(it - params.nondec_iterations) % params.dec_interval == 0
numdec = clamp(floor(Int, params.dec_fraction * graph.np), 1, msgs.num_nondecimated)
decimate(numdec, graph, msgs)
msgs.num_nondecimated == 0 && break
end
end
# Finally, decimate everything just to
# check against inconsistencies
# (old_numnondec is saved just to prevent
# wrong messages about accuracy)
old_numnondec = msgs.num_nondecimated
decimate(msgs.num_nondecimated, graph, msgs)
msgs.num_nondecimated = old_numnondec
return getsolution(msgs)
end
end