"""
Contains tools to easily generate stubs for the most common pyfar Classes.
Stubs are used instead of pyfar objects for testing functions that have pyfar
objects as input arguments. This makes testing such functions independent from
the pyfar objects themselves and helps to find bugs.
"""
import numpy as np
import deepdiff
from copy import deepcopy
from unittest import mock
from pyfar import Signal, TimeData, FrequencyData
from pyfar.io import _codec
[docs]
def signal_stub(time, freq, sampling_rate, fft_norm, is_complex=False):
"""Function to generate stub of pyfar Signal class based on MagicMock.
The properties of the signal are set without any further check.
Parameters
----------
time : ndarray
Time data
freq : ndarray
Frequency data
sampling_rate : float
Sampling rate
fft_norm : 'unitary', 'amplitude', 'rms', 'power', 'psd'
See documentation of pyfar.fft.normalization.
is_complex : bool
Specifies if the underlying time domain data are complex
real-valued.
Returns
-------
signal
stub of pyfar Signal class
"""
# Use MagicMock and side_effect to mock __getitem__
# See "Mocking a dictionary with MagicMock",
# https://het.as.utexas.edu/HET/Software/mock/examples.html
def getitem(slice_idx):
time = np.atleast_2d(signal.time[slice_idx])
freq = np.atleast_2d(signal.freq[slice_idx])
item = signal_stub(
time,
freq,
signal.sampling_rate,
signal.fft_norm)
return item
def find_nearest_time(times):
samples = np.zeros(len(times), dtype=int)
for idx, time in enumerate(times):
samples[idx] = np.argmin(np.abs(signal.times-time))
return samples
def find_nearest_frequency(freqs):
bins = np.zeros(len(freqs), dtype=int)
for idx, freq in enumerate(freqs):
bins[idx] = np.argmin(np.abs(signal.frequencies-freq))
return bins
signal = mock.MagicMock(
spec_set=Signal(time, sampling_rate, domain='time',
is_complex=is_complex))
signal.time = np.atleast_2d(time)
signal.freq = np.atleast_2d(freq)
signal.sampling_rate = sampling_rate
signal.fft_norm = fft_norm
signal._complex = is_complex
signal.n_samples = signal.time.shape[-1]
signal.n_bins = signal.freq.shape[-1]
signal.cshape = signal.time.shape[:-1]
signal.times = np.atleast_1d(
np.arange(0, signal.n_samples) / sampling_rate)
signal.frequencies = np.atleast_1d(
np.fft.rfftfreq(signal.n_samples, 1 / sampling_rate))
signal.__getitem__.side_effect = getitem
signal.find_nearest_time = find_nearest_time
signal.find_nearest_frequency = find_nearest_frequency
return signal
[docs]
def time_data_stub(time, times):
"""Function to generate stub of pyfar TimeData class based on MagicMock.
The properties of the signal are set without any further check.
Parameters
----------
time : ndarray
Time data
times : ndarray
Times of time in second
Returns
-------
time_data
stub of pyfar TimeData class
"""
# Use MagicMock and side_effect to mock __getitem__
# See "Mocking a dictionary with MagicMock",
# https://het.as.utexas.edu/HET/Software/mock/examples.html
def getitem(slice_idx):
time = np.atleast_2d(time_data.time[slice_idx])
item = time_data_stub(time, time_data.times)
return item
time_data = mock.MagicMock(
spec_set=TimeData(time, times))
time_data.time = np.atleast_2d(time)
time_data.times = np.atleast_1d(times)
time_data.domain = 'time'
time_data._complex = False
time_data.n_samples = time_data.time.shape[-1]
time_data.cshape = time_data.time.shape[:-1]
time_data.__getitem__.side_effect = getitem
return time_data
[docs]
def frequency_data_stub(freq, frequencies):
"""
Function to generate stub of pyfar FrequencyData class based onMagicMock.
The properties of the signal are set without any further check.
Parameters
----------
freq : ndarray
Frequency data
frequencies : ndarray
Frequencies of freq in Hz
Returns
-------
frequency_data
stub of pyfar FrequencyData class
"""
# Use MagicMock and side_effect to mock __getitem__
# See "Mocking a dictionary with MagicMock",
# https://het.as.utexas.edu/HET/Software/mock/examples.html
def getitem(slice_idx):
freq = np.atleast_2d(frequency_data.freq[slice_idx])
item = frequency_data_stub(freq, frequency_data.frequencies)
return item
frequency_data = mock.MagicMock(
spec_set=FrequencyData(freq, frequencies))
frequency_data.freq = np.atleast_2d(freq)
frequency_data.frequencies = np.atleast_1d(frequencies)
frequency_data.domain = 'freq'
frequency_data.n_bins = frequency_data.freq.shape[-1]
frequency_data.cshape = frequency_data.freq.shape[:-1]
frequency_data.__getitem__.side_effect = getitem
return frequency_data
[docs]
def impulse_func(delay, n_samples, fft_norm, cshape):
"""Generate time and frequency data of delta impulse.
Parameters
----------
delay : ndarray, int
Delay in samples
n_samples : int
Number of samples
fft_norm : 'none', 'rms'
See documentation of pyfar.fft.normalization.
cshape : tuple
Channel shape
Returns
-------
time : ndarray, float
time vector
freq : ndarray, complex
Spectrum
"""
# Convert delay to array
delay = np.atleast_1d(delay)
if np.shape(delay) != cshape:
raise ValueError("Shape of delay needs to equal cshape.")
if delay.max() >= n_samples:
raise ValueError("Delay is larger than number of samples, "
f"which is {n_samples}")
# Time vector
time = np.zeros(cshape + (n_samples,))
for idx, d in np.ndenumerate(delay):
time[idx + (d,)] = 1
# Spectrum
n_bins = int(n_samples / 2) + 1
bins = np.broadcast_to(np.arange(n_bins), (cshape + (n_bins,)))
freq = np.exp(-1j * 2 * np.pi * bins * delay[..., np.newaxis] / n_samples)
# Normalization
freq = _normalization(freq, n_samples, fft_norm)
return time, freq
[docs]
def sine_func(frequency, sampling_rate, n_samples, fft_norm, cshape):
"""Generate time and frequency data of sine signal.
The frequency is adjusted resulting in a fully periodic signal in the
given time interval.
Parameters
----------
frequency : float
Frequency of sine
sampling_rate : float
Sampling rate
n_samples : int
Number of samples
fft_norm : 'none', 'rms'
See documentation of pyfar.fft.normalization.
cshape : tuple
Channel shape
Returns
-------
time : ndarray, float
time vector
freq : ndarray, complex
frequency vector
frequency : float
adjusted frequency
"""
# Convert frequency to array
frequency = np.atleast_1d(frequency)
if np.shape(frequency) != cshape:
raise ValueError("Shape of frequency needs to equal cshape.")
if np.any(frequency >= sampling_rate / 2):
raise ValueError(f"Frequency is larger than Nyquist frequency,"
f"which is {sampling_rate/2}.")
# Round to the nearest frequency bin
n_periods = np.floor(n_samples / sampling_rate * frequency)
frequency = n_periods * sampling_rate / n_samples
# Time vector
times = np.arange(0, n_samples) / sampling_rate
times = np.broadcast_to(times, (cshape + (n_samples,)))
time = np.sin(2 * np.pi * frequency[..., np.newaxis] * times)
# Spectrum
n_bins = int(n_samples / 2) + 1
freq = np.zeros(cshape + (n_bins,), dtype=complex)
for idx, f in np.ndenumerate(frequency):
f_bin = int(f / sampling_rate * n_samples)
freq[idx + (f_bin,)] = -0.5j * float(n_samples)
# Normalization
freq = _normalization(freq, n_samples, fft_norm)
return time, freq, frequency
[docs]
def noise_func(sigma, n_samples, cshape):
"""Generate time and frequency data of zero-mean, gaussian white noise,
RMS FFT normalization.
Parameters
----------
sigma : float
Standard deviation
n_samples : int
Number of samples
cshape : tuple
Channel shape
Returns
-------
time : ndarray, float
time vector
freq : ndarray, complex
Spectrum
"""
# Time vector
rng = np.random.default_rng(1000)
time = sigma * rng.standard_normal((cshape + (n_samples,)))
freq = np.fft.rfft(time)
norm = 1 / n_samples / np.sqrt(2) * 2
freq *= norm
return time, freq
def _normalization(freq, n_samples, fft_norm):
"""Normalized spectrum as defined in _[1],
see documentation of pyfar.fft.normalization.
Parameters
----------
freq : ndarray, complex
frequency data
n_samples : int
Number of samples
fft_norm : 'none', 'rms'
See documentation of pyfar.fft.normalization.
Returns
-------
freq
Normalized frequency data
References
----------
.. [1] J. Ahrens, C. Andersson, P. Höstmad, and W. Kropp, “Tutorial on
Scaling of the Discrete Fourier Transform and the Implied Physical
Units of the Spectra of Time-Discrete Signals,” Vienna, Austria,
May 2020, p. e-Brief 600.
"""
norm = np.ones_like(freq)
if fft_norm == 'rms':
# Equation 4 in Ahrens et al. 2020
norm /= n_samples
# Equation 8 and 10 in Ahrens et al. 2020
if n_samples % 2 != 0:
norm[..., 1:] *= np.sqrt(2)
else:
norm[..., 1:-1] *= np.sqrt(2)
elif fft_norm != 'none':
raise ValueError(("norm type must be 'none' or 'rms', "
f"but is '{fft_norm}'"))
freq_norm = norm * freq
return freq_norm
[docs]
def any_ndarray():
"""Return an arbitrary ndarray for testing purposes."""
return np.arange(0, 24).reshape((2, 3, 4))
[docs]
def dict_of_builtins():
"""
Return a dictionary that contains all builtin types that can
be written to and read from disk.
"""
typename_instance = {}
for type_ in _codec._supported_builtin_types():
try:
typename_instance[type_.__name__] = type_(42)
except TypeError:
typename_instance[type_.__name__] = type_([42])
return typename_instance
[docs]
class AnyClass:
"""Placeholder class."""
def __init__(self, x=42):
self.x = x
[docs]
class NoEncodeClass:
"""Placeholder class to Raise NotImplementedError for `_encode`."""
def __init__(self, x=42):
self.x = x
[docs]
class NoDecodeClass:
"""Placeholder class to Raise NotImplementedError for `_decode`."""
def __init__(self, x=42):
self.x = x
[docs]
def copy(self):
"""Return a deep copy of the Orientations object."""
return deepcopy(self)
def _encode(self):
"""Return dictionary for the encoding."""
return self.copy().__dict__
[docs]
class FlatData:
"""Class only containing flat data and methods.
"""
def __init__(self, m=49):
self.signal = any_ndarray()
self._m = m
def _encode(self):
"""Return dictionary for the encoding."""
return self.copy().__dict__
@classmethod
def _decode(cls, obj_dict):
"""Decode object based on its respective `_encode` counterpart."""
obj = cls()
obj.__dict__.update(obj_dict)
return obj
[docs]
def copy(self):
"""Return a deep copy of the Orientations object."""
return deepcopy(self)
def __eq__(self, other):
"""Compare two FlatData objects."""
return not deepdiff.DeepDiff(self, other)
[docs]
class NestedData:
"""Class containing nested data such as lists, dicts and other objects
as well as methods. The purpose of this class is, to define and test
general requirements for the encoding and decoding process.
"""
def __init__(self, n, comment, matrix, subobj, mylist, mydict):
self._n = n
self._comment = comment
self._matrix = matrix
self._subobj = subobj
self._list = mylist
self._dict = mydict
self._complex = 3 + 4j
print('foo')
self._tuple = (1, 2, 3)
self._set = {'a', 1, 2}
self._frozenset = frozenset(('a', 1, 2))
[docs]
@classmethod
def create(cls):
"""Create a NestedData object with arbitrary data."""
n = 42
comment = 'My String'
matrix = any_ndarray()
subobj = FlatData()
mylist = [1, np.int32, np.arange(10), FlatData()]
mydict = {
'number': 1,
'numpy-type': np.int32,
'numpy-ndarray': np.arange(10),
'subobject': FlatData(-1),
'complex-number': 3 + 4j,
'a tuple': (1, 2, 3),
'a set': {'a', 1, 2},
'a frozenset': frozenset(('a', 1, 2))}
return NestedData(
n, comment, matrix, subobj, mylist, mydict)
def _encode(self):
"""Return dictionary for the encoding."""
return self.copy().__dict__
@classmethod
def _decode(cls, obj_dict):
obj = cls(
obj_dict['_n'],
obj_dict['_comment'],
obj_dict['_matrix'],
obj_dict['_subobj'],
obj_dict['_list'],
obj_dict['_dict'])
obj.__dict__.update(obj_dict)
return obj
[docs]
def copy(self):
"""Return a deep copy of the Orientations object."""
return deepcopy(self)
def __eq__(self, other):
"""Compare two NestedData objects."""
return not deepdiff.DeepDiff(self, other)
[docs]
def stub_str_to_type():
"""Stubs `_codec.str_to_type` for tests that use general data structures.
"""
def side_effect(type_str):
if type_str == "BuiltinsWrapper":
return _codec.BuiltinsWrapper
else:
return {
'AnyClass': type(AnyClass()),
'NoEncodeClass': type(NoEncodeClass()),
'NoDecodeClass': type(NoDecodeClass()),
'FlatData': type(FlatData()),
'NestedData': type(NestedData.create()),
}.get(type_str)
return mock.MagicMock(side_effect=side_effect)
[docs]
def stub_is_pyfar_type():
"""Stubs `_codec._is_pyfar_type` for tests that use general data
structures.
"""
def side_effect(obj):
type_str = obj if isinstance(obj, str) else type(obj).__name__
return type_str in [
'NestedData', 'FlatData', 'NoEncodeClass', 'NoDecodeClass']
return mock.MagicMock(side_effect=side_effect)