FST  openfst-1.8.3
OpenFst Library
bitmap-index.cc
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 
19 
20 #include <algorithm>
21 #include <array>
22 #include <cstddef>
23 #include <cstdint>
24 #include <utility>
25 #include <vector>
26 
27 #include <fst/log.h>
29 
30 namespace fst {
31 
// The __builtin_popcountll / __builtin_ctzll intrinsics used in this file
// take unsigned long long arguments; uint64_t values must fit without loss.
static_assert(sizeof(uint64_t) <= sizeof(long long),  // NOLINT
              "__builtin_...ll is used on uint64_t values.");
34 
// Returns a table whose entry i has exactly the i least significant bits set
// (entry 0 is 0, entry 63 is 2^63 - 1).
constexpr std::array<uint64_t, 64> LowBitsMasks() {
  std::array<uint64_t, 64> masks{};  // masks[0] stays 0.
  // Each mask is the previous one with one more low bit switched on.
  for (size_t i = 1; i < masks.size(); ++i) {
    masks[i] = (masks[i - 1] << 1) | 1;
  }
  return masks;
}
// kLowBitsMasks[i] == (uint64_t{1} << i) - 1, computed at compile time.
constexpr std::array<uint64_t, 64> kLowBitsMasks = LowBitsMasks();
41 
42 size_t BitmapIndex::Rank1(size_t end) const {
43  DCHECK_LE(end, Bits());
44  // TODO(jrosenstock): Remove nullptr support and this special case.
45  if (end == 0) return 0;
46  // Without this special case, we'd go past the end. It's questionable
47  // whether we should support end == Bits().
48  if (end >= num_bits_) return GetOnesCount();
49  const uint32_t end_word = end / kStorageBitSize;
50  const uint32_t sum = GetIndexOnesCount(end_word);
51  const int bit_index = end % kStorageBitSize;
52  // TODO(jrosenstock): better with or without special case, and does
53  // this depend on whether there's a popcnt instruction?
54  if (bit_index == 0) return sum; // Entire answer is in the index.
55  return sum + __builtin_popcountll(bits_[end_word] & kLowBitsMasks[bit_index]);
56 }
57 
58 size_t BitmapIndex::Select1(size_t bit_index) const {
59  if (bit_index >= GetOnesCount()) return Bits();
60  const RankIndexEntry& entry = FindRankIndexEntry(bit_index);
61  const uint32_t block_index = &entry - rank_index_.data();
62  // TODO(jrosenstock): Look at whether word or bit indices are faster.
63  static_assert(kUnitsPerRankIndexEntry == 8);
64  uint32_t word_index = block_index * kUnitsPerRankIndexEntry;
65 
66  // Find position within this block.
67  uint32_t rembits = bit_index - entry.absolute_ones_count();
68  if (rembits < entry.relative_ones_count_4()) {
69  if (rembits < entry.relative_ones_count_2()) {
70  if (rembits < entry.relative_ones_count_1()) {
71  // First word, nothing to do.
72  } else {
73  word_index += 1;
74  rembits -= entry.relative_ones_count_1();
75  }
76  } else if (rembits < entry.relative_ones_count_3()) {
77  word_index += 2;
78  rembits -= entry.relative_ones_count_2();
79  } else {
80  word_index += 3;
81  rembits -= entry.relative_ones_count_3();
82  }
83  } else if (rembits < entry.relative_ones_count_6()) {
84  if (rembits < entry.relative_ones_count_5()) {
85  word_index += 4;
86  rembits -= entry.relative_ones_count_4();
87  } else {
88  word_index += 5;
89  rembits -= entry.relative_ones_count_5();
90  }
91  } else if (rembits < entry.relative_ones_count_7()) {
92  word_index += 6;
93  rembits -= entry.relative_ones_count_6();
94  } else {
95  word_index += 7;
96  rembits -= entry.relative_ones_count_7();
97  }
98 
99  const int nth = nth_bit(bits_[word_index], rembits);
100  return kStorageBitSize * word_index + nth;
101 }
102 
103 size_t BitmapIndex::Select0(size_t bit_index) const {
104  const uint32_t zeros_count = Bits() - GetOnesCount();
105  if (bit_index >= zeros_count) return Bits();
106  const RankIndexEntry& entry = FindInvertedRankIndexEntry(bit_index);
107  const uint32_t block_index = &entry - rank_index_.data();
108  static_assert(kUnitsPerRankIndexEntry == 8);
109  uint32_t word_index = block_index * kUnitsPerRankIndexEntry;
110 
111  // Find position within this block.
112  uint32_t entry_zeros_count =
113  kStorageBitSize * word_index - entry.absolute_ones_count();
114  uint32_t remzeros = bit_index - entry_zeros_count;
115  if (remzeros < 4 * kStorageBitSize - entry.relative_ones_count_4()) {
116  if (remzeros < 2 * kStorageBitSize - entry.relative_ones_count_2()) {
117  if (remzeros < kStorageBitSize - entry.relative_ones_count_1()) {
118  // Nothing to do.
119  } else {
120  word_index += 1;
121  remzeros -= kStorageBitSize - entry.relative_ones_count_1();
122  }
123  } else if (remzeros < 3 * kStorageBitSize - entry.relative_ones_count_3()) {
124  word_index += 2;
125  remzeros -= 2 * kStorageBitSize - entry.relative_ones_count_2();
126  } else {
127  word_index += 3;
128  remzeros -= 3 * kStorageBitSize - entry.relative_ones_count_3();
129  }
130  } else if (remzeros < 6 * kStorageBitSize - entry.relative_ones_count_6()) {
131  if (remzeros < 5 * kStorageBitSize - entry.relative_ones_count_5()) {
132  word_index += 4;
133  remzeros -= 4 * kStorageBitSize - entry.relative_ones_count_4();
134  } else {
135  word_index += 5;
136  remzeros -= 5 * kStorageBitSize - entry.relative_ones_count_5();
137  }
138  } else if (remzeros < 7 * kStorageBitSize - entry.relative_ones_count_7()) {
139  word_index += 6;
140  remzeros -= 6 * kStorageBitSize - entry.relative_ones_count_6();
141  } else {
142  word_index += 7;
143  remzeros -= 7 * kStorageBitSize - entry.relative_ones_count_7();
144  }
145 
146  const int nth = nth_bit(~bits_[word_index], remzeros);
147  return kStorageBitSize * word_index + nth;
148 }
149 
150 std::pair<size_t, size_t> BitmapIndex::Select0s(size_t bit_index) const {
151  const uint32_t zeros_count = Bits() - GetOnesCount();
152  if (bit_index >= zeros_count) return {Bits(), Bits()};
153  if (bit_index + 1 >= zeros_count) return {Select0(bit_index), Bits()};
154 
155  const RankIndexEntry& entry = FindInvertedRankIndexEntry(bit_index);
156  const uint32_t block_index = &entry - rank_index_.data();
157  uint32_t word_index = block_index * kUnitsPerRankIndexEntry;
158 
159  // Find position within this block.
160  uint32_t entry_zeros_count =
161  kStorageBitSize * word_index - entry.absolute_ones_count();
162  uint32_t remzeros = bit_index - entry_zeros_count;
163  if (remzeros < 4 * kStorageBitSize - entry.relative_ones_count_4()) {
164  if (remzeros < 2 * kStorageBitSize - entry.relative_ones_count_2()) {
165  if (remzeros < kStorageBitSize - entry.relative_ones_count_1()) {
166  // Nothing to do.
167  } else {
168  word_index += 1;
169  remzeros -= kStorageBitSize - entry.relative_ones_count_1();
170  }
171  } else if (remzeros < 3 * kStorageBitSize - entry.relative_ones_count_3()) {
172  word_index += 2;
173  remzeros -= 2 * kStorageBitSize - entry.relative_ones_count_2();
174  } else {
175  word_index += 3;
176  remzeros -= 3 * kStorageBitSize - entry.relative_ones_count_3();
177  }
178  } else if (remzeros < 6 * kStorageBitSize - entry.relative_ones_count_6()) {
179  if (remzeros < 5 * kStorageBitSize - entry.relative_ones_count_5()) {
180  word_index += 4;
181  remzeros -= 4 * kStorageBitSize - entry.relative_ones_count_4();
182  } else {
183  word_index += 5;
184  remzeros -= 5 * kStorageBitSize - entry.relative_ones_count_5();
185  }
186  } else if (remzeros < 7 * kStorageBitSize - entry.relative_ones_count_7()) {
187  word_index += 6;
188  remzeros -= 6 * kStorageBitSize - entry.relative_ones_count_6();
189  } else {
190  word_index += 7;
191  remzeros -= 7 * kStorageBitSize - entry.relative_ones_count_7();
192  }
193 
194  // Find the position of the bit_index-th zero.
195  const uint64_t inv_word = ~bits_[word_index];
196  const int nth = nth_bit(inv_word, remzeros);
197 
198  // Then, we want to "1-out" everything below that position, and count trailing
199  // ones on the result. This gives us the position of the next zero.
200  // There is no count trailing ones builtin, so we invert and use count
201  // trailing zeros.
202 
203  // This mask has 1s in the nth+1 low order bits; it is equivalent to
204  // (1 << (nth + 1)) - 1, but doesn't need a special case when nth == 63.
205  // We want ~0 in this case anyway. We want nth+1 because if the bit_index-th
206  // zero is in position nth, we need to skip nth+1 positions.
207  const uint64_t mask = -(uint64_t{0x2} << nth); // == ~((2 << nth) - 1)
208  const uint64_t masked_inv_word = inv_word & mask;
209 
210  // If this is 0, then the next zero is not in the same word.
211  if (masked_inv_word != 0) {
212  // We can't ctz on 0, but we already checked that.
213  const int next_nth = __builtin_ctzll(masked_inv_word);
214  return {kStorageBitSize * word_index + nth,
215  kStorageBitSize * word_index + next_nth};
216  } else {
217  // TODO(jrosenstock): Try other words in the block.
218  // This should not be massively important. With a bit density of 1/2,
219  // 31/32 zeros in a word have the next zero in the same word.
220  return {kStorageBitSize * word_index + nth, Select0(bit_index + 1)};
221  }
222 }
223 
224 uint32_t BitmapIndex::GetIndexOnesCount(size_t array_index) const {
225  const auto& rank_index_entry =
226  rank_index_[array_index / kUnitsPerRankIndexEntry];
227  uint32_t ones_count = rank_index_entry.absolute_ones_count();
228  static_assert(kUnitsPerRankIndexEntry == 8);
229  return ones_count + rank_index_entry.relative_ones_count(
230  array_index % kUnitsPerRankIndexEntry);
231 }
232 
233 void BitmapIndex::BuildIndex(const uint64_t* bits, size_t num_bits,
234  bool enable_select_0_index,
235  bool enable_select_1_index) {
236  // Absolute counts are uint32s, so this is the most *set* bits we support
237  // for now. Just check the number of *input* bits is less than this
238  // to keep things simple.
239  DCHECK_LT(num_bits, uint64_t{1} << 32);
240  bits_ = bits;
241  num_bits_ = num_bits;
242  rank_index_.clear();
243  rank_index_.resize(rank_index_size());
244 
245  select_0_index_.clear();
246  if (enable_select_0_index) {
247  // Reserve approximately enough for density = 1/2.
248  select_0_index_.reserve(num_bits / (2 * kBitsPerSelect0Block) + 1);
249  }
250 
251  select_1_index_.clear();
252  if (enable_select_1_index) {
253  select_1_index_.reserve(num_bits / (2 * kBitsPerSelect1Block) + 1);
254  }
255 
256  const size_t kArraySize = ArraySize();
257  uint32_t ones_count = 0;
258  uint32_t zeros_count = 0;
259  for (uint32_t word_index = 0; word_index < kArraySize; word_index += 8) {
260  const uint64_t word[8] = {
261  bits[word_index],
262  (word_index + 1 < kArraySize) ? bits[word_index + 1] : 0,
263  (word_index + 2 < kArraySize) ? bits[word_index + 2] : 0,
264  (word_index + 3 < kArraySize) ? bits[word_index + 3] : 0,
265  (word_index + 4 < kArraySize) ? bits[word_index + 4] : 0,
266  (word_index + 5 < kArraySize) ? bits[word_index + 5] : 0,
267  (word_index + 6 < kArraySize) ? bits[word_index + 6] : 0,
268  (word_index + 7 < kArraySize) ? bits[word_index + 7] : 0,
269  };
270 
271  const int word_ones_count[8] = {
272  __builtin_popcountll(word[0]), __builtin_popcountll(word[1]),
273  __builtin_popcountll(word[2]), __builtin_popcountll(word[3]),
274  __builtin_popcountll(word[4]), __builtin_popcountll(word[5]),
275  __builtin_popcountll(word[6]), __builtin_popcountll(word[7]),
276  };
277 
278  auto& rank_index_entry = rank_index_[word_index / kUnitsPerRankIndexEntry];
279  const uint32_t abs_ones_count = ones_count;
280  rank_index_entry.set_absolute_ones_count(abs_ones_count);
281  ones_count += word_ones_count[0];
282  rank_index_entry.set_relative_ones_count_1(ones_count - abs_ones_count);
283  ones_count += word_ones_count[1];
284  rank_index_entry.set_relative_ones_count_2(ones_count - abs_ones_count);
285  ones_count += word_ones_count[2];
286  rank_index_entry.set_relative_ones_count_3(ones_count - abs_ones_count);
287  ones_count += word_ones_count[3];
288  rank_index_entry.set_relative_ones_count_4(ones_count - abs_ones_count);
289  ones_count += word_ones_count[4];
290  rank_index_entry.set_relative_ones_count_5(ones_count - abs_ones_count);
291  ones_count += word_ones_count[5];
292  rank_index_entry.set_relative_ones_count_6(ones_count - abs_ones_count);
293  ones_count += word_ones_count[6];
294  rank_index_entry.set_relative_ones_count_7(ones_count - abs_ones_count);
295  ones_count += word_ones_count[7];
296 
297  if (enable_select_0_index) {
298  int s0_zeros_count = zeros_count;
299  for (int i = 0; i < 8; ++i) {
300  const size_t bit_offset = (word_index + i) * kStorageBitSize;
301  if (bit_offset >= num_bits_) break;
302 
303  // Zeros count is somewhat more involved to compute, so only do it
304  // if we need it. The last word has zeros in the high bits, so
305  // that needs to be accounted for when computing the zeros count
306  // from the ones count.
307  const uint32_t bits_remaining = num_bits - bit_offset;
308  const int word_zeros_count =
309  std::min(bits_remaining, kStorageBitSize) - word_ones_count[i];
310 
311  // We record a 0 every kBitsPerSelect0Block bits. So, if zeros_count
312  // is 0 mod kBitsPerSelect0Block, we record the next zero. If
313  // zeros_count is 1 mod kBitsPerSelect0Block, we need to skip
314  // kBitsPerSelect0Block - 1 zeros, then record a zero. And so on.
315  // What function is this? It's -zeros_count % kBitsPerSelect0Block.
316  const uint32_t zeros_to_skip = -s0_zeros_count % kBitsPerSelect0Block;
317  if (word_zeros_count > zeros_to_skip) {
318  const int nth = nth_bit(~word[i], zeros_to_skip);
319  select_0_index_.push_back(bit_offset + nth);
320  static_assert(kBitsPerSelect0Block >= 512,
321  "kBitsPerSelect0Block must be at least 512.");
322  break; // 8 entries is 512 bits, so we can't push another bit here.
323  }
324  s0_zeros_count += word_zeros_count;
325  }
326  zeros_count += 8 * kStorageBitSize - (ones_count - abs_ones_count);
327  }
328 
329  if (enable_select_1_index) {
330  int s1_ones_count = abs_ones_count;
331  for (int i = 0; i < 8; ++i) {
332  const size_t bit_offset = (word_index + i) * kStorageBitSize;
333  uint32_t ones_to_skip = -s1_ones_count % kBitsPerSelect1Block;
334  if (word_ones_count[i] > ones_to_skip) {
335  const int nth = nth_bit(word[i], ones_to_skip);
336  select_1_index_.push_back(bit_offset + nth);
337  static_assert(kBitsPerSelect1Block >= 512,
338  "kBitsPerSelect1Block must be at least 512.");
339  break; // 8 entries is 512 bits, so we can't push another bit here.
340  }
341  s1_ones_count += word_ones_count[i];
342  }
343  }
344  }
345  // Add the extra entry with the total number of bits.
346  rank_index_.back().set_absolute_ones_count(ones_count);
347 
348  if (enable_select_0_index) {
349  // Add extra entry with num_bits_.
350  select_0_index_.push_back(num_bits_);
351  select_0_index_.shrink_to_fit();
352  }
353 
354  if (enable_select_1_index) {
355  select_1_index_.push_back(num_bits_);
356  select_1_index_.shrink_to_fit();
357  }
358 }
359 
360 const BitmapIndex::RankIndexEntry& BitmapIndex::FindRankIndexEntry(
361  size_t bit_index) const {
362  DCHECK_GE(bit_index, 0);
363  DCHECK_LT(bit_index, rank_index_.back().absolute_ones_count());
364 
365  const RankIndexEntry* begin = nullptr;
366  const RankIndexEntry* end = nullptr;
367  if (select_1_index_.empty()) {
368  begin = &rank_index_[0];
369  end = begin + rank_index_.size();
370  } else {
371  const uint32_t select_index = bit_index / kBitsPerSelect1Block;
372  DCHECK_LT(select_index + 1, select_1_index_.size());
373 
374  // TODO(jrosenstock): It would be nice to handle the exact hit
375  // bit_index % kBitsPerSelect1Block == 0 case so we could
376  // return the value, but that requiries some refactoring:
377  // either inlining this into Select1, or returning a pair
378  // or out param, etc.
379 
380  // The bit is between these indices.
381  const uint32_t lo_bit_index = select_1_index_[select_index];
382  const uint32_t hi_bit_index = select_1_index_[select_index + 1];
383 
384  begin = &rank_index_[lo_bit_index / kBitsPerSelect1Block];
385  end = &rank_index_[(hi_bit_index + kBitsPerSelect1Block - 1) /
386  kBitsPerSelect1Block];
387  }
388 
389  // Linear search if the range is small.
390  const RankIndexEntry* entry = nullptr;
391  if (end - begin <= kMaxLinearSearchBlocks) {
392  for (entry = begin; entry != end; ++entry) {
393  if (entry->absolute_ones_count() > bit_index) break;
394  }
395  } else {
396  RankIndexEntry search_entry;
397  search_entry.set_absolute_ones_count(bit_index);
398  // TODO(jrosenstock): benchmark upper vs custom bsearch.
399  entry = &*std::upper_bound(
400  begin, end, search_entry,
401  [](const RankIndexEntry& e1, const RankIndexEntry& e2) {
402  return e1.absolute_ones_count() < e2.absolute_ones_count();
403  });
404  }
405 
406  const auto& e = *(entry - 1);
407  DCHECK_LE(e.absolute_ones_count(), bit_index);
408  DCHECK_GT(entry->absolute_ones_count(), bit_index);
409  return e;
410 }
411 
412 const BitmapIndex::RankIndexEntry& BitmapIndex::FindInvertedRankIndexEntry(
413  size_t bit_index) const {
414  DCHECK_GE(bit_index, 0);
415  DCHECK_LT(bit_index, num_bits_ - rank_index_.back().absolute_ones_count());
416 
417  uint32_t lo = 0, hi = 0;
418  if (select_0_index_.empty()) {
419  lo = 0;
420  hi = (num_bits_ + kBitsPerRankIndexEntry - 1) / kBitsPerRankIndexEntry;
421  } else {
422  const uint32_t select_index = bit_index / kBitsPerSelect0Block;
423  DCHECK_LT(select_index + 1, select_0_index_.size());
424 
425  // TODO(jrosenstock): Same special case for exact hit.
426 
427  lo = select_0_index_[select_index] / kBitsPerSelect0Block;
428  hi = (select_0_index_[select_index + 1] + kBitsPerSelect0Block - 1) /
429  kBitsPerSelect0Block;
430  }
431 
432  DCHECK_LT(hi, rank_index_.size());
433  // Linear search never showed an advantage when benchmarking. This may be
434  // because the linear search is more complex with the zeros_count computation,
435  // or because the ranges are larger, so linear search is triggered less often,
436  // and the difference is harder to measure.
437  while (lo + 1 < hi) {
438  const uint32_t mid = lo + (hi - lo) / 2;
439  if (bit_index <
440  kBitsPerRankIndexEntry * mid - rank_index_[mid].absolute_ones_count()) {
441  hi = mid;
442  } else {
443  lo = mid;
444  }
445  }
446 
447  DCHECK_LE(lo * kBitsPerRankIndexEntry - rank_index_[lo].absolute_ones_count(),
448  bit_index);
449  if ((lo + 1) * kBitsPerRankIndexEntry <= num_bits_) {
450  DCHECK_GT((lo + 1) * kBitsPerRankIndexEntry -
451  rank_index_[lo + 1].absolute_ones_count(),
452  bit_index);
453  } else {
454  DCHECK_GT(num_bits_ - rank_index_[lo + 1].absolute_ones_count(), bit_index);
455  }
456  return rank_index_[lo];
457 }
458 
459 } // end namespace fst
#define DCHECK_LT(x, y)
Definition: log.h:76
static constexpr uint32_t kStorageBitSize
Definition: bitmap-index.h:172
constexpr std::array< uint64_t, 64 > kLowBitsMasks
Definition: bitmap-index.cc:40
void BuildIndex(const uint64_t *bits, size_t num_bits, bool enable_select_0_index=false, bool enable_select_1_index=false)
size_t Bits() const
Definition: bitmap-index.h:124
size_t Select1(size_t bit_index) const
Definition: bitmap-index.cc:58
#define DCHECK_GT(x, y)
Definition: log.h:77
size_t Rank1(size_t end) const
Definition: bitmap-index.cc:42
size_t Select0(size_t bit_index) const
size_t GetOnesCount() const
Definition: bitmap-index.h:139
constexpr std::array< uint64_t, 64 > LowBitsMasks()
Definition: bitmap-index.cc:35
size_t ArraySize() const
Definition: bitmap-index.h:126
#define DCHECK_LE(x, y)
Definition: log.h:78
std::pair< size_t, size_t > Select0s(size_t bit_index) const
#define DCHECK_GE(x, y)
Definition: log.h:79
int nth_bit(const uint64_t v, uint32_t r)
Definition: nthbit.cc:235