from typing import Tuple
import scipy.signal
import numpy as np
from floodlight import XY
from floodlight.utils.types import Numeric
def _get_filterable_and_short_sequences(
data: np.ndarray, min_signal_len: int
) -> Tuple[np.ndarray, np.ndarray]:
"""Returns start and end indices of continuous, filterable sequences and sequences
too short for filtering with the specified filter.
Parameters
----------
data: np.ndarray
Array of shape (T,) potentially containing NaNs.
min_signal_len: int
The minimum signal length that the specified filter can be applied on.
Returns
-------
filterable_sequences: np.ndarray
Two-dimensional array of shape (N, 2) and form
``[[sequence_start_idx, sequence_end_idx]]`` containing start and end indices of
N filterable sequences in the original data, ordered ascendingly. A sequence is
filterable when it doesn't contain NaNs and is at least as long as the minimum
window length of the specified filter.
short_sequences: np.ndarray
Two-dimensional array of shape (N, 2) and form
``[[sequence_start_idx, sequence_end_idx]]`` containing start and end indices of
N sequences in the original data that don't contain NaNs but are too short to
apply the specified filter on.
"""
if data.ndim != 1:
raise ValueError(
f"Expected input data to be one-dimensional. Got {data.ndim}-dimensional "
f"data instead."
)
# Convert possible None-types in data to np.nan
data = np.array(data, dtype=float)
# indices where nans and numbers are next to each other
change_points = np.where(np.diff(np.isnan(data), prepend=np.nan, append=np.nan))[0]
sequences = np.array(
[
(change_points[i], change_points[i + 1])
for i in range(len(change_points) - 1)
]
)
# which sequences contain NaNs
seq_is_nan = np.where(np.isnan(data[sequences[:, 0]]), False, True)
# remove sequences containing NaNs
non_nan_sequences = sequences[seq_is_nan == 1]
# split remaining sequences into filterable and short
filterable_sequences = non_nan_sequences[
(non_nan_sequences[:, 1] - non_nan_sequences[:, 0]) > min_signal_len
]
short_sequences = non_nan_sequences[
(non_nan_sequences[:, 1] - non_nan_sequences[:, 0]) <= min_signal_len
]
return filterable_sequences, short_sequences
def _filter_sequence_butterworth_lowpass(
signal: np.ndarray,
order: int = 3,
Wn: Numeric = 1,
framerate: Numeric = None,
**kwargs,
) -> np.ndarray:
"""Filters the incoming signal with a digital Butterworth lowpass filter.
Wrapper for combined application of the `scipy.signal.butter <https://docs.scipy.
org/doc/scipy/reference/generated/scipy.signal.butter.html>`__ and `scipy.signal.
filtfilt <https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.
filtfilt.html>`_ functions.
Parameters
----------
signal: np.ndarray
Array of shape (T, N) containing the signal to be smoothed with T frames and N
independent signals. Corresponds to the argument ``x`` from the `scipy.signal.
filtfilt <https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.
filtfilt.html>`_ function.
order: int, optional
The order of the filter. Corresponds to the argument ``N`` from the `scipy.
signal. butter <https://docs.scipy.org/doc/scipy/reference/generated/scipy.
signal.butter.html>`_ function. Default is 3.
Wn: Numeric, optional
The critical cutoff frequency. Corresponds to the argument ``Wn`` from the
`scipy.signal.butter <https://docs.scipy.org/doc/scipy/reference/generated/scipy
.signal.butter.html>`_ function. Default is 1.
framerate: Numeric, optional
The sampling frequency of the signal. Corresponds to the argument ``fs`` from
the `scipy.signal.butter <https://docs.scipy.org/doc/scipy/reference/
generated/scipy.signal.butter.html>`_ function.
kwargs:
Optional arguments {'padtype', 'padlen', 'method', 'irlen'} that can be passed
to the `scipy.signal.filtfilt <https://docs.scipy.org/doc/scipy/reference/
generated/scipy.signal.filtfilt.html>`_ function.
Returns
-------
signal_filtered: np.array
Signal filtered by the Butterworth filter.
"""
# Calculation of filter coefficients
coeffs = scipy.signal.butter(
order,
Wn,
btype="lowpass",
output="ba",
fs=framerate,
)
# applying the filter to the data
signal_filtered = scipy.signal.filtfilt(
coeffs[0], coeffs[1], np.asarray(signal, dtype=np.float64), axis=0, **kwargs
)
return signal_filtered
[docs]
def butterworth_lowpass(
    xy: XY, order: int = 3, Wn: Numeric = 1, remove_short_seqs: bool = False, **kwargs
) -> XY:
    """Applies a digital Butterworth lowpass-filter to an XY data object. [1]_

    For filtering, the `scipy.signal.butter <https://docs.scipy.org/doc/scipy/
    reference/generated/scipy.signal.butter.html>`_ and the `scipy.signal.filtfilt
    <https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.
    filtfilt.html>`_ functions are used. This function provides a convenience access
    to both functions, directly applying the filter to all non-NaN sequences in all
    columns.

    Parameters
    ----------
    xy: XY
        Floodlight XY Data object.
    order: int, optional
        The order of the filter. Corresponds to the argument ``N`` from the `scipy.
        signal.butter <https://docs.scipy.org/doc/scipy/reference/generated/scipy.
        signal.butter.html>`_ function. Default is 3.
    Wn: Numeric, optional
        The critical cutoff frequency. Corresponds to the argument ``Wn`` from the
        `scipy.signal.butter <https://docs.scipy.org/doc/scipy/reference/generated/
        scipy.signal.butter.html>`_ function. ``xy.framerate`` is passed to scipy as
        the sampling frequency ``fs``, so ``Wn`` is given in the same unit as the
        framerate (e.g. Hz). If ``xy.framerate`` is None, scipy interprets ``Wn`` as
        normalized to the Nyquist frequency, in which case it must lie strictly
        between 0 and 1. Default is 1.
    remove_short_seqs: bool, optional
        If True, sequences that are too short for the filter with the specified
        settings are replaced with np.nan. If False, they are kept unfiltered.
        Default is False.
    kwargs:
        Optional arguments {'padtype', 'padlen', 'method', 'irlen'} that can be passed
        to the `scipy.signal.filtfilt <https://docs.scipy.org/doc/scipy/reference/
        generated/scipy.signal.filtfilt.html>`_ function.

    Returns
    -------
    xy_filtered: XY
        XY object with position data filtered by designed Butterworth low pass filter.

    Notes
    -----
    The values of the input data are assumed to be numerical. Missing data is assumed
    to be either np.nan or None. The Butterworth-filter requires a minimum signal
    length depending on the settings. A signal is a sequence of data in the XY-object
    that is not interrupted by missing values. The minimum signal length is defined as
    :math:`3 \\cdot (order + 1)`, which is the default padding length of
    ``scipy.signal.filtfilt``; a signal has to be strictly longer than that to be
    filtered. If a larger ``padlen`` is passed via ``kwargs`` (and padding is active,
    i.e. ``method='pad'`` and ``padtype`` is not None), ``padlen`` is used as minimum
    signal length instead. The treatment of signals that do not exceed the minimum
    signal length is specified with the ``remove_short_seqs``-argument, where True
    will replace these sequences with np.nan and False will keep the sequences in the
    data unfiltered.

    Examples
    --------
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> from floodlight import XY
    >>> from floodlight.transforms.filter import butterworth_lowpass

    We first generate a noisy XY-object to smooth.

    >>> t = np.linspace(-5, 5, 1000)
    >>> player_x = np.sin(t) * t + np.random.rand(1000)
    >>> player_x[450:495] = np.nan
    >>> player_x[505:550] = np.nan
    >>> player_y = t + np.random.randn()
    >>> xy = XY(np.transpose(np.stack((player_x, player_y))), framerate=20)

    Apply the Butterworth lowpass filter with its default settings.

    >>> xy_filt = butterworth_lowpass(xy)
    >>> plt.plot(xy.x)
    >>> plt.plot(xy_filt.x, linewidth=3)
    >>> plt.legend(("Raw", "Smoothed"))
    >>> plt.show()

    .. image:: ../../_img/butterworth_default_example.png

    Apply the same filter but remove the sequence that is too short to filter.

    >>> xy_filt = butterworth_lowpass(xy, remove_short_seqs=True)
    >>> plt.plot(xy.x)
    >>> plt.plot(xy_filt.x, linewidth=3)
    >>> plt.legend(("Raw", "Smoothed"))
    >>> plt.show()

    .. image:: ../../_img/butterworth_removed_short_example.png

    Apply the filter with different specifications.

    >>> xy_filt = butterworth_lowpass(xy, order=5, Wn=4)
    >>> plt.plot(xy.x)
    >>> plt.plot(xy_filt.x, linewidth=3)
    >>> plt.legend(("Raw", "Smoothed"))
    >>> plt.show()

    .. image:: ../../_img/butterworth_adjusted_example.png

    References
    ----------
    .. [1] `Butterworth, S. (1930). On the theory of filter amplifiers. Wireless
        Engineer, 3, 536-541. <https://www.changpuak.ch/electronics/downloads/
        On_the_Theory_of_Filter_Amplifiers.pdf>`_
    """
    # minimum signal length a filter with this specs can be applied on: filtfilt
    # pads with 3 * max(len(a), len(b)) = 3 * (order + 1) samples by default and
    # requires the signal to be longer than the padding
    min_signal_len = 3 * (order + 1)
    # a user-supplied padlen that exceeds the default raises that requirement, but
    # only if filtfilt actually pads (method "pad" with a padtype other than None)
    padlen = kwargs.get("padlen")
    pads_signal = (
        kwargs.get("method", "pad") == "pad"
        and kwargs.get("padtype", "odd") is not None
    )
    if pads_signal and padlen is not None:
        min_signal_len = max(min_signal_len, padlen)
    framerate = xy.framerate

    # pre-allocate space for filtered data
    xy_filt = np.empty(xy.xy.shape)

    # loop through the xy-object columns
    for i, column in enumerate(np.transpose(xy.xy)):
        # extract indices of filterable and short sequences
        seqs_filt, seqs_short = _get_filterable_and_short_sequences(
            column, min_signal_len
        )

        # pre-allocate space for filtered column; gaps and (optionally) short
        # sequences simply stay NaN
        col_filt = np.full(column.shape, np.nan)

        # loop through filterable sequences
        for start, end in seqs_filt:
            # apply filter to the sequence and enter filtered data to their
            # respective indices in the data
            col_filt[start:end] = _filter_sequence_butterworth_lowpass(
                column[start:end], order, Wn, framerate, **kwargs
            )

        # check treatment of sequences that don't meet minimum signal length;
        # truthiness is used so that np.bool_ and 0/1 flags behave like bools
        if not remove_short_seqs:
            # enter short sequences unfiltered to their respective indices in the data
            for start, end in seqs_short:
                col_filt[start:end] = column[start:end]

        # enter filtered data into respective column
        xy_filt[:, i] = col_filt

    # create new XY-data object with filtered data
    xy_filtered = XY(xy=xy_filt, framerate=xy.framerate, direction=xy.direction)

    return xy_filtered
[docs]
def savgol_lowpass(
    xy: XY,
    window_length: int = 5,
    poly_order: Numeric = 3,
    remove_short_seqs: bool = False,
    **kwargs,
) -> XY:
    """Applies a Savitzky-Golay lowpass-filter to an XY data object. [2]_

    For filtering, the `scipy.signal.savgol_filter <https://docs.scipy.org/doc/scipy/
    reference/generated/scipy.signal.savgol_filter.html>`_ function is used. This
    function provides a convenient access to the function, directly applying the
    filter to all non-NaN sequences in all columns.

    Parameters
    ----------
    xy: XY
        Floodlight XY Data object.
    window_length: int, optional
        The length of the filter window. Corresponds to the argument ``window_length``
        from the `scipy.signal.savgol_filter <https://docs.scipy.org/doc/scipy/
        reference/generated/scipy.signal.savgol_filter.html>`_ function. Default is 5.
    poly_order: Numeric, optional
        The order of the polynomial used to fit the samples. ``poly_order`` must be
        less than ``window_length``. Corresponds to the argument ``polyorder`` from
        the `scipy.signal.savgol_filter <https://docs.scipy.org/doc/scipy/reference/
        generated/scipy.signal.savgol_filter.html>`_ function. Default is 3.
    remove_short_seqs: bool, optional
        If True, sequences that are too short for the filter with the specified
        settings are replaced with np.nan. If False, they are kept unfiltered.
        Default is False.
    kwargs:
        Optional arguments {'deriv', 'delta', 'mode', 'cval'} that can be passed to
        the `scipy.signal.savgol_filter <https://docs.scipy.org/doc/scipy/reference/
        generated/scipy.signal.savgol_filter.html>`_ function.

    Returns
    -------
    xy_filtered: XY
        XY object with position data filtered by designed Savitzky-Golay low pass
        filter.

    Notes
    -----
    The values of the input data are assumed to be numerical. Missing data is assumed
    to be either np.nan or None. The Savitzky-Golay-filter requires a minimum signal
    length depending on the settings. A signal is a sequence of data in the XY-object
    that is not interrupted by missing values. The minimum signal length is defined as
    the ``window_length``; only signals that are strictly longer than that are
    filtered. The treatment of signals that do not exceed the minimum signal length is
    specified with the ``remove_short_seqs``-argument, where True will replace these
    sequences with np.nan and False will keep the sequences in the data unfiltered.

    Examples
    --------
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> from floodlight import XY
    >>> from floodlight.transforms.filter import savgol_lowpass

    We first generate a noisy XY-object to smooth.

    >>> t = np.linspace(-5, 5, 1000)
    >>> player_x = np.sin(t) * t + np.random.rand(1000)
    >>> player_x[450:495] = np.nan
    >>> player_x[505:550] = np.nan
    >>> player_y = t + np.random.randn()
    >>> xy = XY(np.transpose(np.stack((player_x, player_y))), framerate=20)

    Apply the Savgol lowpass filter with its default settings.

    >>> xy_filt = savgol_lowpass(xy)
    >>> plt.plot(xy.x)
    >>> plt.plot(xy_filt.x, linewidth=3)
    >>> plt.legend(("Raw", "Smoothed"))
    >>> plt.show()

    .. image:: ../../_img/savgol_default_example.png

    Apply the filter with a longer window length and remove the sequence that is too
    short to filter.

    >>> xy_filt = savgol_lowpass(xy, window_length=12, remove_short_seqs=True)
    >>> plt.plot(xy.x)
    >>> plt.plot(xy_filt.x, linewidth=3)
    >>> plt.legend(("Raw", "Smoothed"))
    >>> plt.show()

    .. image:: ../../_img/savgol_removed_short_example.png

    Apply the filter with different specifications.

    >>> xy_filt = savgol_lowpass(xy, window_length=50, poly_order=5)
    >>> plt.plot(xy.x)
    >>> plt.plot(xy_filt.x, linewidth=3)
    >>> plt.legend(("Raw", "Smoothed"))
    >>> plt.show()

    .. image:: ../../_img/savgol_adjusted_example.png

    References
    ----------
    .. [2] `Savitzky, A.; Golay, M.J. (1964). Smoothing and differentiation of data
        by simplified least squares procedures. Analytical Chemistry, 36(1), 1627-
        1639. <https://pubs.acs.org/doi/abs/10.1021/ac60214a047>`_
    """
    # minimum signal length a filter with this specs can be applied on
    min_signal_len = window_length

    # pre-allocate space for filtered data
    xy_filt = np.empty(xy.xy.shape)

    # loop through the xy-object columns
    for i, column in enumerate(np.transpose(xy.xy)):
        # extract indices of filterable and short sequences
        seqs_filt, seqs_short = _get_filterable_and_short_sequences(
            column, min_signal_len
        )

        # pre-allocate space for filtered column; gaps and (optionally) short
        # sequences simply stay NaN
        col_filt = np.full(column.shape, np.nan)

        # loop through filterable sequences
        for start, end in seqs_filt:
            # apply filter to the sequence and enter filtered data to their
            # respective indices in the data
            col_filt[start:end] = scipy.signal.savgol_filter(
                column[start:end], window_length, poly_order, **kwargs
            )

        # check treatment of sequences that don't meet minimum signal length;
        # truthiness is used so that np.bool_ and 0/1 flags behave like bools
        if not remove_short_seqs:
            # enter short sequences unfiltered to their respective indices in the data
            for start, end in seqs_short:
                col_filt[start:end] = column[start:end]

        # enter filtered data into respective column
        xy_filt[:, i] = col_filt

    # create new XY-data object with filtered data
    xy_filtered = XY(xy=xy_filt, framerate=xy.framerate, direction=xy.direction)

    return xy_filtered
def _filter_sequence_fir_lowpass(
signal: np.ndarray,
numtaps: int = 21,
cutoff: Numeric = 1,
framerate: Numeric = None,
window: str = "hamming",
**kwargs,
) -> np.ndarray:
"""Filters the incoming signal with a FIR lowpass filter.
Wrapper for combined application of the `scipy.signal.firwin <https://docs.scipy.
org/doc/scipy/reference/generated/scipy.signal.firwin.html>`__ and `scipy.signal.
filtfilt <https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.
filtfilt.html>`_ functions.
Parameters
----------
signal: np.ndarray
Array of shape (T, N) containing the signal to be smoothed with T frames and N
independent signals. Corresponds to the argument ``x`` from the `scipy.signal.
filtfilt <https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.
filtfilt.html>`_ function.
numtaps: int, optional
Length of the FIR filter (number of coefficients, i.e. the filter order + 1).
``numtaps`` must be odd for a Type I filter. Corresponds to the argument
``numtaps`` from the `scipy.signal.firwin <https://docs.scipy.org/doc/scipy/
reference/generated/scipy.signal.firwin.html>`_ function. Default is 21.
cutoff: Numeric, optional
The cutoff frequency of the filter in Hz (when ``framerate`` is specified).
Corresponds to the argument ``cutoff`` from the `scipy.signal.firwin
<https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.
firwin.html>`_ function. Default is 1.
framerate: Numeric, optional
The sampling frequency of the signal in Hz. Corresponds to the argument ``fs``
from the `scipy.signal.firwin <https://docs.scipy.org/doc/scipy/reference/
generated/scipy.signal.firwin.html>`_ function.
window: str, optional
Desired window to use for the FIR filter design. Corresponds to the argument
``window`` from the `scipy.signal.firwin <https://docs.scipy.org/doc/scipy/
reference/generated/scipy.signal.firwin.html>`_ function. Default is
``"hamming"``.
kwargs:
Optional arguments {'padtype', 'padlen', 'method', 'irlen'} that can be passed
to the `scipy.signal.filtfilt <https://docs.scipy.org/doc/scipy/reference/
generated/scipy.signal.filtfilt.html>`_ function.
Returns
-------
signal_filtered: np.array
Signal filtered by the FIR lowpass filter.
"""
# Design FIR filter coefficients
h = scipy.signal.firwin(numtaps, cutoff, window=window, fs=framerate)
# Apply zero-phase filtering (FIR filter has a = 1)
signal_filtered = scipy.signal.filtfilt(
h, 1, np.asarray(signal, dtype=np.float64), axis=0, **kwargs
)
return signal_filtered
[docs]
def fir_lowpass(
    xy: XY,
    numtaps: int = 21,
    cutoff: Numeric = 1,
    window: str = "hamming",
    remove_short_seqs: bool = False,
    **kwargs,
) -> XY:
    """Applies a FIR lowpass-filter to an XY data object.

    For filtering, the `scipy.signal.firwin <https://docs.scipy.org/doc/scipy/
    reference/generated/scipy.signal.firwin.html>`_ and the `scipy.signal.filtfilt
    <https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.
    filtfilt.html>`_ functions are used. This function provides a convenience access
    to both functions, directly applying the filter to all non-NaN sequences in all
    columns.

    Parameters
    ----------
    xy: XY
        Floodlight XY Data object.
    numtaps: int, optional
        Length of the FIR filter (number of coefficients, i.e. the filter order + 1).
        ``numtaps`` must be odd for a Type I filter. Corresponds to the argument
        ``numtaps`` from the `scipy.signal.firwin <https://docs.scipy.org/doc/scipy/
        reference/generated/scipy.signal.firwin.html>`_ function. Default is 21.
    cutoff: Numeric, optional
        The cutoff frequency of the filter in Hz. Corresponds to the argument
        ``cutoff`` from the `scipy.signal.firwin <https://docs.scipy.org/doc/scipy/
        reference/generated/scipy.signal.firwin.html>`_ function. ``xy.framerate`` is
        passed to scipy as the sampling frequency ``fs``. Default is 1.
    window: str, optional
        Desired window to use for the FIR filter design. Corresponds to the argument
        ``window`` from the `scipy.signal.firwin <https://docs.scipy.org/doc/scipy/
        reference/generated/scipy.signal.firwin.html>`_ function. Default is
        ``"hamming"``.
    remove_short_seqs: bool, optional
        If True, sequences that are too short for the filter with the specified
        settings are replaced with np.nan. If False, they are kept unfiltered.
        Default is False.
    kwargs:
        Optional arguments {'padtype', 'padlen', 'method', 'irlen'} that can be passed
        to the `scipy.signal.filtfilt <https://docs.scipy.org/doc/scipy/reference/
        generated/scipy.signal.filtfilt.html>`_ function.

    Returns
    -------
    xy_filtered: XY
        XY object with position data filtered by the designed FIR lowpass filter.

    Notes
    -----
    The values of the input data are assumed to be numerical. Missing data is assumed
    to be either np.nan or None. The FIR filter requires a minimum signal length
    depending on the settings. A signal is a sequence of data in the XY-object that is
    not interrupted by missing values. The minimum signal length is defined as
    :math:`3 \\cdot numtaps`, which is the default padding length of
    ``scipy.signal.filtfilt``; a signal has to be strictly longer than that to be
    filtered. If a larger ``padlen`` is passed via ``kwargs`` (and padding is active,
    i.e. ``method='pad'`` and ``padtype`` is not None), ``padlen`` is used as minimum
    signal length instead. The treatment of signals that do not exceed the minimum
    signal length is specified with the ``remove_short_seqs``-argument, where True
    will replace these sequences with np.nan and False will keep the sequences in the
    data unfiltered.

    Examples
    --------
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> from floodlight import XY
    >>> from floodlight.transforms.filter import fir_lowpass

    We first generate a noisy XY-object to smooth.

    >>> t = np.linspace(-5, 5, 1000)
    >>> player_x = np.sin(t) * t + np.random.rand(1000)
    >>> player_x[450:495] = np.nan
    >>> player_x[505:550] = np.nan
    >>> player_y = t + np.random.randn()
    >>> xy = XY(np.transpose(np.stack((player_x, player_y))), framerate=20)

    Apply the FIR lowpass filter with its default settings.

    >>> xy_filt = fir_lowpass(xy)
    >>> plt.plot(xy.x)
    >>> plt.plot(xy_filt.x, linewidth=3)
    >>> plt.legend(("Raw", "Smoothed"))
    >>> plt.show()

    .. image:: ../../_img/fir_default_example.png

    Apply the filter with different specifications.

    >>> xy_filt = fir_lowpass(xy, numtaps=101, cutoff=3)
    >>> plt.plot(xy.x)
    >>> plt.plot(xy_filt.x, linewidth=3)
    >>> plt.legend(("Raw", "Smoothed"))
    >>> plt.show()

    .. image:: ../../_img/fir_adjusted_example.png
    """
    # minimum signal length a filter with this specs can be applied on: filtfilt
    # pads with 3 * max(len(a), len(b)) = 3 * numtaps samples by default and
    # requires the signal to be longer than the padding
    min_signal_len = 3 * numtaps
    # a user-supplied padlen that exceeds the default raises that requirement, but
    # only if filtfilt actually pads (method "pad" with a padtype other than None)
    padlen = kwargs.get("padlen")
    pads_signal = (
        kwargs.get("method", "pad") == "pad"
        and kwargs.get("padtype", "odd") is not None
    )
    if pads_signal and padlen is not None:
        min_signal_len = max(min_signal_len, padlen)
    framerate = xy.framerate

    # pre-allocate space for filtered data
    xy_filt = np.empty(xy.xy.shape)

    # loop through the xy-object columns
    for i, column in enumerate(np.transpose(xy.xy)):
        # extract indices of filterable and short sequences
        seqs_filt, seqs_short = _get_filterable_and_short_sequences(
            column, min_signal_len
        )

        # pre-allocate space for filtered column; gaps and (optionally) short
        # sequences simply stay NaN
        col_filt = np.full(column.shape, np.nan)

        # loop through filterable sequences
        for start, end in seqs_filt:
            # apply filter to the sequence and enter filtered data to their
            # respective indices in the data
            col_filt[start:end] = _filter_sequence_fir_lowpass(
                column[start:end], numtaps, cutoff, framerate, window, **kwargs
            )

        # check treatment of sequences that don't meet minimum signal length;
        # truthiness is used so that np.bool_ and 0/1 flags behave like bools
        if not remove_short_seqs:
            # enter short sequences unfiltered to their respective indices in the data
            for start, end in seqs_short:
                col_filt[start:end] = column[start:end]

        # enter filtered data into respective column
        xy_filt[:, i] = col_filt

    # create new XY-data object with filtered data
    xy_filtered = XY(xy=xy_filt, framerate=xy.framerate, direction=xy.direction)

    return xy_filtered
def _kalman_filter_1d(
signal: np.ndarray,
dt: float,
process_noise: float = 1.0,
measurement_noise: float = 0.04,
) -> np.ndarray:
"""Applies a forward-only 1D Kalman filter with a constant-velocity model.
The state vector is ``[position, velocity]`` and only position is observed.
NaN values in the input are handled natively: when an observation is missing,
only the predict step runs (covariance grows), and the output remains NaN.
The filter state is preserved across gaps so that velocity and uncertainty
information carries over when observations resume.
Parameters
----------
signal: np.ndarray
Array of shape (T,) containing the signal to be smoothed. May contain
NaNs (missing observations).
dt: float
Time step between frames in seconds (1 / framerate).
process_noise: float, optional
Process noise intensity (acceleration variance in m²/s⁴) that controls
how much the model trusts the constant-velocity prediction. Larger
values allow faster changes in velocity. Default corresponds to
:math:`\\sigma_a = 1\\,\\mathrm{m/s^2}`, which is a conservative smoothing
prior.
measurement_noise: float, optional
Measurement noise variance (in m²) that controls how much the model
trusts the observed positions. Larger values produce smoother output.
Default is 0.04, corresponding to 0.20 m (20 cm) RMSE, a conservative
estimate for common local or optical tracking systems during high-dynamic
situations.
Returns
-------
signal_filtered: np.ndarray
Filtered signal of shape (T,). Frames where the input is NaN remain
NaN in the output.
"""
T = len(signal)
signal = np.array(signal, dtype=float)
# state transition matrix (constant velocity model)
F = np.array([[1.0, dt], [0.0, 1.0]])
# observation matrix (only position is observed)
H = np.array([[1.0, 0.0]])
# process noise covariance (standard CV discretization assuming
# piecewise-constant acceleration noise with intensity process_noise)
Q = process_noise * np.array([[dt**4 / 4.0, dt**3 / 2.0], [dt**3 / 2.0, dt**2]])
# measurement noise covariance
R = np.array([[measurement_noise]])
filtered = np.full(T, np.nan)
# find first non-NaN observation to initialize state
initialized = False
x = None
P = None
for t in range(T):
obs = signal[t]
if not initialized:
if np.isnan(obs):
continue
# initialize state from first observation
x = np.array([obs, 0.0])
P = np.diag([measurement_noise, measurement_noise / dt])
filtered[t] = obs
initialized = True
continue
# predict
x = F @ x
P = F @ P @ F.T + Q
if np.isnan(obs):
# no observation: predict only, output stays NaN
continue
# update with observation
y = obs - H @ x # innovation
S = H @ P @ H.T + R # innovation covariance
K = P @ H.T / S[0, 0] # Kalman gain (S is scalar)
x = x + K.ravel() * y[0]
P = (np.eye(2) - K @ H) @ P
filtered[t] = x[0]
return filtered
[docs]
def kalman(
    xy: XY,
    process_noise: float = 1.0,
    measurement_noise: float = 0.04,
) -> XY:
    """Applies a forward Kalman filter to an XY data object. [3]_

    Every column of the XY object is filtered independently with a
    constant-velocity motion model whose state consists of position and velocity.
    Only positions are observed. Noisy position data is smoothed by blending the
    predictions of the motion model with the observed measurements.

    Parameters
    ----------
    xy: XY
        Floodlight XY Data object. Must have ``framerate`` set.
    process_noise: float, optional
        Process noise intensity (acceleration variance in m²/s⁴) that controls
        how much the model trusts the constant-velocity prediction. Larger
        values allow faster changes in velocity. Default corresponds to
        :math:`\\sigma_a = 1\\,\\mathrm{m/s^2}`, which is a conservative smoothing
        prior.
    measurement_noise: float, optional
        Measurement noise variance (in m²) controlling how much the model
        trusts the observed positions. Larger values produce smoother output.
        Default is 0.04, corresponding to 0.20 m RMSE, a conservative
        estimate for common optical [4]_ and local [5]_ tracking systems
        during high-dynamic situations.

    Returns
    -------
    xy_filtered: XY
        XY object with position data filtered by the Kalman filter.

    Raises
    ------
    ValueError
        If ``xy.framerate`` is None.

    Notes
    -----
    The Kalman filter requires ``xy.framerate`` to be set in order to compute
    the time interval between frames.

    The default noise parameters assume positions are given in meters. If
    positions use a different unit (e.g. centimeters), the noise parameters
    must be adjusted accordingly (e.g. ``measurement_noise = 20.0**2`` for
    20 cm RMSE in centimeter units).

    Unlike :func:`~floodlight.transforms.filter.butterworth_lowpass` and
    :func:`~floodlight.transforms.filter.savgol_lowpass`, the Kalman filter
    handles missing data (NaN) natively. When an observation is missing, the
    filter runs a predict-only step, maintaining its internal state (velocity
    estimate, covariance) across gaps. Frames with missing input data remain
    NaN in the output — no gap-filling is performed.

    Examples
    --------
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> from floodlight import XY
    >>> from floodlight.transforms.filter import kalman

    Generate a noisy XY-object to smooth.

    >>> t = np.linspace(-5, 5, 1000)
    >>> player_x = np.sin(t) * t + np.random.rand(1000)
    >>> player_x[450:495] = np.nan
    >>> player_x[505:550] = np.nan
    >>> player_y = t + np.random.randn()
    >>> xy = XY(np.transpose(np.stack((player_x, player_y))), framerate=20)

    Apply the Kalman filter with default settings.

    >>> xy_filt = kalman(xy)
    >>> plt.plot(xy.x)
    >>> plt.plot(xy_filt.x, linewidth=3)
    >>> plt.legend(("Raw", "Smoothed"))
    >>> plt.show()

    .. image:: ../../_img/kalman_default_example.png

    Apply the filter with increased measurement noise for stronger smoothing.

    >>> xy_filt = kalman(xy, measurement_noise=1.0)
    >>> plt.plot(xy.x)
    >>> plt.plot(xy_filt.x, linewidth=3)
    >>> plt.legend(("Raw", "Smoothed"))
    >>> plt.show()

    .. image:: ../../_img/kalman_adjusted_example.png

    References
    ----------
    .. [3] `Kalman, R. E. (1960). A New Approach to Linear Filtering and
        Prediction Problems. Journal of Basic Engineering, 82(1), 35-45.
        <https://doi.org/10.1115/1.3662552>`_
    .. [4] `Linke, D., Link, D., & Lames, M. (2020). Football-specific
        validity of TRACAB's optical video tracking systems. PLoS ONE,
        15(3), e0230179. <https://doi.org/10.1371/journal.pone.0230179>`_
    .. [5] `Blauberger, P., Marzilger, R., & Lames, M. (2021). Validation
        of player and ball tracking with a local positioning system.
        Sensors, 21(4), 1465. <https://doi.org/10.3390/s21041465>`_
    """
    framerate = xy.framerate
    if framerate is None:
        raise ValueError(
            "The Kalman filter requires xy.framerate to be set in order to "
            "compute the time step between frames."
        )

    # duration of a single frame in seconds
    dt = 1.0 / framerate

    positions = xy.xy
    # output buffer; every column is overwritten below
    smoothed = np.empty(positions.shape)

    # each column is an independent 1D track and gets its own filter run
    for col_idx, track in enumerate(positions.T):
        smoothed[:, col_idx] = _kalman_filter_1d(
            track, dt, process_noise, measurement_noise
        )

    # wrap the result in a new XY object carrying over the metadata
    return XY(xy=smoothed, framerate=framerate, direction=xy.direction)
[docs]
def wiener(
    xy: XY,
    window_size: int = 5,
    noise: float = None,
    remove_short_seqs: bool = False,
) -> XY:
    """Applies a Wiener filter to an XY data object. [6]_

    For filtering, the `scipy.signal.wiener <https://docs.scipy.org/doc/scipy/
    reference/generated/scipy.signal.wiener.html>`_ function is used. This function
    provides a convenient access to the function, directly applying the filter to all
    non-NaN sequences in all columns.

    Parameters
    ----------
    xy: XY
        Floodlight XY Data object.
    window_size: int, optional
        Size of the local window used for noise estimation and filtering. Corresponds
        to the argument ``mysize`` from the `scipy.signal.wiener <https://docs.scipy.
        org/doc/scipy/reference/generated/scipy.signal.wiener.html>`_ function.
        Default is 5.
    noise: float, optional
        Noise power estimate. If None, the noise power is estimated from the data by
        scipy. Corresponds to the argument ``noise`` from the `scipy.signal.wiener
        <https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.
        wiener.html>`_ function. Default is None.
    remove_short_seqs: bool, optional
        If True, sequences that are too short for the filter with the specified
        settings are replaced with np.nan. If False, they are kept unfiltered.
        Default is False.

    Returns
    -------
    xy_filtered: XY
        XY object with position data filtered by the Wiener filter.

    Notes
    -----
    The values of the input data are assumed to be numerical. Missing data is assumed
    to be either np.nan or None. The Wiener filter requires a minimum signal length
    depending on the settings. A signal is a sequence of data in the XY-object that is
    not interrupted by missing values. The minimum signal length is defined as
    the ``window_size``; only signals that are strictly longer than that are
    filtered. The treatment of signals that do not exceed the minimum signal length is
    specified with the ``remove_short_seqs``-argument, where True will replace these
    sequences with np.nan and False will keep the sequences in the data unfiltered.

    Examples
    --------
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> from floodlight import XY
    >>> from floodlight.transforms.filter import wiener

    We first generate a noisy XY-object to smooth.

    >>> t = np.linspace(-5, 5, 1000)
    >>> player_x = np.sin(t) * t + np.random.rand(1000)
    >>> player_x[450:495] = np.nan
    >>> player_x[505:550] = np.nan
    >>> player_y = t + np.random.randn()
    >>> xy = XY(np.transpose(np.stack((player_x, player_y))), framerate=20)

    Apply the Wiener filter with its default settings.

    >>> xy_filt = wiener(xy)
    >>> plt.plot(xy.x)
    >>> plt.plot(xy_filt.x, linewidth=3)
    >>> plt.legend(("Raw", "Smoothed"))
    >>> plt.show()

    .. image:: ../../_img/wiener_default_example.png

    Apply the filter with a larger window size for stronger smoothing.

    >>> xy_filt = wiener(xy, window_size=25)
    >>> plt.plot(xy.x)
    >>> plt.plot(xy_filt.x, linewidth=3)
    >>> plt.legend(("Raw", "Smoothed"))
    >>> plt.show()

    .. image:: ../../_img/wiener_adjusted_example.png

    References
    ----------
    .. [6] `Wiener, N. (1949). Extrapolation, Interpolation, and Smoothing of
        Stationary Time Series. MIT Press.
        <https://doi.org/10.7551/mitpress/2946.001.0001>`_
    """
    # minimum signal length a filter with this specs can be applied on
    min_signal_len = window_size

    # pre-allocate space for filtered data
    xy_filt = np.empty(xy.xy.shape)

    # loop through the xy-object columns
    for i, column in enumerate(np.transpose(xy.xy)):
        # extract indices of filterable and short sequences
        seqs_filt, seqs_short = _get_filterable_and_short_sequences(
            column, min_signal_len
        )

        # pre-allocate space for filtered column; gaps and (optionally) short
        # sequences simply stay NaN
        col_filt = np.full(column.shape, np.nan)

        # loop through filterable sequences
        for start, end in seqs_filt:
            # apply filter to the sequence; the explicit float cast also covers
            # object-typed columns
            col_filt[start:end] = scipy.signal.wiener(
                column[start:end].astype(float), mysize=window_size, noise=noise
            )

        # check treatment of sequences that don't meet minimum signal length;
        # truthiness is used so that np.bool_ and 0/1 flags behave like bools
        if not remove_short_seqs:
            # enter short sequences unfiltered to their respective indices
            for start, end in seqs_short:
                col_filt[start:end] = column[start:end]

        # enter filtered data into respective column
        xy_filt[:, i] = col_filt

    # create new XY-data object with filtered data
    xy_filtered = XY(xy=xy_filt, framerate=xy.framerate, direction=xy.direction)

    return xy_filtered