Source code for muspy.base

"""Base classes.

This module defines the base classes for MusPy objects.

Classes
-------

- Base
- ComplexBase

"""
from collections import OrderedDict
from inspect import isclass
from operator import attrgetter
from typing import Any, Callable, List, Mapping, Optional, Type, TypeVar

import yaml

# Public names exported by this module.
__all__ = ["Base", "ComplexBase"]

# Type variables used to annotate methods that return ``self`` (or, for
# ``from_dict``, a new instance of ``cls``) so that subclasses keep their
# precise type in static analysis.
BaseType = TypeVar("BaseType", bound="Base")
ComplexBaseType = TypeVar("ComplexBaseType", bound="ComplexBase")


class _OrderedDumper(yaml.SafeDumper):
    """Safe YAML dumper that never emits indentless block sequences.

    OrderedDict support is added to this class separately, by registering
    `_dict_representer` on it at module level.
    """

    def increase_indent(self, flow=False, indentless=False):
        # The caller's ``indentless`` request is deliberately ignored and
        # indentation is always requested; presumably so that list items
        # are indented under their parent key in the pretty-printed output.
        parent = super()
        return parent.increase_indent(flow, indentless=False)


def _dict_representer(dumper, data):
    """Represent a mapping (e.g. an OrderedDict) as a plain YAML mapping.

    The items are handed over as ``data.items()`` rather than the mapping
    itself, presumably so the dumper keeps the existing key order.
    """
    mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
    items = data.items()
    return dumper.represent_mapping(mapping_tag, items)


# Register `_dict_representer` for OrderedDict so that values such as those
# returned by `Base.to_ordered_dict` are dumped as ordinary YAML mappings.
_OrderedDumper.add_representer(OrderedDict, _dict_representer)


def _yaml_dump(data):
    """Return `data` serialized as a YAML string.

    The custom `_OrderedDumper` is used so that OrderedDicts are rendered
    as regular mappings, and non-ASCII characters are emitted as-is.

    Code adapted from https://stackoverflow.com/a/21912744.
    """
    text = yaml.dump(data, allow_unicode=True, Dumper=_OrderedDumper)
    return text


def _get_type_string(attr_type):
    """Return a string represeting acceptable type(s)."""
    if isinstance(attr_type, (list, tuple)):
        if len(attr_type) > 1:
            return (
                ", ".join([x.__name__ for x in attr_type[:-1]])
                + " or "
                + attr_type[-1].__name__
            )
        return attr_type[0].__name__
    return attr_type.__name__


class Base:
    """Base class for MusPy classes.

    This is the base class for MusPy classes. It provides two handy I/O
    methods---`from_dict` and `to_ordered_dict`. It also provides
    intuitive `__repr__` as well as methods `pretty_str` and `print` for
    beautifully printing the content.

    Hint
    ----
    To implement a new class in MusPy, please inherit from this class
    and set the following class variables properly.

    - `_attributes`: An OrderedDict with attribute names as keys and
      their types as values.
    - `_optional_attributes`: A list of optional attribute names.
    - `_list_attributes`: A list of attributes that are lists.

    Take :class:`muspy.Note` for example.::

        _attributes = OrderedDict(
            [
                ("time", int),
                ("duration", int),
                ("pitch", int),
                ("velocity", int),
                ("pitch_str", str),
            ]
        )
        _optional_attributes = ["pitch_str"]

    See Also
    --------
    :class:`muspy.ComplexBase` : Base class that supports advanced
    operations on list attributes.

    """

    # Attribute names mapped to their expected type(s). Iteration order
    # of this mapping drives `__repr__`, `to_ordered_dict` and validation.
    _attributes: Mapping[str, Any] = {}
    # Names of attributes that are allowed to be None.
    _optional_attributes: List[str] = []
    # Names of attributes holding a list of items of the declared type.
    _list_attributes: List[str] = []

    def __init__(self, **kwargs):
        # Keyword arguments are stored verbatim; no validation is done
        # here (see `validate`).
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __repr__(self) -> str:
        # None values and empty lists are omitted; lists longer than three
        # items are truncated to their first three items plus "...".
        to_join = []
        for attr in self._attributes:
            value = getattr(self, attr)
            if attr in self._list_attributes:
                if not value:
                    continue
                if len(value) > 3:
                    to_join.append(
                        attr + "=" + repr(value[:3])[:-1] + ", ...]"
                    )
                else:
                    to_join.append(attr + "=" + repr(value))
            elif value is not None:
                to_join.append(attr + "=" + repr(value))
        return type(self).__name__ + "(" + ", ".join(to_join) + ")"

    def __eq__(self, other) -> bool:
        """Return True if every attribute in `_attributes` compares equal.

        Comparing against an object that is not a MusPy object returns
        NotImplemented (letting Python try the reflected comparison)
        instead of raising AttributeError. A MusPy object that lacks one
        of the compared attributes is considered unequal.
        """
        if not isinstance(other, Base):
            return NotImplemented
        for attr in self._attributes:
            if not hasattr(other, attr):
                return False
            if getattr(self, attr) != getattr(other, attr):
                return False
        return True

    @classmethod
    def from_dict(cls: Type[BaseType], dict_: Mapping) -> BaseType:
        """Return an instance constructed from a dictionary.

        Instantiate an object whose attributes and the corresponding values
        are given as a dictionary.

        Parameters
        ----------
        dict_ : dict or mapping
            A dictionary that stores the attributes and their values as
            key-value pairs, e.g., `{"attr1": value1, "attr2": value2}`.

        Returns
        -------
        Constructed object.

        Raises
        ------
        TypeError
            If a non-optional attribute is missing from `dict_` or is None.

        """
        kwargs = {}
        for attr, attr_type in cls._attributes.items():
            value = dict_.get(attr)
            if value is None:
                # Missing optional attributes are simply not passed on.
                if attr in cls._optional_attributes:
                    continue
                raise TypeError("`{}` must not be None.".format(attr))
            if isclass(attr_type) and issubclass(attr_type, Base):
                # Nested MusPy objects are rebuilt recursively.
                if attr in cls._list_attributes:
                    kwargs[attr] = [attr_type.from_dict(v) for v in value]
                else:
                    kwargs[attr] = attr_type.from_dict(value)
            else:
                kwargs[attr] = value
        return cls(**kwargs)

    def to_ordered_dict(self, skip_none: bool = True) -> OrderedDict:
        """Return the object as an OrderedDict.

        Return an ordered dictionary that stores the attributes and their
        values as key-value pairs.

        Parameters
        ----------
        skip_none : bool
            Whether to skip attributes with value None or those that are
            empty lists. Defaults to True.

        Returns
        -------
        OrderedDict
            A dictionary that stores the attributes and their values as
            key-value pairs, e.g., `{"attr1": value1, "attr2": value2}`.
            Nested MusPy objects are converted recursively.

        """
        ordered_dict: OrderedDict = OrderedDict()
        for attr, attr_type in self._attributes.items():
            value = getattr(self, attr)
            is_base = isclass(attr_type) and issubclass(attr_type, Base)
            if attr in self._list_attributes:
                if not value:
                    # Empty (or unset) list attribute: keep it only when
                    # not skipping, without iterating a possible None.
                    if not skip_none:
                        ordered_dict[attr] = None if value is None else []
                    continue
                if is_base:
                    ordered_dict[attr] = [
                        v.to_ordered_dict(skip_none=skip_none) for v in value
                    ]
                else:
                    ordered_dict[attr] = value
            elif value is None:
                if not skip_none:
                    ordered_dict[attr] = None
            elif is_base:
                ordered_dict[attr] = value.to_ordered_dict(skip_none=skip_none)
            else:
                ordered_dict[attr] = value
        return ordered_dict

    def pretty_str(self, skip_none: bool = True) -> str:
        """Return the attributes as a string in a YAML-like format.

        Parameters
        ----------
        skip_none : bool
            Whether to skip attributes with value None or those that are
            empty lists. Defaults to True.

        Returns
        -------
        str
            Stored data as a string in a YAML-like format.

        See Also
        --------
        :meth:`muspy.Base.print` : Print the attributes in a YAML-like
        format.

        """
        return _yaml_dump(self.to_ordered_dict(skip_none=skip_none))

    def print(self, skip_none: bool = True):
        """Print the attributes in a YAML-like format.

        Parameters
        ----------
        skip_none : bool
            Whether to skip attributes with value None or those that are
            empty lists. Defaults to True.

        See Also
        --------
        :meth:`muspy.Base.pretty_str` : Return the attributes as a string
        in a YAML-like format.

        """
        print(self.pretty_str(skip_none=skip_none))

    def _validate_attr_type(self, attr: str, recursive: bool):
        """Raise TypeError if attribute `attr` has an invalid type.

        None is accepted only for optional attributes. List attributes
        must be lists whose items are all of the declared type. When
        `recursive` is True, nested MusPy objects are type-checked too.
        """
        attr_type = self._attributes[attr]
        value = getattr(self, attr)
        if value is None:
            if attr in self._optional_attributes:
                return
            raise TypeError("`{}` must not be None".format(attr))
        if attr in self._list_attributes:
            if not isinstance(value, list):
                raise TypeError("`{}` must be a list.".format(attr))
            for item in value:
                if not isinstance(item, attr_type):
                    raise TypeError(
                        "`{}` must be a list of type {}.".format(
                            attr, _get_type_string(attr_type)
                        )
                    )
        elif not isinstance(value, attr_type):
            raise TypeError(
                "`{}` must be of type {}.".format(
                    attr, _get_type_string(attr_type)
                )
            )

        # Apply recursively
        if recursive and isclass(attr_type) and issubclass(attr_type, Base):
            if attr in self._list_attributes:
                for item in value:
                    item.validate_type(recursive=recursive)
            else:
                value.validate_type(recursive=recursive)

    def validate_type(
        self: BaseType, attr: Optional[str] = None, recursive: bool = True,
    ) -> BaseType:
        """Raise an error if an attribute is of an invalid type.

        This will apply recursively to an attribute's attributes.

        Parameters
        ----------
        attr : str
            Attribute to validate. Defaults to validate all attributes.
        recursive : bool
            Whether to apply recursively. Defaults to True.

        Returns
        -------
        Object itself.

        Raises
        ------
        TypeError
            If an attribute has an invalid type.

        See Also
        --------
        :meth:`muspy.Base.is_valid_type` : Return True if an attribute is
        of a valid type.
        :meth:`muspy.Base.validate` : Raise an error if an attribute has
        an invalid type or value.

        """
        if attr is None:
            for attribute in self._attributes:
                self._validate_attr_type(attribute, recursive)
        else:
            self._validate_attr_type(attr, recursive)
        return self

    def _validate(self, attr: str, recursive: bool):
        """Raise an error if attribute `attr` has an invalid type or value.

        The type is always checked (shallowly). A `time` attribute must be
        nonnegative. Nested MusPy objects are validated exactly once, and
        only when `recursive` is True.
        """
        # Check the type of this attribute only; nested objects are fully
        # validated below through `validate`, which type-checks them too.
        self._validate_attr_type(attr, False)

        value = getattr(self, attr)
        if value is None:
            # Only reachable for optional attributes (checked above).
            return

        if attr == "time":
            times = value if attr in self._list_attributes else [value]
            if any(time < 0 for time in times):
                raise ValueError("`time` must be nonnegative.")

        # Apply recursively
        attr_type = self._attributes[attr]
        if recursive and isclass(attr_type) and issubclass(attr_type, Base):
            if attr in self._list_attributes:
                for item in value:
                    item.validate(recursive=recursive)
            else:
                value.validate(recursive=recursive)

    def validate(
        self: BaseType, attr: Optional[str] = None, recursive: bool = True,
    ) -> BaseType:
        """Raise an error if an attribute has an invalid type or value.

        This will apply recursively to an attribute's attributes.

        Parameters
        ----------
        attr : str
            Attribute to validate. Defaults to validate all attributes.
        recursive : bool
            Whether to apply recursively. Defaults to True.

        Returns
        -------
        Object itself.

        Raises
        ------
        TypeError
            If an attribute has an invalid type.
        ValueError
            If an attribute has an invalid value (e.g. a negative `time`).

        See Also
        --------
        :meth:`muspy.Base.is_valid` : Return True if an attribute has a
        valid type and value.
        :meth:`muspy.Base.validate_type` : Raise an error if an attribute
        is of an invalid type.

        """
        if attr is None:
            for attribute in self._attributes:
                self._validate(attribute, recursive)
        else:
            self._validate(attr, recursive)
        return self

    def is_valid_type(
        self, attr: Optional[str] = None, recursive: bool = True,
    ) -> bool:
        """Return True if an attribute is of a valid type.

        This will apply recursively to an attribute's attributes.

        Parameters
        ----------
        attr : str
            Attribute to validate. Defaults to validate all attributes.
        recursive : bool
            Whether to apply recursively. Defaults to True.

        Returns
        -------
        bool
            Whether the attribute is of a valid type.

        See Also
        --------
        :meth:`muspy.Base.validate_type` : Raise an error if a certain
        attribute is of an invalid type.
        :meth:`muspy.Base.is_valid` : Return True if an attribute has a
        valid type and value.

        """
        try:
            self.validate_type(attr, recursive)
        except TypeError:
            return False
        return True

    def is_valid(
        self, attr: Optional[str] = None, recursive: bool = True,
    ) -> bool:
        """Return True if an attribute has a valid type and value.

        This will recursively apply to an attribute's attributes.

        Parameters
        ----------
        attr : str
            Attribute to validate. Defaults to validate all attributes.
        recursive : bool
            Whether to apply recursively. Defaults to True.

        Returns
        -------
        bool
            Whether the attribute has a valid type and value.

        See Also
        --------
        :meth:`muspy.Base.validate` : Raise an error if an attribute has
        an invalid type or value.
        :meth:`muspy.Base.is_valid_type` : Return True if an attribute is
        of a valid type.

        """
        try:
            self.validate(attr, recursive)
        except (TypeError, ValueError):
            return False
        return True

    def _adjust_time(
        self, func: Callable[[int], int], attr: str, recursive: bool
    ):
        """Apply `func` to the `time` attribute and/or nested objects.

        Only an attribute named `time` is modified on this object. Other
        attributes are traversed only if they hold MusPy objects and
        `recursive` is True. Unset (None) attributes are left untouched.
        """
        value = getattr(self, attr)
        if value is None:
            return
        attr_type = self._attributes[attr]
        if attr == "time":
            if attr in self._list_attributes:
                setattr(self, attr, [func(item) for item in value])
            else:
                setattr(self, attr, func(value))
        elif recursive and isclass(attr_type) and issubclass(attr_type, Base):
            if attr in self._list_attributes:
                for item in value:
                    item.adjust_time(func, recursive=recursive)
            else:
                value.adjust_time(func, recursive=recursive)

    def adjust_time(
        self: BaseType,
        func: Callable[[int], int],
        attr: Optional[str] = None,
        recursive: bool = True,
    ) -> BaseType:
        """Adjust the timing of time-stamped objects.

        Parameters
        ----------
        func : callable
            The function used to compute the new timing from the old
            timing, i.e., `new_time = func(old_time)`.
        attr : str
            Attribute to adjust. Defaults to adjust all attributes.
        recursive : bool
            Whether to apply recursively. Defaults to True.

        Returns
        -------
        Object itself.

        """
        if attr is None:
            for attribute in self._attributes:
                self._adjust_time(func, attribute, recursive)
        else:
            self._adjust_time(func, attr, recursive)
        return self


class ComplexBase(Base):
    """Base class that supports advanced operations on list attributes.

    This class extends the Base class with advanced operations on list
    attributes, including `append`, `remove_invalid`, `remove_duplicate`
    and `sort`.

    See Also
    --------
    :class:`muspy.Base` : Base class for MusPy classes.

    """

    def _resolve_list_attributes(self, attr: Optional[str]) -> List[str]:
        """Return the list attributes that an operation should act on.

        Returns all list attributes when `attr` is None, otherwise
        ``[attr]``. Raises TypeError if `attr` is not a list attribute.
        """
        if attr is None:
            return list(self._list_attributes)
        if attr in self._list_attributes:
            return [attr]
        raise TypeError("`{}` must be a list attribute.".format(attr))

    def _append(self, obj):
        # Append to the first list attribute whose declared type is a
        # MusPy class that `obj` is an instance of.
        for attr in self._list_attributes:
            attr_type = self._attributes[attr]
            if isinstance(obj, attr_type):
                if isclass(attr_type) and issubclass(attr_type, Base):
                    if getattr(self, attr) is None:
                        setattr(self, attr, [obj])
                    else:
                        getattr(self, attr).append(obj)
                    return
        raise TypeError(
            "Cannot find a list attribute for type {}.".format(
                type(obj).__name__
            )
        )

    def append(self: ComplexBaseType, obj) -> ComplexBaseType:
        """Append an object to the corresponding list.

        This will automatically determine the list attributes to append
        based on the type of the object.

        Parameters
        ----------
        obj
            Object to append.

        Returns
        -------
        Object itself.

        Raises
        ------
        TypeError
            If no list attribute accepts objects of this type.

        """
        self._append(obj)
        return self

    def _remove_invalid(self, attr: str, recursive: bool):
        value = getattr(self, attr)
        # Skip it if empty (or unset)
        if not value:
            return

        attr_type = self._attributes[attr]
        is_class = isclass(attr_type)

        # NOTE: The ordering matters here. We first apply recursively and
        # later check the current object so that something that can be
        # fixed at a lower level would not cause the higher-level object
        # to be removed.
        if recursive and is_class and issubclass(attr_type, ComplexBase):
            # A distinct loop variable is used so `value` still refers to
            # the list when it is filtered below.
            for item in value:
                if isinstance(item, attr_type):
                    item.remove_invalid(recursive=recursive)

        # Replace the old list with a new list of only valid items
        if is_class and issubclass(attr_type, Base):
            new_value = [
                item
                for item in value
                if isinstance(item, attr_type) and item.is_valid()
            ]
        else:
            new_value = [item for item in value if isinstance(item, attr_type)]
        setattr(self, attr, new_value)

    def remove_invalid(
        self: ComplexBaseType,
        attr: Optional[str] = None,
        recursive: bool = True,
    ) -> ComplexBaseType:
        """Remove invalid items from a list attribute.

        Items of the wrong type, and MusPy objects whose `is_valid` returns
        False, are removed.

        Parameters
        ----------
        attr : str
            Attribute to validate. Defaults to validate all attributes.
        recursive : bool
            Whether to apply recursively. Defaults to True.

        Returns
        -------
        Object itself.

        Raises
        ------
        TypeError
            If `attr` is given but is not a list attribute.

        """
        for attribute in self._resolve_list_attributes(attr):
            self._remove_invalid(attribute, recursive)
        return self

    def _remove_duplicate(self, attr: str, recursive: bool):
        value = getattr(self, attr)
        # Skip it if empty (or unset)
        if not value:
            return

        # Replace the old list with a new list in which each run of
        # consecutive equal items is collapsed into its first item.
        new_value = [value[0]]
        for item, next_item in zip(value, value[1:]):
            if item != next_item:
                new_value.append(next_item)
        setattr(self, attr, new_value)

        # Apply recursively
        attr_type = self._attributes[attr]
        if (
            recursive
            and isclass(attr_type)
            and issubclass(attr_type, ComplexBase)
        ):
            for item in new_value:
                item.remove_duplicate(recursive=recursive)

    def remove_duplicate(
        self: ComplexBaseType,
        attr: Optional[str] = None,
        recursive: bool = True,
    ) -> ComplexBaseType:
        """Remove duplicate items from a list attribute.

        Only *consecutive* duplicates are removed, so call `sort` first to
        remove all duplicates of time-stamped objects.

        Parameters
        ----------
        attr : str
            Attribute to check. Defaults to check all attributes.
        recursive : bool
            Whether to apply recursively. Defaults to True.

        Returns
        -------
        Object itself.

        Raises
        ------
        TypeError
            If `attr` is given but is not a list attribute.

        """
        for attribute in self._resolve_list_attributes(attr):
            self._remove_duplicate(attribute, recursive)
        return self

    def _sort(self, attr: str, recursive: bool):
        # Skip it if empty (or unset)
        if not getattr(self, attr):
            return

        # Sort the list by `time`, but only when the item class declares
        # a `time` attribute; `list.sort` is stable, so ties keep order.
        attr_type = self._attributes[attr]
        if isclass(attr_type) and issubclass(attr_type, Base):
            # pylint: disable=protected-access
            if "time" in attr_type._attributes:
                getattr(self, attr).sort(key=attrgetter("time"))
            # Apply recursively
            if recursive and issubclass(attr_type, ComplexBase):
                for item in getattr(self, attr):
                    item.sort(recursive=recursive)

    def sort(
        self: ComplexBaseType,
        attr: Optional[str] = None,
        recursive: bool = True,
    ) -> ComplexBaseType:
        """Sort a list attribute.

        Parameters
        ----------
        attr : str
            Attribute to sort. Defaults to sort all attributes.
        recursive : bool
            Whether to apply recursively. Defaults to True.

        Returns
        -------
        Object itself.

        Raises
        ------
        TypeError
            If `attr` is given but is not a list attribute.

        """
        for attribute in self._resolve_list_attributes(attr):
            self._sort(attribute, recursive)
        return self