1 module satd.cnf;
2 import std.container : redBlackTree, RedBlackTree;
3 import std.array : array;
4 import std.algorithm : count, sort;
5 import std.string : format;
6 import std.range : empty, zip;
7 import std.math : abs;
8 import satd.dimacs;
9 
10 debug import std.stdio;
11 
/// Shorthand: `Set!T` is `std.container.RedBlackTree!T`.
alias Set(T) = RedBlackTree!T;

public:

/++
+ Integer encoding of literals.
+ 0 denotes \overline \Lambda (the conflict node).
+ \cdots, -3, -2, -1, 1, 2, 3, \cdots are used for ordinary literals.
+ For x > 0:
+ the positive integer x denotes the literal x, and
+ the negative integer -x denotes \lnot x.
+/
alias Literal = long;
24 
/// Returns the complement of `lit`: `x` becomes `-x` (i.e. \lnot x) and `-x` becomes `x`.
Literal negate(Literal lit)
{
    immutable Literal flipped = -lit;
    return flipped;
}
30 
/// A clause: a set of literals identified by an ID.
struct Clause
{
    alias ID = size_t;

    /// ID used to tell clauses apart.
    ID id;
    /// Set of literals contained in the clause. It is null for a default-initialized
    /// Clause; the members below treat a null set as an empty one.
    Set!Literal literals;

    /// Creates a clause with the given ID that refers to (does not copy) `literals`.
    this(ID id, Set!Literal literals)
    {
        this.id = id;
        this.literals = literals;
    }

    /// Creates a clause with the given ID from an array of literals
    /// (duplicate literals collapse into one, since the set disallows duplicates).
    this(ID id, Literal[] literals)
    {
        this(id, redBlackTree!Literal(literals));
    }

    /// Creates a clause with the default ID (0) that refers to `literals`.
    this(Set!Literal literals)
    {
        this.literals = literals;
    }

    /// Creates a clause with the default ID (0) from an array of literals.
    this(Literal[] literals)
    {
        this(redBlackTree!Literal(literals));
    }

    /// Copies `clause`, giving the copy its own literal set.
    this(Clause clause)
    {
        this.id = clause.id;
        // dup so later removals on the copy do not affect the source clause.
        this.literals = clause.literals is null ? null : clause.literals.dup;
    }

    /// Returns true when the clause has no literals.
    bool isEmptyClause() const
    {
        return literals is null || literals.length == 0;
    }

    /// Returns true when the clause has exactly one literal.
    bool isUnitClause() const
    {
        return literals !is null && literals.length == 1;
    }

    /// Returns the only literal of a unit clause; asserts that the clause is unit.
    Literal unitLiteral()
    {
        assert(this.isUnitClause());
        return literals.front;
    }

    /// Returns true when `lit` is one of the clause's literals.
    bool containsLiteral(Literal lit) const
    {
        return literals !is null && lit in literals;
    }

    /// Removes `lit` from the clause and returns the number of elements removed.
    auto removeLiteral(Literal lit)
    {
        if (literals is null)
            return cast(size_t) 0;
        return literals.removeKey(lit);
    }

    /// Orders clauses by ID. Returns a negative, zero, or positive value when
    /// this.id is less than, equal to, or greater than other.id.
    int opCmp(R)(const R other) const
    {
        // Compare instead of subtracting: ID is unsigned, so a difference would wrap.
        if (this.id < other.id)
            return -1;
        if (this.id > other.id)
            return 1;
        return 0;
    }

    /// Enables `lit in clause`.
    bool opBinaryRight(string op)(Literal lit) const if (op == "in")
    {
        return literals !is null && lit in literals;
    }

    /// Formats the clause as "(l1 ∨ l2 ∨ ...)" ordered by absolute value,
    /// or "(empty)" when it has no literals. It is const so that it is also used
    /// when the clause is formatted from a const context (e.g. CNF.toString).
    string toString() const
    {
        if (literals is null || literals.length == 0)
            return "(empty)";
        // The set's elements are const here, so copy them into a mutable array
        // that can be re-sorted by absolute value.
        Literal[] ordered;
        foreach (lit; literals[])
            ordered ~= lit;
        ordered.sort!((a, b) => a.abs < b.abs);
        return format("(%(%d ∨ %))", ordered);
    }
}
112 
/// A CNF formula: its clauses, partitioned by size, plus literal/clause indexes.
///
/// The primary constructor stores plain struct copies of each Clause in
/// allClauses and in exactly one partition table (normal/unit/empty). Those
/// copies share the same literal set, so removing a literal through a partition
/// entry is also visible through allClauses.
struct CNF
{
    /// Counts taken from the DIMACS preamble.
    size_t variableNum, clauseNum;

    /// Every clause that has not been removed, keyed by ID.
    Clause[Clause.ID] allClauses;

    Clause[Clause.ID] normalClauses; // other than "unit" or "empty" clause
    Clause[Clause.ID] unitClauses;
    Clause[Clause.ID] emptyClauses;

    /// Literals of each clause, recorded at construction. Not updated by the methods below.
    Literal[][Clause.ID] literalsInClause;
    /// For each literal, IDs of the clauses that contained it at construction.
    /// removeLiterals does not prune these lists, so an ID may refer to a clause
    /// that no longer contains the literal or that has already been removed.
    Clause.ID[][Literal] clausesContainingLiteral;

    /// Builds the formula from the parsed clauses and the DIMACS preamble.
    this(Clause[] clauses, Preamble preamble)
    {
        this.variableNum = preamble.variables;
        this.clauseNum = preamble.clauses;

        foreach (clause; clauses)
        {
            Clause.ID cid = clause.id;
            // Plain struct copies: allClauses and the partition share the literal set.
            allClauses[cid] = clause;

            if (clause.isEmptyClause())
                emptyClauses[cid] = clause;
            else if (clause.isUnitClause())
                unitClauses[cid] = clause;
            else
                normalClauses[cid] = clause;

            foreach (literal; clause.literals.array)
            {
                literalsInClause[cid] ~= literal;
                clausesContainingLiteral[literal] ~= cid;
                debug stderr.writeln(clause.id);
            }
        }
    }

    /// Deep copy of `rhs`. Every clause gets its own literal set, and each
    /// partition table shares that set with allClauses, matching the primary constructor.
    this(CNF rhs)
    {
        this.variableNum = rhs.variableNum;
        this.clauseNum = rhs.clauseNum;

        foreach (key, value; rhs.allClauses)
            this.allClauses[key] = Clause(value);

        // Take each partition entry from the fresh copy in allClauses so the two
        // stay in sync. A key absent from allClauses gets an independent copy.
        Clause[Clause.ID] adopt(Clause[Clause.ID] source)
        {
            Clause[Clause.ID] result;
            foreach (key, value; source)
            {
                if (auto copied = key in this.allClauses)
                    result[key] = *copied;
                else
                    result[key] = Clause(value);
            }
            return result;
        }

        this.normalClauses = adopt(rhs.normalClauses);
        this.unitClauses = adopt(rhs.unitClauses);
        this.emptyClauses = adopt(rhs.emptyClauses);

        // AA .dup is shallow; also duplicate the value arrays.
        foreach (key, value; rhs.literalsInClause)
            this.literalsInClause[key] = value.dup;
        foreach (key, value; rhs.clausesContainingLiteral)
            this.clausesContainingLiteral[key] = value.dup;
    }

    /// Removes `literal` from every unit/normal clause listed for it in
    /// clausesContainingLiteral. A unit clause that loses it moves to emptyClauses.
    /// A normal clause reduced to one literal moves to unitClauses.
    void removeLiterals(Literal literal)
    {
        auto ids = literal in clausesContainingLiteral;
        if (ids is null)
            return;
        foreach (id; *ids)
        {
            if (auto unit = id in unitClauses)
            {
                // The index is not pruned; skip clauses that no longer hold `literal`
                // instead of reporting a spurious empty clause.
                if (unit.removeLiteral(literal) == 0)
                    continue;
                emptyClauses[id] = *unit;
                unitClauses.remove(id);
            }
            else if (auto normal = id in normalClauses)
            {
                if (normal.removeLiteral(literal) == 0)
                    continue;
                if (normal.isUnitClause())
                {
                    unitClauses[id] = *normal;
                    normalClauses.remove(id);
                }
            }
        }
    }

    /// Removes the clause with the given ID from every clause table (no-op where absent).
    void removeClauseById(Clause.ID clauseId)
    {
        normalClauses.remove(clauseId);
        unitClauses.remove(clauseId);
        emptyClauses.remove(clauseId);
        allClauses.remove(clauseId);
    }

    /// Removes every clause listed for `literal` and drops that index entry.
    /// Does nothing when `literal` has no entry (never occurred, or already handled).
    void removeClauseContainingLiteral(Literal literal)
    {
        if (auto ids = literal in clausesContainingLiteral)
        {
            foreach (id; *ids)
                removeClauseById(id);
            clausesContainingLiteral.remove(literal);
        }
    }

    /// Simplifies the formula under the assignment making `literal` true. Clauses
    /// containing it are removed, and its negation is removed from the remaining clauses.
    void simplify(Literal literal)
    {
        this.removeClauseContainingLiteral(literal);
        this.removeLiterals(-literal);
    }

    /// Formats all clauses, ordered by ID and joined with "∧", or "<none>" if there are none.
    /// Each clause is rendered by its own formatting; no extra parentheses are added here.
    string toString() const
    {
        if (allClauses.length == 0)
            return "<none>";
        const(Clause)[] ordered;
        foreach (key; allClauses.keys.sort)
            ordered ~= allClauses[key];
        return format("%(%s∧%)", ordered);
    }

    /// Dumps every table for debugging.
    debug string debugString() const
    {
        return format("all: %s\nnormal:%s\nunit:%s\nempty:%s\nclcon:%s", allClauses,
                normalClauses, unitClauses, emptyClauses, clausesContainingLiteral);
    }
}