from typing import Tuple
import scipy.signal
import numpy as np
from floodlight import XY
from floodlight.utils.types import Numeric
def _get_filterable_and_short_sequences(
data: np.ndarray, min_signal_len: int
) -> Tuple[np.ndarray, np.ndarray]:
"""Returns start and end indices of continuous, filterable sequences and sequences
to short for filtering with the specified filter.
Parameters
----------
data: np.ndarray
Array of shape (T,) potentially containing NaNs.
min_signal_len: int
The minimum signal length that the specified filter can be applied on.
Returns
-------
filterable_sequences: np.ndarray
Two-dimensional array of shape (N, 2) and form
``[[sequence_start_idx, sequence_end_idx]]`` containing start and end indices of
N filterable sequences in the original data, ordered ascendingly. A sequence is
filterable when it doesn't contain NaNs and is at least as long as the minimum
window length of the specified filter.
short_sequences: np.ndarray
Two-dimensional array of shape (N, 2) and form
``[[sequence_start_idx, sequence_end_idx]]`` containing start and end indices of
N sequences in the original data that don't contain NaNs but are to short to
apply the specified filter on.
"""
if data.ndim != 1:
raise ValueError(
f"Expected input data to be one-dimensional. Got {data.ndim}-dimensional "
f"data instead."
)
# Convert possible None-types in data to np.NaN
data = np.array(data, dtype=float)
# indices where nans and numbers are next to each other
change_points = np.where(np.diff(np.isnan(data), prepend=np.nan, append=np.nan))[0]
sequences = np.array(
[
(change_points[i], change_points[i + 1])
for i in range(len(change_points) - 1)
]
)
# which sequences contain NaNs
seq_is_nan = np.where(np.isnan(data[sequences[:, 0]]), False, True)
# remove sequences containing NaNs
non_nan_sequences = sequences[seq_is_nan == 1]
# split remaining sequences into filterable and short
filterable_sequences = non_nan_sequences[
(non_nan_sequences[:, 1] - non_nan_sequences[:, 0]) > min_signal_len
]
short_sequences = non_nan_sequences[
(non_nan_sequences[:, 1] - non_nan_sequences[:, 0]) <= min_signal_len
]
return filterable_sequences, short_sequences
def _filter_sequence_butterworth_lowpass(
signal: np.ndarray,
order: int = 3,
Wn: Numeric = 1,
framerate: Numeric = None,
**kwargs,
) -> np.ndarray:
"""Filters the incoming signal with a digital Butterworth lowpass filter.
Wrapper for combined application of the `scipy.signal.butter <https://docs.scipy.
org/doc/scipy/reference/generated/scipy.signal.butter.html>`__ and `scipy.signal.
filtfilt <https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.
filtfilt.html>`_ functions.
Parameters
----------
signal: np.ndarray
Array of shape (T, N) containing the signal to be smoothed with T frames and N
independent signals. Corresponds to the argument ``x`` from the `scipy.signal.
filtfilt <https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.
filtfilt.html>`_ function.
order: int, optional
The order of the filter. Corresponds to the argument ``N`` from the `scipy.
signal. butter <https://docs.scipy.org/doc/scipy/reference/generated/scipy.
signal.butter.html>`_ function. Default is 3.
Wn: Numeric, optional
The critical cutoff frequency. Corresponds to the argument ``Wn`` from the
`scipy.signal.butter <https://docs.scipy.org/doc/scipy/reference/generated/scipy
.signal.butter.html>`_ function. Default is 1.
framerate: Numeric, optional
The sampling frequency of the signal. Corresponds to the argument ``fs`` from
the `scipy.signal.butter <https://docs.scipy.org/doc/scipy/reference/
generated/scipy.signal.butter.html>`_ function.
kwargs:
Optional arguments {'padtype', 'padlen', 'method', 'irlen'} that can be passed
to the `scipy.signal.filtfilt <https://docs.scipy.org/doc/scipy/reference/
generated/scipy.signal.filtfilt.html>`_ function.
Returns
-------
signal_filtered: np.array
Signal filtered by the Butterworth filter.
"""
# Calculation of filter coefficients
coeffs = scipy.signal.butter(
order,
Wn,
btype="lowpass",
output="ba",
fs=framerate,
)
# applying the filter to the data
signal_filtered = scipy.signal.filtfilt(
coeffs[0], coeffs[1], signal, axis=0, **kwargs
)
return signal_filtered
def butterworth_lowpass(
    xy: XY, order: int = 3, Wn: Numeric = 1, remove_short_seqs: bool = False, **kwargs
) -> XY:
    """Applies a digital Butterworth lowpass-filter to a XY data object. [1]_

    For filtering, the `scipy.signal.butter <https://docs.scipy.org/doc/scipy/
    reference/generated/scipy.signal.butter.html>`_ and the `scipy.signal.filtfilt
    <https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.filtfilt.html>`_
    functions are used. This function provides convenient access to both functions,
    directly applying the filter to all non-NaN sequences in all columns.

    Parameters
    ----------
    xy: XY
        Floodlight XY Data object.
    order: int, optional
        The order of the filter. Corresponds to the argument ``N`` from the `scipy.
        signal.butter <https://docs.scipy.org/doc/scipy/reference/generated/scipy.
        signal.butter.html>`_ function. Default is 3.
    Wn: Numeric, optional
        The critical cutoff frequency. Corresponds to the argument ``Wn`` from the
        `scipy.signal.butter <https://docs.scipy.org/doc/scipy/reference/generated/
        scipy.signal.butter.html>`_ function. The framerate of ``xy`` is passed to
        scipy as sampling frequency ``fs``, so ``Wn`` is given in the same units as
        the framerate. If the framerate is None, scipy interprets ``Wn`` as
        normalized frequency, which must satisfy ``0 < Wn < 1``. Default is 1.
    remove_short_seqs: bool, optional
        If True, sequences that are too short for the filter with the specified
        settings are replaced with np.nan. If False, they are kept unfiltered.
        Default is False.
    kwargs:
        Optional arguments {'padtype', 'padlen', 'method', 'irlen'} that can be passed
        to the `scipy.signal.filtfilt <https://docs.scipy.org/doc/scipy/reference/
        generated/scipy.signal.filtfilt.html>`_ function.

    Returns
    -------
    xy_filtered: XY
        XY object with position data filtered by designed Butterworth low pass filter.

    Notes
    -----
    The values of the input data are assumed to be numerical. Missing data is assumed
    to be either np.nan or None. The Butterworth-filter requires a minimum signal
    length depending on the settings. A signal is a sequence of data in the XY-object
    that is not interrupted by missing values. The minimum signal length is defined as
    :math:`3 \\cdot (order + 1)`, or as the ``padlen`` keyword argument if one is
    given and larger. Only signals longer than the minimum signal length are filtered.
    The treatment of signals not exceeding the minimum signal length is specified with
    the ``remove_short_seqs``-argument, where True will replace these sequences with
    np.nan and False will keep the sequences in the data unfiltered.

    Examples
    --------
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> from floodlight import XY
    >>> from floodlight.transforms.filter import butterworth_lowpass

    We first generate a noisy XY-object to smooth.

    >>> t = np.linspace(-5, 5, 1000)
    >>> player_x = np.sin(t) * t + np.random.rand(1000)
    >>> player_x[450:495] = np.nan
    >>> player_x[505:550] = np.nan
    >>> player_y = t + np.random.randn()
    >>> xy = XY(np.transpose(np.stack((player_x, player_y))), framerate=20)

    Apply the Butterworth lowpass filter with its default settings.

    >>> xy_filt = butterworth_lowpass(xy)
    >>> plt.plot(xy.x)
    >>> plt.plot(xy_filt.x, linewidth=3)
    >>> plt.legend(("Raw", "Smoothed"))
    >>> plt.show()

    .. image:: ../../_img/butterworth_default_example.png

    Apply the same filter but remove the sequence that is too short to filter.

    >>> xy_filt = butterworth_lowpass(xy, remove_short_seqs=True)
    >>> plt.plot(xy.x)
    >>> plt.plot(xy_filt.x, linewidth=3)
    >>> plt.legend(("Raw", "Smoothed"))
    >>> plt.show()

    .. image:: ../../_img/butterworth_removed_short_example.png

    Apply the filter with different specifications.

    >>> xy_filt = butterworth_lowpass(xy, order=5, Wn=4)
    >>> plt.plot(xy.x)
    >>> plt.plot(xy_filt.x, linewidth=3)
    >>> plt.legend(("Raw", "Smoothed"))
    >>> plt.show()

    .. image:: ../../_img/butterworth_adjusted_example.png

    References
    ----------
        .. [1] `Butterworth, S. (1930). On the theory of filter amplifiers. Wireless
            Engineer, 3, 536-541. <https://www.changpuak.ch/electronics/downloads/
            On_the_Theory_of_Filter_Amplifiers.pdf>`_
    """
    # minimum signal length a filter with this specs can be applied on; this equals
    # the default padding length of scipy.signal.filtfilt for a filter of this order
    min_signal_len = 3 * (order + 1)
    # filtfilt rejects signals that are not longer than an explicitly requested
    # padding length, so a larger user-supplied padlen raises the minimum
    padlen = kwargs.get("padlen")
    if padlen is not None:
        min_signal_len = max(min_signal_len, padlen)
    framerate = xy.framerate

    # pre-allocate space for filtered data
    xy_filt = np.empty(xy.xy.shape)

    # loop through the xy-object columns
    for i, column in enumerate(np.transpose(xy.xy)):
        # extract indices of filterable and short sequences
        seqs_filt, seqs_short = _get_filterable_and_short_sequences(
            column, min_signal_len
        )

        # pre-allocate space for filtered column; everything that is not explicitly
        # written below (missing data, removed short sequences) stays NaN
        col_filt = np.full(column.shape, np.nan)

        # loop through filterable sequences
        for start, end in seqs_filt:
            # apply filter to the sequence and enter filtered data to their
            # respective indices in the data
            col_filt[start:end] = _filter_sequence_butterworth_lowpass(
                column[start:end], order, Wn, framerate, **kwargs
            )

        # check treatment of sequences that don't meet minimum signal length
        if not remove_short_seqs:
            # enter short sequences unfiltered to their respective indices in the data
            for start, end in seqs_short:
                col_filt[start:end] = column[start:end]

        # enter filtered data into respective column
        xy_filt[:, i] = col_filt

    # create new XY-data object with filtered data
    xy_filtered = XY(xy=xy_filt, framerate=framerate, direction=xy.direction)

    return xy_filtered
def savgol_lowpass(
    xy: XY,
    window_length: int = 5,
    poly_order: Numeric = 3,
    remove_short_seqs: bool = False,
    **kwargs,
) -> XY:
    """Applies a Savitzky-Golay lowpass-filter to a XY data object. [2]_

    For filtering, the `scipy.signal.savgol_filter <https://docs.scipy.org/doc/scipy/
    reference/generated/scipy.signal.savgol_filter.html>`_ function is used. This
    function provides convenient access to the function, directly applying the filter
    to all non-NaN sequences in all columns.

    Parameters
    ----------
    xy: XY
        Floodlight XY Data object.
    window_length: int, optional
        The length of the filter window. Corresponds to the argument ``window_length``
        from the `scipy.signal.savgol_filter <https://docs.scipy.org/doc/scipy/
        reference/generated/scipy.signal.savgol_filter.html>`_ function. Default is 5.
    poly_order: Numeric, optional
        The order of the polynomial used to fit the samples. ``poly_order`` must be
        less than ``window_length``. Corresponds to the argument ``polyorder`` from
        the `scipy.signal.savgol_filter <https://docs.scipy.org/doc/scipy/reference/
        generated/scipy.signal.savgol_filter.html>`_ function. Default is 3.
    remove_short_seqs: bool, optional
        If True, sequences that are too short for the filter with the specified
        settings are replaced with np.nan. If False, they are kept unfiltered.
        Default is False.
    kwargs:
        Optional arguments {'deriv', 'delta', 'mode', 'cval'} that can be passed to
        the `scipy.signal.savgol_filter <https://docs.scipy.org/doc/scipy/reference/
        generated/scipy.signal.savgol_filter.html>`_ function.

    Returns
    -------
    xy_filtered: XY
        XY object with position data filtered by designed Savitzky-Golay low pass
        filter.

    Notes
    -----
    The values of the input data are assumed to be numerical. Missing data is assumed
    to be either np.nan or None. The Savitzky-Golay-filter requires a minimum signal
    length depending on the settings. A signal is a sequence of data in the XY-object
    that is not interrupted by missing values. The minimum signal length is defined as
    the ``window_length``. Only signals longer than the minimum signal length are
    filtered. The treatment of signals not exceeding the minimum signal length is
    specified with the ``remove_short_seqs``-argument, where True will replace these
    sequences with np.nan and False will keep the sequences in the data unfiltered.

    Examples
    --------
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> from floodlight import XY
    >>> from floodlight.transforms.filter import savgol_lowpass

    We first generate a noisy XY-object to smooth.

    >>> t = np.linspace(-5, 5, 1000)
    >>> player_x = np.sin(t) * t + np.random.rand(1000)
    >>> player_x[450:495] = np.nan
    >>> player_x[505:550] = np.nan
    >>> player_y = t + np.random.randn()
    >>> xy = XY(np.transpose(np.stack((player_x, player_y))), framerate=20)

    Apply the Savgol lowpass filter with its default settings.

    >>> xy_filt = savgol_lowpass(xy)
    >>> plt.plot(xy.x)
    >>> plt.plot(xy_filt.x, linewidth=3)
    >>> plt.legend(("Raw", "Smoothed"))
    >>> plt.show()

    .. image:: ../../_img/savgol_default_example.png

    Apply the filter with a longer window length and remove the sequence that is too
    short to filter.

    >>> xy_filt = savgol_lowpass(xy, window_length=12, remove_short_seqs=True)
    >>> plt.plot(xy.x)
    >>> plt.plot(xy_filt.x, linewidth=3)
    >>> plt.legend(("Raw", "Smoothed"))
    >>> plt.show()

    .. image:: ../../_img/savgol_removed_short_example.png

    Apply the filter with different specifications.

    >>> xy_filt = savgol_lowpass(xy, window_length=50, poly_order=5)
    >>> plt.plot(xy.x)
    >>> plt.plot(xy_filt.x, linewidth=3)
    >>> plt.legend(("Raw", "Smoothed"))
    >>> plt.show()

    .. image:: ../../_img/savgol_adjusted_example.png

    References
    ----------
        .. [2] `Savitzky, A.; Golay, M.J. (1964). Smoothing and differentiation of data
            by simplified least squares procedures. Analytical Chemistry, 36(1), 1627-
            1639. <https://pubs.acs.org/doi/abs/10.1021/ac60214a047>`_
    """
    # minimum signal length a filter with this specs can be applied on
    min_signal_len = window_length

    # pre-allocate space for filtered data
    xy_filt = np.empty(xy.xy.shape)

    # loop through the xy-object columns
    for i, column in enumerate(np.transpose(xy.xy)):
        # extract indices of filterable and short sequences
        seqs_filt, seqs_short = _get_filterable_and_short_sequences(
            column, min_signal_len
        )

        # pre-allocate space for filtered column; everything that is not explicitly
        # written below (missing data, removed short sequences) stays NaN
        col_filt = np.full(column.shape, np.nan)

        # loop through filterable sequences
        for start, end in seqs_filt:
            # apply filter to the sequence and enter filtered data to their
            # respective indices in the data
            col_filt[start:end] = scipy.signal.savgol_filter(
                column[start:end], window_length, poly_order, **kwargs
            )

        # check treatment of sequences that don't meet minimum signal length
        if not remove_short_seqs:
            # enter short sequences unfiltered to their respective indices in the data
            for start, end in seqs_short:
                col_filt[start:end] = column[start:end]

        # enter filtered data into respective column
        xy_filt[:, i] = col_filt

    # create new XY-data object with filtered data
    xy_filtered = XY(xy=xy_filt, framerate=xy.framerate, direction=xy.direction)

    return xy_filtered