Source code for utils.plotting

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import patches
from mpl_toolkits.mplot3d.art3d import Line3DCollection
from matplotlib.collections import LineCollection
import networkx as nx
import numpy.typing as npt
from scipy.interpolate import griddata
from matplotlib import cm
import os
import plotly.express as px
import plotly.graph_objects as go
from math import floor
from typing import Optional, Iterable, Tuple, Callable, Literal, Union


def parity_plot(true, pred, fn: str):
    fig, ax = plt.subplots()
    ax.scatter(true, pred, s=15, alpha=0.5)
    ax.axline((true.min(),true.min()),slope=1,c='k')
    ax.set_xlabel(f'True')
    ax.set_ylabel(f'Predicted')
    validAREL = np.mean(np.abs(pred-true)/true)
    ax.text(0.95,0.05,
        f'MAPE {100*validAREL:.1f}%',
        fontsize=12,
        transform=ax.transAxes,
        horizontalalignment='right') 
    os.makedirs(os.path.dirname(fn), exist_ok=True)
    plt.savefig(fn, dpi=300, facecolor='w', bbox_inches='tight', pad_inches=0.1)

def surface_plot(true, pred, dataset, nplot: int, fn: str):
    fig = plt.figure(figsize=(6,3*nplot))
    names = np.array(dataset.data.name)
    uq_names = np.unique(names)
    nplot = min(nplot, len(uq_names))
    plot_uq_names = np.random.choice(uq_names, nplot, replace=False)
    for i in range(nplot):
        name = plot_uq_names[i]
        idx = np.flatnonzero(names==name)
        phi = [dataset[k].phi.item() for k in idx]
        th = [dataset[k].th.item() for k in idx]
        ax = fig.add_subplot(nplot,2,1+(2*i), projection='3d')
        plot_surface_data(phi, th, true[idx], ax=ax, resolution=400j, clims=(1,2))
        ax = fig.add_subplot(nplot,2,2+(2*i), projection='3d')
        plot_surface_data(phi, th, pred[idx], ax=ax, resolution=400j, clims=(1,2))
    plt.subplot(nplot, 2, 1).set_title('Ground truth')
    plt.subplot(nplot, 2, 2).set_title('Prediction')
    plt.tight_layout()
    os.makedirs(os.path.dirname(fn), exist_ok=True)
    plt.savefig(fn, dpi=300, facecolor='w', bbox_inches='tight', pad_inches=0.1)

def get_nodes_edge_coords(
    lat, repr, coords, highlight_nodes: Optional[Iterable]=None
) -> Tuple[np.ndarray, np.ndarray, Union[None, Iterable], np.ndarray, np.ndarray]:

    nodes = lat.reduced_node_coordinates

    if coords=='transformed': 
        Q = lat.transform_matrix
    elif coords=='reduced':
        Q = np.eye(3)
    else: 
        raise ValueError
    Q6 = np.block([[Q, np.zeros_like(Q)],[np.zeros_like(Q), Q]])

    if repr=='cropped':
        edges = lat.edge_adjacency
        edge_coords = lat._node_adj_to_ec(nodes, edges)
        node_numbers = np.arange(nodes.shape[0])
        try:
            edge_widths = lat.windowed_edge_radii
        except AttributeError:
            edge_widths = 2*np.ones(edges.shape[0])
    elif repr=='fundamental':
        if not hasattr(lat, 'fundamental_edge_adjacency'):
            lat.calculate_fundamental_representation()
        edges = lat.fundamental_edge_adjacency
        edge_coords = lat._node_adj_to_ec(nodes, edges)
        edge_coords += lat.fundamental_tesselation_vecs
        uq_inds = np.unique(edges)
        nodes = nodes[uq_inds] # only plot the fundamental nodes
        if isinstance(highlight_nodes, Iterable):
            assert np.all(np.in1d(highlight_nodes, uq_inds)), \
                "Highlighted node must be a fundamental node"
            highlight_nodes = np.searchsorted(uq_inds, highlight_nodes)
        node_numbers = uq_inds
        try:
            edge_widths = lat.fundamental_edge_radii
        except AttributeError:
            edge_widths = 2*np.ones(edges.shape[0])
    else:
        raise ValueError
    
    # scale mean to 2
    edge_widths = 2*edge_widths/np.mean(edge_widths)

    return nodes@(Q.T), edge_coords@(Q6.T), highlight_nodes, node_numbers, edge_widths

def plot_unit_cell_2d(
    lat, repr='cropped', coords='reduced', show_node_numbers=True,
    ax=None
    ) -> plt.Axes:
    
    nodes, edge_coords, _, node_numbers, edge_widths = get_nodes_edge_coords(
        lat, repr, coords
    )
    
    if not isinstance(ax, plt.Axes):
        fig = plt.figure(figsize=(5,5),facecolor='w')
        ax = plt.axes()
    ax.scatter(nodes[:,0], nodes[:,1])
    segments = []
    colors = [] 
    for i_e, e in enumerate(edge_coords):
        p0, p1 = e[:3], e[3:]
        x_0, y_0, z_0 = p0
        x_1, y_1, z_1 = p1
        segments.append([(x_0, y_0), 
                        (x_1, y_1)])
        colors.append(f'C{i_e%5}')

    lc = LineCollection(segments, colors=colors, linewidths=edge_widths)
    ax.add_collection(lc)
    if show_node_numbers:
        for n, num in zip(nodes, node_numbers):
            ax.text(n[0],n[1],f"{num}")
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    return ax

def plot_unit_cell_3d(
    lat, repr='cropped', coords='reduced', show_node_numbers=False,
    show_uc_box: bool = False,
    ax=None
    ) -> plt.Axes:

    nodes, edge_coords, _, node_numbers, edge_widths = get_nodes_edge_coords(
        lat, repr, coords
    )
    
    if not isinstance(ax, plt.Axes):
        fig = plt.figure(figsize=(5,5),facecolor='w')
        ax = plt.axes(projection='3d')

    if show_uc_box:
        pts = np.array(
            [
                [0,0,0],
                [1,0,0],
                [1,1,0],
                [0,1,0],
                [0,0,1],
                [1,0,1],
                [1,1,1],
                [0,1,1]
            ]
        )
        if coords=='transformed':
            pts = lat.transform_coordinates(pts)
        inds = [1,0,3,2,None,0,4,7,3,None,4,5,6,7,None,5,1,2,6]
        segments = []
        for i0, i1 in zip(inds[:-1], inds[1:]):
            if i0 is not None and i1 is not None:
                segments.append([pts[i0], pts[i1]])
        uc_box = Line3DCollection(segments, colors='black', linewidths=1)
        ax.add_collection(uc_box)      

    ax.scatter(nodes[:,0], nodes[:,1], nodes[:,2])
    segments = []
    colors = [] 
    for i_e, e in enumerate(edge_coords):
        p0, p1 = e[:3], e[3:]
        x_0, y_0, z_0 = p0
        x_1, y_1, z_1 = p1
        segments.append([(x_0, y_0, z_0), 
                        (x_1, y_1, z_1)])
        colors.append(f'C{i_e%5}')
    lc = Line3DCollection(segments, colors=colors, linewidths=edge_widths)
    ax.add_collection(lc)
    if show_node_numbers:
        for n, num in zip(nodes, node_numbers):
            ax.text(n[0],n[1],n[2],f"{num}")
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
    return ax

[docs] def plotly_unit_cell_3d( lat: "Lattice", repr: Literal['cropped','fundamental'] = 'cropped', coords: Literal['reduced','transformed'] = 'reduced', show_node_numbers: bool = False, fig: Optional[go.Figure] = None, subplot: Optional[dict] = None, highlight_nodes: Optional[Iterable] = None, highlight_edges: Optional[Iterable] = None, show_uc_box: bool = False ) -> go.Figure: """Plot unit cell in 3D using plotly Args: lat (Lattice): Unit cell to plot repr (Literal["cropped", "fundamental"], optional): Representation of edges. Defaults to 'cropped'. coords (Literal["reduced", "transformed"], optional): Coordinate system to use. Defaults to 'reduced'. show_node_numbers (bool, optional): Whether to show node numbers. Defaults to False. fig (Optional[go.Figure], optional): Existing plotly figure. Defaults to None. subplot (Optional[dict], optional): Subplot information. Should contain keys 'index' and 'ncols'. Defaults to None. highlight_nodes (Optional[Iterable], optional): List of node numbers that should be highlighted. Defaults to None. highlight_edges (Optional[Iterable], optional): List of edges that should be highlighted. Defaults to None. show_uc_box (bool, optional): Whether to show unit cell bounding box. Defaults to False. Returns: go.Figure: Resulting plotly figure """ nodes, edge_coords, highlight_nodes, node_numbers, edge_widths = get_nodes_edge_coords( lat, repr, coords, highlight_nodes ) colororder = px.colors.qualitative.G10 x,y,z = nodes.T if isinstance(highlight_nodes, Iterable): assert np.all(np.in1d(highlight_nodes, np.arange(nodes.shape[0]))), \ "Highlighted nodes outside of limits" colors = ['rgba(40,40,40,0.3)' for _ in range(len(x))] for i_node_highlight in highlight_nodes: colors[i_node_highlight] = 'rgb(255,0,0)' else: colors = [colororder[i%10] for i in range(len(x))] if isinstance(highlight_edges, Iterable): assert np.min(highlight_edges)>=0 and np.max(highlight_edges)<len(edge_coords), \ "Highlighted edges outside of limits" edge_colors = ['rgba(40,40,40,0.3)' for _ in range(len(edge_coords))] for i_edge_highlight in highlight_edges: edge_colors[i_edge_highlight] = 'rgb(255,0,0)' else: edge_colors = [colororder[i%10] for i in range(len(edge_coords))] if not isinstance(fig, go.Figure): fig = go.Figure() mode = 'text+markers' if show_node_numbers else 'markers' if isinstance(subplot, dict): subplot_args = dict( row=floor(subplot['index']/subplot['ncols']) + 1, col=subplot['index']%subplot['ncols'] + 1 ) else: subplot_args = {} if show_uc_box: pts = np.array( [ [0,0,0], [1,0,0], [1,1,0], [0,1,0], [0,0,1], [1,0,1], [1,1,1], [0,1,1] ] ) if coords=='transformed': pts = lat.transform_coordinates(pts) inds = [1,0,3,2,None,0,4,7,3,None,4,5,6,7,None,5,1,2,6] fig.add_scatter3d( x=[pts[i,0] if isinstance(i,int) else None for i in inds], y=[pts[i,1] if isinstance(i,int) else None for i in inds], z=[pts[i,2] if isinstance(i,int) else None for i in inds], mode='lines', line=dict(color='black', width=2), name='unit cell', showlegend=False, **subplot_args ) fig.add_scatter3d( x=x, y=y, z=z, marker={'color':colors,'size':6}, mode=mode, text=node_numbers, textfont={'size':14}, showlegend=False, name='nodes', **subplot_args ) colors = [] x = [] y = [] z = [] for i_e, e in enumerate(edge_coords): n0, n1 = e[:3], e[3:] x_0, y_0, z_0 = n0 x_1, y_1, z_1 = n1 x.extend([x_0, x_1, None]) y.extend([y_0, y_1, None]) z.extend([z_0, z_1, None]) col = edge_colors[i_e] colors.extend([col, col, col]) fig.add_scatter3d( x=x, y=y, z=z, line={'width':7,'color':colors}, mode='lines', name='edges', hoverinfo='none', connectgaps=False, showlegend=False, **subplot_args ) if hasattr(lat, 'name'): title = lat.name else: title = '' if isinstance(subplot, dict): fig.layout.annotations[subplot['index']].update(text=title) else: fig.update_layout(title=title) return fig
# %% def visualize_graph( edges : npt.NDArray, nodes= None, node_types=None, ax=None ) -> plt.Axes: if not isinstance(nodes, np.ndarray): nodes = np.unique(edges) colors = {'corner':'blue', 'edge':'red', 'face':'green', 'inside':'grey'} cmap = ['grey' for i in range(len(nodes))] if isinstance(node_types, dict): for ntype in node_types.keys(): for n in node_types[ntype]: i = n cmap[i] = colors[ntype] edges_in = [] edges_count = [] for e in edges: e_sorted = sorted(e) if e_sorted in edges_in: edges_count[edges_in.index(e_sorted)] += 1 else: edges_in.append(e_sorted) edges_count.append(1) edges_tuples = [] for e_sorted in edges_in: for i in range(edges_count[edges_in.index(e_sorted)]): e = list(e_sorted) edges_tuples.append((e[0],e[1],{'r':i})) G = nx.Graph() G.add_nodes_from(nodes) G.add_edges_from(edges_tuples) # if isinstance(ax, plt.Axes): pass else: fig = plt.figure(facecolor='w') ax = plt.gca() pos = nx.spring_layout(G) nx.draw_networkx_nodes(G, pos, node_color = cmap, node_size = 200, alpha = 1, ax=ax ) nx.draw_networkx_labels(G, pos, ax=ax) edges_in = [] edges_count = [] for j,e in enumerate(edges_tuples): ec="0" if e[0]==e[1]: c = [x+0.1 for x in pos[e[0]]] circ = patches.Circle(xy=c, radius=0.1/np.sqrt(2), fc='none', ec='0') ax.add_patch(circ) else: if len(e)>2: r=0.3*e[2]['r'] else: r=0 ax.annotate("", xy=pos[e[0]], xycoords='data', xytext=pos[e[1]], textcoords='data', arrowprops=dict(arrowstyle="-", color=ec, shrinkA=10, shrinkB=10, patchA=None, patchB=None, connectionstyle=f"arc3,rad=rrr".replace('rrr',str(r)), ), ) plt.axis('off') return ax # %% def plotly_elasticity_surf( S: np.ndarray, title: str='', fig=None, subplot: Optional[dict] = None, clim: Optional[Tuple[float, float]] = None, ): return plotly_tensor_projection( S, point_func=lambda x: 1/x, title=title, fig=fig, subplot=subplot, clim=clim )
[docs] def plotly_tensor_projection( C: np.ndarray, point_func: Optional[Callable] = None, title: str='', fig: Optional[go.Figure] = None, subplot: Optional[dict] = None, clim: Optional[Tuple[float, float]] = None, resolution: complex = 100j, ) -> go.Figure: """Plot a projection of 4-th order tensor using plotly Args: C (np.ndarray): 4-th order tensor of shape (3,3,3,3) or (1,3,3,3,3) point_func (Optional[Callable]): function to apply point-wise to projections. title (str, optional): plot title fig (Optional[go.Figure]): existing figure. If None, a new figure is created. subplot (Optional[dict]): dictionary with keys ['index','ncol'] to pick the subplot clim (Optional[Tuple[float, float]]): limits of color scale resolution (complex, optional): resolution of grid on the sphere given as complex number. Defaults to 100j. Returns: go.Figure: Resulting plotly figure """ assert C.shape==(3,3,3,3) or C.shape==(1,3,3,3,3) u, v = np.mgrid[0:2*np.pi:resolution, 0:np.pi:resolution] X = np.sin(v)*np.cos(u) Y = np.sin(v)*np.sin(u) Z = np.cos(v) x = X.flatten() y = Y.flatten() z = Z.flatten() pos = np.column_stack((x,y,z)) e = np.einsum('ai,aj,ak,al,ijkl->a',pos,pos,pos,pos,C) if point_func is not None: e = point_func(e) rows, cols = X.shape indices = np.unravel_index(np.arange(len(e)), (rows, cols)) E = np.zeros_like(X) E[indices] = e R = E X = R*np.sin(v)*np.cos(u) Y = R*np.sin(v)*np.sin(u) Z = R*np.cos(v) if not isinstance(fig, go.Figure): fig = go.Figure() if isinstance(subplot, dict): subplot_args = dict( row=floor(subplot['index']/subplot['ncols']) + 1, col=subplot['index']%subplot['ncols'] + 1 ) else: subplot_args = {} clim = clim or (np.min(R), np.max(R)) fig.add_trace( go.Surface(x=X, y=Y, z=Z, surfacecolor=R, cmin=clim[0], cmax=clim[1], lighting=dict(roughness=0.25, specular=0.1)), **subplot_args ) if isinstance(subplot, dict): if len(fig.layout.annotations)>0: fig.layout.annotations[subplot['index']].update(text=title) else: print(f'Could not update subplot {subplot["index"]} with title {title}') else: fig.update_layout(title=title) return fig
[docs] def plotly_scaling_surf( C: np.ndarray, rel_dens: Iterable[float], title: str='', fig: Optional[go.Figure] = None, subplot: Optional[dict] = None, resolution: int = 100j, ) -> go.Figure: """Plot the surface of scaling exponent for a given stiffness tensor Args: C (np.ndarray): stacked compliance tensors [n_rel_dens, 3,3,3,3] rel_dens (Iterable[float]): corresponding relative densities title (str, optional): title for plot. Defaults to ''. fig (Optional[dict], go.Figure): can pass an existing figure. Defaults to None. subplot (Optional[dict], optional): subplot arguments. Expect keys ['index','ncol']. Defaults to None. Returns: go.Figure: plotly figure """ # C (np.ndarray): stacked compliance tensors [n_rel_dens, 3,3,3,3] assert C.ndim==5 assert C.shape[1:]==(3,3,3,3) assert len(rel_dens)==C.shape[0] u, v = np.mgrid[0:2*np.pi:resolution, 0:np.pi:resolution] X = np.sin(v)*np.cos(u) Y = np.sin(v)*np.sin(u) Z = np.cos(v) x = X.flatten() y = Y.flatten() z = Z.flatten() pos = np.column_stack((x,y,z)) e = 1/np.einsum('pi,pj,pk,pl,...ijkl->...p',pos,pos,pos,pos,C) # [rel_dens, direction] # linear fit x_fit = np.log(rel_dens) y_fit = np.log(e) fit = np.polyfit(x_fit, y_fit, 1) n = fit[0] # [direction] rows, cols = X.shape indices = np.unravel_index(np.arange(len(n)), (rows, cols)) N = np.zeros_like(X) N[indices] = n R = N X = R*np.sin(v)*np.cos(u) Y = R*np.sin(v)*np.sin(u) Z = R*np.cos(v) if not isinstance(fig, go.Figure): fig = go.Figure() if isinstance(subplot, dict): subplot_args = dict( row=floor(subplot['index']/subplot['ncols']) + 1, col=subplot['index']%subplot['ncols'] + 1 ) else: subplot_args = {} fig.add_trace( go.Surface(x=X, y=Y, z=Z, surfacecolor=R, cmin=1, cmax=2), **subplot_args ) if isinstance(subplot, dict): fig.layout.annotations[subplot['index']].update(text=title) else: fig.update_layout(title=title) return fig
# %% def plot_surface_data(phi, th, val, ax=None, resolution=200j, clims=None): PHI, TH = np.mgrid[0:np.pi:resolution, 0:2*np.pi:resolution] R = griddata(np.column_stack((phi,th)), val, (PHI,TH)) if not isinstance(ax, plt.Axes): fig = plt.figure() ax = plt.axes(projection='3d') z = R*np.cos(PHI) x = R*np.sin(PHI)*np.cos(TH) y = R*np.sin(PHI)*np.sin(TH) if isinstance(clims, tuple): color_val = (R-clims[0])/(clims[1]-clims[0]) maxlim = 1.1*clims[1] ax.set_xlim(-maxlim, maxlim) ax.set_ylim(-maxlim, maxlim) ax.set_zlim(-maxlim, maxlim) else: color_val = (R-np.nanmin(R))/(np.nanmax(R)-np.nanmin(R)) ax.plot_surface(x,y,z, facecolors=cm.viridis(color_val)) return ax