Source code for chempy.util.numutil

# -*- coding: utf-8 -*-
from __future__ import (absolute_import, division, print_function)

try:
    import numpy as np
except ImportError:
    np = None


[docs]def broadcast_stack(*args, **kwargs): as_scalars = kwargs.pop('as_scalars', False) if kwargs != {}: raise ValueError("Got unknown kwargs: %s" % kwargs) args = [np.atleast_1d(arg) for arg in args] if as_scalars: args = [arg.reshape(arg.shape + (1,)) if arg.size > 1 else arg for arg in args] if all([arg.ndim == 1 for arg in args]): return np.concatenate(args) head_shape = () leading_length = 0 for arg in args: leading_length += arg.shape[-1] if arg.ndim > 1: if head_shape == (): head_shape = arg.shape[:-1] else: if arg.shape[:-1] != head_shape: raise ValueError("Incompatible shapes") out = np.empty(head_shape + (leading_length,)) for idx, arg in enumerate(args): if arg.shape[-1] != 1: raise ValueError("Trailing dimensions needs to be 1") out[..., idx] = arg[..., 0] return out