Source code for astropy.table.sorted_array

# Licensed under a 3-clause BSD style license - see LICENSE.rst
import numpy as np


def _searchsorted(array, val, side='left'):
    '''
    Call np.searchsorted or use a custom binary
    search if necessary.
    '''
    if hasattr(array, 'searchsorted'):
        return array.searchsorted(val, side=side)
    # Python binary search
    begin = 0
    end = len(array)
    while begin < end:
        mid = (begin + end) // 2
        if val > array[mid]:
            begin = mid + 1
        elif val < array[mid]:
            end = mid
        elif side == 'right':
            begin = mid + 1
        else:
            end = mid
    return begin


[docs]class SortedArray: ''' Implements a sorted array container using a list of numpy arrays. Parameters ---------- data : Table Sorted columns of the original table row_index : Column object Row numbers corresponding to data columns unique : bool Whether the values of the index must be unique. Defaults to False. ''' def __init__(self, data, row_index, unique=False): self.data = data self.row_index = row_index self.num_cols = len(getattr(data, 'colnames', [])) self.unique = unique @property def cols(self): return list(self.data.columns.values())
[docs] def add(self, key, row): ''' Add a new entry to the sorted array. Parameters ---------- key : tuple Column values at the given row row : int Row number ''' pos = self.find_pos(key, row) # first >= key if self.unique and 0 <= pos < len(self.row_index) and \ all(self.data[pos][i] == key[i] for i in range(len(key))): # already exists raise ValueError(f'Cannot add duplicate value "{key}" in a unique index') self.data.insert_row(pos, key) self.row_index = self.row_index.insert(pos, row)
def _get_key_slice(self, i, begin, end): ''' Retrieve the ith slice of the sorted array from begin to end. ''' if i < self.num_cols: return self.cols[i][begin:end] else: return self.row_index[begin:end]
[docs] def find_pos(self, key, data, exact=False): ''' Return the index of the largest key in data greater than or equal to the given key, data pair. Parameters ---------- key : tuple Column key data : int Row number exact : bool If True, return the index of the given key in data or -1 if the key is not present. ''' begin = 0 end = len(self.row_index) num_cols = self.num_cols if not self.unique: # consider the row value as well key = key + (data,) num_cols += 1 # search through keys in lexicographic order for i in range(num_cols): key_slice = self._get_key_slice(i, begin, end) t = _searchsorted(key_slice, key[i]) # t is the smallest index >= key[i] if exact and (t == len(key_slice) or key_slice[t] != key[i]): # no match return -1 elif t == len(key_slice) or (t == 0 and len(key_slice) > 0 and key[i] < key_slice[0]): # too small or too large return begin + t end = begin + _searchsorted(key_slice, key[i], side='right') begin += t if begin >= len(self.row_index): # greater than all keys return begin return begin
[docs] def find(self, key): ''' Find all rows matching the given key. Parameters ---------- key : tuple Column values Returns ------- matching_rows : list List of rows matching the input key ''' begin = 0 end = len(self.row_index) # search through keys in lexicographic order for i in range(self.num_cols): key_slice = self._get_key_slice(i, begin, end) t = _searchsorted(key_slice, key[i]) # t is the smallest index >= key[i] if t == len(key_slice) or key_slice[t] != key[i]: # no match return [] elif t == 0 and len(key_slice) > 0 and key[i] < key_slice[0]: # too small or too large return [] end = begin + _searchsorted(key_slice, key[i], side='right') begin += t if begin >= len(self.row_index): # greater than all keys return [] return self.row_index[begin:end]
[docs] def range(self, lower, upper, bounds): ''' Find values in the given range. Parameters ---------- lower : tuple Lower search bound upper : tuple Upper search bound bounds : (2,) tuple of bool Indicates whether the search should be inclusive or exclusive with respect to the endpoints. The first argument corresponds to an inclusive lower bound, and the second argument to an inclusive upper bound. ''' lower_pos = self.find_pos(lower, 0) upper_pos = self.find_pos(upper, 0) if lower_pos == len(self.row_index): return [] lower_bound = tuple([col[lower_pos] for col in self.cols]) if not bounds[0] and lower_bound == lower: lower_pos += 1 # data[lower_pos] > lower # data[lower_pos] >= lower # data[upper_pos] >= upper if upper_pos < len(self.row_index): upper_bound = tuple([col[upper_pos] for col in self.cols]) if not bounds[1] and upper_bound == upper: upper_pos -= 1 # data[upper_pos] < upper elif upper_bound > upper: upper_pos -= 1 # data[upper_pos] <= upper return self.row_index[lower_pos:upper_pos + 1]
[docs] def remove(self, key, data): ''' Remove the given entry from the sorted array. Parameters ---------- key : tuple Column values data : int Row number Returns ------- successful : bool Whether the entry was successfully removed ''' pos = self.find_pos(key, data, exact=True) if pos == -1: # key not found return False self.data.remove_row(pos) keep_mask = np.ones(len(self.row_index), dtype=bool) keep_mask[pos] = False self.row_index = self.row_index[keep_mask] return True
[docs] def shift_left(self, row): ''' Decrement all row numbers greater than the input row. Parameters ---------- row : int Input row number ''' self.row_index[self.row_index > row] -= 1
[docs] def shift_right(self, row): ''' Increment all row numbers greater than or equal to the input row. Parameters ---------- row : int Input row number ''' self.row_index[self.row_index >= row] += 1
[docs] def replace_rows(self, row_map): ''' Replace all rows with the values they map to in the given dictionary. Any rows not present as keys in the dictionary will have their entries deleted. Parameters ---------- row_map : dict Mapping of row numbers to new row numbers ''' num_rows = len(row_map) keep_rows = np.zeros(len(self.row_index), dtype=bool) tagged = 0 for i, row in enumerate(self.row_index): if row in row_map: keep_rows[i] = True tagged += 1 if tagged == num_rows: break self.data = self.data[keep_rows] self.row_index = np.array( [row_map[x] for x in self.row_index[keep_rows]])
[docs] def items(self): ''' Retrieve all array items as a list of pairs of the form [(key, [row 1, row 2, ...]), ...] ''' array = [] last_key = None for i, key in enumerate(zip(*self.data.columns.values())): row = self.row_index[i] if key == last_key: array[-1][1].append(row) else: last_key = key array.append((key, [row])) return array
[docs] def sort(self): ''' Make row order align with key order. ''' self.row_index = np.arange(len(self.row_index))
[docs] def sorted_data(self): ''' Return rows in sorted order. ''' return self.row_index
def __getitem__(self, item): ''' Return a sliced reference to this sorted array. Parameters ---------- item : slice Slice to use for referencing ''' return SortedArray(self.data[item], self.row_index[item]) def __repr__(self): t = self.data.copy() t['rows'] = self.row_index return f'<{self.__class__.__name__} length={len(t)}>\n{t}'