from copy import deepcopy
from dataclasses import dataclass
from typing import Tuple, Union
import matplotlib
import numpy as np
from floodlight.utils.types import Numeric
from floodlight.vis.positions import plot_positions, plot_trajectories
@dataclass
class XY:
    """Spatio-temporal data fragment. Core class of floodlight.

    Parameters
    ----------
    xy: np.ndarray
        Full data array containing x- and y-coordinates, where each player's
        coordinates occupy two consecutive columns.
    framerate: int, optional
        Temporal resolution of data in frames per second/Hertz.
    direction: {'lr', 'rl'}, optional
        Playing direction of players in data fragment, should be either
        'lr' (left-to-right) or 'rl' (right-to-left).

    Attributes
    ----------
    x: np.ndarray
        X-data array, where each player's x-coordinates occupy one column.
    y: np.ndarray
        Y-data array, where each player's y-coordinates occupy one column.
    N: int
        The object's number of players.
    """

    xy: np.ndarray
    framerate: int = None
    direction: str = None

    def __str__(self):
        return f"Floodlight XY object of shape {self.xy.shape}"

    def __len__(self):
        # Number of frames, i.e. rows of ``xy``.
        return len(self.xy)

    def __getitem__(self, key):
        return self.xy[key]

    def __setitem__(self, key, value):
        self.xy[key] = value

    @property
    def N(self) -> int:
        """Number of players, i.e. half the number of columns of ``xy``.

        Raises
        ------
        ValueError
            If ``xy`` has an odd number of columns.
        """
        n_columns = self.xy.shape[1]
        if (n_columns % 2) != 0:
            raise ValueError(f"XY has an odd number of columns ({n_columns})")
        return n_columns // 2

    @property
    def x(self) -> np.ndarray:
        """X-data array, where each player's x-coordinates occupy one column."""
        # Basic slicing returns a view, so in-place writes reach ``xy``.
        return self.xy[:, ::2]

    @x.setter
    def x(self, x_data: np.ndarray):
        self.xy[:, ::2] = x_data

    @property
    def y(self) -> np.ndarray:
        """Y-data array, where each player's y-coordinates occupy one column."""
        return self.xy[:, 1::2]

    @y.setter
    def y(self, y_data: np.ndarray):
        self.xy[:, 1::2] = y_data

    def frame(self, t: int) -> np.ndarray:
        """Returns data for given frame *t*.

        Parameters
        ----------
        t : int
            Frame index.

        Returns
        -------
        frame : np.ndarray
            One-dimensional xy-data row for given frame.
        """
        return self.xy[t, :]

    def player(self, xID: int) -> np.ndarray:
        """Returns data for player with given player index *xID*.

        Parameters
        ----------
        xID : int
            Player index.

        Returns
        -------
        player : np.ndarray
            Two-dimensional xy-data for given player (columns ``2*xID`` and
            ``2*xID + 1``).
        """
        return self.xy[:, xID * 2 : xID * 2 + 2]

    def point(self, t: int, xID: int) -> np.ndarray:
        """Returns data for a point determined by frame *t* and player index *xID*.

        Parameters
        ----------
        t: int
            Frame index.
        xID: int
            Player index.

        Returns
        -------
        point : np.ndarray
            Point-data of shape (2,)
        """
        return self.xy[t, xID * 2 : xID * 2 + 2]

    def _ensure_float(self):
        """Cast ``xy`` to ``np.float32`` if it has a non-floating dtype.

        Floating arrays are left untouched. ``np.issubdtype`` is used instead of
        enumerating dtypes because the previous list referenced ``np.float_``,
        which no longer exists in NumPy >= 2.0.
        """
        if not np.issubdtype(self.xy.dtype, np.floating):
            self.xy = self.xy.astype(np.float32, copy=False)

    def translate(self, shift: "Tuple[Numeric, Numeric]"):
        """Translates data by shift vector.

        Parameters
        ----------
        shift : list or array-like
            Shift vector of form v = (x, y). Any iterable data type with two numeric
            entries is accepted.

        Notes
        -----
        Executing this method will cast the object's xy attribute to dtype np.float32 if
        it previously has a non-floating dtype. Resulting coordinates are rounded to
        three decimals.
        """
        self._ensure_float()
        self.x = np.round(self.x + shift[0], 3)
        self.y = np.round(self.y + shift[1], 3)

    def scale(self, factor: float, axis: str = None):
        """Scales data by a given factor and optionally selected axis.

        Parameters
        ----------
        factor : float
            Scaling factor.
        axis : {None, 'x', 'y'}, optional
            Name of scaling axis. If set to 'x' data is scaled on x-axis, if set to 'y'
            data is scaled on y-axis. If None, data is scaled in both directions
            (default).

        Raises
        ------
        ValueError
            If ``axis`` is not one of None, 'x' or 'y'.

        Notes
        -----
        Executing this method will cast the object's xy attribute to dtype np.float32 if
        it previously has a non-floating dtype. Resulting coordinates are rounded to
        three decimals.
        """
        self._ensure_float()
        if axis is None:
            self.xy = np.round(self.xy * factor, 3)
        elif axis == "x":
            self.x = np.round(self.x * factor, 3)
        elif axis == "y":
            self.y = np.round(self.y * factor, 3)
        else:
            raise ValueError(f"Expected axis to be one of ('x', 'y', None), got {axis}")

    def reflect(self, axis: str):
        """Reflects data on given `axis`.

        Parameters
        ----------
        axis : {'x', 'y'}
            Name of reflection axis. If set to "x", data is reflected on x-axis
            (y-coordinates are negated), if set to "y", data is reflected on y-axis
            (x-coordinates are negated).

        Raises
        ------
        ValueError
            If ``axis`` is neither 'x' nor 'y'.
        """
        if axis == "x":
            self.scale(factor=-1, axis="y")
        elif axis == "y":
            self.scale(factor=-1, axis="x")
        else:
            raise ValueError(f"Expected axis to be one of ('x', 'y'), got {axis}")

    def rotate(self, alpha: float):
        """Rotates data on given angle 'alpha' around the origin.

        Parameters
        ----------
        alpha: float
            Rotation angle in degrees. Alpha must be between -360 and 360. If positive
            alpha, data is rotated in counter clockwise direction around the origin. If
            negative, data is rotated in clockwise direction around the origin.

        Raises
        ------
        ValueError
            If ``alpha`` lies outside [-360, 360].

        Notes
        -----
        Executing this method will cast the object's xy attribute to dtype np.float32 if
        it previously has a non-floating dtype. Resulting coordinates are rounded to
        three decimals.
        """
        if not (-360 <= alpha <= 360):
            raise ValueError(
                f"Expected alpha to be from -360 to 360, got {alpha} instead"
            )
        self._ensure_float()

        # Rotation matrix, transposed because points are row vectors (p @ r).
        phi = np.radians(alpha)
        cos = np.cos(phi)
        sin = np.sin(phi)
        r = np.array([[cos, -sin], [sin, cos]]).transpose()

        # Rotate player by player: a NaN in one player's columns cannot leak into
        # another player's coordinates (unlike a single block-matrix product).
        for p in range(self.N):
            columns = slice(2 * p, 2 * p + 2)
            self.xy[:, columns] = np.round(self.xy[:, columns] @ r, 3)

    def slice(
        self, startframe: int = None, endframe: int = None, inplace: bool = False
    ) -> Union["XY", None]:
        """Return copy of object with sliced data. Mimics numpy's array slicing.

        Parameters
        ----------
        startframe : int, optional
            Start of slice. Defaults to beginning of segment.
        endframe : int, optional
            End of slice (endframe is excluded). Defaults to end of segment.
        inplace: bool, optional
            If set to ``False`` (default), a new object is returned, otherwise the
            operation is performed in place on the called object.

        Returns
        -------
        xy_sliced: Union[XY, None]
            The sliced copy, or ``None`` if ``inplace`` is True.
        """
        sliced_data = self.xy[startframe:endframe, :].copy()
        xy_sliced = None
        if inplace:
            self.xy = sliced_data
        else:
            xy_sliced = XY(
                xy=sliced_data,
                framerate=deepcopy(self.framerate),
                direction=deepcopy(self.direction),
            )
        return xy_sliced

    def plot(
        self,
        t: Union[int, Tuple[int, int]],
        plot_type: str = "positions",
        ball: bool = False,
        ax: "matplotlib.axes.Axes" = None,
        **kwargs,
    ) -> "matplotlib.axes.Axes":
        """Plots a snapshot or time interval of the object's spatiotemporal data on a
        matplotlib axes.

        Parameters
        ----------
        t: Union[int, Tuple [int, int]]
            Frame for which positions should be plotted if plot_type == 'positions', or
            a Tuple that has the form (start_frame, end_frame) if
            plot_type == 'trajectories'.
        plot_type: str, optional
            One of {'positions', 'trajectories'}. Determines which plotting function is
            called. Defaults to 'positions'.
        ball: bool, optional
            Boolean indicating whether this object is storing ball data. If set to True,
            the styling is adjusted accordingly. Defaults to False.
        ax: matplotlib.axes.Axes, optional
            Axes from matplotlib library to plot on. Defaults to None.
        kwargs:
            Optional keyword arguments e.g. {'color', 'zorder', 'marker', 'linestyle',
            'alpha'} which are forwarded to the underlying matplotlib plot functions.
            If not given, default values are used (see floodlight.vis.positions).

        Returns
        -------
        axes: matplotlib.axes.Axes
            Axes from matplotlib library on which the specified plot type is plotted.

        Raises
        ------
        ValueError
            If ``plot_type`` is not one of 'positions' or 'trajectories'.

        Notes
        -----
        The kwargs are only passed to the plot functions of matplotlib. To customize the
        plots have a look at
        `matplotlib
        <https://matplotlib.org/3.5.0/api/_as_gen/matplotlib.axes.Axes.plot.html>`_.
        For example in order to modify the color of the points and lines pass a color
        name or rgb-value (`matplotlib colors
        <https://matplotlib.org/3.5.0/tutorials/colors/colors.html>`_) to the keyword
        argument 'color'. The same principle applies to other kwargs like 'zorder',
        'marker' and 'linestyle'.

        Examples
        --------
        - :ref:`Positions plot <positions-plot-label>`
        - :ref:`Trajectories plot <trajectories-plot-label>`
        """
        plot_types = ["positions", "trajectories"]
        # call visualization function based on plot_type
        if plot_type == "positions":
            return plot_positions(self, t, ball, ax=ax, **kwargs)
        elif plot_type == "trajectories":
            return plot_trajectories(self, t[0], t[1], ball, ax=ax, **kwargs)
        else:
            raise ValueError(
                f"Expected plot_type to be one of {plot_types}, got {plot_type} "
                "instead."
            )