# coding=utf-8
# This file is part of intset (https://github.com/DRMacIver/intset)
# Most of this work is copyright (C) 2013-2015 David R. MacIver
# (david@drmaciver.com), but it contains contributions by others, who hold
# copyright over their individual contributions.
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at http://mozilla.org/MPL/2.0/.
# END HEADER
# coding=utf-8
# This file is part of intset (https://github.com/DRMacIver/intset)
# Copyright (C) 2013-2015 David R. MacIver (david@drmaciver.com)
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import division, print_function, absolute_import
from collections import Set, Sequence
# The public API of this module.
__all__ = ['IntSet']
class IntSetMeta(type):
    """Metaclass that makes calling the class behave like calling ``set``.

    ``IntSet()`` returns the empty IntSet. ``IntSet(iterable)`` builds one
    from the iterable's items. Integers are inserted directly, and anything
    that is not an integer is unpacked as a ``(start, end)`` interval.
    The real constructor, which takes the raw wrapped tuple, stays
    reachable through ``_wrap``.
    """

    def __call__(self, *args, **kwargs):
        if kwargs:
            # Keyword arguments used to be silently ignored, so e.g.
            # IntSet(values=[1, 2]) quietly produced an empty set.
            raise TypeError('IntSet() takes no keyword arguments')
        if len(args) == 0:
            return self._wrap(())
        elif len(args) == 1:
            result = IntSet.Builder()
            for i in args[0]:
                try:
                    result.insert(i)
                except TypeError:
                    # insert() rejects non-integers with TypeError, so
                    # treat the item as a (start, end) pair instead.
                    result.insert_interval(*i)
            return result.build()
        else:
            raise TypeError('IntSet expected at most 1 arguments, got %d' % (
                len(args),
            ))

    def _wrap(self, value):
        """Instantiate the class directly around a raw tuple node.

        This bypasses the set-like ``__call__`` above.
        """
        return type.__call__(self, value)
class IntSet(IntSetMeta('IntSet', (object,), {})):
    """
    An IntSet is a compressed immutable representation of a sorted list of
    unsigned 64-bit integers with fast membership, union and range restriction.

    It mostly behaves as if it were a sorted list of deduplicated integer
    values. In particular, you can index it as if it were one, and it will
    sort and compare equal (to other IntSets) as if it were.

    Note that unlike lists, intsets may feasibly have more than sys.maxsize
    elements, and calling len() on such an intset may raise an OverflowError.
    If you wish to avoid this, use .size() instead.

    Because IntSet is immutable, unlike list it may also be used as a hash key.

    It also supports set operations. In particular, all the boolean operations
    are supported:

    x & y: An IntSet containing the values that are present in both x and y
    x | y: An IntSet containing the values present in either x or y
    x - y: An IntSet containing the values present in x but not y
    x ^ y: An IntSet containing the values present in x or y but not both
    ~x: An IntSet containing all values in the range 0 <= i < 2 ** 64 that
        are not present in x (IntSet can represent this efficiently. It
        won't allocate 2 ** 64 integers worth of memory).

    The binary operators return NotImplemented for operands that are not
    IntSets. Python then raises TypeError unless the other operand handles
    the operation.

    IntSets may be constructed either from the dedicated class methods or by
    calling the class as you usually would for a set. So IntSet([1, 2, 3]) is
    an IntSet containing the values 1, 2 and 3.

    When calling an IntSet this way, non-integer values which are iterable
    sequences of length 2 will be interpreted as intervals start <= x < end.
    So e.g. IntSet([1, [10, 100]]) will contain the numbers 1 and 10, ..., 99.
    """

    # A real one-element tuple. The original ('wrapped') was a plain string,
    # which Python happens to accept as a single slot name.
    __slots__ = ('wrapped',)

    class Builder(object):
        """An IntSet.Builder is for building up an IntSet incrementally through
        a series of insertions.

        This will typically be much faster than repeatedly calling
        insert on an IntSet object. The intended usage is to repeatedly
        call insert() or insert_interval() on a builder, then call
        build() at the end. Note that you can continue to insert further
        data into a Builder afterwards if you wish, and this will not
        affect previously built IntSet instances.
        """

        def __init__(self):
            # Tree accumulated by earlier build() calls.
            self.wrapped = ()
            # Individual values inserted since the last build().
            self.pending = []
            # [start, end] pairs inserted since the last build().
            self.intervals = []

        def insert(self, value):
            """Add a single value to the IntSet to be built."""
            _validate_integer_in_range('value', value)
            self.pending.append(value)

        def insert_interval(self, start, end):
            """Add all values x such that start <= x < end to the IntSet to be
            built.

            Bounds are validated exactly as in IntSet.interval: start must
            satisfy 0 <= start < 2 ** 64, and end must be 0 or satisfy
            1 <= end <= 2 ** 64. Raises TypeError for non-integer bounds and
            ValueError for out-of-range ones. An empty range (end <= start)
            adds nothing.
            """
            # These bounds used to be unchecked, so negative, oversized or
            # non-integer values slipped in here although insert() rejects
            # them.
            _validate_integer_in_range('start', start)
            if end != 0:
                _validate_integer_in_range('end - 1', end - 1)
            if start >= end:
                return
            self.intervals.append([start, end])

        def build(self):
            """Produce a new IntSet with all the values previously inserted to
            this builder.

            You may call build() more than once, and any values inserted
            in between those calls will also be present, but previously
            built values will be unaffected by subsequent inserts.
            """
            self.pending.sort()
            still_pending = []
            i = 0
            while i < len(self.pending):
                # Advance j to the end of the run of consecutive values
                # that starts at i.
                j = i
                while j + 1 < len(self.pending):
                    if self.pending[j + 1] == self.pending[j] + 1:
                        j += 1
                    else:
                        break
                if i < j:
                    # A run of two or more values becomes one interval.
                    self.intervals.append([
                        self.pending[i], self.pending[j] + 1
                    ])
                else:
                    still_pending.append(self.pending[i])
                i = j + 1
            self.pending = []
            if still_pending:
                self.wrapped = _union(
                    self.wrapped,
                    _from_sorted_list(still_pending, 0, len(still_pending))
                )
            if self.intervals:
                intervals = _normalize_intervals(self.intervals)
                self.intervals = []
                self.wrapped = _union(
                    self.wrapped, _from_intervals(intervals))
            return IntSet._wrap(self.wrapped)

    def __getstate__(self):
        # wrap in a tuple because a falsey value will cause the corresponding
        # setstate to not be called.
        return (list(self.intervals()),)

    def __setstate__(self, state):
        self.wrapped = IntSet.from_intervals(state[0]).wrapped

    def __init__(self, wrapped):
        # Receives the raw tree tuple. IntSet(...) calls are intercepted by
        # IntSetMeta.__call__; this is reached through IntSet._wrap.
        assert isinstance(wrapped, tuple)
        self.wrapped = wrapped

    def __repr__(self):
        # Show single values bare and longer runs as (start, end) pairs,
        # which IntSet(...) accepts back as intervals.
        bits = []
        for i, j in self.intervals():
            if i + 1 < j:
                bits.append((i, j))
            else:
                bits.append(i)
        return 'IntSet(%r)' % (bits,)

    @classmethod
    def empty(cls):
        """Return an empty IntSet."""
        return IntSet._wrap(())

    @classmethod
    def single(cls, value):
        """Return an IntSet containing only the single value provided.

        Raises TypeError if value is not an integer and ValueError unless
        0 <= value < 2 ** 64.
        """
        # Only value itself needs checking. The node's exclusive end,
        # value + 1, may be 2 ** 64, just as IntSet.interval allows. The
        # former extra check on value + 1 wrongly rejected 2 ** 64 - 1,
        # which insert() and Builder accept.
        _validate_integer_in_range('value', value)
        return IntSet._wrap(_new_single(value))

    @classmethod
    def interval(cls, start, end):
        """
        Return an IntSet containing only the values x such that
        start <= x < end
        """
        _validate_integer_in_range('start', start)
        if end != 0:
            _validate_integer_in_range('end - 1', end - 1)
        return IntSet._wrap(_new_maybe_empty_interval(start, end))

    @classmethod
    def from_iterable(cls, values):
        """Return an IntSet containing everything in values.

        values may be any iterable over integers in the valid range
        0 <= x < 2 ** 64, including a one-shot iterator such as a generator.
        Invalid values raise TypeError or ValueError, as with insert().
        """
        # The previous version called len(values) after sorting, which failed
        # for generators and other unsized iterables, and it never validated
        # the values. Builder fixes both and also collapses runs of
        # consecutive values into intervals.
        builder = IntSet.Builder()
        for value in values:
            builder.insert(value)
        return builder.build()

    @classmethod
    def from_intervals(cls, intervals):
        """Return a new IntSet which contains precisely the intervals passed
        in."""
        return cls._wrap(
            _from_intervals(_normalize_intervals(list(map(list, intervals)))))

    def size(self):
        """This returns the same as len() when the latter is defined, but
        IntSet may have more values than will fit in the size of index that len
        will allow."""
        if self.wrapped:
            return self.wrapped[_SIZE]
        else:
            return 0

    def insert(self, value):
        """Returns an IntSet which contains all the values of the current one
        plus the provided value."""
        _validate_integer_in_range('value', value)
        return IntSet._wrap(_insert(self.wrapped, value))

    def discard(self, value):
        """Returns an IntSet which contains all the values of the current one
        except for the passed in value.

        Does not raise an error if the value is not present.
        """
        _validate_integer_in_range('value', value)
        return IntSet._wrap(_discard(self.wrapped, value))

    def restrict(self, start, end):
        """Return a new IntSet with all values x in self such that start <=
        x < end."""
        return IntSet._wrap(_restrict(self.wrapped, start, end))

    def __len__(self):
        return self.size()

    def __bool__(self):
        return bool(self.wrapped)

    def __nonzero__(self):
        # Python 2 spelling of __bool__.
        return self.__bool__()

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, IntSet):
            return False
        if self.size() != other.size():
            return False
        return self.__cmp__(other) == 0

    def __ne__(self, other):
        return not self.__eq__(other)

    def __cmp__(self, other):
        # Lexicographic comparison of both sets viewed as sorted lists,
        # carried out interval by interval instead of element by element.
        if not isinstance(other, IntSet):
            raise TypeError(
                'Unorderable types IntSet and %s' % (type(other).__name__,))
        self_intervals = list(self.intervals())
        other_intervals = list(other.intervals())
        # Reverse so that pop() consumes intervals from the front.
        self_intervals.reverse()
        other_intervals.reverse()
        while self_intervals and other_intervals:
            self_head = self_intervals.pop()
            other_head = other_intervals.pop()
            if self_head[0] < other_head[0]:
                return -1
            if self_head[0] > other_head[0]:
                return 1
            # Equal starts: push back the unmatched tail of the longer one.
            if self_head[1] < other_head[1]:
                other_intervals.append((self_head[1], other_head[1]))
            if self_head[1] > other_head[1]:
                self_intervals.append((other_head[1], self_head[1]))
        # A proper prefix sorts first.
        if self_intervals:
            return 1
        if other_intervals:
            return -1
        return 0

    def __lt__(self, other):
        return self.__cmp__(other) < 0

    def __gt__(self, other):
        return self.__cmp__(other) > 0

    def __le__(self, other):
        return self.__cmp__(other) <= 0

    def __ge__(self, other):
        return self.__cmp__(other) >= 0

    def __contains__(self, i):
        return _contains(self.wrapped, i)

    def __iter__(self):
        for start, end in self.intervals():
            for i in range(start, end):
                yield i

    def __getitem__(self, i):
        size = self.size()
        if i < -size or i >= size:
            raise IndexError('IntSet index %d out of range for size %d' % (
                i, size,
            ))
        if i < 0:
            i += size
        assert i >= 0
        return _getitem(self.wrapped, i)

    def __hash__(self):
        # (start, end, size) of the root: cheap, and identical for equal sets.
        return hash(self.wrapped[:3])

    def __copy__(self):
        # Immutable, so copies can share the instance.
        return self

    def __deepcopy__(self, table):
        return table.setdefault(self, self)

    def isdisjoint(self, other):
        """Returns True if self and other have no common elements."""
        return _isdisjoint(self.wrapped, other.wrapped)

    def intersects(self, other):
        """Returns True if there is an element i such that i in self and i in
        other."""
        return not self.isdisjoint(other)

    def issubset(self, other):
        """Returns True if every element of self is also in other."""
        return _issubset(self.wrapped, other.wrapped)

    def issuperset(self, other):
        """Returns True if every element of other is also in self."""
        return other.issubset(self)

    def __and__(self, other):
        # Consistent with |, - and ^: defer to Python's TypeError machinery
        # instead of an assert (which disappears under -O).
        if not isinstance(other, IntSet):
            return NotImplemented
        return IntSet._wrap(_intersect(self.wrapped, other.wrapped))

    def __invert__(self):
        return whole_range - self

    def __sub__(self, other):
        if not isinstance(other, IntSet):
            return NotImplemented
        return IntSet._wrap(_subtract(self.wrapped, other.wrapped))

    def __xor__(self, other):
        if not isinstance(other, IntSet):
            return NotImplemented
        return (self | other) - (self & other)

    def __or__(self, other):
        if not isinstance(other, IntSet):
            return NotImplemented
        return IntSet._wrap(_union(self.wrapped, other.wrapped))

    def intervals(self):
        """
        Provide a sorted iterator over a sequence of values start < end which
        represent non-overlapping intervals such that for any start <= x < end
        x in self
        """
        return _intervals(self.wrapped)

    def reversed_intervals(self):
        """Iterator over the reverse of intervals()"""
        return _reversed_intervals(self.wrapped)

    def __reversed__(self):
        for start, end in self.reversed_intervals():
            for i in range(end - 1, start - 1, -1):
                yield i
# IntSet behaves both like a sorted sequence and like a set, so advertise
# it as a virtual subclass of both ABCs (Sequence first, then Set).
for _abc in (Sequence, Set):
    _abc.register(IntSet)
del _abc
def _new_maybe_empty_interval(start, end):
    """Interval leaf for start <= x < end, or the empty tree () if end <= start."""
    return _new_interval(start, end) if start < end else ()
# Field offsets into tree nodes. An interval leaf is the 3-tuple
# (start, end, size); a split node is the 7-tuple
# (start, end, size, prefix, mask, left, right).
(_START, _END, _SIZE, _PREFIX, _MASK, _LEFT, _RIGHT) = range(7)
_INTERVAL_LENGTH = _SIZE + 1
_SPLIT_LENGTH = _RIGHT + 1
def _new_interval(start, end):
    """Build the leaf node (start, end, size) for start <= x < end."""
    size = end - start
    return start, end, size
def _new_single(value):
    """Build the leaf node holding just ``value``."""
    return value, value + 1, 1
def _new_split_maybe_empty(prefix, mask, left, right):
    """Like _new_split, but an empty child collapses to the other child."""
    if not left:
        return right
    if not right:
        return left
    return _new_split(prefix, mask, left, right)
def _new_split(prefix, mask, left, right):
    """Build a split node, or a single interval leaf when the halves leave
    no gap between left's start and right's end."""
    total = left[_SIZE] + right[_SIZE]
    if left[_START] + total == right[_END]:
        return _new_interval(left[_START], right[_END])
    return _new_split_no_collapse(prefix, mask, left, right)
def _new_split_no_collapse(prefix, mask, left, right):
    """Build a 7-tuple split node without checking for contiguity."""
    start = left[_START]
    end = right[_END]
    size = left[_SIZE] + right[_SIZE]
    return start, end, size, prefix, mask, left, right
def _split_interval(ins):
    """Turn an interval leaf into an equivalent split node.

    The leaf is cut at prefix | mask, where mask is the highest bit on which
    its first and last values differ. Each half becomes an interval leaf.
    """
    start, end, size = ins[_START], ins[_END], ins[_SIZE]
    mask = branch_mask(start, end - 1)
    prefix = _mask_off(start, mask)
    middle = prefix | mask
    lower = _new_interval(start, middle)
    upper = _new_interval(middle, end)
    return (start, end, size, prefix, mask, lower, upper)
def _join(p1, t1, p2, t2):
    """Combine trees t1 and t2, whose prefixes p1 and p2 differ, under a new
    parent branching on their highest differing bit."""
    mask = branch_mask(p1, p2)
    prefix = _mask_off(p1, mask)
    if _is_zero(p1, mask):
        return _new_split(prefix, mask, t1, t2)
    return _new_split(prefix, mask, t2, t1)
def _insert(ins, value):
    """Return the tree ``ins`` with ``value`` added.

    ``ins`` is a raw node: () for empty, a 3-tuple interval leaf or a
    7-tuple split node. Nodes are tuples, so ``ins`` itself is never
    modified; the result is a new node (or ``ins`` if value was present).
    """
    l = len(ins)
    if l == 0:
        return _new_single(value)
    elif l == _INTERVAL_LENGTH:
        start = ins[_START]
        end = ins[_END]
        if start <= value < end:
            # Already present.
            return ins
        elif value == end:
            # Extends the interval upward by one.
            return _new_interval(start, end + 1)
        elif value + 1 == start:
            # Extends the interval downward by one.
            return _new_interval(value, end)
        elif ins[_SIZE] == 1:
            return _join(start, ins, value, _new_single(value))
        else:
            # Rewrite the interval as an equivalent split node and fall
            # through to the split-node handling below.
            ins = _split_interval(ins)
    prefix = ins[_PREFIX]
    mask = ins[_MASK]
    if _no_match(value, prefix, mask):
        # value's bits above mask differ from this node's prefix: pair the
        # new single with the whole node under a fresh parent.
        return _join(
            value, _new_single(value),
            prefix, ins
        )
    elif _is_zero(value, mask):
        return _new_split(
            prefix, mask, _insert(ins[_LEFT], value), ins[_RIGHT])
    else:
        return _new_split(
            prefix, mask, ins[_LEFT], _insert(ins[_RIGHT], value))
def _getitem(self, i):
    """Return the i-th smallest value of a non-empty tree (0 <= i < size)."""
    node = self
    while len(node) > _INTERVAL_LENGTH:
        left = node[_LEFT]
        left_size = left[_SIZE]
        if i >= left_size:
            # Skip over everything in the left half.
            i -= left_size
            node = node[_RIGHT]
        else:
            node = left
    return node[_START] + i
def _discard(self, value):
    """Return the tree ``self`` without ``value``."""
    node = self
    width = len(node)
    if width == 0:
        return node
    if width == _INTERVAL_LENGTH:
        lo, hi = node[_START], node[_END]
        if not lo <= value < hi:
            return node
        # Trimming either end keeps a single leaf.
        if value == lo:
            return _new_maybe_empty_interval(lo + 1, hi)
        if value == hi - 1:
            return _new_maybe_empty_interval(lo, hi - 1)
        # value is strictly inside: split so the hole lands in one half.
        node = _split_interval(node)
    prefix, mask = node[_PREFIX], node[_MASK]
    left, right = node[_LEFT], node[_RIGHT]
    if _is_zero(value, mask):
        left = _discard(left, value)
    else:
        right = _discard(right, value)
    return _new_split_maybe_empty(prefix, mask, left, right)
def _normalize_intervals(intervals):
    """Sort ``intervals`` in place and return the merged, non-empty ones.

    Each element is a mutable [start, end] pair. Overlapping or touching
    pairs are merged by widening the earlier pair in place.
    """
    intervals.sort()
    merged = []
    for pair in intervals:
        lo, hi = pair[0], pair[1]
        if lo >= hi:
            # Empty interval: contributes nothing.
            continue
        if merged and lo <= merged[-1][-1]:
            merged[-1][-1] = max(hi, merged[-1][-1])
        else:
            merged.append(pair)
    return merged
def _from_intervals(intervals):
    """Build a tree from normalized [start, end] pairs; () if there are none."""
    if not intervals:
        return ()
    return _from_intervals_worker(intervals)
def _from_intervals_worker(intervals):
    """Recursively build a tree from a non-empty list of normalized pairs.

    The pairs are partitioned around prefix | mask, computed from the
    overall first start and last end. A pair straddling that point is cut
    in two.
    """
    if len(intervals) == 1:
        return _new_interval(*intervals[0])
    lowest = intervals[0][0]
    highest = intervals[-1][-1]
    mask = branch_mask(lowest, highest - 1)
    prefix = _mask_off(lowest, mask)
    pivot = prefix | mask
    below, above = [], []
    for pair in intervals:
        lo, hi = pair[0], pair[1]
        if hi <= pivot:
            below.append(pair)
        elif lo >= pivot:
            above.append(pair)
        else:
            below.append([lo, pivot])
            above.append([pivot, hi])
    return _new_split_no_collapse(
        prefix, mask,
        _from_intervals_worker(below), _from_intervals_worker(above)
    )
def _union(self, other):
    """Return the tree of values present in ``self`` or ``other``.

    Both arguments are raw nodes: () for empty, a 3-tuple interval leaf
    (start, end, size) or a 7-tuple split node
    (start, end, size, prefix, mask, left, right).
    """
    if len(self) == 0:
        return other
    if len(other) == 0:
        return self
    # Make self the operand with more elements.
    if other[_SIZE] > self[_SIZE]:
        self, other = other, self
    if len(self) == _INTERVAL_LENGTH:
        if self[_START] <= other[_START] and other[_END] <= self[_END]:
            # self is a contiguous run whose bounds enclose all of other.
            return self
        if len(other) == _INTERVAL_LENGTH:
            if self[_START] <= other[_END] and other[_START] <= self[_END]:
                # Overlapping or touching intervals merge into one.
                return _new_interval(
                    min(self[_START], other[_START]),
                    max(self[_END], other[_END]),
                )
            elif self[_SIZE] > 1:
                # Separate intervals: retry with self as a split node.
                return _union(_split_interval(self), other)
            else:
                # Both are single values (other is no larger than self).
                return _join(self[_START], self, other[_START], other)
    if len(other) == _INTERVAL_LENGTH:
        if other[_SIZE] == 1:
            return _insert(self, other[_START])
        else:
            other = _split_interval(other)
    if len(self) == _INTERVAL_LENGTH:
        self = _split_interval(self)
    # Both are split nodes now. Make self the one with the larger mask,
    # i.e. the one branching at the higher bit (see _shorter).
    if _shorter(other[_MASK], self[_MASK]):
        self, other = other, self
    if _shorter(self[_MASK], other[_MASK]):
        if _no_match(other[_PREFIX], self[_PREFIX], self[_MASK]):
            # other lies outside self's prefix: combine under a new parent.
            return _join(
                self[_PREFIX], self, other[_PREFIX], other
            )
        elif _is_zero(other[_PREFIX], self[_MASK]):
            # other belongs in self's left half.
            return _new_split(
                self[_PREFIX], self[_MASK],
                _union(self[_LEFT], other), self[_RIGHT]
            )
        else:
            # other belongs in self's right half.
            return _new_split(
                self[_PREFIX], self[_MASK],
                self[_LEFT], _union(self[_RIGHT], other)
            )
    else:
        assert self[_MASK] == other[_MASK]
        if self[_PREFIX] == other[_PREFIX]:
            # Same place in the trie: merge half by half.
            return _new_split(
                self[_PREFIX], self[_MASK],
                _union(self[_LEFT], other[_LEFT]),
                _union(self[_RIGHT], other[_RIGHT])
            )
        else:
            return _join(self[_PREFIX], self, other[_PREFIX], other)
def _restrict(self, start, end):
    """Return the tree of values x in ``self`` with start <= x < end.

    Returns () when no value survives, including when start >= end.
    """
    if not self:
        return self
    if start >= end:
        # An empty or inverted range. Without this guard, an interval leaf
        # overlapping it was clipped to a zero- or negative-size
        # "interval" such as (5, 5, 0). That tuple is truthy but has no
        # elements, which breaks every other operation.
        return ()
    if start >= self[_END] or self[_START] >= end:
        return ()
    if len(self) == _INTERVAL_LENGTH:
        return _new_interval(
            max(start, self[_START]), min(end, self[_END]))
    return _new_split_maybe_empty(
        self[_PREFIX], self[_MASK],
        _restrict(self[_LEFT], start, end),
        _restrict(self[_RIGHT], start, end),
    )
def _contains(self, value):
    """Return True if ``value`` is stored in the tree ``self``."""
    if not self:
        return False
    node = self
    # Follow value's bit at each split down to the single candidate leaf.
    while len(node) != _INTERVAL_LENGTH:
        node = node[_LEFT] if _is_zero(value, node[_MASK]) else node[_RIGHT]
    return node[_START] <= value < node[_END]
def _intersect(self, other):
    """Return the tree of values present in both ``self`` and ``other``."""
    if not (self and other):
        return ()
    if self[_SIZE] > other[_SIZE]:
        self, other = other, self
    # From here on self is the operand with fewer elements.
    if self[_SIZE] == 1:
        # The smaller operand is a single value, so a membership test
        # decides. This used to test other[_SIZE], the larger operand, so
        # the shortcut only fired when both operands were single values.
        if _contains(other, self[_START]):
            return self
        else:
            return ()
    if len(self) == _INTERVAL_LENGTH:
        return _restrict(other, self[_START], self[_END])
    if len(other) == _INTERVAL_LENGTH:
        return _restrict(self, other[_START], other[_END])
    if self[_START] >= other[_END]:
        return ()
    if self[_END] <= other[_START]:
        return ()
    # Make self the split node branching at the higher bit (larger mask).
    if _shorter(other[_MASK], self[_MASK]):
        self, other = other, self
    if _shorter(self[_MASK], other[_MASK]):
        # Only the half of self on other's side can overlap other.
        if _is_zero(other[_PREFIX], self[_MASK]):
            return _intersect(self[_LEFT], other)
        else:
            return _intersect(self[_RIGHT], other)
    else:
        # Equal masks. Differing prefixes would mean disjoint ranges, which
        # the bound checks above already handled, so intersect half by half.
        return _new_split_maybe_empty(
            self[_PREFIX], self[_MASK],
            _intersect(self[_LEFT], other[_LEFT]),
            _intersect(self[_RIGHT], other[_RIGHT])
        )
def _subtract(self, other):
    """Return the tree of values in ``self`` that are not in ``other``.

    Both arguments are raw nodes: (), a 3-tuple interval leaf or a 7-tuple
    split node.
    """
    if not (other and self):
        return self
    if len(other) == _INTERVAL_LENGTH:
        # Removing one contiguous run: keep the parts of self below and
        # above it.
        return _union(
            _restrict(self, self[_START], other[_START]),
            _restrict(self, other[_END], self[_END]))
    if self[_SIZE] == 1:
        if _contains(other, self[_START]):
            return ()
        else:
            return self
    if len(self) == _INTERVAL_LENGTH:
        # Give self a prefix/mask so it can be compared with other below.
        self = _split_interval(self)
    if _shorter(self[_MASK], other[_MASK]):
        # self branches at a higher bit: other fits inside one half of self,
        # or lies outside self's prefix entirely.
        if _no_match(other[_PREFIX], self[_PREFIX], self[_MASK]):
            return self
        elif _is_zero(other[_PREFIX], self[_MASK]):
            return _new_split_maybe_empty(
                self[_PREFIX], self[_MASK],
                _subtract(self[_LEFT], other), self[_RIGHT]
            )
        else:
            return _new_split_maybe_empty(
                self[_PREFIX], self[_MASK], self[_LEFT],
                _subtract(self[_RIGHT], other)
            )
    elif _shorter(other[_MASK], self[_MASK]):
        # other branches at a higher bit: only the half of other on the
        # side of self's prefix can overlap self.
        if _is_zero(self[_PREFIX], other[_MASK]):
            return _subtract(self, other[_LEFT])
        else:
            return _subtract(self, other[_RIGHT])
    else:
        # Same branching bit: equal prefixes subtract half by half, while
        # different prefixes mean the two trees share no values.
        if self[_PREFIX] == other[_PREFIX]:
            return _new_split_maybe_empty(
                self[_PREFIX], self[_MASK],
                _subtract(self[_LEFT], other[_LEFT]),
                _subtract(self[_RIGHT], other[_RIGHT])
            )
        else:
            return self
def _isdisjoint(self, other):
    """Return True if the trees ``self`` and ``other`` share no value."""
    if not self or not other:
        return True
    if self[_START] >= other[_END] or other[_START] >= self[_END]:
        return True
    if len(self) == _INTERVAL_LENGTH:
        if len(other) == _INTERVAL_LENGTH:
            # Two contiguous runs with overlapping bounds must share a value.
            return False
        # Recurse into whichever operand is a split node.
        self, other = other, self
    return all(
        _isdisjoint(half, other) for half in (self[_LEFT], self[_RIGHT]))
def _issubset(self, other):
    """Return True if every value in tree ``self`` is also in tree ``other``."""
    if not self:
        return True
    if not other:
        return False
    if len(other) == _INTERVAL_LENGTH:
        # other is one contiguous run, so comparing bounds suffices.
        return (
            other[_START] <= self[_START] and
            self[_END] <= other[_END])
    # self is non-empty here, so disjoint bounds rule out containment.
    if self[_START] >= other[_END]:
        return False
    if other[_START] >= self[_END]:
        return False
    if len(self) == _INTERVAL_LENGTH:
        # Very small intervals are checked value by value.
        if self[_SIZE] == 1:
            return _contains(other, self[_START])
        elif self[_SIZE] == 2:
            return (
                _contains(other, self[_START]) and
                _contains(other, self[_END] - 1))
        self = _split_interval(self)
    if _shorter(self[_MASK], other[_MASK]):
        # self has values on both sides of a bit on which all of other's
        # values agree, so some of self cannot be in other.
        return False
    elif _shorter(other[_MASK], self[_MASK]):
        # self must then fit entirely inside one half of other.
        if _is_zero(self[_PREFIX], other[_MASK]):
            return _issubset(self, other[_LEFT])
        else:
            return _issubset(self, other[_RIGHT])
    else:
        # If they have incompatible prefixes the above start/end checks
        # must have returned False already because they're actually
        # disjoint.
        assert self[_PREFIX] == other[_PREFIX]
        return (
            _issubset(self[_LEFT], other[_LEFT]) and
            _issubset(self[_RIGHT], other[_RIGHT]))
def _intervals(self):
    """Yield (start, end) for each leaf of the tree in ascending order."""
    pending = [self] if self else []
    while pending:
        node = pending.pop()
        if len(node) != _INTERVAL_LENGTH:
            # Push right first so the left half is visited first.
            pending.extend((node[_RIGHT], node[_LEFT]))
            continue
        yield node[_START], node[_END]
def _from_sorted_list(ls, start, end):
    """Build a tree from the sorted values ls[start:end] by divide and conquer."""
    count = end - start
    if count == 0:
        return ()
    if count == 1:
        return _new_single(ls[start])
    middle = start + count // 2
    lower = _from_sorted_list(ls, start, middle)
    upper = _from_sorted_list(ls, middle, end)
    return _union(lower, upper)
def _reversed_intervals(self):
    """Yield (start, end) for each leaf of the tree in descending order."""
    pending = [self] if self else []
    while pending:
        node = pending.pop()
        if len(node) != _INTERVAL_LENGTH:
            # Push left first so the right half is visited first.
            pending.extend((node[_LEFT], node[_RIGHT]))
            continue
        yield node[_START], node[_END]
def _right_fill_bits(key):
    """Set every bit below the highest set bit of a 64-bit ``key``."""
    for shift in (1, 2, 4, 8, 16, 32):
        key |= key >> shift
    return key
def _highest_bit_mask(k):
    """Keep only the highest set bit of a 64-bit ``k`` (0 stays 0)."""
    filled = _right_fill_bits(k)
    return filled ^ (filled >> 1)
def branch_mask(p1, p2):
    """Return the highest bit at which p1 and p2 differ (0 if they are equal)."""
    differing = p1 ^ p2
    return _highest_bit_mask(differing)
def _mask_off(i, m):
    """Clear bit ``m`` and every bit below it in ``i``, for a single-bit mask m."""
    keep = ~((m - 1) ^ m)
    return i & keep
def _is_zero(i, m):
    """Return True if ``i`` has none of the bits in ``m`` set."""
    return not (i & m)
def _no_match(i, p, m):
    """Return True when i's bits above mask ``m`` differ from prefix ``p``."""
    high_bits = _mask_off(i, m)
    return high_bits != p
def _shorter(m1, m2):
    """Return True if mask m1 is a higher bit than m2 (a shorter prefix)."""
    return m2 < m1
# Exclusive upper bound on stored values: IntSet holds unsigned 64-bit ints.
_UPPER_BOUND = 1 << 64
# Every value 0 <= x < 2 ** 64; IntSet.__invert__ subtracts from this.
whole_range = IntSet._wrap(_new_interval(0, _UPPER_BOUND))
# Accepted integer types: int, plus long on Python 2.
INTEGER_TYPES = (type(0), type(1 << 64))
def _validate_integer_in_range(name, i):
    """Check that ``i`` is an integer with 0 <= i < 2 ** 64.

    ``name`` labels the argument in error messages. Raises TypeError for a
    non-integer and ValueError for an out-of-range integer.
    """
    if isinstance(i, INTEGER_TYPES):
        if 0 <= i < _UPPER_BOUND:
            return
        raise ValueError(
            'Argument %s=%d out of required range 0 <= %s < 2 ** 64' % (
                name, i, name))
    raise TypeError(
        'Expected %s to be an integer but got %r of type %s' % (
            name, i, type(i).__name__))