On Tue, 17 Aug 2021 17:50:59 +0200, joseph pareti <joeparet...@gmail.com> declaimed the following:
>In the following code, where does tl.Fn come from? I see it nowhere in the
>documents, i.e. I was looking for trax.layers.Fn :

	"layers" imports a whole slew of submodules using "from xxx import *"
in order to put all the submodule names at the same level.

https://github.com/google/trax/blob/master/trax/layers/base.py

From line 748 on...

def Fn(name, f, n_out=1):  # pylint: disable=invalid-name
  """Returns a layer with no weights that applies the function `f`.

  `f` can take and return any number of arguments, and takes only positional
  arguments -- no default or keyword arguments. It often uses JAX-numpy (`jnp`).
  The following, for example, would create a layer that takes two inputs and
  returns two outputs -- element-wise sums and maxima:

      `Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)`

  The layer's number of inputs (`n_in`) is automatically set to number of
  positional arguments in `f`, but you must explicitly set the number of
  outputs (`n_out`) whenever it's not the default value 1.

  Args:
    name: Class-like name for the resulting layer; for use in debugging.
    f: Pure function from input tensors to output tensors, where each input
        tensor is a separate positional arg, e.g., `f(x0, x1) --> x0 + x1`.
        Output tensors must be packaged as specified in the `Layer` class
        docstring.
    n_out: Number of outputs promised by the layer; default value 1.

  Returns:
    Layer executing the function `f`.
  """

-- 
	Wulfraed                 Dennis Lee Bieber         AF6VN
	wlfr...@ix.netcom.com    http://wlfraed.microdiversity.freeddns.org/
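The sketch below builds a throwaway package on disk to show the re-export mechanism on its own. The package name `mylayers` and its stand-in `Fn` are hypothetical. The only assumption about trax is the one stated in the reply: the `trax.layers` package star-imports its submodules, so names defined in `base.py` show up directly on `tl`.

```python
import importlib
import os
import sys
import tempfile
import textwrap

# Build a throwaway package "mylayers" (hypothetical) on disk that mimics the
# pattern: the package __init__.py star-imports from a submodule "base".
pkg_root = tempfile.mkdtemp()
pkg_dir = os.path.join(pkg_root, "mylayers")
os.makedirs(pkg_dir)

with open(os.path.join(pkg_dir, "base.py"), "w") as fh:
    fh.write(textwrap.dedent('''
        def Fn(name, f, n_out=1):
            """Stand-in for a function defined down in a submodule."""
            return (name, f, n_out)
    '''))

with open(os.path.join(pkg_dir, "__init__.py"), "w") as fh:
    # Re-export every public name from base at the package level.
    fh.write("from mylayers.base import *\n")

sys.path.insert(0, pkg_root)
importlib.invalidate_caches()
import mylayers as tl  # same aliasing style as "import trax.layers as tl"

# Fn is reachable as tl.Fn even though it is defined in mylayers/base.py;
# importing the submodule also binds it as the attribute tl.base.
print(tl.Fn is tl.base.Fn)  # True
```

With trax installed, the analogous check `tl.Fn is trax.layers.base.Fn` should also hold if the package is laid out as described above.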
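The docstring's contract can be illustrated with a plain-Python sketch. That contract has three parts: the layer has no weights, `n_in` is inferred from the positional parameters of `f`, and `n_out` is declared explicitly.

`FnLayerSketch` is a hypothetical name, and this is not how trax implements `Fn`. It also differs from the docstring in one respect: it uses the built-in `max` on scalars, where the docstring's example uses `jnp.maximum` element-wise on arrays.

```python
import inspect


class FnLayerSketch:
    """Hypothetical stand-in for the layer tl.Fn returns (not trax's code).

    Mimics only the documented contract: no weights, n_in taken from the
    number of positional parameters of f, n_out declared by the caller.
    """

    def __init__(self, name, f, n_out=1):
        params = list(inspect.signature(f).parameters.values())
        allowed = (inspect.Parameter.POSITIONAL_ONLY,
                   inspect.Parameter.POSITIONAL_OR_KEYWORD)
        for p in params:
            # f may take only positional args: no *args, **kwargs,
            # keyword-only args, or default values.
            if p.kind not in allowed or p.default is not inspect.Parameter.empty:
                raise ValueError(f"{name}: f must take only positional "
                                 "arguments without defaults")
        self.name = name
        self.f = f
        self.n_in = len(params)  # inferred automatically from f
        self.n_out = n_out       # must be given when it is not 1

    def __call__(self, *inputs):
        if len(inputs) != self.n_in:
            raise TypeError(f"{self.name} expects {self.n_in} inputs, "
                            f"got {len(inputs)}")
        outputs = self.f(*inputs)
        # When more than one output is promised, f must return a tuple
        # with exactly n_out items.
        if self.n_out > 1 and (not isinstance(outputs, tuple)
                               or len(outputs) != self.n_out):
            raise ValueError(f"{self.name} promised {self.n_out} outputs")
        return outputs


# Scalar analogue of the docstring's 'SumAndMax' example.
sum_and_max = FnLayerSketch(
    "SumAndMax", lambda x0, x1: (x0 + x1, max(x0, x1)), n_out=2)
print(sum_and_max.n_in)   # 2
print(sum_and_max(3, 5))  # (8, 5)
```

Inferring `n_in` from the signature is why the docstring forbids defaults and keyword arguments. Once an optional parameter is allowed, the number of inputs is no longer fixed by the function alone. `n_out` has no such source, because Python cannot tell how many values a function returns without calling it.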