# Source code for muspy.base

"""Base classes.

This module defines the base classes for MusPy objects.

Classes
-------

- Base
- ComplexBase

"""
from collections import OrderedDict
from inspect import isclass
from operator import attrgetter
from typing import Any, Callable, List, Mapping, Optional, Type, TypeVar

import yaml

# Public names of this module.
__all__ = ["Base", "ComplexBase"]

# Type variables bound to the two base classes. They annotate methods that
# return the object itself (e.g. `validate`, `sort`) or an instance of the
# class the method was called on (`from_dict`).
BaseType = TypeVar("BaseType", bound="Base")
ComplexBaseType = TypeVar("ComplexBaseType", bound="ComplexBase")


class _OrderedDumper(yaml.SafeDumper):
    """A YAML dumper used to serialize OrderedDict-based data.

    Support for OrderedDict comes from the representer registered on
    this class at module level; the class itself only customizes how
    the indent is increased.
    """

    def increase_indent(self, flow=False, indentless=False):
        """Increase the indent, always passing ``indentless=False``.

        The `indentless` argument is accepted to keep the parent
        signature but its value is deliberately ignored.
        """
        # NOTE(review): presumably this forces nested sequences to be
        # indented under their parent key -- confirm against PyYAML.
        return super().increase_indent(flow=flow, indentless=False)


def _dict_representer(dumper, data):
    """Represent a mapping as a YAML node with the default mapping tag.

    Parameters
    ----------
    dumper : yaml dumper
        Dumper used to build the node.
    data : mapping
        Mapping whose items are to be represented.

    Returns
    -------
    Whatever `dumper.represent_mapping` returns for the items.
    """
    mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
    # NOTE(review): passing the items view instead of the mapping itself
    # presumably keeps the insertion order -- confirm against PyYAML.
    return dumper.represent_mapping(mapping_tag, data.items())


# Make `_OrderedDumper` serialize OrderedDict objects via `_dict_representer`.
_OrderedDumper.add_representer(OrderedDict, _dict_representer)


def _yaml_dump(data):
    """Return `data` serialized as a YAML string.

    The dump uses `_OrderedDumper`, so OrderedDict objects are
    supported, and unicode characters are allowed in the output.

    Code adapted from https://stackoverflow.com/a/21912744.
    """
    dump_options = {"Dumper": _OrderedDumper, "allow_unicode": True}
    return yaml.dump(data, **dump_options)


def _get_type_string(attr_type):
    """Return a string represeting acceptable type(s)."""
    if isinstance(attr_type, (list, tuple)):
        if len(attr_type) > 1:
            return (
                ", ".join([x.__name__ for x in attr_type[:-1]])
                + " or "
                + attr_type[-1].__name__
            )
        return attr_type[0].__name__
    return attr_type.__name__


class Base:
    """The base class for MusPy classes.

    This is the base class for MusPy classes. It provides two handy I/O
    methods---`from_dict` and `to_ordered_dict`. It also provides
    intuitive `__repr__` as well as methods `pretty_str` and `print` for
    beautifully printing the content.

    Hint
    ----
    To implement a new class in MusPy, please inherit from this class
    and set the following class variables properly.

    - `_attributes`: An OrderedDict with attribute names as keys and
      their types as values.
    - `_optional_attributes`: A list of optional attribute names.
    - `_list_attributes`: A list of attributes that are lists.
    - `_sort_attributes`: A list of attributes used when being sorted,
      which will be passed to operator.attrgetter.

    Take :class:`muspy.Note` for example.::

        _attributes = OrderedDict(
            [
                ("time", int),
                ("duration", int),
                ("pitch", int),
                ("velocity", int),
                ("pitch_str", str),
            ]
        )
        _optional_attributes = ["pitch_str"]
        _sort_attributes = ["time", "duration", "pitch"]

    See Also
    --------
    :class:`muspy.ComplexBase` :
        A base class that supports advanced operations on list
        attributes.

    """

    # These class-level containers are only read by this class; subclasses
    # are expected to override them rather than mutate them in place.
    _attributes: Mapping[str, Any] = {}
    _optional_attributes: List[str] = []
    _list_attributes: List[str] = []
    _sort_attributes: List[str] = []

    def __init__(self, **kwargs):
        """Set every keyword argument as an attribute of the object."""
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __repr__(self):
        """Return a compact representation of the declared attributes.

        Attributes that are None and empty list attributes are omitted.
        List attributes with more than three items are truncated to the
        first three items followed by an ellipsis.
        """
        to_join = []
        for attr in self._attributes:
            value = getattr(self, attr)
            if attr in self._list_attributes:
                if not value:
                    continue
                if len(value) > 3:
                    # Drop the closing bracket and append the ellipsis
                    to_join.append(
                        attr + "=" + repr(value[:3])[:-1] + ", ...]"
                    )
                else:
                    to_join.append(attr + "=" + repr(value))
            elif value is not None:
                to_join.append(attr + "=" + repr(value))
        return type(self).__name__ + "(" + ", ".join(to_join) + ")"

    def __eq__(self, other) -> bool:
        """Return True if `other` matches on every declared attribute.

        An object that lacks one of the declared attributes compares
        unequal instead of raising an AttributeError.
        """
        missing = object()
        for attr in self._attributes:
            other_value = getattr(other, attr, missing)
            if other_value is missing:
                return False
            if getattr(self, attr) != other_value:
                return False
        return True

    @classmethod
    def from_dict(cls: "Type[BaseType]", dict_: Mapping) -> "BaseType":
        """Return an instance constructed from a dictionary.

        Instantiate an object whose attributes and the corresponding
        values are given as a dictionary.

        Parameters
        ----------
        dict_ : dict or mapping
            A dictionary that stores the attributes and their values as
            key-value pairs, e.g., `{"attr1": value1, "attr2": value2}`.

        Returns
        -------
        Constructed object.

        Raises
        ------
        TypeError
            If a non-optional attribute is missing or None.

        """
        kwargs = {}
        for attr, attr_type in cls._attributes.items():
            value = dict_.get(attr)
            if value is None:
                if attr in cls._optional_attributes:
                    continue
                raise TypeError("`{}` must not be None.".format(attr))
            if isclass(attr_type) and issubclass(attr_type, Base):
                # Nested MusPy objects are themselves built from dicts
                if attr in cls._list_attributes:
                    kwargs[attr] = [attr_type.from_dict(v) for v in value]
                else:
                    kwargs[attr] = attr_type.from_dict(value)
            else:
                kwargs[attr] = value
        return cls(**kwargs)

    def to_ordered_dict(self, skip_none: bool = True) -> OrderedDict:
        """Return the object as an OrderedDict.

        Return an ordered dictionary that stores the attributes and
        their values as key-value pairs. Nested MusPy objects are
        converted with the same `skip_none` setting.

        Parameters
        ----------
        skip_none : bool
            Whether to skip attributes with value None or those that are
            empty lists.

        Returns
        -------
        OrderedDict
            A dictionary that stores the attributes and their values as
            key-value pairs, e.g., `{"attr1": value1, "attr2": value2}`.

        """
        ordered_dict: OrderedDict = OrderedDict()
        for attr, attr_type in self._attributes.items():
            value = getattr(self, attr)
            is_base = isclass(attr_type) and issubclass(attr_type, Base)
            if attr in self._list_attributes:
                if not value and skip_none:
                    continue
                if is_base and value is not None:
                    ordered_dict[attr] = [
                        v.to_ordered_dict(skip_none=skip_none) for v in value
                    ]
                else:
                    # Plain values, or a list attribute that is None
                    ordered_dict[attr] = value
            elif value is None:
                if not skip_none:
                    ordered_dict[attr] = None
            elif is_base:
                ordered_dict[attr] = value.to_ordered_dict(
                    skip_none=skip_none
                )
            else:
                ordered_dict[attr] = value
        return ordered_dict

    def pretty_str(self) -> str:
        """Return the stored data as a string in a beautiful YAML-like format.

        Returns
        -------
        str
            Stored data as a string in pretty YAML-like format.

        See Also
        --------
        :meth:`muspy.Base.print` :
            Print the stored data in a beautiful YAML-like format.

        """
        return _yaml_dump(self.to_ordered_dict())

    def print(self):
        """Print the stored data in a beautiful YAML-like format.

        See Also
        --------
        :meth:`muspy.Base.pretty_str` :
            Return the stored data as a string in a beautiful YAML-like
            format.

        """
        # Inside a method body, `print` resolves to the builtin function
        print(self.pretty_str())

    def _validate_attr_type(self, attr: str):
        """Raise a TypeError if attribute `attr` has an invalid type."""
        attr_type = self._attributes[attr]
        value = getattr(self, attr)
        if value is None:
            if attr in self._optional_attributes:
                return
            raise TypeError("`{}` must not be None".format(attr))
        if attr in self._list_attributes:
            if not isinstance(value, list):
                raise TypeError("`{}` must be a list.".format(attr))
            for item in value:
                if not isinstance(item, attr_type):
                    raise TypeError(
                        "`{}` must be a list of type {}.".format(
                            attr, _get_type_string(attr_type)
                        )
                    )
        elif not isinstance(value, attr_type):
            raise TypeError(
                "`{}` must be of type {}.".format(
                    attr, _get_type_string(attr_type)
                )
            )

    def validate_type(
        self: "BaseType", attr: Optional[str] = None
    ) -> "BaseType":
        """Raise an error if a certain attribute has an invalid type.

        Only the attributes of the object itself are checked. A nested
        object is checked against its expected class, but its own
        attributes are not inspected; use `validate` for that.

        Parameters
        ----------
        attr : str
            Attribute to validate. Defaults to validate all attributes.

        Returns
        -------
        Object itself.

        Raises
        ------
        TypeError
            If an attribute has an invalid type.

        See Also
        --------
        :meth:`muspy.Base.is_valid_type` :
            Return True if an attribute has a valid type, otherwise
            False.
        :meth:`muspy.Base.validate` :
            Raise an error if a certain attribute has an invalid type or
            value.

        """
        if attr is None:
            for attribute in self._attributes:
                self._validate_attr_type(attribute)
        else:
            self._validate_attr_type(attr)
        return self

    def _validate(self, attr: str):
        """Raise an error if attribute `attr` has an invalid type or value."""
        # Check the type first so that a wrongly typed value is reported
        # as a TypeError instead of failing below with an AttributeError.
        self._validate_attr_type(attr)
        value = getattr(self, attr)
        if value is None:
            # An optional attribute that is not set; nothing more to check
            return
        attr_type = self._attributes[attr]
        if isclass(attr_type) and issubclass(attr_type, Base):
            if attr in self._list_attributes:
                for item in value:
                    item.validate()
            else:
                value.validate()
        elif attr == "time":
            times = value if attr in self._list_attributes else [value]
            if any(time < 0 for time in times):
                raise ValueError("`time` must be nonnegative.")

    def validate(self: "BaseType", attr: Optional[str] = None) -> "BaseType":
        """Raise an error if a certain attribute has an invalid type or value.

        This will apply recursively to an attribute's attributes.

        Parameters
        ----------
        attr : str
            Attribute to validate. Defaults to validate all attributes.

        Returns
        -------
        Object itself.

        Raises
        ------
        TypeError
            If an attribute has an invalid type.
        ValueError
            If an attribute has an invalid value.

        See Also
        --------
        :meth:`muspy.Base.is_valid` :
            Return True if an attribute is valid, otherwise False.
        :meth:`muspy.Base.validate_type` :
            Raise an error if a certain attribute has an invalid type.

        """
        if attr is None:
            for attribute in self._attributes:
                self._validate(attribute)
        else:
            self._validate(attr)
        return self

    def is_valid_type(self, attr: Optional[str] = None) -> bool:
        """Return True if an attribute has a valid type, otherwise False.

        Parameters
        ----------
        attr : str
            Attribute to validate. Defaults to validate all attributes.

        Returns
        -------
        bool
            Whether the attribute has a valid type.

        See Also
        --------
        :meth:`muspy.Base.validate_type` :
            Raise an error if a certain attribute has an invalid type.
        :meth:`muspy.Base.is_valid` :
            Return True if an attribute is valid, otherwise False.

        """
        try:
            self.validate_type(attr)
        except TypeError:
            return False
        return True

    def is_valid(self, attr: Optional[str] = None) -> bool:
        """Return True if an attribute is valid, otherwise False.

        This will recursively apply to an attribute's attributes.

        Parameters
        ----------
        attr : str
            Attribute to validate. Defaults to validate all attributes.

        Returns
        -------
        bool
            Whether the attribute has a valid type and value.

        See Also
        --------
        :meth:`muspy.Base.validate` :
            Raise an error if a certain attribute has an invalid type or
            value.
        :meth:`muspy.Base.is_valid_type` :
            Return True if an attribute has a valid type, otherwise
            False.

        """
        try:
            self.validate(attr)
        except (TypeError, ValueError):
            return False
        return True

    def _adjust_time(self, func: Callable[[int], int], attr: str):
        """Apply `func` to the timing stored in or under attribute `attr`."""
        attr_type = self._attributes[attr]
        if attr == "time":
            if "time" in self._list_attributes:
                new_list = [func(item) for item in getattr(self, "time")]
                setattr(self, "time", new_list)
            else:
                setattr(self, "time", func(getattr(self, "time")))
            return
        if not (isclass(attr_type) and issubclass(attr_type, Base)):
            # Plain values other than `time` carry no timing
            return
        value = getattr(self, attr)
        if value is None:
            # Nothing to adjust for an attribute that is not set
            return
        if attr in self._list_attributes:
            for item in value:
                item.adjust_time(func)
        else:
            value.adjust_time(func)

    def adjust_time(
        self: "BaseType",
        func: Callable[[int], int],
        attr: Optional[str] = None,
    ) -> "BaseType":
        """Adjust the timing of time-stamped objects.

        This will apply recursively to an attribute's attributes.

        Parameters
        ----------
        func : callable
            The function used to compute the new timing from the old
            timing, i.e., `new_time = func(old_time)`.
        attr : str
            Attribute to adjust. If None, adjust all attributes.
            Defaults to None.

        Returns
        -------
        Object itself.

        """
        if attr is None:
            for attribute in self._attributes:
                self._adjust_time(func, attribute)
        else:
            self._adjust_time(func, attr)
        return self
class ComplexBase(Base):
    """A base class that supports advanced operations on list attributes.

    This class extends the Base class with advanced operations on list
    attributes, including `append`, `remove_invalid`, `remove_duplicate`
    and `sort`.

    See Also
    --------
    :class:`muspy.Base` : The base class for MusPy classes.

    """

    def _append(self, obj):
        """Append `obj` to the first list attribute that accepts its type."""
        for attr in self._list_attributes:
            attr_type = self._attributes[attr]
            if isinstance(obj, attr_type):
                # Only list attributes holding MusPy objects are considered
                if isclass(attr_type) and issubclass(attr_type, Base):
                    if getattr(self, attr) is None:
                        setattr(self, attr, [obj])
                    else:
                        getattr(self, attr).append(obj)
                    return
        raise TypeError(
            "Cannot find a list attribute for type {}.".format(
                type(obj).__name__
            )
        )

    def append(self: ComplexBaseType, obj) -> ComplexBaseType:
        """Append an object to the corresponding list.

        This will automatically determine the list attribute to append
        to based on the type of the object.

        Parameters
        ----------
        obj
            Object to append.

        Returns
        -------
        Object itself.

        Raises
        ------
        TypeError
            If no list attribute accepts the type of the object.

        """
        self._append(obj)
        return self

    def _remove_invalid(self, attr: str, recursive: bool):
        """Drop the invalid items of list attribute `attr`."""
        # Skip it if empty
        if not getattr(self, attr):
            return

        # Replace the old list with a new list of only valid items
        attr_type = self._attributes[attr]
        value = getattr(self, attr)
        is_class = isclass(attr_type)
        if is_class and issubclass(attr_type, Base):
            # Check the class first so that a foreign item is dropped
            # rather than failing on a missing `is_valid` method.
            new_value = [
                item
                for item in value
                if isinstance(item, attr_type) and item.is_valid()
            ]
        else:
            new_value = [item for item in value if isinstance(item, attr_type)]
        setattr(self, attr, new_value)

        # Apply recursively
        if recursive and is_class and issubclass(attr_type, ComplexBase):
            for item in getattr(self, attr):
                item.remove_invalid(recursive=recursive)

    def remove_invalid(
        self: ComplexBaseType,
        attr: Optional[str] = None,
        recursive: bool = True,
    ) -> ComplexBaseType:
        """Remove invalid items from list attributes, others left unchanged.

        Parameters
        ----------
        attr : str
            Attribute to validate. Defaults to validate all attributes.
        recursive : bool
            Whether to apply recursively. Defaults to True.

        Returns
        -------
        Object itself.

        Raises
        ------
        TypeError
            If `attr` is given but is not a list attribute.

        """
        if attr is None:
            for attribute in self._list_attributes:
                self._remove_invalid(attribute, recursive)
        elif attr in self._list_attributes:
            self._remove_invalid(attr, recursive)
        else:
            raise TypeError("`{}` must be a list attribute.".format(attr))
        return self

    def _remove_duplicate(self, attr: str, recursive: bool):
        """Drop the consecutive duplicates of list attribute `attr`."""
        # Skip it if empty
        if not getattr(self, attr):
            return

        # Replace the old list with a new list without duplicates. Each
        # item is compared only with its predecessor, so equal items that
        # are not adjacent are all kept.
        attr_type = self._attributes[attr]
        value = getattr(self, attr)
        new_value = [value[0]]
        for item, next_item in zip(value[:-1], value[1:]):
            if item != next_item:
                new_value.append(next_item)
        setattr(self, attr, new_value)

        # Apply recursively
        if (
            recursive
            and isclass(attr_type)
            and issubclass(attr_type, ComplexBase)
        ):
            for item in getattr(self, attr):
                item.remove_duplicate(recursive=recursive)

    def remove_duplicate(
        self: ComplexBaseType,
        attr: Optional[str] = None,
        recursive: bool = True,
    ) -> ComplexBaseType:
        """Remove duplicate items.

        Only consecutive duplicates are removed, so the list should be
        sorted first if equal items may not be adjacent.

        Parameters
        ----------
        attr : str
            Attribute to check. If None, check all attributes. Defaults
            to None.
        recursive : bool
            Whether to apply recursively. Defaults to True.

        Returns
        -------
        Object itself.

        Raises
        ------
        TypeError
            If `attr` is given but is not a list attribute.

        """
        if attr is None:
            for attribute in self._list_attributes:
                self._remove_duplicate(attribute, recursive)
        elif attr in self._list_attributes:
            self._remove_duplicate(attr, recursive)
        else:
            raise TypeError("`{}` must be a list attribute.".format(attr))
        return self

    def _sort(self, attr: str, recursive: bool):
        """Sort list attribute `attr` in place."""
        # Skip it if empty
        if not getattr(self, attr):
            return

        # Sort the list; lists of plain values are left untouched
        attr_type = self._attributes[attr]
        if isclass(attr_type) and issubclass(attr_type, Base):
            # pylint: disable=protected-access
            if attr_type._sort_attributes:
                getattr(self, attr).sort(
                    key=attrgetter(*attr_type._sort_attributes)
                )

            # Apply recursively
            if recursive and issubclass(attr_type, ComplexBase):
                for item in getattr(self, attr):
                    item.sort(recursive=recursive)

    def sort(
        self: ComplexBaseType,
        attr: Optional[str] = None,
        recursive: bool = True,
    ) -> ComplexBaseType:
        """Sort a list attribute.

        Items are ordered by the `_sort_attributes` of their class.

        Parameters
        ----------
        attr : str
            Attribute to sort. If None, sort all attributes. Defaults to
            None.
        recursive : bool
            Whether to apply recursively. Defaults to True.

        Returns
        -------
        Object itself.

        Raises
        ------
        TypeError
            If `attr` is given but is not a list attribute.

        """
        if attr is None:
            for attribute in self._list_attributes:
                self._sort(attribute, recursive)
        elif attr in self._list_attributes:
            self._sort(attr, recursive)
        else:
            raise TypeError("`{}` must be a list attribute.".format(attr))
        return self