Last updated: 2023-03-16.
tf_quant_finance.utils.broadcast_common_batch_shape
Broadcasts argument batch shapes to the common shape.
tf_quant_finance.utils.broadcast_common_batch_shape(
    *args, event_ranks=None, name=None
)
Each input Tensor is assumed to be of shape batch_shape_i + event_shape_i.
The function finds a common batch_shape and broadcasts each Tensor to
batch_shape + event_shape_i. The common batch shape is the minimal shape
such that all batch_shape_i can broadcast to it.
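The shape bookkeeping described above can be sketched in plain Python, independent of TensorFlow. This is a minimal sketch that assumes fully known static shapes given as tuples and 0 <= event_rank <= rank for every input. The helper names `broadcast_shapes` and `common_batch_target_shapes` are hypothetical and are not part of `tf_quant_finance`. The sketch splits each shape into batch and event parts, folds the batch parts together with the usual right-aligned broadcasting rule, and appends each input's own event shape to the result.

```python
from typing import List, Optional, Sequence, Tuple

Shape = Tuple[int, ...]


def broadcast_shapes(a: Shape, b: Shape) -> Shape:
    """Right-aligned broadcasting of two static shapes (NumPy/TF rules)."""
    rank = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both have the same rank.
    a = (1,) * (rank - len(a)) + tuple(a)
    b = (1,) * (rank - len(b)) + tuple(b)
    out = []
    for da, db in zip(a, b):
        if da == db or db == 1:
            out.append(da)
        elif da == 1:
            out.append(db)
        else:
            raise ValueError(f"Incompatible shapes {a} and {b}")
    return tuple(out)


def common_batch_target_shapes(
    shapes: Sequence[Shape], event_ranks: Optional[Sequence[int]] = None
) -> Tuple[Shape, List[Shape]]:
    """Returns the common batch shape and each input's target shape.

    Hypothetical helper; assumes 0 <= event_rank <= len(shape) per input.
    """
    if event_ranks is None:
        # Documented default: only the last dimension is the event shape.
        event_ranks = [1] * len(shapes)
    # Split each shape into batch_shape_i and event_shape_i. Slicing with
    # len(s) - r (not -r) keeps event_rank == 0 correct.
    batch_shapes = [tuple(s)[: len(s) - r] for s, r in zip(shapes, event_ranks)]
    event_shapes = [tuple(s)[len(s) - r :] for s, r in zip(shapes, event_ranks)]
    # Fold the batch shapes into the minimal shape they all broadcast to.
    common: Shape = ()
    for batch_shape in batch_shapes:
        common = broadcast_shapes(common, batch_shape)
    return common, [common + e for e in event_shapes]


# Shapes of the inputs used in Example 1.
print(common_batch_target_shapes([(2, 3), (2,)]))  # ((2,), [(2, 3), (2, 2)])
```

The fold starts from the empty shape `()`, which broadcasts to anything, so a single input is returned unchanged.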
Example 1. Batch shape is all dimensions but the last one
import tensorflow as tf
import tf_quant_finance as tff
# Two Tensors of shapes [2, 3] and [2]. The batch shape of the first Tensor is
# [2] and that of the second is []. The common batch shape is [2], so the
# outputs have shapes [2] + [3] and [2] + [2].
args = [tf.ones([2, 3], dtype=tf.float64), tf.constant([True, False])]
tff.utils.broadcast_common_batch_shape(*args)
# Expected: (array([[1., 1., 1.], [1., 1., 1.]]),
#            array([[True, False], [True, False]]))
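The second output of Example 1 follows from ordinary broadcasting: the boolean Tensor has batch shape [] and event shape [2], so reaching shape [2, 2] means prepending a batch axis and repeating the whole vector along it. The sketch below mimics that with nested Python lists instead of TensorFlow. `broadcast_nested` and `shape_of` are hypothetical helpers written for illustration, and they only handle rectangular nested lists.

```python
def shape_of(x):
    """Shape of a rectangular nested list; a scalar has shape ()."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return tuple(shape)


def _expand(x, src, target):
    # Walk both shapes axis by axis, repeating size-1 axes as needed.
    if not target:
        return x
    if src[0] == target[0]:
        return [_expand(item, src[1:], target[1:]) for item in x]
    if src[0] == 1:
        return [_expand(x[0], src[1:], target[1:]) for _ in range(target[0])]
    raise ValueError(f"cannot broadcast axis of size {src[0]} to {target[0]}")


def broadcast_nested(x, target):
    """Broadcasts nested list `x` to shape `target` (right-aligned rules)."""
    target = tuple(target)
    src = shape_of(x)
    if len(src) > len(target):
        raise ValueError("target rank is lower than input rank")
    # Prepend new leading axes of size 1, then expand them to the target size.
    for _ in range(len(target) - len(src)):
        x = [x]
    src = (1,) * (len(target) - len(src)) + src
    return _expand(x, src, target)


flags = [True, False]  # batch shape [], event shape [2]
print(broadcast_nested(flags, (2, 2)))  # [[True, False], [True, False]]
```

Because the new batch axis is prepended, each row of the result is a copy of the original vector rather than one of its elements repeated.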
Example 2. Specify ranks of event shapes
import tensorflow as tf
import tf_quant_finance as tff
args = [tf.ones([2, 3], dtype=tf.float64), tf.constant([True, False])]
tff.utils.broadcast_common_batch_shape(*args, event_ranks=...)
# Expected: (array([[1., 1., 1.], [1., 1., 1.]]),
#            array([[True, True], [False, False]]))
Args:
*args: A sequence of Tensors of compatible shapes and any dtypes.
event_ranks: A sequence of integers of the same length as args, specifying the rank of event_shape for each input Tensor. Default value: None, which means that all dimensions but the last one are treated as batch dimensions.
name: Python string. The name to give to the ops created by this function. Default value: None, which maps to the default name broadcast_tensor_shapes.
Returns:
A tuple of broadcasted Tensors. Each Tensor has the same dtype as the
corresponding input Tensor.
Raises:
ValueError: (a) If event_ranks is supplied and its length differs from the number of args. (b) If the inputs have incompatible shapes.
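Both error conditions can be checked on static shapes before any broadcasting happens. This is a minimal sketch assuming condition (b) refers to batch shapes that cannot be broadcast together; event shapes are kept per input and need not match each other. `validate_inputs` is a hypothetical name, and the library's actual checks and error messages may differ.

```python
from typing import Optional, Sequence, Tuple


def validate_inputs(
    shapes: Sequence[Tuple[int, ...]],
    event_ranks: Optional[Sequence[int]] = None,
) -> None:
    """Raises ValueError in the two situations listed under Raises."""
    if event_ranks is not None and len(event_ranks) != len(shapes):
        # Case (a): exactly one event rank is required per input.
        raise ValueError(
            f"`event_ranks` has length {len(event_ranks)} "
            f"but {len(shapes)} arguments were given"
        )
    ranks = event_ranks if event_ranks is not None else [1] * len(shapes)
    batch_shapes = [tuple(s)[: len(s) - r] for s, r in zip(shapes, ranks)]
    # Case (b): compare batch dimensions right-aligned, as in broadcasting.
    # On each axis, all sizes other than 1 must agree (missing axes count as 1).
    rank = max((len(b) for b in batch_shapes), default=0)
    for axis in range(1, rank + 1):
        sizes = {b[-axis] for b in batch_shapes if len(b) >= axis}
        sizes.discard(1)
        if len(sizes) > 1:
            raise ValueError(f"Incompatible batch shapes: {batch_shapes}")


validate_inputs([(2, 3), (2,)])  # passes: batch shapes (2,) and ()
```

With event_ranks=[0, 0], the same inputs would have batch shapes (2, 3) and (2,), whose trailing sizes 3 and 2 cannot broadcast, so case (b) would be raised.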