tf_quant_finance.utils.dataclass

Last updated: 2023-03-16.



Creates a data class object compatible with tf.function.

tf_quant_finance.utils.dataclass(
    cls
)

Modifies the dunder methods of an input class with type-annotated attributes so that its instances can be used as inputs and outputs of a tf.function, as well as loop variables of tf.while_loop.

The intended use case is to apply this decorator to a simple class definition with type-annotated attributes, as in the example below. The decorator is not guaranteed to work with an arbitrary class.

Examples

import tensorflow as tf
import tf_quant_finance as tff

@tff.utils.dataclass
class Coords:
  x: tf.Tensor
  y: tf.Tensor

@tf.function
def fn(start_coords: Coords) -> Coords:
  def cond(it, _):
    return it < 10
  def body(it, coords):
    return it + 1, Coords(x=coords.x + 1, y=coords.y + 2)
  return tf.while_loop(cond, body, loop_vars=(0, start_coords))[1]

start_coords = Coords(x=tf.constant(0), y=tf.constant(0))
fn(start_coords)
# Expected: Coords(x=10, y=20), where x and y are scalar int32 tensors
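As a sanity check on the expected values, the sketch below replays the same loop eagerly in plain Python. It substitutes the standard-library dataclasses module for tff.utils.dataclass and a hypothetical helper, python_while_loop, for tf.while_loop. It only illustrates the loop semantics (ten iterations, adding 1 to x and 2 to y each time), not TensorFlow graph execution.

```python
import dataclasses


@dataclasses.dataclass
class Coords:
  x: int
  y: int


def cond(it, _):
  return it < 10


def body(it, coords):
  # Return a new Coords with the same structure rather than mutating in place,
  # mirroring what tf.while_loop expects from its body.
  return it + 1, Coords(x=coords.x + 1, y=coords.y + 2)


def python_while_loop(cond, body, loop_vars):
  # Hypothetical eager stand-in for tf.while_loop: unpack the loop variables
  # into cond/body and keep iterating while cond holds.
  while cond(*loop_vars):
    loop_vars = body(*loop_vars)
  return loop_vars


result = python_while_loop(cond, body, (0, Coords(x=0, y=0)))[1]
print(result)  # Coords(x=10, y=20)
```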

Args:

  • cls: Input class with type-annotated attributes. The class should not define an __init__ method. Fields are treated as ordered in the same order in which they appear in the class definition.
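To illustrate what "ordered in the same order as the class definition" means, here is an analogy using Python's standard dataclasses module. This is an assumption for illustration only, not tff's implementation: fields follow declaration order, and both positional construction and flattening to a tuple use that order.

```python
import dataclasses


@dataclasses.dataclass
class Coords:
  y: float  # declared first, so it is field 0
  x: float  # declared second, so it is field 1


names = [f.name for f in dataclasses.fields(Coords)]
point = Coords(1.0, 2.0)  # positional args bind in declaration order: y, then x
print(names)                       # ['y', 'x']
print(dataclasses.astuple(point))  # (1.0, 2.0)
```

Reordering the annotations in the class body therefore changes both the positional signature and the flattened order.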

Returns:

Modified class that can be used as an input/output of a tf.function and as a loop variable of tf.while_loop. Only the type-annotated attributes of the original class are treated as fields, ordered as they appear in the class definition; untyped attributes are ignored. The modified class defines __len__ and __iter__ for its instances: __len__ returns the number of fields, and __iter__ returns an iterator over the field values in that order.
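The following is a minimal sketch of the __len__/__iter__ contract described above, assuming the returned class behaves like a standard-library dataclass extended with those two methods. The decorator toy_dataclass is hypothetical and is not how tff.utils.dataclass is implemented; it only makes the documented behavior concrete without requiring TensorFlow.

```python
import dataclasses
from typing import Any, Iterator


def toy_dataclass(cls):
  """Hypothetical stand-in for tff.utils.dataclass; NOT the library's code."""
  # Standard dataclass machinery: annotated attributes become ordered fields
  # and an __init__ taking those fields is generated.
  cls = dataclasses.dataclass(cls)

  def __len__(self) -> int:
    # Number of typed fields.
    return len(dataclasses.fields(self))

  def __iter__(self) -> Iterator[Any]:
    # Field values in the order they appear in the class body.
    return (getattr(self, f.name) for f in dataclasses.fields(self))

  cls.__len__ = __len__
  cls.__iter__ = __iter__
  return cls


@toy_dataclass
class Coords:
  x: float
  y: float
  label = "origin"  # untyped: not a field, so ignored by __len__/__iter__


c = Coords(x=1.0, y=2.0)
print(len(c))   # 2
print(list(c))  # [1.0, 2.0]
x, y = c        # unpacking relies on __iter__
```

Because __iter__ yields values in field order, an instance can be unpacked like a tuple, as in the last line.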