Source code for sgpykit.main.plot3_sparse_grid

import matplotlib.pyplot as plt
import numpy as np

from sgpykit.util.misc import reshape_nested_lists_to_nrows
from sgpykit.util.struct_array import StructArray


def plot3_sparse_grid(S, dims, *args, **kwargs):
    """
    Plot a sparse grid in 3D.

    Parameters
    ----------
    S : StructArray
        Sparse grid object to plot. S is a sparse grid in 3D. S can be
        either reduced or not. S can also be a tensor grid. Only its
        ``knots`` attribute is read.
    dims : list
        List of dimensions to plot. It plots the components d1, d2, d3 of
        the points in S if S is more than 3D. If empty or None, defaults
        to [0, 1, 2].
    *args : tuple
        Plot styling given as alternating ``name, value`` pairs, e.g.
        ``'marker', 'o', 's', 20``. Only used when no keyword arguments
        are supplied.
    **kwargs : dict
        Plot styling forwarded to the scatter call of the 3D axes. Takes
        precedence over ``*args``.

    Returns
    -------
    h : object
        Handle to the plot, as returned by the ``scatter`` call on the
        3D axes.

    Raises
    ------
    ValueError
        If styling is given positionally with an odd number of entries,
        so that it cannot be split into ``name, value`` pairs.
    """
    # Default dimensions if not provided. ``len`` is used rather than
    # truthiness so that array-like ``dims`` keep working.
    if dims is None or len(dims) == 0:
        dims = [0, 1, 2]

    # Extract knots from S
    x = S.knots
    if isinstance(S, StructArray):
        # NOTE(review): nrows is fixed to 3 here; if S has more than three
        # dimensions and ``dims`` selects a row beyond index 2, confirm
        # that reshape_nested_lists_to_nrows still yields that row.
        x = reshape_nested_lists_to_nrows(x, nrows=3)

    # Resolve the styling options. Pairs are built with zip instead of a
    # NumPy reshape so values keep their original Python types (a NumPy
    # array of mixed names and numbers would turn every value into a string).
    if kwargs:
        style = dict(kwargs)
    elif args:
        if len(args) % 2 != 0:
            raise ValueError(
                "positional styling arguments must be given as "
                "name, value pairs"
            )
        style = {
            str(name): value
            for name, value in zip(args[::2], args[1::2])
        }
    else:
        style = {}

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    # A single scatter call serves both the default and the styled case;
    # with no styling given, ``style`` is empty and the defaults apply.
    h = ax.scatter(x[dims[0]], x[dims[1]], x[dims[2]], **style)

    return h