
    VpfK             	         d Z ddlmZ ddlZddlZddlmZmZ ddlZddl	Z	ddl
Z
ddlZddlZddlZddlmZmZ ddlZddlmZ ddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZ ddlmZ ddlmZ ddlmZmZmZ ddlm Z  ddlm!Z" ddlm#Z# ddlm$Z$ ddlm%Z% ddlm&Z& ddlm'Z( ddlm)Z) ddlm*Z* ddlm+Z+ ddlm,Z, ddlm-Z. ddl/m0Z0 ddl/m1Z1 ddl2m3Z3 ddl4Z5 ej6        d ej7        dd           d!"          Z8 ej9        d# ej:        d$ e;d%                    d& e;d'          (          Z< ej6        d) ej7        d*d           d+"          Z= ej6        d, ej7        d-d.          d/"          Z> e
j?        e@          ZAdd4ZBdd7ZCe0jD        jE        ZFe0jG        ZGe0jH        ZIe0jJ        ZJe0jK        ZLe0jM        ZNeZO G d8 d9ejP                  ZQdd deQjR        d:d;ZSdd dddeQjR        d<d=ZTdd deQjR        d>ddAZU G dB dC          ZVdd dd eQjR        dDddEZW ejX                    ZYdddFdGdHZZddJZ[	  ej\        dK          Z]d.e]_^        ej_        `                    e]           ddOZae]b                    ea           dP Zce]d                    ec           dQ ZedRdSddVZf e#jg        e]ef           dRdSdd]Zh e ji        e]ehd^_           d.dddRd d`daZjddeZkddfZldg Zmdh Znenejo        e]<   di Zpepejq        e]<   dj Zrerejs        e]<   ddoZtddrZudd~ZvddZw ej\        d          Zxd.ex_^        exd                    d            exb                    d             e#jg        exd            d e&_y         G d dez          Z{e{Z| G d d          Z} e}            Z~ddZddZ e;d%          fddZdddZd ZdZeeSfeeTfeeUfeefeefdZddlZej        reSZeTZeUZeZeZnddlmZ  ee@e          Z[[dS )aU  Primitives for calling Python functions on the host from JAX accelerator code.

.. warning::
  The host_callback APIs are deprecated as of March 20, 2024.
  The functionality is subsumed by the
  `new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
  See https://github.com/google/jax/issues/20385.

This module introduces the host callback functions :func:`call`,
:func:`id_tap`, and :func:`id_print`, that send their arguments from the device
to the host and invoke user-defined Python functions on the host, optionally
returning results back to the device computation.

We show below how these functions can be used. We start with :func:`call`,
and we discuss examples of calling from JAX to arbitrary Python functions
on the CPU, e.g., to use NumPy CPU custom kernels. Then we
show uses of :func:`id_tap` and :func:`id_print`, which have the restriction
that they cannot return values from the host to the device.
These primitives are generally faster
because they are executed asynchronously with the device code.
In particular, they can be used to tap into and to debug JAX code.

Using :func:`call` to call a host function and return results to device
-----------------------------------------------------------------------

Use :func:`call` to invoke a computation on the host and return
NumPy arrays to the device computation.
Host computation is useful, e.g., when a device computation needs some data
that requires I/O on the host, or it needs a library that is available on the
host and you do not want to code it in JAX.
For example, eigen decomposition for general matrices in JAX does not work on TPU.
We can call the Numpy implementation from any JAX accelerator computation,
using a host computation::

  # This function runs on the host
  def host_eig(m: np.ndarray) -> np.ndarray:
    return np.linalg.eigvals(m)

  # This function is used in JAX
  def device_fun(m):
    # We send "m" to the host, asking it to call "host_eig" and return the result.
    # We have to specify the result shape and dtype, either in the form of an
    # example return value or any object that has `shape` and `dtype` attributes,
    # e.g., a NumPy array or a `jax.ShapeDtypeStruct`.
    return hcb.call(host_eig, m,
                    # Given an input of shape (..., d, d), eig output has shape (..., d)
                    result_shape=jax.ShapeDtypeStruct(m.shape[:-1], m.dtype))


The :func:`call` function and the Python host function both take a single argument
and return a single result, but those can be pytrees. Note that we must tell
the :func:`call` what shape and dtype to expect from the host invocation, using
the ``result_shape`` keyword argument.
This is important because the device code is compiled with that expectation.
There will be an error raised at runtime if the actual invocation produces a
different result shape. In general, **such errors and also exceptions raised
by the host computation may be difficult to debug**. See the Debugging section
below.
This is a problem for :func:`call` but not for :func:`id_tap` because for the
latter the device code does not expect a returned value.

The :func:`call` API can be used inside a jit or pmap computation or inside
cond/scan/while control flow. When used inside :func:`jax.pmap`, there will be
separate calls to the host from each of the participating devices::

  def host_sin(x, *, device):
    # The ``device`` argument is passed due to ``call_with_device=True`` below.
    print(f"Invoking host_sin with {x.shape} on {device}")
    return np.sin(x)

  # Use pmap to run the computation on two devices
  jax.pmap(lambda x: hcb.call(host_sin, x,
                              result_shape=x,
                              # Ask that the `host_sin` function be passed `device=dev`
                              call_with_device=True))(
           np.ones((2, 4), dtype=np.float32))

  # prints (in arbitrary order)
  # Invoking host_sin with (4,) on cpu:0
  # Invoking host_sin with (4,) on cpu:1

Note that :func:`call` does not support any JAX transformations, but as we
show below one can make use of the
existing support for `Custom differentiation in JAX <https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html>`_.

Using :func:`id_tap` to call a Python function on the host, with no returned values
-----------------------------------------------------------------------------------

The :func:`id_tap` and :func:`id_print` are special cases of :func:`call`, when
you just want the side effects of your Python callback. These functions have
the advantage that once the arguments have been sent to the host, the device
computation can proceed without waiting for the Python callback to return.
For :func:`id_tap` you can specify your Python callback to be called, while
:func:`id_print` uses a built-in callback that prints the arguments to
`stdout` on the host.
The Python function passed
to :func:`id_tap` takes two positional arguments (the value tapped
from the device computation along with a ``transforms`` tuple,
described below). Optionally, the function may be passed a keyword argument
``device`` with the Device from which the value was tapped.

A few examples::

  def host_func(arg, transforms):
     ...do something with arg...

  # calls host_func(2x, []) on host
  id_tap(host_func, 2 * x)

  # calls host_func((2x, 3x), [])
  id_tap(host_func, (2 * x, 3 * x))  # The argument can be a pytree

  # calls host_func(2x, [], device=jax.devices()[0])
  id_tap(host_func, 2 * x, tap_with_device=True)  # Pass the device to the tap

  # calls host_func(2x, [], what='activation')
  id_tap(functools.partial(host_func, what='activation'), 2 * x)

  # calls host_func(dict(x=x, y=y), what='data')
  id_tap(lambda tap, transforms: host_func(tap, what='data'), dict(x=x, y=y))

The above examples can all be adapted to use :func:`id_print` instead, with
the difference that :func:`id_print` prints on the host the positional argument,
along with any additional kwargs and the automatic kwarg ``transforms``.

Using :func:`barrier_wait` to wait until all callbacks have executed
--------------------------------------------------------------------

If your Python callbacks have side-effects you may need to wait until the
computation has finished to ensure that the side-effects have been observed.
You can use the :func:`barrier_wait` function for that purpose::

   accumulator = []
   def host_log(arg, transforms):
     # We just record the arguments in a list
     accumulator.append(arg)


   def device_fun(x):
     id_tap(host_log, x)
     id_tap(host_log, 2. * x)

   jax.jit(device_fun)(1.)
   jax.jit(device_fun)(1.)

   # At this point, we have started two computations, each with two
   # taps, but they may not have yet executed.
   barrier_wait()
   # Now we know that all the computations started before `barrier_wait`
   # on all devices, have finished, and all the callbacks have finished
   # executing.

Note that :func:`barrier_wait` will start one
tiny computation with one tap on each of the `jax.local_devices()` and
will wait for all these taps to be received.

An alternative to using :func:`barrier_wait` is to just wait for the end
of the computation, if all the callbacks are :func:`call`::

   accumulator = []
   def host_log(arg):
     # We just record the arguments in a list
     accumulator.append(arg)
     return 0.  #  return something


   def device_fun(x):
     y = call(host_log, x, result_shape=jax.ShapeDtypeStruct((), np.float32))
     z = call(host_log, 2. * x, result_shape=jax.ShapeDtypeStruct((), np.float32))
     return y + z  # return something that uses both results

   res1 = jax.jit(device_fun)(1.)
   res2 = jax.jit(device_fun)(1.)
   res1.block_until_ready()
   res2.block_until_ready()

Behavior under parallelization transformations
----------------------------------------------

In presence of :func:`jax.pmap` the code will run on multiple devices and
each device will tap its values independently.
It may be helpful to use the ``tap_with_device`` option for :func:`id_print`
or :func:`id_tap`, so that you see which device is sending which data::

  jax.pmap(power3, devices=jax.local_devices()[:2])(np.array([3., 4.]))
  # device=cpu:0 what=x,x^2: (3., 9.)  # from the first device
  # device=cpu:1 what=x,x^2: (4., 16.)  # from the second device

When using :func:`jax.pmap` with multiple devices on multiple hosts, every
host will receive callbacks from all of its local devices, with an operand
that corresponds to each device slice. For a
:func:`call`, the callback must return to each device only the slice of the
result that pertains to the corresponding device.

When using the experimental :func:`pjit.pjit` the code will run on multiple
devices on different shards of the input. The current implementation of
host callbacks will ensure that a single device will collect and outfeed
the entire operand, in a single callback. The callback function is supposed
to return the entire array, which will then be sent in a single infeed to the
same device that issued the outfeed. This device is then responsible for
sending the required shards to the other devices::

  with jax.sharding.Mesh(jax.local_devices()[:2], ["d"]):
    pjit.pjit(power3, in_shardings=(P("d"),),
              out_shardings=(P("d"),))(np.array([3., 4.]))

  # device=TPU:0 what=x,x^2: ( [3., 4.],
  #                            [9., 16.] )

Note that the collection of the operand on one device may result in OOM if
the operand was sharded across devices.

When using :func:`pjit.pjit` with multiple devices on multiple hosts, only
the host for the device 0 (w.r.t. the mesh) will receive the callback, with
the operand collected
from all participating devices on all hosts. For a :func:`call`, the callback
must return the entire array for all devices on all hosts.

Behavior under JAX autodiff transformations
-------------------------------------------

When used under a JAX autodiff transformation, the host callback functions
operate on the primal values only. Consider the following example::

    def power3(x):
      y = x * x
      # Print both 'x' and 'x^2'. Must pack as a tuple.
      hcb.id_print((x, y), what="x,x^2")
      return y * x

    power3(3.)
    # what: x,x^2 : (3., 9.)

(You can see these examples tested in `host_callback_test.HostCallbackTapTest.test_tap_transforms`.)

When used under :func:`jax.jvp` there will be one callback with the primal
values only::

    jax.jvp(power3, (3.,), (0.1,))
    # what: x,x^2 : (3., 9.)

Similarly for :func:`jax.grad`, we get a callback from the forward computation
only::

    jax.grad(power3)(3.)
    # what: x,x^2 : (3., 9.)

If you want to invoke the callback on the tangents during a :func:`jax.jvp`,
you can use a custom_jvp. For example, you can define a function that does
nothing interesting except that its custom_jvp will print the tangents::

    @jax.custom_jvp
    def print_tangents(arg):
      return None

    @print_tangents.defjvp
    def print_tangents_jvp(primals, tangents):
      arg_dot, = tangents
      hcb.id_print(arg_dot, what="tangents")
      return primals, tangents

Then you use this function in the places where you want to tap the tangents::

    def power3_with_tangents(x):
      y = x * x
      # Print both 'x' and 'x^2'. Must pack as a tuple.
      hcb.id_print((x, y), what="x,x^2")
      print_tangents((x, y))
      return y * x

    jax.jvp(power3_with_tangents, (3.,), (0.1,))
    # what: x,x^2 : (3., 9.)
    # what: tangents : (0.1, 0.6)

You can do a similar thing for the cotangents during :func:`jax.grad`. This
time you must be careful to use in the rest of the computation the values whose
cotangents you want to tap. Hence we make the ``print_cotangents`` return
its argument::

    @jax.custom_vjp
    def print_cotangents(arg):
      # Must return the argument for which we want the cotangent.
      return arg

    # f_fwd: a -> (b, residual)
    def print_cotangents_fwd(arg):
      return print_cotangents(arg), None
    # f_bwd: (residual, CT b) -> [CT a]
    def print_cotangents_bwd(residual, ct_b):
      hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream)
      return ct_b,

    print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd)

    def power3_with_cotangents(x):
      y = x * x
      # Print both 'x' and 'x^2'. Must pack as a tuple.
      hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
      (x1, y1) = print_cotangents((x, y))
      # Must use the output of print_cotangents
      return y1 * x1

    jax.grad(power3_with_cotangents)(3.)
    # what: x,x^2 : (3., 9.)
    # what: cotangents : (9., 3.)

If you use :func:`ad_checkpoint.checkpoint` to rematerialize the residuals
for the backward pass, then the callbacks from the primal computation will
be called twice::

    jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)
    # what: x,x^2 : (3., 9.)
    # what: x,x^2 : (27., 729.)
    # what: x,x^2 : (3., 9.)

The callbacks are, in order from: the primal computation of the inner ``power3``,
the primal computation of the outer ``power3``, and the rematerialization
of the residuals for the inner ``power3``.


Behavior under jax.vmap
-----------------------

The host callback functions :func:`id_print` and :func:`id_tap` support the
vectorization transformation :func:`jax.vmap`.

For :func:`jax.vmap` the arguments to the callback are batched,
and the callback function is
passed an additional special ``transforms`` containing a list of transformation descriptors
in the form ``("batch", {"batch_dims": ...})``, where ``...`` denotes the
batched dimensions for the tapped values (one entry per argument,
``None`` denotes an argument that was broadcast)::

  jax.vmap(power3)(np.array([2., 3.]))
  # transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 : ([2., 3.], [4., 9.])

See documentation for :func:`id_tap`, :func:`id_print`, and :func:`call`.

For more usage examples, see tests/host_callback_test.py.

Using :func:`call` to call a TensorFlow function, with reverse-mode autodiff support
------------------------------------------------------------------------------------

Another possible use for host computation is to invoke a library written for
another framework, such as TensorFlow.
In this case it becomes interesting to support JAX autodiff for host callbacks
by deferring to the autodiff mechanism in TensorFlow,
using the :func:`jax.custom_vjp` mechanism.

This is relatively easy to do, once one understands both the JAX custom VJP
and the TensorFlow autodiff mechanisms.
The code for how this can be done is shown in the ``call_tf_full_ad``
function in `host_callback_to_tf_test.py <https://github.com/google/jax/blob/main/tests/host_callback_to_tf_test.py>`_.
This example supports arbitrary higher-order differentiation as well.

Note that if you just want to call TensorFlow functions from JAX, you can also
use the `jax2tf.call_tf function <https://github.com/google/jax/blob/main/jax/experimental/jax2tf/call_tf.py>`_.

Using :func:`call` to call a JAX function on another device, with reverse-mode autodiff support
------------------------------------------------------------------------------------------------

It should not be surprising that we can use host computation to invoke a JAX
computation on another device. The arguments are sent from the accelerator to
the host, and then to the outside device on which the JAX host
computation will run, and then the results are sent back to the original accelerator.

The code for how this can be done is shown in the ``call_jax_other_device`` function
in `host_callback_test.py <https://github.com/google/jax/blob/main/tests/host_callback_test.py>`_.

Low-level details and debugging
-------------------------------

The host callback functions will be executed for each device in the order in
which the send operations were performed on the device.

The host callback functions for multiple devices may be interleaved.
The data from the devices is received by separate threads managed by the JAX
runtime (one thread per device). The runtime maintains a buffer of
configurable size (see the flag ``--jax_host_callback_max_queue_byte_size``).
When the buffer is full, all the receiving threads are paused
which eventually pauses the computation on devices. The runtime has one
additional thread for each device to invoke the Python user functions with the
received data. If the processing of the callbacks is slow, it may actually
lead to the runtime buffer filling up, and eventually pausing the computation
on the devices when they need to send something.
For more details on the outfeed receiver runtime mechanism see
`runtime code
<https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/outfeed_receiver.cc>`_.

In order to pause the execution until all data from computations already
started on devices has arrived and has been processed, use :func:`barrier_wait`.

Exceptions from the user-defined callback functions are logged along with their
stack traces, but the receiving threads are not stopped. Instead the last
exception is recorded and the subsequent :func:`barrier_wait` will
raise :exc:`CallbackException` if any exception had occurred
in one of the tap functions. This exception will include the text and the
stack trace of the last exception encountered.

One further complication arises for callback functions that must return
results to the call origin device, such as :func:`call()`. This is handled
differently on CPU/GPU devices compared to TPU devices.

On CPU/GPU devices, in order to avoid the device computation
being stuck waiting for a result that will never arrive, in case of any
error during the processing of the callback (whether raised by the user-code
itself or due to a mismatch of the returned value and the expected ``result_shape``)
we send the device a "fake" result of shape ``int8[12345]``.
This will make the device
computation abort because the received data is different than the one that
it expects. On CPU the runtime will crash with a distinctive error message:

```
Check failed: buffer->length() == buffer_length (12345 vs. ...)
```

On GPU, the failure is more user-friendly and will be surfaced to the Python
program as:

```
RET_CHECK failure ... Mismatch between infeed source buffer shape s8[12345] ...
```

To debug the underlying cause for these messages, see the Debugging section.

On TPU devices, there is currently no shape check for infeed, so we take the
safer route of not sending this fake result in case of errors. This means
that the computation will hang, and no exception will be raised (but any
exceptions in the callback functions will still appear in the logs).

The current implementation uses the outfeed mechanism provided by XLA. The
mechanism itself is quite primitive in the sense that a receiver must know
exactly the shape of each incoming packet, and how many packets are expected.
This makes it hard to use for multiple kinds of data in the same computation,
and it is practically impossible to use it under conditionals or in loops
of non-constant iteration count. Furthermore, code that uses the outfeed
mechanism directly cannot be transformed by JAX. All these limitations are
addressed by the host callback functions. The tapping API introduced here
makes it easy to share the outfeed mechanism for multiple purposes, while
supporting all transformations.

**Note that after you have used the host callback functions, you cannot
use lax.outfeed directly**. You may want to :func:`stop_outfeed_receiver`
if you later need to use lax.outfeed.

Since the actual calls to your callback functions are made from the C++
receiver, it may be hard to debug the calls. In particular, the stack trace
will not include the calling code. You can use the flag
``jax_host_callback_inline`` (or the environment variable
``JAX_HOST_CALLBACK_INLINE``) to ensure that the calls to the callbacks are
inlined. This works only if the calls are outside a staging context
(:func:`~jax.jit` or a control-flow primitive).

The C++ `receiver
<https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/outfeed_receiver.cc>`_
is started automatically on the first call to :func:`id_tap`. In order to stop
it properly, upon start an ``atexit`` handler is registered to call
:func:`barrier_wait` with the logging name "at_exit".

There are a few environment variables that you can use to turn on logging
for the C++ outfeed `receiver backend
<https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/outfeed_receiver.cc>`_.

  * ``TF_CPP_MIN_LOG_LEVEL=0``: will turn on INFO logging, needed for all below.
  * ``TF_CPP_MIN_VLOG_LEVEL=3``: will make all VLOG logging up to level 3 behave
    like INFO logs. This may be too much, but you will see which modules are
    logging relevant info, and then you can select which modules to log from.
  * ``TF_CPP_VMODULE=<module_name>=3`` (the module name can be either C++ or
    Python, without the extension).

You should also use the ``--verbosity=2`` flag so that you see the logs
from Python.

For example, you can try to enable logging in the ``host_callback`` module:
``TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=host_callback=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple``

If you want to enable logging in lower-level implementation modules try:
``TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple``

(For bazel tests use ``--test_arg=--vmodule=...``.)

Still to do:
  * More performance tests.
  * Explore implementation with outside compilation for TPU.
  * Explore implementation with XLA CustomCall for CPU and GPU.

    )annotationsN)CallableSequence)Anycast)api)core)config)custom_derivatives)dtypes)lax)pjit)io_callback)adbatchingpxla)mlir)partial_eval)xla)ad_checkpoint)compiler)dispatch)pretty_printer)sharding_impls)source_info_util)	tree_util)util)
xla_bridge)
xla_client)xla_extension)hlojax_host_callback_inlineJAX_HOST_CALLBACK_INLINEFz5Inline the host_callback, if not in a staged context.)help%jax_host_callback_max_queue_byte_size%JAX_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZEg    AzThe size in bytes of the buffer used to hold outfeeds from each device. When this capacity is reached consuming outfeeds from the device is paused, thus potentially pausing the device computation, until the Python callback consume more outfeeds.g    nA)r$   lower_boundjax_host_callback_outfeedJAX_HOST_CALLBACK_OUTFEEDzUse outfeed implementation for host_callback, even on CPU and GPU. If false, use the CustomCall implementation. Has no effect on TPU, since only the outfeed mechanism is implemented.jax_host_callback_legacyJAX_HOST_CALLBACK_LEGACYTzUse old implementation of host_callback, documented in the module docstring.If False, use the jax.experimental.io_callback implementation. See https://github.com/google/jax/issues/20385.platformstrreturnboolc                "    | dv pt           j        S )N)tpugpucudarocm)_HOST_CALLBACK_OUTFEEDvaluer,   s    ^/var/www/html/nettyfy-visnx/env/lib/python3.11/site-packages/jax/experimental/host_callback.py_use_outfeedr9   D  s    
4
4 '
 
&(    backendxb.XlaBackendc                L    t          j        |           rt          d          dS )z;Should be called whenever outfeed (or infeed) will be used.aI  host_callback functionality isn't supported with PJRT C API. See https://jax.readthedocs.io/en/latest/debugging/index.html and https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html for alternatives. Please file a feature request at https://github.com/google/jax/issues if none of the alternatives are sufficient.N)xbusing_pjrt_c_apiNotImplementedError)r;   s    r8   '_raise_if_using_outfeed_with_pjrt_c_apirA   I  s6    !! 
	   r:   c                      e Zd ZdZdZdZdZdS )CallbackFlavorzSpecifies which flavor of callback to use under JAX_HOST_CALLBACK_LEGACY=False.

  See https://github.com/google/jax/issues/20385.
           N)__name__
__module____qualname____doc__IO_CALLBACKPUREDEBUG r:   r8   rC   rC   ^  s)          +	
$
%%%r:   rC   resulttap_with_devicedevice_indexcallback_flavorc          	         |rd}t          |          |0t          j        |          \  }}	|D ]}
t          j        |
           t          | ||dd||          }||S |S )az	  Host-callback tap primitive, like identity function with a call to ``tap_func``.

  .. warning::
    The host_callback APIs are deprecated as of March 20, 2024.
    The functionality is subsumed by the
    `new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
    See https://github.com/google/jax/issues/20385.

  ``id_tap`` behaves semantically like the identity function but has the
  side-effect that a user-defined Python function is called with the runtime
  value of the argument.

  Args:
    tap_func: tap function to call like ``tap_func(arg, transforms)``, with
      ``arg`` as described below and where ``transforms`` is the sequence of
      applied JAX transformations in the form ``(name, params)``. If the
      `tap_with_device` optional argument is True, then the invocation also
      includes the device from which the value is tapped as a keyword argument:
      ``tap_func(arg, transforms, device=dev)``.
    arg: the argument passed to the tap function, can be a pytree of JAX
      types.
    result: if given, specifies the return value of ``id_tap``. This value is
      not passed to the tap function, and in fact is not sent from the device to
      the host. If the ``result`` parameter is not specified then the return
      value of ``id_tap`` is ``arg``.
    tap_with_device: if True then the tap function is invoked with the
      device from which the tap originates as a keyword argument.
    device_index: specifies from which device the tap function is invoked in a
      SPMD program. Works only when using the outfeed implementation mechanism,
      i.e., does not work on CPU unless --jax_host_callback_outfeed=True.
    callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
       the flavor of callback to use.
       See https://github.com/google/jax/issues/20385.

  Returns:
    ``arg``, or ``result`` if given.

  The order of execution is by data dependency: after all the arguments and
  the value of ``result`` if present, are computed and before the returned
  value is used. At least one of the returned values of ``id_tap`` must be
  used in the rest of the computation, or else this operation has no effect.

  Tapping works even for code executed on accelerators and even for code under
  JAX transformations.

  For more details see the :mod:`jax.experimental.host_callback` module documentation.
  zSupport for **kwargs in ``id_tap`` has been removed. Instead, pre-apply keyword arguments, either by using a closure or by passing ``functools.partial(tap_func, **kwargs)``.NT)call_with_deviceresult_shapeidentityrR   rS   )	TypeErrorr   tree_flattenr   	check_arg_call)tap_funcargrP   rQ   rR   rS   kwargsmsgflat_results_rcall_ress               r8   _deprecated_id_taprd   h  s    n  	5  C..,V44OL!  	&%' ' '( MOr:   )rP   rQ   rR   output_stream	thresholdrS   c               `    t          j        t          f||d|}t          || ||||          S )a  Like :func:`id_tap` with a printing tap function.

  .. warning::
    The host_callback APIs are deprecated as of March 20, 2024.
    The functionality is subsumed by the
    `new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
    See https://github.com/google/jax/issues/20385.

   On each invocation of the printing tap, the ``kwargs`` if present
   will be printed first (sorted by keys). Then arg will be printed,
   with the arrays stringified with ``numpy.array2string``.

   See the :func:`id_tap` documentation.

   Additional keyword arguments:

   * ``tap_with_device`` if True, will print also the device from which
     the value originates.
   * ``output_stream`` if given then it will be used instead of the
     built-in ``print``. The string will be passed as
     ``output_stream.write(s)``.
   * ``threshold`` is passed to ``numpy.array2string``.
   * ``callback_flavor``: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
       the flavor of callback to use.
       See https://github.com/google/jax/issues/20385.

  For more details see the :mod:`jax.experimental.host_callback` module documentation.
  )re   rf   rO   )	functoolspartial_print_tap_funcrd   )	r]   rP   rQ   rR   re   rf   rS   r^   printers	            r8   _deprecated_id_printrl     s\    J o =,9(1= =5;= =' 
	%%
' 
' 
' 'r:   )rV   rU   rR   rS   callback_funcr   c          	         t           j        s|t          j        u r|t	          d          t          | |||d||          S )a  Make a call to the host, and expect a result.

  .. warning::
    The host_callback APIs are deprecated as of March 20, 2024.
    The functionality is subsumed by the
    `new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
    See https://github.com/google/jax/issues/20385.

  Args:
    callback_func: The Python function to invoke on the host as
      ``callback_func(arg)``. If the ``call_with_device`` optional argument is True,
      then the invocation also includes the ``device`` kwarg with the device
      from which the call originates: ``callback_func(arg, device=dev)``. This function
      must return a pytree of numpy ndarrays.

    arg: the argument passed to the callback function, can be a pytree of JAX
      types.

    result_shape: a value that describes the expected shape and dtype of the
      result. This can be a numeric scalar, from which a shape and dtype are
      obtained, or an object that has ``.shape`` and ``.dtype`` attributes.
      If the result of the callback is a pytree, then ``result_shape`` should
      also be a pytree with the same structure. In particular, ``result_shape``
      can be `()` or `None` if the function does not have any results.
      The device code containing ``call`` is compiled with the expected result shape and dtype,
      and an error will be raised at runtime if the actual ``callback_func``
      invocation returns a different kind of result.

    call_with_device: if True then the callback function is invoked with the
      device from which the call originates as a keyword argument.

    device_index: specifies from which device the tap function is invoked in a
      SPMD program. Works only when using the outfeed implementation mechanism,
      i.e., does not work on CPU unless --jax_host_callback_outfeed=True.
    callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
       the flavor of callback to use.
       See https://github.com/google/jax/issues/20385.

  Returns:
    the result of the ``callback_func`` invocation.

  For more details see the :mod:`jax.experimental.host_callback` module documentation.
  NzWhen using JAX_HOST_CALLBACK_LEGACY=False you can use the `DEBUG` flavor of callback only when the `result_shape` is None. See https://github.com/google/jax/issues/20385.F)rV   rU   rW   rR   rS   )_HOST_CALLBACK_LEGACYr6   rC   rM   r@   r[   )rm   r]   rV   rU   rR   rS   s         r8   _deprecated_callrp     sh    `  
% ---
	:  
 
}c 05(/
K 
K 
K Kr:   c                  ,    e Zd Zd Zd Zd Zd Zd ZdS )_CallbackWrapperc                n    || _         || _        || _        t          j        s|rt          d          d S d S )NzWhen using JAX_HOST_CALLBACK_LEGACY=False, the host_callback APIs do not support `tap_with_device` and `call_with_device`. See https://github.com/google/jax/issues/20385.)rm   rW   rU   ro   r6   r@   )selfrm   rW   rU   s       r8   __init__z_CallbackWrapper.__init__,  sW    &DDM,D & =+; =<= = == = = =r:   c                D    t          | j        | j        | j        f          S N)hashrm   rW   rU   rt   s    r8   __hash__z_CallbackWrapper.__hash__6  s    #T]D4IJKKKr:   c                b    | j         |j         k    o| j        |j        k    o| j        |j        k    S rw   )rm   rW   rU   )rt   others     r8   __eq__z_CallbackWrapper.__eq__9  s9    %"55 <MU^+<!U%;;=r:   c                    t           j        r | j        |i |S | j        r|                     |d         d          S  | j        |i |S )Nr   rN   )ro   r6   _call_legacyrW   rm   )rt   argsr^   s      r8   __call__z_CallbackWrapper.__call__>  sa    " 1T////	 /!!$q'2...T0000r:   c                    | j         r5| j        r|                     |||          S |                     ||          S | j        r|                     ||          S |                     |          S )Ndevice)rW   rU   rm   )rt   r]   r   
transformss       r8   r   z_CallbackWrapper._call_legacyG  s    } 
'		 3!!#z&!AAA!!#z222		 '!!#f!555!!#&&&r:   N)rG   rH   rI   ru   rz   r}   r   r   rN   r:   r8   rr   rr   +  sb        = = =L L L= = =
1 1 1' ' ' ' 'r:   rr   )rV   rU   rR   rW   rS   c                  t           j        rt          t          j                   t	          j        |            t          j        |          \  }}|D ]}	t          j	        |	           i }
t          | ||          |
d<   ||
d<   ||
d<   ||
d<   |s^t          j        |          \  }}	 d |D             }n"# t          $ r d| }t          |          w xY w||
d<   t          |          |
d	<   t           j        r=t          j        |i |
}|s|                    |          n|                    |          S t#          j                    |         }t"          j                            |          }t          | ||          } |t*          j        u r&|sJ t"          j                            | |           |S |t*          j        u rt#          j        | |||
          }nt7          | |||d          }|s|n|S )N)max_callback_queue_size_bytescallbackrW   arg_treedefrR   c           	         g | ]=}t          j        t          j        |          t	          j        |d                     >S )T)canonicalize)r	   ShapedArraynpshaper   dtype).0rb   s     r8   
<listcomp>z_call.<locals>.<listcomp>t  sM     8 8 8   +BHQKKaVZ9[9[9[\\ 8 8 8r:   zresult_shape should be a pytree of values with structure matching the expected result of the callback function. The values must be either numeric scalars, or must have 'shape' and 'dtype' attributes. Got result_treedefflat_results_aval)shardingT)r   ordered)ro   r6   _initialize_outfeed_receiver"_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZEr   check_callabler   rY   r   rZ   rr   	Exception
ValueErrortupleoutside_call_pbind	unflattenjaxlocal_devicesr   SingleDeviceShardingrC   rM   debugr   rL   pure_callbackr   )rm   r]   rV   rU   rR   rW   rS   	flat_argsr   arg_paramsflat_results_shaper   r   r_   r`   callback_devicer   rc   s                      r8   r[   r[   W  s      P &H&NP P P P]###$1#66)[  dt&'x(8: :&&%&'&	 ;)2)?)M)M&8 8$68 8 8   7 )57 7c sOO  .F"'(9":":F  -!&	<V<<L9Aj>##L111{G\G\]iGjGjj')),7O|00AAH$]H%57 7M....ooo	i,,,j	N/	/	/"=,,46 6 6hh ]L#&.%)+ + +h $,88,s   *B7 7Ci   )r   re   rf   c               n   dfd}|rd |D             |d<   |||d<   d                     d	 t          |                                          D                       }dfdt          5  |r ||            |t	           |                                ddd           dS # 1 swxY w Y   dS )aZ  The consumer for id_print.

  We provide this as a simple tapping function for printing.
  This is **experimental** and may not want to add many features to it;
  it should be easy for the user to roll their own printing function.

  Args:
    device: the device from which the value originates (only if
      ``tap_with_device`` was used for :func:`id_print`).
    output_stream: a function whose `write` method is called with the strings to
      be output.
    threshold: the value of numpy.array2string threshold parameter.
    **kwargs: all other keyword args are printed before printing `arg`.
  sr-   c                ^                         | dz              d S t          |            d S )N
)writeprint)r   re   s    r8   emit_strz!_print_tap_func.<locals>.emit_str  s6     !d(#####Ahhhhhr:   c                $    g | ]\  }}|r||fn|S rN   rN   )r   namer   s      r8   r   z#_print_tap_func.<locals>.<listcomp>  s9     < < < ,f /5>T6NN$ < < <r:   r   Nr    c                "    g | ]\  }}| d | S )z: rN   )r   kvs      r8   r   z#_print_tap_func.<locals>.<listcomp>  s3       akkakk  r:   r.   pp.Docc                   t          | t                    rt          j        t          j        t          j        d          t          j        dt          j        t          j                    fd| D                                 t          j        d          g                    S t          | t                    rt          j        t          j        t          j        d          t          j        dt          j        t          j                    fd| D                                 t          j        d          g                    S t          | t                    rt          j        t          j        t          j        d          t          j        dt          j        t          j                    fd	t          |                                           D                                 t          j        d
          g                    S t          | t          j                  r(t          j        t          j        |                     S t          j        t!          |                     S )Nz( rE   c                &    g | ]} |          S rN   rN   r   epp_vals     r8   r   z3_print_tap_func.<locals>.pp_val.<locals>.<listcomp>  !    %=%=%=AffQii%=%=%=r:   z )z[ c                &    g | ]} |          S rN   rN   r   s     r8   r   z3_print_tap_func.<locals>.pp_val.<locals>.<listcomp>  r   r:   z ]z{ c                \    g | ](\  }}t          j        | d            |          z   )S )=)pptext)r   r   r   r   s      r8   r   z3_print_tap_func.<locals>.pp_val.<locals>.<listcomp>  sG     &
 &
 &
+/1a"'Q'''

VVAYY
&&
 &
 &
r:   z })rf   )
isinstancer   r   groupconcatr   nestjoinbrklistdictsorteditemsr   ndarrayarray2stringr-   )r]   r   rf   s    r8   r   z_print_tap_func.<locals>.pp_val  s
   #u Xbi

272688%=%=%=%=%=%=%=>>??
! 	 	 
 
 

 
C		 Xbi

272688%=%=%=%=%=%=%=>>??
! 	 	 
 
 

 
C		 Xbi

272688 &
 &
 &
 &
39#))++3F3F&
 &
 &
   	 	 	! 	 	 
 
 
 
C	$	$ WR_SI>>>???WSXXr:   )r   r-   )r.   r   )r   r   r   _print_tap_lockr-   )	r]   r   r   re   rf   r^   r   kv_pairsr   s	      ``   @r8   rj   rj     sf   "       << <0:< < <F<F8XX  $V\\^^44    (      4    hxHS                 s   ./B**B.1B.Sequence[core.ShapedArray]c                4    t          d | D                       S )Nc              3  b   K   | ]*}t          j        t          j        |                    V  +d S rw   )r	   raise_to_shapedget_aval)r   r   s     r8   	<genexpr>z#_values_to_avals.<locals>.<genexpr>  s7      DD!t#DM!$4$455DDDDDDr:   r   )valss    r8   _values_to_avalsr     s    	DDtDDD	D	DDr:   outside_callargs_ape.AbstractValueSequence[pe.AbstractValue]c                    | rd|vsJ d|vsJ |S |d         J |d         J |d         J |d         }d|v r*|d         r"t          |          dk    sJ ||dd          z   S |S )Nr   r   rR   	has_tokenrE   )len)rW   r   r   r   s       r8   _outside_call_abstract_evalr     s     6))))f,,,,M			+	+	+	 	!	-	-	-	#	$	0	0	001Fvk2v;;!
vbcc{**r:   c                     d|vsJ t           j        r3|d         }t          j                    |         }t	          | |fddi|}|S t          j        t          g| R i |S )Nr   rR   send_infeedF)_HOST_CALLBACK_INLINEr6   r>   devices_outside_call_run_callbackr   apply_primitiver   )r   r   rR   r   resultss        r8   _outside_call_implr   #  s}    	F	"	"	"	"  E.)LZ\\,'F(vSS5SFSSGN #NDTDDDVDDDr:   c                    |                      |           	  ||i ||                                  S # |                                  w xY w)z7Builds op_fn(*args, **kwargs) with sharding annotation.)set_shardingclear_sharding)buildersharding_protoop_fnr   r^   s        r8   _with_sharding_protor   8  sZ    	~&&&5$!&!!Gs	   3 A	rN   )r   args_opXlaOpc          	        |sJ t          | j                  }	|	s
J d            |d         }
|d         }| j                            |
                                          r'                    |                                          s
J d            |d d         }t          t          d |                    }| ot          |          dk    }|	o|}d}t          t          j
        | j                             t          t          j        t          f|||d|          }t          j                            |
|||          }|rt          |          }|}nfd	|D             |r|sJ t$                              ||g          }t)          j                    }t(          j        j        j        |_        d
g|_        |g|_        t)          j                    }t(          j        j        j        |_        t9          j        |gt          |          z  |gz             }d |D             }t          j        t$          j        |t(          j                             |                    }tC          ||          }t$          "                    |d          t$          "                    |d
          }fdtG          t          |                    D             d}fd|D             }n}|}||k    sJ d| d| d            |sMt          |          t          |          k    s-J dt          |           dt          |           d|             |||gz   S )Nz3Should be using MLIR path for `CustomCall` loweringr   %The last two arguments must be tokensc                "    t          |            S rw   _aval_is_emptyavals    r8   <lambda>z0_outside_call_translation_rule.<locals>.<lambda>W  s    nT>R>R9S r:   r   Fr   rW   r   c           	         g | ]I}t          |          t                              t          j        |j        |j                            JS rN   )r   xopsConstantLiteralr   zerosr   r   )r   r   comps     r8   r   z2_outside_call_translation_rule.<locals>.<listcomp>l  sW       $T28DJ
#C#CDD  r:   rD   c                d    g | ]-}t          j        |          D ]}|                                .S rN   )r   aval_to_xla_shapes$with_major_to_minor_layout_if_absent)r   xr   s      r8   r   z2_outside_call_translation_rule.<locals>.<listcomp>  sV       -a00   
4
4
6
6   r:   c                F    g | ]}t                               |          S rN   )r   GetTupleElement)r   ioutss     r8   r   z2_outside_call_translation_rule.<locals>.<listcomp>  s9        

tQ
'
'  r:   Tc                    g | ];}t          |          r                    d           n                    d           <S )r   )r   pop)r   result_avalempty_resultsnon_empty_resultss     r8   r   z2_outside_call_translation_rule.<locals>.<listcomp>  s]         K((G-

A


.?.C.CA.F.F  r:   zgenerated_infeed (z) != send_infeed ()got  but expected . identity = )$r9   r,   r   	get_shapeis_tokenr   filterr   rA   r>   get_backend_register_callbackrh   ri   r   _callback_handler_datareceiveradd_outfeedr   AfterAllr   
OpShardingTypeMAXIMALtypetile_assignment_dimensionstile_assignment_devices
REPLICATEDr   tuple_sharding_protoInfeedWithTokenShapetuple_shaper   r  range) ctxavals_in	avals_outr   rW   rR   r   r   r   use_outfeedcurrent_tokencurrent_itokenargs_to_outfeednon_empty_flat_results_avalneed_callback_results_on_devicer   generated_infeedcallback_id
next_tokenr   next_itokenafter_outfeed_itokenarray_sharding_prototoken_sharding_protoinfeed_sharding_protor   build_infeedouts_and_tokenr  r  r  r	  s                                @@@@r8   _outside_call_translation_ruler;  A  sL    
S\**+	KKKKKK"+-2;.	$		&	&	/	/	1	1 /dnn^6T6T6]6]6_6_ / /-/ / / CRCL/ !%V,S,S,=&? &? !@ !@)1\ &J%()D%E%E%I "? ?+)".*F*FGGG"
$!-	 
   + &.::
M;G G* 3#?##G KK   %  M
 # *#,,,,!]]4.*1MNN (244","7"<"D9:56B^2'244","7"<"G!6
 3'B#C#C
C
 !" " .  e &t';(<(2(8(D(DU(K(KM Ml ,D2G,8: :n!!.!44d((;;k   899::        /  gg g"k	[	(	(	(M+MM{MMM 
)	(	(	  S\\S):%;%;;;;S\\  ->)?)?    <;; 
J,	,,r:   r(  mlir.LoweringRuleContextr   rW   rR   intc          
     6   t          | j        j                  dk    rt          d          | j        j        d         }t	          |          }|r( t          j        t                    | g|R ||dS |dk    rd| j        j        v rt          d          |sJ |d         }	|d         }
|	j	        t          j                                        k    s
J d	            |
j	        t          j                                        k    s
J d	            |d
d         }t          j                    }|g|}t          j        dt           j                  g| j        d
d         }rg }ng }fd}t'          | j        j        t*          j        t*          j        f          r?t1          j                    }t0          j        j        j        |_	        dg|_        |g|_        nd
}t          j        | ||	|||d|          \  }}}t>          j         !                    |           rtE          |          }|
}sMt          |          t                    k    s-J dt          |           dt                     d             tE          |          ||gz   S )z)MLIR Lowering for `CustomCall`-based HCB.rD   z)multi-platform lowering for host_callbackr   )r   rW   r   rR   cpuz>The device_index feature on CPU works only when using outfeed.r   r   r   NrN   c                 j    | ^}}t          |t          j                    |         fdd}rd}|S )NFr   rN   )r   r>   r   )r   
replica_idarraysresult_arraysr   rW   r   s       r8   wrapped_callbackz0_outside_call_lowering.<locals>.wrapped_callback  sd    J.
:& +   M  mr:   T)has_side_effectr   r  r  r  )#r   module_context	platformsr@   r9   r   xla_fallback_loweringr   r   r  r!   	TokenTypegetReplicaIdOpr	   r   r   uint32r)  r   axis_contextr   SPMDAxisContextShardingContextr   r  r  r  r   r!  emit_python_callbackr  keep_alivesappendr   )r(  r   rW   rR   r   r   r   r,   r+  r,  r-  r.  rA  callback_operandscallback_operand_avalscallback_flat_results_avalrD  r   r   r3  
keep_aliver4  s     ` ` `               r8   _outside_call_loweringrW    s7    			%&&**
I
J
JJ)!,(X&&+ L 64%n55	  +!     	es'9'CCC
JL L L 
r(-8.		s}0022	2	2	24[	2	2	2		 1 1 3 3	3	3	35\	3	3	3"I/   *!4O4
r29%%;(+SbS(9; 6!#!5#4!5       	%%~'EF  

 $&&H).6HM+,#H'(4~H$$H$($=c'88X%/ %/ %/!':z $++J777 $?##G+	  S\\S):%;%;;;;S\\  ->)?)?    <;; 
g*k2	22r:   r?  r7   )r   r   r   r   r   c                  dd}
	 t          j        ||           } |
|          }t                              d|||            ||||          }|rt	          |           S |J |J t          j        |          \  }}||k    rd| d| d| }t          |          t	          t          j	        t          j        |                    }t          |          }t                              d	|||           t          d
 t          j        ||          D                       s@d| d|                    |           d|                    |           }t          |          |r3t	          t!          d |                    }|                    |           |S # t$          $ r}t                              d||           |r|j        dk    rrt          j        t+          j        dt*          j                            g}t                              d||           |                    t	          |                     nt                              d||           |d}~ww xY w)aH  Performs the callback:
       callback(arg, device, transforms)

  Called during the device computation once we have the argument, either from
  an inlined callback or from an XLA computation outfeed.

  Returns the flat list of result arrays. If `send_infeed` then it will also send
  the flat list of results to the device.
  r.   &tuple[tuple[str, dict[str, Any]], ...]c                @    d t          fd| D                       S )Nc                    | dk    r| t          |d                   fS | dk    r| t          d          fS |rJ |  d|             | i fS )Nbatchr   )
batch_dimsmask   )logical_shapes, )r   )r   r   s     r8   _unpack_transformzQ_outside_call_run_callback.<locals>._unpack_transforms.<locals>._unpack_transform  so    	TVAY/////6>>T+++++..d..f.....Rxr:   c              3  "   K   | ]	} | V  
d S rw   rN   )r   trb  s     r8   r   zI_outside_call_run_callback.<locals>._unpack_transforms.<locals>.<genexpr>   s,      ;;1""A&;;;;;;r:   r   )r   rb  s    @r8   _unpack_transformsz6_outside_call_run_callback.<locals>._unpack_transforms  s7       ;;;;
;;;;;;r:   z<Outside call invoking call_func %s, device=%s, transforms=%sNzCallback func z+ should have returned a result with pytree z but returned z;Outside call %s result %s. Sending to infeed for device %s.c              3  p   K   | ]1\  }}|                                 |                                 k    V  2d S rw   )strip_weak_type)r   earas      r8   r   z-_outside_call_run_callback.<locals>.<genexpr>>  s_       H HR ##%%););)=)== H H H H H Hr:   z4 should have returned a result with abstract values c                "    t          |            S rw   r   )rb   s    r8   r   z,_outside_call_run_callback.<locals>.<lambda>I  s    nUVFWFWBW r:   z#Outside call %s threw exception %s.r1   i90  )r   zJOutside call consumer %s exception %s. Sending to infeed the error result.zDOutside call consumer %s exception %s. On TPU we do not send infeed.)r.   rY  )r   tree_unflattenloggerr   r   r   rY   rX   r   safe_mapr   canonicalize_dtyper   allsafe_zipr   r  transfer_to_infeedr   errorr,   r   arangeint8)rB  r   r   r   r   rW   r   r   r   r   re  r]   unpacked_transformsresactual_flat_resultsactual_result_treedefr_   canonical_flat_resultsactual_flat_results_aval non_empty_canonical_flat_resultsr   s                        r8   r   r     s   "
< 
< 
< 
<?

[&
1
1C,,Z88
LLD+   (3 3
4
4C "$6]] '''***3<3I#3N3N00	.	0	0* * *,* *'* * nn$T]33IK^%_%_``!12H!I!IllE#V
 
 

  H H#}->-E G  GH H H H H [ [ [ **+<==[ [  5>>?WXX[ [ nn	 D+08W8W8N2P 2P ,Q ,Q(!!"BCCC##	   
LL6!DDD " 
E	!	!"%"85PRPW9X9X9X"Y"Y!Zaq	" 	" 	"!!%(>"?"?@@@@[q	" 	" 	"
G's    AF !D9F 
I#%B9II#r   r   r   c                `    |g|R }t          | |                     dd          |fz             S )zAdds the `transform` to the params["transforms"].

  Uses a tuple representation internally, will be unpacked before the
  callback by _ConsumerCallable.
  r   rN   )r   )r   rJ  )r   r   transform_paramsnew_transforms       r8   _add_transformr  d  sK     +*++-	&**\266-9II
L 
L 
L Lr:   c                <    t          j        | j                  dk    S )Nr   )mathprodr   r   s    r8   r   r   o  s    	4:		!	##r:   c                ,    ~t          j        |           S rw   )r   instantiate_zeros)tanr]   s     r8   _instantiate_zerosr  r  s    			c	"	""r:   c                    d|vsJ |d         st          d          t          j        | i |}t          |          |fS )Nr   rW   z6JVP rule is implemented only for id_tap, not for call.)r@   r   r   r   )primalstangentsr   out_primals_tappeds       r8   _outside_call_jvp_ruler  v  sZ    	F	"	"	"	"	
	 X
V
W
WW%*G>v>>	!	"	"H	,,r:   c                P   |d         st          d          d|vsJ t          |           t          |          k    sJ t          t          t          | |                    }|                    dd          }|r|d         dk    rt          j        |i t          |d          S J )	NrW   zDdifferentiation rules are implemented only for id_tap, not for call.r   r   rN   r   )jvp	transpose)	r@   r   r   mapr  rJ  r   r   r  )ctsr   r   cts_instantiatedr   s        r8   _outside_call_transpose_ruler    s    	
	 f
d
e
ee	F	"	"	"	"	SSYY				313==>> zz,++*	 /z"~11 	/

-
-/ / / ,r:   c                    |d         st          d          d|vsJ t          |d|          }t          j        | i |}||fS )NrW   z=batching rules are implemented only for id_tap, not for call.r   r\  )r@   r  r   r   )batched_argsr]  r   
new_paramsrv  s        r8   _outside_call_batching_ruler    s`    	
	 _
]
^
^^	F	"	"	"	"fgz::*\8Z88#	jr:   cjaxprcore.ClosedJaxprhas_input_tokenhas_output_tokenc                b    t          | j        ||          }t          j        || j                  S )z6Rewrites a ClosedJaxpr to thread the token, if needed.)_rewrite_jaxprjaxprr	   ClosedJaxprconsts)r  r  r  	new_jaxprs       r8   _rewrite_closed_jaxprr    s-     V\?<LMM)		)V]	3	33r:   r  
core.Jaxprc                   |s|rJ |st          j        |           s| S t          j                    }g } |t           j                  } |t           j                  }|r| j        ||gz   }n| j        }|                    t          j        | j        |gt          j        i t           j	        t          j                                         |                    t          j        | j        |gt          j        i t           j	        t          j                                         | j        D ]p}t          j        |j        |j                  s|                    |           7 ||j                  }	 ||j                  }
t#          ||||	||
|           |	}|
}q| j        |r||gng z   }t          j        | j        |||| j                  }|S )z/Rewrite a Jaxpr to thread the token, if needed.)r	   jaxpr_uses_outfeedgensymabstract_tokeninvarsrR  new_jaxpr_eqnr   create_token_p
no_effectsr   currenteqnsprimitive_uses_outfeed	primitiver   r   _rewrite_eqnoutvarsJaxpr	constvarseffects)r  r  r  
mk_new_varr  last_token_varlast_itoken_varr  eqnoutput_token_varoutput_itoken_varr  r  s                r8   r  r    s    
0 0000	 !8!?!? L{}}* $:d122.Jt233/ 
a\^_==FF\FKK5<.)9-r4?DTD\D^D^	` 	`a a a 	KK5</):-r4?DTD\D^D^	` 	`a a a Z 	* 	*c&s}cjAA *
kk##N$788$*_%9::3n.>"$5zC C C'n)ooMBRZno>>XZ['j&'4OO)	r:   r  core.JaxprEqnr  list[core.JaxprEqn]input_token_varcore.Varr  input_itoken_varr  r  (Callable[[core.AbstractValue], core.Var]c                ~   | j         t          u r`d| j        vsJ |                    |                     | j        ||gz   | j        ||gz   t          | j        d                               dS | j         t          j	        u rt          j        | j        g d          \  }}}	}t          j        |j                  rt          | ||||||           dS |                    |                     | j        ||gz   | j        ||gz   t          | j        t!          |	dd          t!          |dd                                         dS | j         t          j        u rt          j        | j        d	g          \  }
| j        ^}}|g|||}|                    |                     || j        ||gz   t          | j        t%          d
 |
D                                                      dS | j         t          j        u rt          j        | j        g d          \  }}}}}}}}||z   }| j        d|         ||gz   | j        |d         z   }t!          |dd          }|j        j        }|d|         |dd         z   ||d         z   }|                    |j                            |                    }|j        j        }|d|         |dd         z   ||d         z   }|                    |j                            |                    }|                    |                     || j        d|         ||gz   | j        |d         z   t          | j        ||dz   |d|         dz   ||d         z                                  dS | j         t(          j        u rt-          t          j        | j        d                   }|                    |                     | j        ||gz   | j        ||gz   t          | j        t1          |dd          | j        d         dz   | j        d         dz   | j        d         dz                                  dS | j         t2          j        u r}| j        d         }d }d |_        |                    |                     | j        ||gz   | j        ||gz   t          | j        t!          |dd          |                               dS | j         t2          j        u rz| j        d         }g | j        ||}d  }|                    |                     || j        ||gz   t          | j        t!          |dd          |d!d!"                               dS | j         t:          j        u r t-          t          j
        | j        d#                   }|                    |                     | j        ||gz   | j        ||gz   t          | j        t!          |dd          | j        d         dz   | j        d$         t@          j!        t@          j!        fz   | j        d%         t@          j!        t@          j!        fz   | j        d&         dz   | j        d'         dz   (                               dS | j         tD          j#        u rt-          t          j        | j        d#                   }|                    |                     | j        ||gz   | j        ||gz   t          | j        t1          |dd                                         dS tI          d)| j                    )*a  Rewrite an `eqn` and append equations to `eqns`.

  This is only called if the current primitive uses outfeed.
  Assume that the current token is in `input_token_var` and the resulting
  token must end in `output_token_var`.

  Append the result of rewriting to `eqns`.
  r   T)r   )r  r  r   
cond_jaxprcond_nconsts
body_jaxprbody_nconstsNF)r  r  branchesc              3  8   K   | ]}t          |d d           V  dS )TN)r  )r   r  s     r8   r   z_rewrite_eqn.<locals>.<genexpr>  sB       + + *%t<<+ + + + + +r:   )r  )
num_consts	num_carryr  linearreverselengthunroll_split_transposer   r   )r  )r  )r  rE   )FF)r  r  r  
call_jaxprdonated_invarsin_axes)NNout_axes)r   r   )r  r  r  r  c                     J d            NFzShould not be reachedrN   rN   r:   r8   unreachable_thunkz'_rewrite_eqn.<locals>.unreachable_thunk9      +++++r:   c                     d S rw   rN   rN   r:   r8   r   z_rewrite_eqn.<locals>.<lambda>;  s    T r:   )r  jvp_jaxpr_thunk	fun_jaxprc                     J d            r  rN   rN   r:   r8   r  z'_rewrite_eqn.<locals>.unreachable_thunkJ  r  r:   zillegal param)r  fwd_jaxpr_thunkbwd	out_treesr  in_shardingsout_shardings
in_layoutsout_layouts)r  r  r  r  r  r  zoutfeed rewrite )%r  r   r   rR  replacer  r  r   r   while_pr   
split_dictr	   r  r  _rewrite_while_outfeed_condr  cond_pr   scan_pr   
xla_pmap_pr   r  r  r   custom_jvp_call_preset_storescustom_vjp_call_jaxpr_pr   pjit_pr  r   UNSPECIFIEDr   remat_pr@   )r  r  r  r  r  r  r  r  ra   r  r  indexoperands
new_invarsr  r  carry_jaxprr  nr_const_and_carryr  new_jaxpr_invarsnew_jaxpr_outvarsr  r  r  r  jaxpr_s                              r8   r  r    s    	]n$$cj((((KK3:BR0S#S$'K3CEV2W$W#'
d#C#C#C  E E F F F F F }###'?
DDD$F $F J:q z/00 !#t_>N"24E",. . . fKK:2B CCK#35F"GG
0T4HH0T5IIK K K 	 	L 	LM M M M M }
""
ZL99IHzEHF(FOF5EFJKKs{6FHY5Z'Z
 + +!)+ + + + +, , , 	 	- 	-. . . . . }
""=A_
	' 	' 	'>( >(:J	;1a
 $i/A001)5+ +-0Z8J8K8K-LMJ%k4>>I ---.1A"##1FF+B./	0  !!	(?(?GW(?(X(X!YYI!/!I+&):233)??)B,'	(  !!	(?(?HY(?(Z(Z![[IKK[9-1ACT0UU[,-
#a-a 223nDvN`NaNaGbb	d d d 	 		e 		e
f 
f 
f 
f 
f }''dj#*\":;;JKK:2B CCK#35F"GG
)*dDAA"z*:;nL 
9-<J/&8: : : 	 
	; 
	;< < < < < }*<<<
<(I, , ,%1\"KK:2B CCK#35F"GG
0D$GG 1   	 	 	     }*BBB
;'IA3:AA0@AJ, , , 	KKK#35F"GG
/	4FF 1 $)+ + + 	 	, 	,- - - - - }##!3:g#677EKK:2B CCK#35F"GG
+E4>>"z*:;nLJ~.%1>3MNO J/%1>3MNO  J|4|C Z6E   	 	
 	
    * }---$*cj122FKK:2B CCK#35F"GG
$VT488   	 	 	     @@@
A
AAr:   c                   t          j        | j        g d          \  }}}	}
t          |dd          }| j        ||
z   d         }fd|j        j        D             }|                    t          j	        | j        d|         |z   ||gz   |t          j
        t          |j        d          |j        j        | j                              |j        d                   }|gfd|D             z    |j                   |j                  gz   }t          j        t          j        g ||gg t%                                g           }t          |	dd          }fd	| j        d|         D             }fd
| j        |||
z            D             } |j        d                   }fd|D             } |j                  } |j                  }fd|D             } |j                  } |j                  } |j        d                   } |j                  } |j                  }t          j	        ||z   ||gz   |||gz   t          j
        t          |j        d          |j        | j                  t          j	        ||z   ||gz   |||gt          j
        t          |j        d          |j        | j                  g}t          j        d |D              }t          j        t          j        g ||z   |gz   |z   ||gz   |g|z   ||gz   ||          g           }  |j        d                   }!|                    t          j	        | j        d||
z            |d         gz   |z   |dd         z   |!g| j        z   ||gz   t(          j        t          |d| ||
z             | j        | j                             dS )z&Rewrite a while whose cond has outfeedr  TNc                0    g | ]} |j                   S rN   r   )r   ovr  s     r8   r   z/_rewrite_while_outfeed_cond.<locals>.<listcomp>  s2        jj  r:   r   cond_before)r  r   c                0    g | ]} |j                   S rN   r   r   cvr  s     r8   r   z/_rewrite_while_outfeed_cond.<locals>.<listcomp>  s%    JJJrzz"'22JJJr:   c                0    g | ]} |j                   S rN   r   r   r   r  s     r8   r   z/_rewrite_while_outfeed_cond.<locals>.<listcomp>  s2     $ $ $jj$ $ $r:   c                0    g | ]} |j                   S rN   r   r  s     r8   r   z/_rewrite_while_outfeed_cond.<locals>.<listcomp>  s5     $ $ $
 j$ $ $r:   c                0    g | ]} |j                   S rN   r   r  s     r8   r   z/_rewrite_while_outfeed_cond.<locals>.<listcomp>  s%    FFF2::bg..FFFr:   c                0    g | ]} |j                   S rN   r   r  s     r8   r   z/_rewrite_while_outfeed_cond.<locals>.<listcomp>  s%    @@@RZZ((@@@r:   body	cond_bodyc              3  $   K   | ]}|j         V  d S rw   )r  )r   r  s     r8   r   z._rewrite_while_outfeed_cond.<locals>.<genexpr>  s$      EEEEEEEEr:   rD   )r   r  r   r  r  r  r  rR  r	   r  call_pr   r  source_info	out_avalsr   r  r  setjoin_effectsr   r  )"r  r  r  r  r  r  r  r  r  r  r  transformed_cond_jaxprcarry_invarspred1_and_token1new_cond_pred_invarnew_cond_invarsnew_cond_jaxprtransformed_body_jaxprnew_body_invars_cond_constvarsnew_body_invars_body_constvarsnew_body_invars_prednew_body_invars_carrynew_body_invars_tokennew_body_invars_itokennew_body_carry2new_body_token2new_body_itoken2new_body_pred2new_body_token3new_body_itoken3new_body_eqnsr  new_body_jaxprpred_outs"         `                           r8   r  r    s     8<	jNNN8P 8P4*lJ0T4HHL<7889,   $:$@$H   ++

*Q|^
$|
3HX6Y
Y
DK
/5 " " " !
&
.
/    #
:#7#:;;JJJJ\JJJJz/&''z"'((**  #
j_':&;RGGM M. 1T4HH$ $ $ $"%*Q|^"<$ $ $ $ $ $ $z,|l'BBC$ $ $  $J$8$;<<FFFFFFF$*_%9::%:&6&;<<@@@@<@@@/J344/Z 0 566:j21566.J344/Z 0 566 
(+@
@ "8
9:
_.>?
?
+
/5   !
(
/	 	 
(?
:oO_=`
`?,<
=t{
/5      !
(
/ -( EE}EEEF'#
j4458L7MN+,/DF\.]^ ""_4IY7ZZ	) ) +-. .. Z
,Q/00(++
:a|3348H8K7LL*122./:#'79J&KK
+
''',6	8 8 8
 
 
/     r:   idc                     | S rw   rN   r   s    r8   r   r     s    D r:   c                     | S rw   rN   r#  s    r8   r   r     s    T r:   c                    |S rw   rN   )r(  r)  r*  r   s       r8   r   r     s    t r:   c                $    t          | dd          S )NF)r  )js    r8   r   r     s    nQu&E&E r:   c                      e Zd ZdZdS )CallbackExceptionzSignals that some callback function had exceptions.

  Raised by :func:`barrier_wait`. See the :mod:`jax.experimental.host_callback`
  module documentation for details.
  N)rG   rH   rI   rJ   rN   r:   r8   r)  r)    s         
 $r:   r)  c                  z    e Zd ZU dZded<   ded<   ded<   ded<   d	ed
<   ded<   ded<   ded<   ded<   d Zd ZdS )_CallbackHandlerDataz(Keep track of the outfeed receiver data.r   r  r/   initializedon_exitzthreading.Locklockztuple[Exception, str] | Nonelast_callback_exceptionztuple[XlaLocalClient, ...]clientsztuple[XlaDevice, ...]r   zdict[Callable, int]consumer_registryzdict[int, Callable]consumer_registry_by_idc                    d | _         d| _        d| _        t          j                    | _        d | _        d| _        d| _        i | _	        i | _
        g | _        d S )NFrN   )r  r,  r-  	threadingLockr.  r/  r0  r   callback_registrycallback_registry_by_idrQ  ry   s    r8   ru   z_CallbackHandlerData.__init__  s_    DMDDL  DI#'D DLDL
  D#%D  Dr:   c                >    d| _         d| _        d| _        d| _        dS )z4Wait for all pending outfeeds and stop the receiver.NFrN   )r  r,  r0  r   ry   s    r8   stopz_CallbackHandlerData.stop  s#    DMDDLDLLLr:   N)rG   rH   rI   rJ   __annotations__ru   r9  rN   r:   r8   r+  r+    s         00------7777%%%%    ((((....  $    r:   r+  rB  r   c                   d                     d |D                       }t                              d| ||           t          j                            |          }|
J d            	  |||           S # t          $ rG}t          j                    }t          	                    d|           ||ft          _
        Y d }~d S d }~ww xY w)Nra  c                2    g | ]}d |j          |j         dS )(r  )r   r   )r   as     r8   r   z,_callback_input_received.<locals>.<listcomp>$  s-    BBBa1ag1qw111BBBr:   z?Callback input received on device %s for consumer %s arrays: %sz%We should have crashed in the runtimez4Postponing exception raised in callback function: %s)r   rl  r   r  r7  rJ  r   	traceback
format_excrr  r/  )r   consumer_idrB  
array_reprr   r   formatted_es          r8   _callback_input_receivedrD  #  s    yyBB6BBBCC*,,P
K% % %#;??LL(			F			F8FF###	 F F F&((K
LLGUUU675E2222222Fs   )A5 5
C?<CCr   c                    t           j                            |           }||S t          |           dz  }|dz  }|t           j        vs
J d            |t           j        | <   | t           j        |<   |S )zRegisters a callback function, cache by hash of callback.

  The callback is a function to be invoked as `callback(arrays, device)`.
  Nl    rD   zcallback id collision)r  r6  rJ  rx   r7  )r   r2  s     r8   r  r  1  s    
 '8<<XFF+X+++	2D	D	D	D 
E	D	D7B*84@H0=	r:   r   c           
        t           j        }t          j        5  t          j        r	 ddd           dS d t          j                                                    D             }t          t          j
        d |D                        }|t          _        |t          _        d |D             }|D ]}t          |           |rt          t          j
        d |D                        }t                              t           j                  r;d                    d |D                       }t                              d||            |                    t*          t-          |          | t/          j        d	d	          j                  t          _        d
 }t7          j        |           dt          _        ddd           dS # 1 swxY w Y   dS )az  Creates and starts the outfeed_receiver.

  This function is called lazily only when we compile an id_tap.

  Args:
    * clients: the list of clients (backends) on whose devices to listen on.
    * max_callback_queue_size_bytes: an optional integer to bound the maximum
      size of arrays in the callback queue. When this limit is reached the
      device listener pauses.
  Nc                     g | ]\  }}|d v 	|S ))r?  r3   r4   r1   rN   )r   r   r;   s      r8   r   z0_initialize_outfeed_receiver.<locals>.<listcomp>U  s1     : : :=4888 888r:   c                6    g | ]}|                                 S rN   r   r   r;   s     r8   r   z0_initialize_outfeed_receiver.<locals>.<listcomp>X  s$    IIIg'//11IIIr:   c                :    g | ]}t          |j                  |S rN   )r9   r,   )r   cs     r8   r   z0_initialize_outfeed_receiver.<locals>.<listcomp>[  s'    KKK!,qz2J2JKAKKKr:   c                6    g | ]}|                                 S rN   rI  rJ  s     r8   r   z0_initialize_outfeed_receiver.<locals>.<listcomp>`  s$    VVVg'//11VVVr:   ra  c                ,    g | ]}t          |          S rN   )r-   )r   ds     r8   r   z0_initialize_outfeed_receiver.<locals>.<listcomp>b  s     F F FAQ F F Fr:   zBStarting outfeed_receiver for %s. max_callback_queue_size_bytes=%srD   c                 p    dt           _        t          j        sdt          _        t	          d           d S d S )NTat_exit)r   _on_exitr  r-  _deprecated_barrier_waitrN   r:   r8   exit_handlerz2_initialize_outfeed_receiver.<locals>.exit_handlerj  s<    h#+ ,)-& +++++, ,r:   T)r    outfeed_receiverr  r.  r,  r>   backendsr   r   	itertoolschainr0  r   rA   rl  isEnabledForloggingrM   r   r   startrD  r   r   get_compile_optionsexecutable_build_optionsr  atexitregister)	r   outfeed_receiver_moduler0  r   clients_with_outfeedclientdevices_with_outfeeddevice_reprrT  s	            r8   r   r   B  sJ    *:" ". ".) ". ". ". ". ". ". ". ".
: :BKMM,?,?,A,A : : :GIIIIIJL LG%,"%,"KKwKKK& 6 6-f5555 
G!VVAUVVVWY Y			W]	+	+ Cii F F1E F F FGGY"$A	C 	C 	C(?(E(E
"E*>$?$?
'

&q!
,
,
E)G )G%
, , , OL!!!)-&E". ". ". ". ". ". ". ". ". ". ". ". ". ". ". ". ". ".s   GFGG	Glogging_name
str | Nonec                   	 t           j        st          j                     dS  pd t                              d            t          j                    	t          j        	          g 	 fdt          t          j                  D ]Y\  }}t                              d |           t          j        ||          } t          j        fd|          |           Zt                              d	            	5                      fd
           ddd           n# 1 swxY w Y   t                              d            t          j        .t          j        \  }}dt          _        t#          d|           |dS )a5  Blocks the calling thread until all current outfeed is processed.

  Waits until all callbacks from computations already running on all devices
  have been received and processed by the Python callbacks. Raises
  CallbackException if there were exceptions while processing the callbacks.

  This works by enqueueing a special tap computation to all devices to which
  we are listening for outfeed. Once all those tap computations are done, we
  return from barrier_wait.

  Note: If any of the devices are busy and cannot accept new computations,
  this will deadlock.

  Args:
    logging_name: an optional string that will be used in the logging statements
      for this invocation. See `Debugging` in the module documentation.

  For more details see the :mod:`jax.experimental.host_callback` module documentation.
  N zbarrier_wait[%s]: start)r.  c                   t           j        |          }t                              d|t	          j                               5                      |           t                              t          j	                  rBfdt           j        D             }t                              dt          |          |                                            d d d            d S # 1 swxY w Y   d S )Nz9barrier_wait[%s]: at barrier_tap for device %s. Thread %sc                    g | ]}|v|	S rN   rN   )r   rO  devices_at_barriers     r8   r   zJ_deprecated_barrier_wait.<locals>.barrier_tap_received.<locals>.<listcomp>  s.     ? ? ?Q"#+="="=  !"="="=r:   z>barrier_wait[%s]: still waiting for %s devices at barrier (%s))r  r   rl  r   r4  current_threadrR  rY  rZ  rM   r   notify)dev_idxra   r   waiting_for_devicesr  rk  r.  re  s       r8   barrier_tap_receivedz6_deprecated_barrier_wait.<locals>.barrier_tap_received  s6   #+G4F
LLAFI466   
 	 	'''			W]	+	+ 
? ? ? ?*@*H ? ? ?
J
/002E	
 	
 	
 	iikkk	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	s   BC!!C%(C%z1barrier_wait[%s]: enqueueing barrier on device %sr   c                $    t          |           S rw   )rd   )r  rp  s    r8   r   z*_deprecated_barrier_wait.<locals>.<lambda>  s    ()=qAA r:   z'barrier_wait[%s]: waiting for callbacksc                 X    t                     t          t          j                  k    S rw   )r   r  r   )rk  s   r8   r   z*_deprecated_barrier_wait.<locals>.<lambda>  s!    .//37M7U3V3VV r:   zbarrier_wait[%s]: donez@There were exceptions during callback processing. Last one was: )ro   r6   r   effects_barrierrl  r   r4  r5  	Condition	enumerater  r   r   
device_putjitwait_forr/  r)  )
re  d_idxrO  x_on_devlast_exceptionformatted_last_exceptionrp  r  rk  r.  s
   `     @@@@r8   rS  rS  u  s   ( 
	$ 
F#,,,(,777			$%%%"       " 2:;; W WheQ
LLDlTUVVV~eA...HLCGAAAA!LLLXVVVV,,8,GGG X XKKVVVVWWWX X X X X X X X X X X X X X X 	,,'6663?/E/],N,592
	41	4 	45 5:HI @?s   D--D14D1c                 8    t                                            dS )aX  Stops the outfeed receiver runtime.

  .. warning::
    The host_callback APIs are deprecated as of March 20, 2024.
    The functionality is subsumed by the
    `new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_

  This waits for all outfeeds from computations already running on all devices,
  and then stops the outfeed receiver runtime. The runtime will be restarted
  next time you use a tap function.

  It should not be necessary to use this function, unless you want to start
  using lax.outfeed directly after having used host callbacks.
  N)r  r9  rN   r:   r8   !_deprecated_stop_outfeed_receiverr~    s     r:   zThe host_callback APIs are deprecated as of March 20, 2024. The functionality is subsumed by the new JAX external callbacks. See https://github.com/google/jax/issues/20385.)id_tapid_printcallbarrier_waitstop_outfeed_receiver)deprecation_getattr)r,   r-   r.   r/   )r;   r<   )rm   r   )r.   r   )r   r   r.   r   )r   r   )r(  r<  r   r/   rW   r/   rR   r=  )r   r   r   r-   r.   r   )r.   r/   )r  r  r  r/   r  r/   r.   r  )r  r  r  r/   r  r/   r.   r  )r  r  r  r  r  r  r  r  r  r  r  r  r  r  )r  r  r  r  r  r  r  r  r  r  r  r  r  r   )rB  r   )r   r   r.   r=  )r   r=  rw   )re  rf  )rJ   
__future__r   r^  enumcollections.abcr   r   rh   rW  rZ  r  r4  r?  typingr   r   r   jax._srcr   r	   r
   r   r   r   jax.experimentalr   r   jax._src.interpretersr   r   r   r   r   per   r   r   r   r   r   r   r   r   r   r   r>   jax._src.libr   r    jax._src.lib.mlir.dialectsr!   numpyr   	bool_flagbool_envr   int_flagint_envr=  r   r5   ro   	getLoggerrG   rl  r9   rA   _xlaopsr   r   r%  XlaShape
XlaBuilderDevice	XlaDeviceClientXlaLocalClientDTypeEnumrC   rK   rd   rl   rp   rr   r[   r5  r   rj   r   	Primitiver   multiple_resultsoutfeed_primitivesaddr   def_abstract_evalr   def_implr   r;  register_translationrW  register_loweringr   r  r   r  r  primitive_jvpsr  primitive_transposesr  primitive_batchersr  r  r  r  id_poutfeed_rewriterr   r)  TapFunctionExceptionr+  r  rD  r  r   rS  r~  _deprecation_msg_deprecationsTYPE_CHECKINGr  r  r  r  r  jax._src.deprecationsr  _deprecation_getattr__getattr__rN   r:   r8   <module>r     s	  f fP # " " " " "   . . . . . . . .                           



                   " " " " " "             ! ! ! ! ! ! ( ( ( ( ( ( 4 4 4 4 4 4 4 4 4 4 & & & & & & 4 4 4 4 4 4 % % % % % % " " " " " "             ) ) ) ) ) ) # # # # # # % % % % % %             % % % % % % # # # # # # & & & & & & * * * * * *     )(FO.66	@   
 &5V_+FN:CC	NNKK= H& & & " *)FO/77	Q	    )(FO.55	:	    
	8	$	$( ( ( (
	 	 	 	 "
	"    TY     )5O O O O Oh "+7.' .' .' .' .'d '3	:K :K :K :K :K :K@'' '' '' '' '' '' '' ''^  (4<- <- <- <- <- <-B !).""  $$= = = = =@E E E E)T  //"&     N + + +   $     !< = = =E E E$   * + + +    68a- a- a- a- a- a-H  )G H H H .0Y3 Y3 Y3 Y3 Y3 Y3v  ~'= N N N N  "TU\ \ \ \ \~L L L L$ $ $ $# # #- - - %; . !  & +G  '   /J N +4 4 4 4' ' ' 'TeB eB eB eBPb b b bL t~d     ! ! !   )) * * *  KK L L LEE     	    ) # # # # # # # #N .-// F F F F   $ *-Y0. 0. 0. 0. 0.f@I @I @I @I @IF     $6   !34!#78/0%'?@.0QR  	 	&!(	$),;OOOOOO$$X}==+
FFr:   