
    VpfEV                       U d dl mZ d dlZd dlmZmZmZ d dlZd dlZd dl	m
Z
 d dlZd dlZd dlmZmZ d dlZd dlZd dlZd dlZd dlmZ d dlmZ d dlmZ d d	lmZ d d
lmZ d dlmZ d dlmZ d dlmZ d dlmZ d dlm Z  d dl!m"Z" d dlm#Z# d dlm$Z$ d dlm%Z% d dlm&Z& d dl'm(Z) d dl*m+Z+ d dl,m-Z- d dl.m/Z/ d dl0m1Z1m2Z2m3Z3m4Z4m5Z5 d dl6m7Z7m8Z8 dZ9dZ:dZ; ej<        e=           e)j>        Z?e?j@        ZAe)jB        ZBe)jC        ZCejD        eEcZEZFejG        eHcZHZI ejJ        eK          ZLdZMd ZN ejO                    dud"            ZPd# ZQeZR G d$ d%ejS                  ZT eT            ZUd%eVd&<   ejW        d'             ZXejY        dvdwd-            ZZdxd3Z[dyd7Z\ e]            Z^d8eVd9<   dzd:Z_ G d; d<e          Z`d{d>Zad|d?Zbd}dBZcdZddCeVdD<   d~dEZeddHZfddIZgddLZhddOZiddPZjddUZkddZZld[ Zmd\ Zn ejo        d]^           G d_ d`                      Zpda ZqddeZrddiZs ejt        dj          Zud]eu_v        euw                    es           eux                    dk            dl Zy e
ejz        eu          ej{        eu<   eyej|        eu<   dm Z}e}e j~        eu<   dn Z e#j        euedop            e#j        euedqp           dr Z e#j        eue           dddsdtZee%j        eu<   dS )    )annotationsN)CallableIteratorSequence)partial)Any
NamedTuple)	basearray)config)core)api)dtypes)source_info_util)traceback_util)util)ad)batching)array_types)mlir)xla)pxla)lib)
xla_client)record_event_duration_secs)PartitionSpec)Sharding)SingleDeviceShardingNamedShardingGSPMDShardingTransferToMemoryKindis_single_device_sharding)LayoutDeviceLocalLayoutz&/jax/core/compile/jaxpr_trace_durationz//jax/core/compile/jaxpr_to_mlir_module_durationz*/jax/core/compile/backend_compile_durationFc                    t          | fi |}t          j                            d          }	  || }t          j                            |           n$# t          j                            |           w xY w|S )zEImpl rule that compiles and runs a single primitive 'prim' using XLA.F)xla_primitive_callabler   jax_jit#swap_thread_local_state_disable_jit)primargsparamsfunprevoutss         Q/var/www/html/nettyfy-visnx/env/lib/python3.11/site-packages/jax/_src/dispatch.pyapply_primitiver/   P   sx    t..v..# 
	8	8	?	?$:3:DK33D9999CK33D9999	+s   A !A4r(   core.Primitivec                j      fd} j         |_         j         |_        t          j        |          S )Nc                      j         | i S N)bind)r)   r*   r(   s    r.   prim_funz(xla_primitive_callable.<locals>.prim_fun^   s    49d%f%%%    )name__name____qualname__r   jit)r(   r*   r5   s   `` r.   r%   r%   \   sF    & & & & & &i()(			r6   c                V    |                      t          t          |                      d S r3   )def_implr   r/   )r(   s    r.   simple_implr=   e   s$    --../////r6   c                  R    e Zd ZU dZded<   ded<   d ZddZddZddZd Z	d Z
dS )RuntimeTokenSetzJSee docstring for effects.py module for the calling convention for tokens.zdict[core.Effect, core.Token]current_tokenszdict[Device, RuntimeToken]output_runtime_tokensc                "    i | _         i | _        d S r3   r@   rA   selfs    r.   __init__zRuntimeTokenSet.__init__u       D!#Dr6   effcore.Effectdeviceslist[Device]return
core.Tokenc                   | j                             |t          j        dt          j                            }t          |t          j                  r2t          j	        |t          j
                            |                    S t          j
        j                            |          }t          j        t          j        |g|g          d                   }|| j         |<   |S )Nr   )r@   getnpzerosbool_
isinstancer   Tokenjax
device_putshardingPositionalShardingr   get_replicatedr   
shard_args)rE   rH   rJ   tokssharded_toks         r.   get_token_inputzRuntimeTokenSet.get_token_inputy   s     

!
!#rx28'<'<
=
=C#tz"" K
 ^C!@!@!I!IJJJ 	"11'::A*T_aS3%88;<<K*Dr6   tokenc                    || j         |<   d S r3   )r@   )rE   rH   r_   s      r.   set_token_resultz RuntimeTokenSet.set_token_result   s    $Dr6   deviceDeviceRuntimeTokenc                    || j         |<   d S r3   )rA   )rE   rb   r_   s      r.   set_output_runtime_tokenz(RuntimeTokenSet.set_output_runtime_token   s     */Dv&&&r6   c                "    i | _         i | _        d S r3   rC   rD   s    r.   clearzRuntimeTokenSet.clear   rG   r6   c                    | j                                         D ]}|                                 | j                                        D ]}|                                 |                                  d S r3   )r@   valuesblock_until_readyrA   rh   )rE   r_   s     r.   rk   z!RuntimeTokenSet.block_until_ready   sy    $++--    +2244    JJLLLLLr6   N)rH   rI   rJ   rK   rL   rM   )rH   rI   r_   rM   )rb   rc   r_   rd   )r8   
__module__r9   __doc____annotations__rF   r^   ra   rf   rh   rk    r6   r.   r?   r?   j   s         RR 0/// 4333$ $ $   &% % % %/ / / /$ $ $    r6   r?   runtime_tokensc                 8    t                                            d S r3   )rp   rk   ro   r6   r.   wait_for_tokensrr      s    ""$$$$$r6   fmtstrfun_nameevent
str | Nonec              #    K   t           rd V  d S t          j        j        rt          j        nt          j        }t          j                    }d V  t          j                    |z
  }t          	                    |          r0t          
                    ||                     ||                     |t          ||           d S d S )N)ru   elapsed_time)_on_exitr   log_compilesvalueloggingWARNINGDEBUGtimeloggerisEnabledForlogformatr   )rs   ru   rv   log_priority
start_timery   s         r.   log_elapsed_timer      s       6	EEEEE&,&9&?R7??W]LJ	EEE9;;+L<(( 9jjszz,  *  8  8 9 9 9 55555 r6   num_argsintplatformrL   boolc                    |dk    r| dk    S dS )Ntpui  Fro   )r   r   s     r.   should_tuple_argsr      s     d?5r6   jaxpr
core.Jaxpr	prim_namec                    | j         D ]}||j        j        v r dS t          j        |           D ]}t          ||          r dS dS )zCWhether there is a primitive given by user anywhere inside a Jaxpr.TF)eqns	primitiver7   r   	subjaxprsjaxpr_has_primitive)r   r   eqnsubjaxprs       r.   r   r      sk    Z  cCM&&&TT '.''  h8Y// TT	r6   zset[core.Primitive]%prim_requires_devices_during_loweringc                    | j         D ]}|j        t          v r dS t          j        |           D ]}t          |          r dS dS )NTF)r   r   r   r   r    jaxpr_has_prim_requiring_devices)r   r   r   s      r.   r   r      sg    Z  c
}===TT >.''  h'11 TT	r6   c                  $    e Zd ZU ded<   ded<   dS )
SourceInfozsource_info_util.SourceInfosource_infort   eqn_nameN)r8   rl   r9   rn   ro   r6   r.   r   r      s%         ****-----r6   r   %Iterator[tuple[Sharding, SourceInfo]]c              #  p  K   ddl m} ddlm} | j        D ]oj        |j        u r1t          j        j        j	                  j
        d         fV  Bj        |j        u r^t          j        j        j	                  fdj
        d         D             E d {V  fdj
        d         D             E d {V  j        |j        u rgj
        d	         j        st          j        j        j	                  d
 fdg j
        d         j
        d         D             E d {V  #j        t          u r>t          j        j        j	                  fdj
        d         D             E d {V  qt          j        |           D ]}t#          |          E d {V  d S )Nr   )pjit)	shard_maprW   c              3      K   | ]}|fV  	d S r3   ro   ).0ir   s     r.   	<genexpr>z-get_intermediate_shardings.<locals>.<genexpr>   s(      GGq1k"GGGGGGr6   in_shardingsc              3      K   | ]}|fV  	d S r3   ro   )r   or   s     r.   r   z-get_intermediate_shardings.<locals>.<genexpr>   s(      HHq1k"HHHHHHr6   out_shardingsmeshc                t      rt                     dz   nd}t           fdt          |          D              S )N   r   c              3  B   K   | ]}                     |          V  d S r3   )rO   )r   r   namess     r.   r   zFget_intermediate_shardings.<locals>._names_to_pspec.<locals>.<genexpr>   s-      BBuyy||BBBBBBr6   )maxr   range)r   ndmins   ` r.   _names_to_pspecz3get_intermediate_shardings.<locals>._names_to_pspec   sA    "'.E

QQBBBBU5\\BBBCCr6   c              3  d   K   | ]*}t          j        d           |                    fV  +dS )r   N)r   r*   )r   r   r   r   r   s     r.   r   z-get_intermediate_shardings.<locals>.<genexpr>   s_       T T !F!3__U5K5KLLkZ T T T T T Tr6   in_names	out_namesc              3  X   K   | ]$}t          |t                    r|j        |fV  %d S r3   )rS   r   memory_kind)r   r\   r   s     r.   r   z-get_intermediate_shardings.<locals>.<genexpr>   sR       L Lq8,,L121J k"1J1J1J1JL Lr6   rJ   )jax._srcr   jax.experimentalr   r   r   sharding_constraint_pr   r   r7   r*   pjit_pshard_map_p_is_jax_device_meshdevice_put_pr   r   get_intermediate_shardings)r   r   r   r   r   r   r   s       @@@r.   r   r      s      ((((((Z L Lc
}222s0BCCkZ
#[11111	$+	%	%s0BCCkGGGGCJ~,FGGGGGGGGGGHHHHCJ,GHHHHHHHHHHH	)/	/	/Z3 s0BCCkD D DT T T T T TRJ!7R#*[:QRT T T T T T T T T T T	,	&	&s0BCCkL L L LCJy,A L L L L L L L L L L.'' 4 4h)(33333333334 4r6   c           	         t          d | j        D                       p>t          d t          j        | gt	          j        |                     D                       S )Nc              3     K   | ]H}t          |j        t          j                  !t	          |j        j                  t          j        u V  Id S r3   )rS   avalr   UnshapedArraytypedtypebint)r   vs     r.   r   z"jaxpr_has_bints.<locals>.<genexpr>   s\       9 9!AFD$6779d16<  DI- 9 9 9 9 9 9r6   c              3     K   | ]U}|j         D ]K}|j        D ]A}t          |j        t          j                  !|j        j        D ]}t          |          V  BLVd S r3   )r   outvarsrS   r   r   DShapedArrayshape_is_bint_axis_size)r   jer   ds        r.   r   z"jaxpr_has_bints.<locals>.<genexpr>   s       N NvN NqyN N"#AFD$566N ABN N <= !## N N N N N N N N Nr6   )anyinvars	itertoolschainr   r   r   s    r.   jaxpr_has_bintsr      s    
 9 9u| 9 9 9 9 9 N
 N N wu0E0EFFN N N N NOr6   r   core.AxisSizec                H   t          | t          j                  r)| j        rJ t	          | j                  t          j        u S t          | t          j                  rDt          | j        t          j	                  o$t	          | j        j                  t          j        u S dS )NF)
rS   r   DArrayr   r   r   r   Varr   r   )r   s    r.   r   r     s    4; -w==DI%%!TX -qvt011 ,$)+-	r6   z)Callable[[core.Jaxpr], core.Jaxpr] | Noneoutfeed_rewriterc                2    t           t          |           S | S r3   )r   r   s    r.   apply_outfeed_rewriterr     s    !E"""Lr6   argr   c                    t          | t          j                  s7t          j        |           s%t	          d|  dt          |            d          d S d S )N
Argument '
' of type z is not a valid JAX type.)rS   r   Tracervalid_jaxtype	TypeErrorr   )r   s    r.   	check_argr     sv    
S$+
&
& !$*<S*A*A !
      S		       ! ! !! ! ! !r6   c                T    t          t          t          | j                  d          S )zThe number of replicas needed for a jaxpr.

  For a eqn, multiply the `axis_size` with the `jaxpr_replicas` of the
  subjaxprs. For a list of eqns, take the maximum number of replicas.
  r   default)r   
unsafe_map_eqn_replicasr   r   s    r.   jaxpr_replicasr      s#     
Zuz22A	>	>	>>r6   r   core.JaxprEqnc                    | j                             d          }|r+| j                             dd          t          |          z  S | j        t          j        v rt          | j                   S dS )N
call_jaxpr	axis_sizer   )r*   rO   r   r   r   initial_style_primitives!_initial_style_primitive_replicas)r   r   s     r.   r   r   *  se    z~~l++* :>>+q))N:,F,FFF
}444,SZ8881r6   r*   dict[str, Any]c                x    t          t          j        t          |                                           d          S )Nr   r   )r   r   traverse_jaxpr_paramsr   rj   )r*   s    r.   r   r   3  s7    	T'??FFHH
 
 
 r6   c                 F    t           j        j        pt           j        j        S r3   )r   
debug_infsr|   
debug_nansro   r6   r.   needs_check_specialr   7  s    			 	;F$5$;;r6   r7   bufsSequence[basearray.Array]Nonec                \    t                      r|D ]}t          | |j        |           d S d S r3   )r   _check_specialr   )r7   r   bufs      r.   check_specialr   :  sJ     + + +T39c****+ ++ +r6   r   np.dtyper   basearray.Arrayc                   t          j        |t          j                  rt          j        j        rJt          j        t          j        t          j	        |                              rt          d|            t          j        j        rLt          j        t          j        t          j	        |                              rt          d|            d S d S d S )Nz#invalid value (nan) encountered in z#invalid value (inf) encountered in )r   
issubdtyperP   inexactr   r   r|   r   isnanasarrayFloatingPointErrorr   isinf)r7   r   r   s      r.   r   r   ?  s    ubj)) M M26"(2:c??*C*C#D#D MKTKKLLL M26"(2:c??*C*C#D#D MKTKKLLL	M MM M M Mr6   c                    | S r3   ro   )xs    r.   _identity_fnr  G  s    	
(r6   c           
         ddl m}m} |                                  | j        }|j        |j        k    r   |j        t          |          |           S |j        |j        k    rd |j        D             }|j        d         j	        
                                }d |j        D             }|j        d         j	        
                                }t          d| d| d| d|           |                    | j                  }	|	                                r|	}
n
 t          j        |j        j        t$          g	          |j                  }|	                                }g |_        g |_        t          j        ||	                                          |_        t0          j                            |          }
t7          t          j        |j        |	                                                    t7          t          j        |j        |j                            k    sJ |                    | j        t=          |j        |
|j        
          | j                   }  |j        t          |          |          S )Nr   )r   arrayr   c                    g | ]	}|j         
S ro   idr   r   s     r.   
<listcomp>z3_different_device_order_reshard.<locals>.<listcomp>T  s    ===qt===r6   c                    g | ]	}|j         
S ro   r  r  s     r.   r  z3_different_device_order_reshard.<locals>.<listcomp>V  s    CCC1!$CCCr6   z[Input and target sharding should have the same set of devices. Got input's device set ids: z on platform z' and target sharding's device set ids: )otypes)r   )!r   r   r  _check_if_deletedrW   _device_assignmentr:   r  
device_setr   upper
ValueError_to_xla_hlo_shardingndimis_replicatedrP   	vectorizeindexr   to_protoiota_reshape_dimsiota_transpose_permtaketile_assignment_devicesxcHloSharding
from_protolist$make_array_from_single_device_arraysr   r   r   _arrays)r  target_shardingr   r  inp_shardinginp_idsinp_plat
target_idstarget_platold_hlo_shardingnew_hlo_shardingpermute_ordernew_op_shardingnew_xs                 r.   _different_device_order_reshardr7  J  s   !!!!!!!!,$(JJJ?737<???BBB :::==\<===G.q1:@@BBHCC BCCCJ!4Q7@FFHHK
 D=DD D!)D D (D D 7BD D E E E
 "66qv>>##%% G'/BL!C!I),/ / //;/NP PM '//11O(*O%*,O'.0g'??AA/ /O+ ~00AA 8)AACCE E F FBGO>+CE E F FF F F F
 
4
4gO68H / ;= = =i	 % 
>_	=	=	=e	D	DDr6   T)frozenc                  R    e Zd ZU dZded<   ded<   ded<   ded	<   ed
             ZdS )_DeferredShardArgzDeferred call to `pxla.shard_args`.

  Per-array impls return this object instead of a result array to indicate a
  deferred `shard_args` call. `_batched_device_put_impl` then batches all
  `_DeferredShardArg` objects into a single `shard_args` call.
  r   r  r   r\   zcore.AbstractValuer   r   	committedc                L    t          j        | j        | j        | j                  S r3   )r   global_aval_to_result_handlerr   r\   r;  rD   s    r.   result_handlerz _DeferredShardArg.result_handler  s    -diPPPr6   N)r8   rl   r9   rm   rn   propertyr>  ro   r6   r.   r:  r:    sg           	&&&+++///Q Q 8Q Q Qr6   r:  c                $   ddl m} ddlm} t	          |t
                    r|}t          | dd           |k    rt          | dd          r| S |j        sCt	          | |j                  r.| j        s't	          |t
                    sJ t          | |          S |j        rt	          | |j                  rp| j        rit          |j                  dk    rQ|j        | j        j        k    r<|j        | j        j        k    r't	          |t
                    sJ t          | |          S |j        st	          | |j                  r| j        rt          |           t           v r[|                    | t          |            dt          |            d	
            t%          j        t(          |          |           S t+          d| d          t-          | ||d          S t	          | |j                  ri| j        s$t+          d|                                            || S t1          | j                  r&t3          j        |t7          |          | g|g          S t7          |t3          j                    n|          }t-          | |||d u          S )Nr   )r  )multihost_utilsrW   
_committedFr   zc passed to device_put is not the same on each process. Make sure you are passing the same value of z on each process.)fail_messager  zjdevice_put's second argument must be a Device or a Sharding which represents addressable devices, but got zF. Please pass device or Sharding which represents addressable devices.TzZdevice_put's first argument must be a fully addressable array, but got value with devices )r   r  r   rA  rS   r   getattris_fully_addressable	ArrayImplr7  lenr  _internal_device_listrW   rB  r   r   assert_equalr   r:   r  r  r:  rJ   r!   r   batched_device_putr   _get_default_device)r  r   rb   r  rA  r\   shs          r.   _device_put_sharding_implrM    s    ......!!  /Aq*d##q((WQe-L-L(h" 31eo&&3/0/E38$$$$$,Q222	 3:a#A#A 3	3#&q|#4#4q#8#8	1:#CCC	
---8$$$$$,Q222! =a)) 9!, 9
q''[
 
 $$77 / /GG/ / / 	% 	1 	1 	1
 6sw|1555a888<67< < <= = = Q4... 5?## 	/! 32$%IIKK2 23 3 3 ~h	"1:	.	. /$T+?+G+G!&,X/ / / %~ !466639; ;"	1b$d(:	;	;;r6   rb   !Device | Sharding | Layout | Nonesrcc          	     |   t          |t                    st          |t                    rt          d          	 t          j        |           }n6# t
          $ r)}t          d|  dt          |            d          |d }~ww xY wt          |t                    r|}|j        }t          | d          r| j
        j        nd }||j        t          | ||j                  S t          |j        t                    r$t          |t          t          d           f          s't          d| d|                                           t!          | dd           |k    rt!          | dd	          r| S ||t          | ||j                  S  t#          j        t&          |
          |           S t          | ||          S )NzTransferToMemoryKind argument to jax.device_put can only be used inside jax.jit. If you are using device_put outside jax.jit, then please provide a concrete Sharding with memory_kind.r   r   z is not a valid JAX typelayoutzVsharding and device_local_layout in `Layout` instance should be concrete. Got layout: z for input rB  Fr  )rS   r    r  r   abstractifyr   r   r"   device_local_layouthasattrrQ  rW   rM  r   r#   	str_shortrD  r   r:   r  )r  rb   rO  r   errldllx_dlls           r.   _device_put_implrZ    s    -.. A*++A
	@A A A
N?1DD	 N N N
CQCC$q''CCCE EJMNN  5A

C,3Ax,@,@JAH((dE
{qz)&q$
;;;qz8,, Fs.T

;<<FE$%E E26..2B2BE EF F F q(D!!Q&&71lE+J+J&h}&q$
;;;137<q111!444	"1dF	3	33s   A 
B$A>>BrJ   +Sequence[Device | Sharding | Layout | None]srcsc                @   g }g g g }}}t          t          || |                    D ]\  }\  }}	}
t          ||	|
          }t          |t                    rI|                    |           |                    |j                   |                    |j                   |                    |           |rft          j	        ||          }t          ||          D ]@\  }}t          ||         t                    sJ ||         
                    |          ||<   A|S )N)rb   rO  )	enumerateziprZ  rS   r:  appendr  r\   r   rZ   r>  )rJ   r\  xsysshard_arg_indicesshard_arg_xsshard_arg_shardingsr   r  rb   rO  yshard_arg_resultsshard_arg_results                 r.   _batched_device_put_implri    s;   
 
"9;R#6\&s2w'='=>>  a	!VS6s333A!&'' &q!!!!#  %%%IIaLLLL 5 (;\JJ"#46GHH 5 51011111e""#344bee	)r6   rV   c                    |S r3   ro   )rJ   r\  ra  s      r.   <lambda>rk    s    " r6   c               |   d gt          |           z  }g }t          t          | ||                    D ]=\  }\  }}}	t          |          t          j        ur|                    ||||	f           >|rHt          t          |           \  }
}}}t          j	        |||d}t          |
|          D ]
\  }}|||<   |S )NrJ   r\  )
rG  r^  r_  r   r   Zeror`  r)  r   r4   )ctsrJ   r\  _resultsdp_argsr   ctrb   rO  indicesr)   rb  rf  s                 r.   _device_put_transposeru    s    FSXX'''C$(?(?@@ + +a	"fcBxxrwnnaVS)*** #'W#6#6 GT7D		D$W	=	=	=BGR    1gajj	.r6   c                    d |D             r-t          fddd          D                       s
J |            t          j        | i ||fS )Nc                .    g | ]}|t           j        u|S ro   )r   
not_mapped)r   bds     r.   r  z'_device_put_batcher.<locals>.<listcomp>  s%    PPPb"H<O2O2Or2O2O2Or6   c              3  0   K   | ]}d          |k    V  dS )r   Nro   )r   ry  mapped_batch_dimss     r.   r   z&_device_put_batcher.<locals>.<genexpr>   s?       & &%'b & & & & & &r6   r   )allr   r4   )batched_args
batch_dimsr*   r{  s      @r.   _device_put_batcherr    s    PPJPPP # & & & &+<QRR+@& & & # #     
	L	3F	3	3Z	??r6   c          
     d      fd}t          t          |||| j         j                            S )Nc                4   t          |t          t          f          rz|j        st          |t                    rAt	          j        | ||                    |j                                                            } t	          j	        | |j        |          } | S | S r3   )
rS   r   r    r   r   wrap_with_sharding_opr  r  r!  wrap_with_memory_kind)r  rb   rO  r   out_avalctxs        r.   lowerz+_tpu_gpu_device_put_lowering.<locals>.lower'  s    6H&:;<< &	FH	%	% Q&Hf99$)DDMMOOQ Q

$Q(:H
E
EahHr6   )r)  mapavals_in	avals_out)r  rJ   r\  ra  r  s   `    r.   _tpu_gpu_device_put_loweringr  &  sB         
c%WdCL#-HH	I	IIr6   r   )r   gpuc                   |D ]A}t          |t          t          f          r#|j        t	          d| j        j                   B|S )NzNPassing memory_kind to device_put via Shardings is not supported on platforms )rS   r   r    r   NotImplementedErrormodule_context	platforms)r  rJ   r\  ra  rb   s        r.   _common_device_put_loweringr  7  sd     8 8f6H&:;<< 8&7*47 78 8 8 
)r6   rm  c                    g }| D ]N}t          |t          t          f          r|                    |j                   9|                    d            O|S r3   )rS   r   r    r`  r   )rJ   r\  xmmemory_kindsrb   s        r.   _propagate_mem_kind_dpr  A  se    ,    f&8%9:;;  &,----$	r6   )r(   r0   r3   )rs   rt   ru   rt   rv   rw   )r   r   r   rt   rL   r   )r   r   r   rt   rL   r   )r   r   )r   r   rL   r   )r   r   rL   r   )r   r   rL   r   )r   r   rL   r   )r   r   )r   r   rL   r   )r   r   rL   r   )r*   r   rL   r   )rL   r   )r7   rt   r   r   rL   r   )r7   rt   r   r  r   r  rL   r   )rb   rN  rO  rN  )rJ   r[  r\  r[  )
__future__r   atexitcollections.abcr   r   r   
contextlibdataclasses	functoolsr   r   r   typingr   r	   r}   	threadingnumpyrP   rU   r   r
   r   r   r   r   r   r   r   jax._src.interpretersr   r   jax._src.abstract_arraysr   r   r   r   r   jax._src.libr   r&  jax._src.monitoringr   jax._src.partition_specr   jax._src.shardingr   jax._src.sharding_implsr   r   r   r    r!   jax._src.layoutr"   r#   JAXPR_TRACE_EVENTJAXPR_TO_MLIR_MODULE_EVENTBACKEND_COMPILE_EVENTregister_exclusion__file___xlaxeClientBackendrc   CompileOptionssafe_mapr  r   safe_zipr_  
unsafe_zip	getLoggerr8   r   rz   r/   cacher%   r=   rd   localr?   rp   rn   registerrr   contextmanagerr   r   r   setr   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r7  	dataclassr:  rM  rZ  ri  	Primitiver   multiple_resultsr<   def_abstract_evalru  
linear_jvpprimitive_jvpsprimitive_transposesr  primitive_batchersr  register_loweringr  r  memory_kind_propagate_rulero   r6   r.   <module>r     sT    # " " " " " "  8 8 8 8 8 8 8 8 8 8                    " " " " " " " "          



                               % % % % % % # # # # # #       $ $ $ $ $ $ * * * * * * 0 0 0 0 0 0 & & & & & & % % % % % % & & & & & &       ) ) ) ) ) ) : : : : : : 1 1 1 1 1 1 & & & & & &D D D D D D D D D D D D D D 6 5 5 5 5 5 5 5 = N D  ! !( + + +W
)	"-Z-Z		8	$	$ 
 
 
    0 0 0 4 4 4 4 4io 4 4 4l #2/"3"3 3 3 3 3% % % 6 6 6 6 6       >ASUU % B B B B          
4 4 4 4>O O O O    ?C  B B B B   ! ! ! !? ? ? ?      < < < <+ + + +
M M M M  2E 2E 2Ej d###Q Q Q Q Q Q Q $#Q$4< 4< 4<n$4 $4 $4 $4N   4 t~l++ $    . / / /   << = = =   #*'"-"F"F , (=  %@ @ @ -@ L )
J 
J 
J  ,u> > > >  ,u> > > >    |%@ A A A(,4      1G  - - -r6   