from collections.abc import Sequence
import numpy as np
from scipy.io import savemat
from sgpykit.util.cell import Cell
from sgpykit.util.checks import is_numeric_scalar
from sgpykit.util.misc import get_shape_of_cells
from sgpykit.util.struct import Struct
from sgpykit.util.struct_array import StructArray
[docs]
def struct(*args, **kwargs):
    """
    Create a Struct or StructArray from given arguments (MATLAB ``struct``).

    Field/value pairs come from the keyword arguments when any are given;
    otherwise they come from alternating positional arguments
    (name, value, name, value, ...). If no cell values are detected, a single
    Struct is built; otherwise all arguments are forwarded to StructArray.

    Parameters
    ----------
    *args : tuple
        Alternating field names and values (used only when no keyword
        arguments are given).
    **kwargs : dict
        Field names mapped to their values.

    Returns
    -------
    Struct or StructArray
        A Struct when ``get_shape_of_cells`` reports no cells (returns None),
        otherwise a StructArray. With no keywords and fewer than two
        positional arguments, the arguments go straight to StructArray.
    """
    if kwargs or len(args) > 1:
        # like matlab: only when cells are given do we build a struct array
        values = kwargs.values() if kwargs else args[1::2]
        if get_shape_of_cells(values) is None:
            fields = kwargs if kwargs else dict(zip(args[::2], args[1::2]))
            return Struct(**fields)
    return StructArray(*args, **kwargs)
[docs]
def fieldnames(s):
    """
    Get the field names of a Struct or StructArray.

    Parameters
    ----------
    s : Struct or StructArray
        A Struct, or a container exposing ``.struct`` (and optionally
        ``.shape``) whose first element is inspected.

    Returns
    -------
    list
        Attribute names of the Struct (or of the first element of the
        struct array), in definition order.
    """
    if isinstance(s, Struct):
        target = s
    elif not hasattr(s, 'shape'):
        target = s.struct
    elif s.shape[0] > 1:  # 2D case
        target = s.struct[0][0]
    else:
        target = s.struct[0]
    return list(vars(target))
[docs]
def cell(shape):
    """
    Return an empty cell (numpy) array with the given shape (empty nD-matrix).

    2D cell objects can be flat indexed, e.g. ``c[0, 2]`` or ``c[2]``.
    Higher dimensions than 2D are not supported.
    As in MATLAB's ``cell(N)``, a single integer N creates an N x N cell.

    Parameters
    ----------
    shape : int, numpy integer or tuple
        The shape of the cell array. A Python or NumPy integer creates a
        square cell array.

    Returns
    -------
    Cell
        The empty cell array.
    """
    # NumPy integer scalars (e.g. results of np.prod or array arithmetic) are
    # not subclasses of int, so accept them explicitly to keep cell(N) -> NxN.
    if isinstance(shape, (int, np.integer)):
        n = int(shape)
        return Cell((n, n))
    return Cell(shape)
[docs]
def ce(*args):
    """
    Replacement for the MATLAB ``{}`` operator to create cells.

    Builds a 1 x len(args) Cell whose dtype is the type of the first element
    and fills it position by position. The elements are expected to share
    that data type.

    Parameters
    ----------
    *args : tuple
        Elements to be placed in the cell array (at least one is required,
        since the dtype is taken from ``args[0]``).

    Returns
    -------
    Cell
        The created 1 x N cell array.
    """
    element_type = type(args[0])
    container = Cell((1, len(args)), dtype=element_type)
    for position, element in enumerate(args):
        container[position] = element  # flat indexing into the 1 x N cell
    return container
[docs]
def iscell(var):
    """
    Tell whether ``var`` is a Cell (MATLAB ``iscell``).

    Parameters
    ----------
    var : any
        The object to inspect.

    Returns
    -------
    bool
        True if ``var`` is a Cell instance (subclasses included), else False.
    """
    is_cell = isinstance(var, Cell)
    return is_cell
[docs]
def isstruct(var):
    """
    Tell whether ``var`` is a Struct or a StructArray (MATLAB ``isstruct``).

    Parameters
    ----------
    var : any
        The object to inspect.

    Returns
    -------
    bool
        True for Struct or StructArray instances, False otherwise.
    """
    return isinstance(var, (Struct, StructArray))
[docs]
def isfield(var, sel):
    """
    Check whether a field name exists in a struct.

    Note the argument order: the field name comes first, the struct second
    (the reverse of MATLAB's ``isfield(S, name)``).

    Parameters
    ----------
    var : str
        The field name to look for.
    sel : Struct or StructArray
        The struct or struct array to inspect.

    Returns
    -------
    bool
        True if ``var`` is among ``fieldnames(sel)``, False otherwise.
    """
    available = fieldnames(sel)
    return var in available
[docs]
def unique(var, by=None):
    """
    Incomplete port of MATLAB's ``unique()`` function.

    https://de.mathworks.com/help/matlab/ref/double.unique.html
    - 'first' and 'sorted' are defaults (``np.unique`` always returns sorted
      values and the indices of first occurrences).
    https://numpy.org/doc/stable/reference/generated/numpy.unique.html#numpy-unique

    Parameters
    ----------
    var : array_like
        Input array.
    by : str or int, optional
        'rows' (MATLAB spelling) or 'row' de-duplicates along axis 0, i.e.
        unique rows. 'first' is also mapped to axis 0 (kept for backward
        compatibility; NOTE(review): for 2D input this yields unique rows,
        not unique elements). Any other value is passed to ``np.unique`` as
        ``axis``. Default is None (flattened input).

    Returns
    -------
    tuple
        ``(unique_values, first_indices)`` as returned by
        ``np.unique(..., return_index=True)``.
    """
    if by in ('first', 'row', 'rows'):
        axis = 0
    else:
        axis = by
    return np.unique(var, axis=axis, return_index=True)
[docs]
def find(X, n=None, direction='first'):
    """
    Find non-zero elements in an array (subset of MATLAB ``find``).

    https://www.mathworks.com/help/matlab/ref/find.html

    Parameters
    ----------
    X : array_like
        Input array. When ``n`` is given, indices are taken along the first
        axis only (``np.nonzero(X)[0]``), so that mode is meant for 1D input.
    n : int, optional
        Maximum number of indices to return. If None (default), every
        non-zero position is returned as ``np.argwhere(X)``.
    direction : {'first', 'last'}, optional
        With ``n`` given, return the first ``n`` or the last ``n`` non-zero
        positions. Default is 'first'.

    Returns
    -------
    ndarray or tuple
        If ``n`` is None: the ``(k, X.ndim)`` array from ``np.argwhere``.
        Otherwise ``(idx, values)`` where ``idx`` holds at most ``n`` indices
        in ascending order and ``values = X[idx]``.

    Raises
    ------
    ValueError
        If ``n`` is given and ``direction`` is neither 'first' nor 'last'.
    """
    if n is None:
        return np.argwhere(X)
    if direction not in ('first', 'last'):
        raise ValueError("direction must be 'first' or 'last'")
    X = np.asarray(X)  # allow list input for the X[idx] lookup below
    idx = np.nonzero(X)[0]
    n = max(0, min(n, len(idx)))
    if direction == 'last':
        # keep the last n positions; slicing with len - n also handles n == 0
        idx = idx[len(idx) - n:]
    else:
        idx = idx[:n]
    # values are looked up from the selected indices so both stay aligned
    return idx, X[idx]
[docs]
def setdiff(A, B, kind='rows'):
    """
    Row-wise set difference of two arrays: rows of A that do not occur in B.

    Unlike MATLAB's ``setdiff(A, B, 'rows')``, the result is neither sorted
    nor de-duplicated: rows of A keep their original order and multiplicity.

    Parameters
    ----------
    A : array_like
        First input array (1D input is treated as a single row).
    B : array_like
        Second input array; must have the same number of columns as A.
    kind : str, optional
        Must contain 'rows'; other kinds are not implemented. Default 'rows'.

    Returns
    -------
    ndarray
        ``A`` restricted to the rows not present in ``B``. It keeps A's dtype
        and stays 2D, including the empty result of shape ``(0, ncols)``.

    Raises
    ------
    ValueError
        If A and B have a different number of columns.
    NotImplementedError
        If ``kind`` does not contain 'rows'.
    """
    if 'rows' not in kind:
        raise NotImplementedError("Not implemented yet.")
    A = np.atleast_2d(A)
    B = np.atleast_2d(B)
    if A.shape[1] != B.shape[1]:
        raise ValueError("A and B must have the same number of columns.")
    # Hash B's rows once so each row of A is checked in O(1), not O(len(B)).
    b_rows = {tuple(row) for row in B.tolist()}
    keep = np.array([tuple(row) not in b_rows for row in A.tolist()], dtype=bool)
    return A[keep]
[docs]
def sortrows(A):
    """
    Sort rows of a matrix in ascending lexicographic order (MATLAB ``sortrows``).

    Rows are ordered by the first column; ties are broken by the next column,
    and so on. Equal rows keep their original relative order because
    ``np.lexsort`` is stable.

    Parameters
    ----------
    A : array_like
        Input matrix. A 1D input is treated as a single row.

    Returns
    -------
    tuple
        ``(sorted_array, sorted_indices)`` with
        ``sorted_array == A[sorted_indices]``. For inputs with fewer than two
        rows, ``A`` is returned unchanged with a list of its row indices
        (``[0]`` for one row, ``[]`` for an empty input).
    """
    # matlab doc on sortrows(A):
    # B = sortrows(A) sorts the rows of a matrix in ascending order based on the elements in the first column.
    # When the first column contains repeated elements, sortrows sorts according to the values in the next column and
    # repeats this behavior for succeeding equal values.
    if len(A) < 2:
        # [] for an empty input (no valid row index), [0] for a single row
        return A, list(range(len(A)))
    if np.atleast_2d(A).shape[0] == 1:
        return A, [0]
    A = np.asarray(A)  # allow nested lists; A.T below needs an ndarray
    if A.shape[1] == 0:
        # no columns to sort by (np.lexsort would fail): keep original order
        return A, np.arange(A.shape[0])
    # Reverse the column order: lexsort uses its LAST key as the primary key
    sorted_indices = np.lexsort(A.T[::-1])
    return A[sorted_indices], sorted_indices
[docs]
def sort(arr, axis=None, direction='ascend'):
    """
    Sort an array in ascending or descending order (MATLAB ``[B, I] = sort(A)``).

    The sort is stable in both directions: equal elements keep their original
    relative order. NaN values end up last for 'ascend' and first for
    'descend'.

    Parameters
    ----------
    arr : array_like
        Input array (any dtype NumPy can sort, including unsigned ints and
        booleans).
    axis : int, optional
        Axis along which to sort. Default is None, which sorts the flattened
        array (MATLAB would instead use the first non-singleton dimension).
    direction : str, optional
        'ascend' (default) or 'descend'.

    Returns
    -------
    tuple
        ``(sorted_arr, indices)`` where
        ``sorted_arr == np.take_along_axis(arr, indices, axis=axis)``.

    Raises
    ------
    ValueError
        If ``direction`` is not 'ascend' or 'descend'.
    """
    arr = np.asarray(arr)  # lists need array semantics below
    if direction == 'ascend':
        indices = np.argsort(arr, axis=axis, kind='stable')
    elif direction == 'descend':
        # Stable descending sort without negating the data (negation wraps
        # unsigned ints and is not defined for booleans): stable-sort the
        # reversed array, reverse that order, then map positions back.
        order = np.argsort(np.flip(arr, axis=axis), axis=axis, kind='stable')
        length = arr.size if axis is None else arr.shape[axis]
        indices = (length - 1) - np.flip(order, axis=axis)
    else:
        raise ValueError("direction must be 'ascend' or 'descend'")
    sorted_arr = np.take_along_axis(arr, indices, axis=axis)
    return sorted_arr, indices
[docs]
def reshape(var, r, c=None):  # TODO: to be removed
    """
    Reshape an array.

    NOTE: elements are read and written in NumPy's row-major (C) order
    (``flatten``/``reshape`` defaults), unlike MATLAB's column-major reshape.

    Parameters
    ----------
    var : array_like
        Input array.
    r : int or array_like
        Number of rows (when ``c`` is given), or the complete new shape as a
        list, tuple or ndarray (when ``c`` is None).
    c : int, optional
        Number of columns. Default is None.

    Returns
    -------
    ndarray
        Reshaped (copied) array.

    Raises
    ------
    ValueError
        If ``c`` is None and ``r`` is not a shape sequence, or if the element
        count does not match the requested shape.
    """
    # @TODO: can we just use array.reshape() ?
    flat_arr = np.array(var).flatten()
    if c is None:
        # a full shape may be given as list, tuple or ndarray
        if isinstance(r, (np.ndarray, list, tuple)):
            return flat_arr.reshape(r)
        raise ValueError(f"Cannot reshape array of size {flat_arr.size} into shape ({r}) (is {r} a list, tuple or nparray?)")
    # Check if the total number of elements matches r * c
    if flat_arr.size != r * c:
        raise ValueError(f"Cannot reshape array of size {flat_arr.size} into shape ({r}, {c})")
    return flat_arr.reshape((r, c))
[docs]
def logical(var):
    """
    Convert values to a boolean array (MATLAB ``logical``).

    Zero entries become False and every other entry becomes True.

    Parameters
    ----------
    var : array_like
        Input values.

    Returns
    -------
    ndarray
        A new boolean array; the input is never returned as-is.
    """
    return np.array(var, dtype=np.bool_, copy=True)
[docs]
def intersect(x, y):
    """
    Find the intersection of two arrays (MATLAB ``[C, ia, ib] = intersect(A, B)``).

    Inputs are flattened.

    Parameters
    ----------
    x : array_like
        First input array.
    y : array_like
        Second input array.

    Returns
    -------
    tuple
        ``(common, x_indices, y_indices)``:
        - ``common`` holds the sorted unique values present in both inputs.
        - ``x_indices`` and ``y_indices`` hold the positions of the first
          occurrences of those values in the flattened ``x`` and ``y``, so
          ``x.ravel()[x_indices] == common`` and
          ``y.ravel()[y_indices] == common``.
    """
    # return_indices maps back into the ORIGINAL (flattened) inputs, not into
    # their unique-ified copies.
    common, x_indices, y_indices = np.intersect1d(
        np.asarray(x), np.asarray(y), return_indices=True
    )
    return common, x_indices, y_indices
[docs]
def issorted(matrix, typ='rows'):
    """
    Check whether the rows of a matrix are in ascending lexicographic order.

    Mirrors MATLAB's ``issorted(A, 'rows')``. Equal consecutive rows count as
    sorted. A 1D input is treated as a single row, consistent with
    ``sortrows``, and is therefore always sorted.

    Parameters
    ----------
    matrix : array_like
        Input matrix.
    typ : str, optional
        Type of sorting to check; only 'rows' is implemented. Default 'rows'.

    Returns
    -------
    bool
        True if the matrix is empty or its rows are sorted, False otherwise.

    Raises
    ------
    Exception
        If ``typ`` is not 'rows' (and the matrix is not empty).
    """
    matrix = np.asarray(matrix)  # allow nested lists
    if matrix.size == 0:
        return True
    if typ != 'rows':
        raise Exception('Not implemented.')
    rows = np.atleast_2d(matrix)
    if rows.shape[0] < 2:
        return True
    # lexsort is stable, so an already sorted matrix yields exactly 0..n-1;
    # the column order is reversed because lexsort's last key is primary.
    order = np.lexsort(rows.T[::-1])
    return bool(np.all(order == np.arange(rows.shape[0])))
# TODO: remove?
[docs]
def fliplr(arr):
    """
    Flip the elements of an array from left to right (MATLAB ``fliplr``).

    Parameters
    ----------
    arr : array_like
        1D or 2D input. Objects without an ``ndim`` attribute (lists,
        tuples, ...) are converted with ``np.asarray``; array objects are used
        as-is.

    Returns
    -------
    ndarray
        For 1D input the reversed array; for 2D input every row reversed
        (a view when ``arr`` is already an array).

    Raises
    ------
    ValueError
        If the input is not 1D or 2D.
    """
    if not hasattr(arr, 'ndim'):
        arr = np.asarray(arr)
    if arr.ndim == 1:
        # For 1D arrays, simply reverse the array
        return arr[::-1]
    if arr.ndim == 2:
        # For 2D arrays, flip each row from left to right
        return arr[:, ::-1]
    raise ValueError("Input array must be 1D or 2D.")
[docs]
def ifft(arr):
    """
    Inverse FFT of a 1D or 2D array, following MATLAB's ``ifft()``.

    Singleton dimensions are squeezed first, so a 1 x N or N x 1 input is
    handled as a plain vector. A genuine 2D input is transformed column by
    column (along axis 0), as MATLAB does.

    Parameters
    ----------
    arr : array_like
        A 1D or 2D numpy array (after squeezing).

    Returns
    -------
    ndarray
        The complex inverse FFT of the input.

    Raises
    ------
    ValueError
        If the squeezed input has more than two dimensions.
    """
    data = np.atleast_1d(np.squeeze(arr))
    if data.ndim > 2:
        raise ValueError("Input array must be 1D or 2D.")
    # axis 0 is the only axis of a vector and the column axis of a matrix
    return np.fft.ifft(data, axis=0)
[docs]
def eig(var):
    """
    Eigen-decomposition with MATLAB's output order ``[V, D] = eig(A)``.

    Unlike MATLAB, the eigenvalues are returned as a 1D vector rather than a
    diagonal matrix.

    Parameters
    ----------
    var : array_like
        Square input matrix.

    Returns
    -------
    tuple
        ``(eigenvectors, eigenvalues)``: eigenvectors are the columns of the
        first item, matching the order of the eigenvalue vector.
    """
    eigenvalues, eigenvectors = np.linalg.eig(var)
    # numpy yields (values, vectors); MATLAB callers expect (vectors, values)
    return eigenvectors, eigenvalues
[docs]
def isnumeric(obj):
    """
    Check whether an object holds numeric data (MATLAB ``isnumeric``).

    Parameters
    ----------
    obj : any
        The object to check.

    Returns
    -------
    bool
        - For NumPy arrays: whether the dtype is a subtype of ``np.number``
          (booleans and strings are not).
        - For scalars: the verdict of ``is_numeric_scalar``.
        - For non-string sequences: whether every item is a numeric scalar.
        - Anything else: False.
    """
    if isinstance(obj, np.ndarray):
        return np.issubdtype(obj.dtype, np.number)
    if is_numeric_scalar(obj):
        return True
    is_text = isinstance(obj, (str, bytes))
    if is_text or not isinstance(obj, Sequence):
        return False
    for item in obj:
        if not is_numeric_scalar(item):
            return False
    return True
[docs]
def save(filename, *args):
    """
    Save NumPy arrays to a MATLAB ``.mat`` file.

    Parameters
    ----------
    filename : str
        Path of the file to write (passed to ``scipy.io.savemat``).
    *args : tuple
        Alternating variable names and values, e.g.
        ``save('out.mat', 'a', a, 'b', b)``. A repeated name keeps its last
        value.

    Raises
    ------
    ValueError
        If ``args`` does not consist of complete name/value pairs; checked
        before anything is written.
    """
    if len(args) % 2 != 0:
        raise ValueError("save expects name/value pairs; got an odd number of arguments.")
    data_dict = dict(zip(args[::2], args[1::2]))
    savemat(filename, data_dict)