#  Copyright 2020 Regents of the University of Minnesota.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""Labels functionality."""
import threading
from abc import ABC, abstractmethod, ABCMeta
from queue import Queue
from typing import (
    List,
    Tuple,
    Set,
    TYPE_CHECKING,
    TypeVar,
    NamedTuple,
    Any,
    Mapping,
    Sequence,
    Union,
    Optional
)
if TYPE_CHECKING:
    import mtap
    from mtap import data
class Location(NamedTuple('Location', [('start_index', float), ('end_index', float)])):
    """A location in text, a tuple of (`start_index`, `end_index`).

    Used to perform comparison of labels based on their locations.

    Args:
        start_index (float):
            The start index inclusive of the location in text.
        end_index (float):
            The end index exclusive of the location in text.

    Attributes:
        start_index (float):
            The start index inclusive of the location in text.
        end_index (float):
            The end index exclusive of the location in text.
    """
    __slots__ = ()

    def covers(self, other: Union['data.Location', 'data.Label']) -> bool:
        """Whether the span of text covered by this label completely overlaps the span of text
        covered by the ``other`` label or location.

        Args:
            other (~typing.Union[Location, Label]): A location or label to compare against.

        Returns:
            bool: ``True`` if `other` is completely overlapped/covered ``False`` otherwise.
        """
        return self.start_index <= other.start_index and self.end_index >= other.end_index

    @staticmethod
    def _start_of(location: Union['data.Location', 'data.Label', int, float]) -> Union[int, float]:
        """Returns ``location.start_index`` for a label or location, or ``location`` itself for
        a plain number.

        Floats are accepted because ``Location`` itself declares its indices as ``float``.

        Raises:
            ValueError: If the resolved start index is not an ``int`` or ``float``.
        """
        start_index = getattr(location, 'start_index', location)
        if not isinstance(start_index, (int, float)):
            raise ValueError('location must be Label, Location, or an int value')
        return start_index

    def relative_to(self, location: Union['data.Location', 'data.Label', int]) -> 'data.Location':
        """Creates a location relative to the same origin as ``location`` and makes it relative
        to ``location``.

        Args:
            location (int or float or Location or Label): A location to relativize this
                location to.

        Returns:
            ~data.Location: A copy with updated indices.

        Raises:
            ValueError: If ``location`` is not a label, location, or number.

        Examples:
            >>> sentence = Location(10, 20)
            >>> token = Location(10, 15)
            >>> token.relative_to(sentence)
            Location(start_index=0, end_index=5)
        """
        origin = self._start_of(location)
        return Location(self.start_index - origin, self.end_index - origin)

    def offset_by(self, location: Union['data.Location', 'data.Label', int]) -> 'data.Location':
        """Creates a location by offsetting this location by a number or the ``start_index`` of
        a location / label. Derelativizes this location.

        Args:
            location (int or float or Location or Label): A location to offset this location by.

        Returns:
            ~data.Location: A copy with updated indices.

        Raises:
            ValueError: If ``location`` is not a label, location, or number.

        Examples:
            >>> sentence = Location(10, 20)
            >>> token_in_sentence = Location(0, 5)
            >>> token_in_sentence.offset_by(sentence)
            Location(start_index=10, end_index=15)
        """
        origin = self._start_of(location)
        return Location(self.start_index + origin, self.end_index + origin)
class Label(ABC, metaclass=ABCMeta):
    """An abstract base class for a label of attributes on text.

    Subclasses provide storage for the document, label index name, identifier, and the
    ``start_index`` / ``end_index`` span; this base class derives :attr:`location` and
    :attr:`text` from those.
    """

    @property
    @abstractmethod
    def document(self) -> 'mtap.Document':
        """Document: The parent document this label appears on."""
        ...

    @document.setter
    @abstractmethod
    def document(self, value: 'mtap.Document'):
        """Sets the label's document, this will automatically be done when the label is created
        via a Document (i.e. get_label_index) or added to a document (i.e. via labeler or add_labels).
        """
        ...

    @property
    @abstractmethod
    def label_index_name(self) -> str:
        """str: The label index this label appears on."""
        ...

    @label_index_name.setter
    @abstractmethod
    def label_index_name(self, value: str):
        """Sets the name for the label index this label appears on. Will automatically be called
        when a label is added to a document via labeler or add_labels."""
        ...

    @property
    @abstractmethod
    def identifier(self) -> int:
        """int: The index of the label within its label index."""
        ...

    @identifier.setter
    @abstractmethod
    def identifier(self, value: int):
        """The index of the label within its label index. Labels will automatically be assigned
        this when added to a document via labeler or add_labels."""
        ...

    @property
    @abstractmethod
    def start_index(self) -> int:
        """int: The index of the first character of the text covered by this label.
        """
        ...

    @start_index.setter
    @abstractmethod
    def start_index(self, value: int):
        """Sets the index of the first character of the text covered by this label."""
        ...

    @property
    @abstractmethod
    def end_index(self) -> int:
        """int: The index after the last character of the text covered by this label.
        """
        ...

    @end_index.setter
    @abstractmethod
    def end_index(self, value: int):
        """Sets the index after the last character of the text covered by this label."""
        ...

    @property
    def location(self) -> Location:
        """Location: A tuple of (start_index, end_index) used to perform sorting and
        comparison first based on start_index, then based on end_index.
        """
        return Location(self.start_index, self.end_index)

    @property
    def text(self) -> str:
        """str: The slice of document text covered by this label. Will retrieve from events server
        if it is not cached locally.
        """
        return self.document.text[self.start_index:self.end_index]

    @abstractmethod
    def shallow_fields_equal(self, other) -> bool:
        """Tests if the fields on this label and locations of references are the same as another
        label.

        Args:
            other: The other label to test.

        Returns:
            True if all of the fields are equal and the references are at the same locations.
        """
        pass

    @abstractmethod
    def collect_floating_references(self, s):
        """Collects the labels referenced by this label which do not yet have an ``identifier``.

        A "floating" reference points at a label that has not been assigned an identifier yet
        (identifiers are assigned when a label is added to a label index). Implementations add
        a value identifying each such referenced label to ``s``; :class:`GenericLabel` adds the
        ``id()`` of the referenced label.

        Args:
            s (set): The set to update in place.
        """
        pass
# Per-thread scratch space for GenericLabel.__repr__: its ``stack`` attribute holds the ids of
# labels whose repr is currently being built on this thread, so reference cycles print as
# "GenericLabel(...)" instead of recursing forever.
_repr_local = threading.local()
# Type variable bounded by Label, for annotating code that is generic over Label subclasses.
L = TypeVar('L', bound=Label)
class GenericLabel(Label):
    """Default implementation of the Label class which uses a dictionary to store attributes.

    Will be suitable for the majority of use cases for labels. Attribute values that are labels,
    or dicts / lists of labels, are stored as references in ``reference_cache``; all other
    values are stored in ``fields``. A given attribute name lives in only one of those stores.

    Args:
        start_index (int): The index of the first character in text to be included in the label.
        end_index (int): The index after the last character in text to be included in the label.

    Keyword Args:
        identifier (~typing.Optional[int]): The index of the label within its label index.
        document (~typing.Optional[Document]): The parent document of the label. This will be
            automatically set if the label is created via labeler.
        label_index_name (~typing.Optional[str]): The name of the label index this label
            appears on.
        fields (~typing.Optional[dict]): A dictionary used as-is (not copied) as the storage
            for primitive fields.
        reference_field_ids (~typing.Optional[dict]): A dictionary used as-is mapping field
            names to unresolved ``"<label index name>:<identifier>"`` reference strings (or
            containers of them); these are resolved against ``document`` on first access.
        **kwargs : Arbitrary, any other fields that should be added to the label, values must be
            json-serializable, or labels / dicts / lists of labels.

    Examples:
        >>> pos_tag = pos_tag_labeler(0, 5)
        >>> pos_tag.tag = 'NNS'
        >>> pos_tag.tag
        'NNS'
        >>> pos_tag2 = pos_tag_labeler(6, 10, tag='VB')
        >>> pos_tag2.tag
        'VB'
    """

    def __init__(self, start_index: int, end_index: int, *,
                 identifier: Optional[int] = None,
                 document: Optional['mtap.Document'] = None,
                 label_index_name: Optional['str'] = None,
                 fields: Optional[dict] = None,
                 reference_field_ids: Optional[dict] = None,
                 **kwargs):
        self._document = document
        self._label_index_name = label_index_name
        self._identifier = identifier
        self._start_index = int(start_index)
        self._end_index = int(end_index)
        self.fields = {} if fields is None else fields
        self.reference_field_ids = {} if reference_field_ids is None else reference_field_ids
        self.reference_cache = {}
        # The backing dicts must exist before this point: extra keyword arguments are routed
        # through __setattr__, which stores them in ``fields`` or ``reference_cache``.
        for key, value in kwargs.items():
            setattr(self, key, value)

    @property
    def document(self) -> 'mtap.Document':
        return self._document

    @document.setter
    def document(self, document: 'mtap.Document'):
        self._document = document

    @property
    def label_index_name(self) -> str:
        return self._label_index_name

    @label_index_name.setter
    def label_index_name(self, value: str):
        self._label_index_name = value

    @property
    def identifier(self) -> int:
        return self._identifier

    @identifier.setter
    def identifier(self, value: int):
        self._identifier = value

    @property
    def start_index(self) -> int:
        return self._start_index

    @start_index.setter
    def start_index(self, start_index: int):
        self._start_index = start_index

    @property
    def end_index(self) -> int:
        return self._end_index

    @end_index.setter
    def end_index(self, end_index: int):
        self._end_index = end_index

    def _is_reserved(self, key):
        # Names already present on the instance or defined on GenericLabel / Label (properties,
        # methods) cannot be used as field names.
        return key in self.__dict__ or key in vars(GenericLabel) or key in vars(Label)

    def __getattr__(self, item):
        # Only called after normal attribute lookup fails. The backing dicts are read straight
        # from __dict__: on an instance created without __init__ (copy.copy / pickle create one
        # and then probe it for ``__setstate__``) ``self.fields`` would re-enter __getattr__ and
        # recurse until RecursionError.
        state = self.__dict__
        try:
            fields = state['fields']
            reference_cache = state['reference_cache']
            reference_field_ids = state['reference_field_ids']
        except KeyError:
            raise AttributeError(item) from None
        if item in fields:
            return fields[item]
        if item in reference_cache:
            return reference_cache[item]
        try:
            ref_value = reference_field_ids[item]
            # Resolve the serialized reference lazily and memoize it.
            reference_cache[item] = _dereference(ref_value, self.document)
            return reference_cache[item]
        except KeyError:
            raise AttributeError('Key "{}" not in fields, reference cache, or reference ids.'
                                 .format(item))

    def __setattr__(self, key, value):
        # Core attributes (and the properties wrapping them) bypass field storage.
        if key in ('document', 'label_index_name', 'identifier', 'start_index', 'end_index',
                   '_document', '_label_index_name', '_identifier', '_start_index', '_end_index',
                   'fields', 'reference_field_ids', 'reference_cache'):
            object.__setattr__(self, key, value)
            return
        if self._is_reserved(key):
            raise ValueError('The key "{}" is a reserved key.'.format(key))
        # _is_referential may raise TypeError/ValueError; it runs before any store is touched.
        if _is_referential(value, [id(self)]):
            # Drop stale copies first: __getattr__ checks ``fields`` before the reference
            # cache, so a leftover primitive would shadow the new reference.
            self.fields.pop(key, None)
            self.reference_field_ids.pop(key, None)
            self.reference_cache[key] = value
        else:
            self.reference_cache.pop(key, None)
            self.reference_field_ids.pop(key, None)
            self.fields[key] = value

    def __eq__(self, other):
        # Equality is by location plus field / reference contents, not identity. Since __eq__ is
        # defined without __hash__, instances are unhashable.
        if not isinstance(other, GenericLabel):
            return False
        if other is self:
            return True
        if self.location != other.location:
            return False
        return self.shallow_fields_equal(other)

    def shallow_fields_equal(self, other):
        """Tests whether ``other`` has equal ``fields`` and references to labels at the same
        locations, without comparing the referenced labels' own fields.

        Args:
            other (GenericLabel): The other label to test.

        Returns:
            bool: True if the fields are equal and all references point at the same locations.
        """
        if not self.fields == other.fields:
            return False
        refs = set(self.reference_field_ids.keys()).union(self.reference_cache.keys())
        other_refs = set(other.reference_field_ids.keys()).union(other.reference_cache.keys())
        if not refs == other_refs:
            return False
        for k in refs:
            # Identical serialized reference ids are equal without dereferencing.
            try:
                if self.reference_field_ids[k] == other.reference_field_ids[k]:
                    continue
            except KeyError:
                pass
            self_k = getattr(self, k)
            other_k = getattr(other, k)
            if not _collect_locations(self_k) == _collect_locations(other_k):
                return False
        return True

    def __repr__(self):
        # Track labels being repr'd on this thread so reference cycles terminate.
        stack = getattr(_repr_local, 'stack', None)
        if stack is None:
            stack = _repr_local.stack = set()
        if id(self) in stack:
            return 'GenericLabel(...)'
        stack.add(id(self))
        try:
            attributes = [repr(self.start_index), repr(self.end_index)]
            attributes.extend('{}={!r}'.format(k, v) for k, v in self.fields.items())
            attributes.extend('{}={!r}'.format(k, v) for k, v in self.reference_cache.items())
            attributes.extend('{}=ref:{!r}'.format(k, v)
                              for k, v in self.reference_field_ids.items()
                              if k not in self.reference_cache)
        finally:
            # Always unwind, even if a nested repr raises, so later reprs are not truncated.
            stack.discard(id(self))
        return 'GenericLabel(' + ', '.join(attributes) + ')'

    def collect_floating_references(self, s):
        """Adds to ``s`` the ``id()`` of every label reachable through this label's resolved
        references (``reference_cache``, including nested dicts / lists) that has no
        ``identifier`` yet. Referenced labels are not descended into.

        Args:
            s (set): The set to update in place.
        """
        # Order is irrelevant since results go into a set, so a plain list works as the worklist.
        pending = [v for v in self.reference_cache.values() if v is not None]
        while pending:
            o = pending.pop()
            if isinstance(o, Label):
                if o.identifier is None:
                    s.add(id(o))
            elif isinstance(o, Mapping):
                pending.extend(v for v in o.values() if v is not None)
            elif isinstance(o, Sequence) and not isinstance(o, str):
                # Strings are excluded: a 1-char str iterates to itself and would loop forever.
                pending.extend(v for v in o if v is not None)
def label(start_index: int,
          end_index: int,
          *, document: Optional['mtap.Document'] = None,
          **kwargs) -> GenericLabel:
    """An alias for :class:`GenericLabel`.

    Args:
        start_index (int): The index of the first character in text to be included in the label.
        end_index (int): The index after the last character in text to be included in the label.
        document (~typing.Optional[Document]): The parent document of the label. This will be
            automatically set if the label is created via labeler.
        **kwargs : Arbitrary, any other fields that should be added to the label, values must be
            json-serializable. Passed through unchanged to :class:`GenericLabel`.

    Returns:
        GenericLabel: The newly created label.
    """
    return GenericLabel(start_index, end_index, document=document, **kwargs)
def _staticize(labels: Sequence['data.Label'],
               document: 'mtap.Document',
               label_index_name: str) -> Tuple[List['data.Label'], Set[int]]:
    """Finalizes a label index for serialization.

    Sorts the labels by location, then stamps each one with ``document``, its position in the
    sorted order as ``identifier``, and ``label_index_name``. Finally each label reports its
    floating references (referenced labels that still lack an identifier).

    Args:
        labels (~typing.Sequence[Label]): The labels in a label index.
        document (Document): The document the label index belongs to.
        label_index_name (str): The name of the label index.

    Returns:
        ~typing.Tuple[~typing.List[Label], ~typing.Set[int]]: The labels sorted by position, and
        the set filled by each label's ``collect_floating_references`` (for
        :class:`GenericLabel`, the ``id()`` of every referenced label with no identifier).
    """
    ordered = sorted(labels, key=lambda lbl: lbl.location)
    for position, lbl in enumerate(ordered):
        lbl.document = document
        lbl.identifier = position
        lbl.label_index_name = label_index_name
    # Collection happens only after every label in this index has an identifier, so references
    # between labels of the same index are not reported as floating.
    floating = set()
    for lbl in ordered:
        lbl.collect_floating_references(floating)
    return ordered, floating
def _is_referential(o: Any, parents=None) -> bool:
    """Determines whether a field value is a reference to labels or a primitive value.

    Args:
        o: The candidate field value: a primitive (``str``, ``float``, ``bool``, ``int``,
            ``None``), a :class:`Label`, or a mapping / sequence nesting those.
        parents: ``id()`` values of the containers enclosing ``o``; used to detect cycles.

    Returns:
        ``True`` for a label or a container of labels, ``False`` for a primitive or a container
        of primitives, and ``None`` for a container with no decidable element (e.g. empty).

    Raises:
        ValueError: If a container (directly) contains one of its enclosing containers.
        TypeError: If a container mixes label references and primitives, or ``o`` is of an
            unrecognized type.
    """
    if parents is None:
        parents = [id(o)]
    if o is None or isinstance(o, (str, float, bool, int)):
        return False
    if isinstance(o, Label):
        return True
    if isinstance(o, Mapping):
        children, kind = o.values(), 'dictionaries'
    elif isinstance(o, Sequence):
        children, kind = o, 'lists'
    else:
        raise TypeError('Unrecognized type')
    result = None
    for child in children:
        if id(child) in parents:
            raise ValueError('Recursive loop')
        child_is_ref = _is_referential(child, parents + [id(child)])
        # Undecided (None) children, e.g. empty containers, do not fix the container's kind.
        if result is None:
            result = child_is_ref
        elif child_is_ref != result:
            raise TypeError('Label {} cannot have mixes of references to labels '
                            'and primitive types.'.format(kind))
    return result
def _dereference(o: Any, document: 'mtap.Document') -> Any:
    """Resolves serialized label references against ``document``.

    Args:
        o: ``None``, a reference string of the form ``"<label index name>:<identifier>"``, or a
            (possibly nested) mapping / sequence of those.
        document (Document): The document whose ``labels`` are used to look references up.

    Returns:
        The referenced label for a string, a new ``dict`` for a mapping, a new ``list`` for any
        other sequence, and ``None`` for ``None`` or any other type.
    """
    if o is None:
        return None
    if isinstance(o, str):
        # Split on the last colon only: the identifier is an integer, so any earlier colons
        # belong to the label index name.
        label_index_name, label_id = o.rsplit(':', 1)
        return document.labels[label_index_name][int(label_id)]
    if isinstance(o, Mapping):
        return {k: _dereference(v, document) for k, v in o.items()}
    if isinstance(o, Sequence):
        return [_dereference(v, document) for v in o]
    # NOTE(review): unrecognized types silently resolve to None; consider raising TypeError.
    return None
def _collect_locations(o):
    """Replaces every label in a (possibly nested) reference value with its location.

    Used to compare references by where the referenced labels are rather than by identity.

    Args:
        o: ``None``, a :class:`Label`, or a mapping / sequence nesting those.

    Returns:
        ``None`` for ``None``, strings, and other non-label leaves; the label's ``location``
        for a label; a ``dict`` for a mapping; a ``list`` for any other sequence.
    """
    # Strings must be handled before the Sequence branch: a 1-char str is a Sequence whose only
    # element is itself, which would recurse forever.
    if o is None or isinstance(o, str):
        return None
    if isinstance(o, Label):
        return o.location
    if isinstance(o, Mapping):
        return {k: _collect_locations(v) for k, v in o.items()}
    if isinstance(o, Sequence):
        return [_collect_locations(v) for v in o]
    return None