# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
The SCEngine class uses the ``sortedcontainers`` package to implement an
Index engine for Tables.
"""
from collections import OrderedDict
from itertools import starmap
from astropy.utils.compat.optional_deps import HAS_SORTEDCONTAINERS
if HAS_SORTEDCONTAINERS:
    # sortedcontainers is optional: only import SortedList when it is
    # available. SCEngine.__init__ checks the same flag and raises
    # ImportError when the package is missing.
    from sortedcontainers import SortedList
class Node:
    """
    A single index entry: a key together with the row number stored under it.

    Comparing two ``Node`` objects is lexicographic on ``(key, value)``, so
    entries sharing a key are ordered by row number.  When the other operand
    is not exactly a ``Node``, only ``self.key`` is compared against it; this
    lets a bare key act as a search bound against a sorted list of nodes.
    Instances are explicitly unhashable.
    """

    __slots__ = ("key", "value")

    def __init__(self, key, value):
        self.key = key
        self.value = value

    # Every rich comparison below first checks for an exact ``Node`` (via
    # ``__class__``, so subclasses take the key-only branch) and otherwise
    # compares the bare key with the other operand.

    def __eq__(self, other):
        if other.__class__ is not Node:
            return self.key == other
        return (self.key, self.value) == (other.key, other.value)

    def __ne__(self, other):
        if other.__class__ is not Node:
            return self.key != other
        return (self.key, self.value) != (other.key, other.value)

    def __lt__(self, other):
        if other.__class__ is not Node:
            return self.key < other
        return (self.key, self.value) < (other.key, other.value)

    def __le__(self, other):
        if other.__class__ is not Node:
            return self.key <= other
        return (self.key, self.value) <= (other.key, other.value)

    def __gt__(self, other):
        if other.__class__ is not Node:
            return self.key > other
        return (self.key, self.value) > (other.key, other.value)

    def __ge__(self, other):
        if other.__class__ is not Node:
            return self.key >= other
        return (self.key, self.value) >= (other.key, other.value)

    # Node attributes are reassigned in place by the index engine, so the
    # class opts out of hashing.
    __hash__ = None

    def __repr__(self):
        return f"Node({self.key!r}, {self.value!r})"
class SCEngine:
    """
    Index engine for Tables backed by a ``sortedcontainers.SortedList``.

    Every indexed row is stored as a `Node` whose ``key`` is ``tuple(row)`` of
    ``data`` and whose ``value`` is the corresponding row number.  Nodes are
    kept sorted by ``(key, value)``, so rows sharing a key are ordered by row
    number.

    Parameters
    ----------
    data : Table
        Sorted columns of the original table
    row_index : Column object
        Row numbers corresponding to data columns
    unique : bool
        Whether the values of the index must be unique.
        Defaults to False.

    Raises
    ------
    ImportError
        If the ``sortedcontainers`` package is not installed.
    """

    def __init__(self, data, row_index, unique=False):
        if not HAS_SORTEDCONTAINERS:
            raise ImportError("sortedcontainers is needed for using SCEngine")
        # Pair each row's key (as a tuple) with its row number.
        node_keys = map(tuple, data)
        self._nodes = SortedList(starmap(Node, zip(node_keys, row_index)))
        # NOTE: uniqueness is only enforced by ``add``; the initial ``data``
        # is not checked for duplicate keys here.
        self._unique = unique

    def add(self, key, value):
        """
        Add a key, value pair.

        Parameters
        ----------
        key : object
            Index key; it should compare like the ``tuple`` keys built from
            ``data``.
        value : int
            Row number to store under ``key``.

        Raises
        ------
        ValueError
            If the index is unique and ``key`` is already present.
        """
        # ``key in self._nodes`` compares nodes against the bare key, so it
        # matches any existing node with an equal key, whatever its row.
        if self._unique and (key in self._nodes):
            message = f"duplicate {key!r} in unique index"
            raise ValueError(message)
        self._nodes.add(Node(key, value))

    def find(self, key):
        """
        Find rows corresponding to the given key.

        Returns
        -------
        list
            Row numbers of all entries whose key equals ``key``, in ascending
            row order (empty if the key is absent).
        """
        return [node.value for node in self._nodes.irange(key, key)]

    def remove(self, key, data=None):
        """
        Remove data from the given key.

        Parameters
        ----------
        key : object
            Index key whose entries should be removed.
        data : int or None
            If given, remove only the entry for this row number under
            ``key``; otherwise remove every entry stored under ``key``.

        Returns
        -------
        bool
            True if at least one entry was removed, False otherwise.
        """
        if data is not None:
            try:
                self._nodes.remove(Node(key, data))
            except ValueError:
                # SortedList.remove raises ValueError when the entry is absent.
                return False
            return True
        # Collect the matches first so the list is not mutated while it is
        # being iterated.
        matches = list(self._nodes.irange(key, key))
        for node in matches:
            self._nodes.remove(node)
        return bool(matches)

    def shift_left(self, row):
        """
        Decrement rows larger than the given row.
        """
        # Mutating values in place keeps the list sorted: every value above
        # ``row`` drops by one, so it stays >= every unchanged value.
        for node in self._nodes:
            if node.value > row:
                node.value -= 1

    def shift_right(self, row):
        """
        Increment rows greater than or equal to the given row.
        """
        # Mutating values in place keeps the list sorted: every value >= row
        # grows by one, so it stays above every unchanged value.
        for node in self._nodes:
            if node.value >= row:
                node.value += 1

    def items(self):
        """
        Return a list of key, data tuples.

        Each tuple is ``(key, rows)``, where ``rows`` is the list of row
        numbers stored under ``key``.  Keys appear in sorted order.
        """
        grouped = OrderedDict()
        for node in self._nodes:
            grouped.setdefault(node.key, []).append(node.value)
        # Return an actual list, as documented, rather than a dict view.
        return list(grouped.items())

    def sort(self):
        """
        Make row order align with key order.
        """
        # Renumbering by position keeps the list sorted: keys are already
        # non-decreasing and the new row numbers are strictly increasing.
        for index, node in enumerate(self._nodes):
            node.value = index

    def sorted_data(self):
        """
        Return a list of rows in order sorted by key.
        """
        return [node.value for node in self._nodes]

    def range(self, lower, upper, bounds=(True, True)):
        """
        Return row values in the given range.

        Parameters
        ----------
        lower, upper : object
            Lower and upper key bounds.
        bounds : tuple of bool
            Passed as the ``inclusive`` argument of ``SortedList.irange``:
            whether keys equal to ``lower`` / ``upper`` are included.
        """
        iterator = self._nodes.irange(lower, upper, bounds)
        return [node.value for node in iterator]

    def replace_rows(self, row_map):
        """
        Replace rows with the values in row_map.

        Entries whose row number is a key of ``row_map`` receive the mapped
        row number; all other entries are dropped from the index.
        """
        kept = [node for node in self._nodes if node.value in row_map]
        for node in kept:
            node.value = row_map[node.value]
        # Row numbers changed arbitrarily, so rebuild the sorted list to
        # restore its ordering.
        self._nodes.clear()
        self._nodes.update(kept)

    def __repr__(self):
        if len(self._nodes) > 6:
            nodes = list(self._nodes[:3]) + ["..."] + list(self._nodes[-3:])
        else:
            nodes = self._nodes
        nodes_str = ", ".join(str(node) for node in nodes)
        return f"<{self.__class__.__name__} nodes={nodes_str}>"