
    VpfC                        d Z ddlmZ ddlZddlmZ ddlmZ ddlmZ ddlm	Z	 ddl
mZ dd	l
mZ dd
l
mZ ddl
mZ ddlmZ ddlmZ ddlmZ ddlmZ ddlmZ ddlmZ 	 	 	 	 d*dededededededeeef         fdZ	 	 	 	 d*dededededededeeef         fdZ d Z!d Z"d  Z#d! Z$d"d#d$Z%d% Z&d& Z' ej(        d'          Z)de)_*        e)+                     eej,        e)                     e)-                    e!            ej.        e) ee%d#                      ej.        e)e%d()           e&ej/        e)<   e'ej0        e)<   dS )+a  ANN (Approximate Nearest Neighbor) computes top-k with a configurable recall rate.

This package only optimizes the TPU backend. For other device types it fallbacks
to sort and slice.

Usage::

  import functools
  import jax

  # MIPS := maximal inner product search
  # Inputs:
  #   qy: f32[qy_size, feature_dim]
  #   db: f32[db_size, feature_dim]
  #
  # Returns:
  #   (f32[qy_size, k], i32[qy_size, k])
  @functools.partial(jax.jit, static_argnames=["k", "recall_target"])
  def mips(qy, db, k=10, recall_target=0.95):
    dists = jax.lax.dot(qy, db.transpose())
    # Computes max_k along the last dimension
    # returns (f32[qy_size, k], i32[qy_size, k])
    return jax.lax.approx_max_k(dists, k=k, recall_target=recall_target)

  # Multi-core example
  # Inputs:
  #   qy: f32[num_devices, qy_size, feature_dim]
  #   db: f32[num_devices, per_device_db_size, feature_dim]
  #   db_offset: i32[num_devices]
  #   db_size = num_devices * per_device_db_size
  #
  # Returns:
  #   (f32[qy_size, num_devices, k], i32[qy_size, num_devices, k])
  @functools.partial(
      jax.pmap,
      # static args: db_size, k, recall_target
      static_broadcasted_argnums=[3, 4, 5],
      out_axes=(1, 1))
  def pmap_mips(qy, db, db_offset, db_size, k, recall_target):
    dists = jax.lax.dot(qy, db.transpose())
    dists, neighbors = jax.lax.approx_max_k(
        dists, k=k, recall_target=recall_target,
        reduction_input_size_override=db_size)
    return (dists, neighbors + db_offset)

  # i32[qy_size, num_devices, k]
  pmap_neighbors = pmap_mips(qy, db, db_offset, db_size, 10, 0.95)[1]
  # i32[qy_size, num_devices * k]
  neighbors = jax.lax.collapse(pmap_neighbors, start_dimension=1, stop_dimension=3)

Todos::

  * On host top-k aggregation
  * Inaccurate but fast differentiation

    )partialN)ad_util)core)dispatch)dtypes)ad)batching)mlir)xla)lax)
xla_client)ir)func)hlo)Arrayffffff?Toperandkreduction_dimensionrecall_targetreduction_input_size_overrideaggregate_to_topkreturnc           	      D    t                               | |||d||          S )a  Returns max ``k`` values and their indices of the ``operand`` in an approximate manner.

  See https://arxiv.org/abs/2206.14286 for the algorithm details.

  Args:
    operand : Array to search for max-k. Must be a floating number type.
    k : Specifies the number of max-k.
    reduction_dimension : Integer dimension along which to search. Default: -1.
    recall_target : Recall target for the approximation.
    reduction_input_size_override : When set to a positive value, it overrides
      the size determined by ``operand[reduction_dim]`` for evaluating the
      recall. This option is useful when the given ``operand`` is only a subset
      of the overall computation in SPMD or distributed pipelines, where the
      true input size cannot be deferred by the operand shape.
    aggregate_to_topk : When true, aggregates approximate results to the top-k
      in sorted order. When false, returns the approximate results unsorted. In
      this case, the number of the approximate results is implementation defined
      and is greater or equal to the specified ``k``.

  Returns:
    Tuple of two arrays. The arrays are the max ``k`` values and the
    corresponding indices along the ``reduction_dimension`` of the input
    ``operand``. The arrays' dimensions are the same as the input ``operand``
    except for the ``reduction_dimension``: when ``aggregate_to_topk`` is true,
    the reduction dimension is ``k``; otherwise, it is greater equals to ``k``
    where the size is implementation-defined.

  We encourage users to wrap ``approx_max_k`` with jit. See the following
  example for maximal inner production search (MIPS):

  >>> import functools
  >>> import jax
  >>> import numpy as np
  >>> @functools.partial(jax.jit, static_argnames=["k", "recall_target"])
  ... def mips(qy, db, k=10, recall_target=0.95):
  ...   dists = jax.lax.dot(qy, db.transpose())
  ...   # returns (f32[qy_size, k], i32[qy_size, k])
  ...   return jax.lax.approx_max_k(dists, k=k, recall_target=recall_target)
  >>>
  >>> qy = jax.numpy.array(np.random.rand(50, 64))
  >>> db = jax.numpy.array(np.random.rand(1024, 64))
  >>> dot_products, neighbors = mips(qy, db, k=10)
  Tr   r   r   is_max_kr   r   approx_top_k_pbindr   r   r   r   r   r   s         P/var/www/html/nettyfy-visnx/env/lib/python3.11/site-packages/jax/_src/lax/ann.pyapprox_max_kr#   ]   s7    b 
			-!$A) 
 
+ 
+ +    c           	      D    t                               | |||d||          S )a	  Returns min ``k`` values and their indices of the ``operand`` in an approximate manner.

  See https://arxiv.org/abs/2206.14286 for the algorithm details.

  Args:
    operand : Array to search for min-k. Must be a floating number type.
    k : Specifies the number of min-k.
    reduction_dimension: Integer dimension along which to search. Default: -1.
    recall_target: Recall target for the approximation.
    reduction_input_size_override : When set to a positive value, it overrides
      the size determined by ``operand[reduction_dim]`` for evaluating the
      recall. This option is useful when the given operand is only a subset of
      the overall computation in SPMD or distributed pipelines, where the true
      input size cannot be deferred by the ``operand`` shape.
    aggregate_to_topk : When true, aggregates approximate results to the top-k
      in sorted order. When false, returns the approximate results unsorted. In
      this case, the number of the approximate results is implementation defined
      and is greater or equal to the specified ``k``.

  Returns:
    Tuple of two arrays. The arrays are the least ``k`` values and the
    corresponding indices along the ``reduction_dimension`` of the input
    ``operand``.  The arrays' dimensions are the same as the input ``operand``
    except for the ``reduction_dimension``: when ``aggregate_to_topk`` is true,
    the reduction dimension is ``k``; otherwise, it is greater equals to ``k``
    where the size is implementation-defined.

  We encourage users to wrap ``approx_min_k`` with jit. See the following example
  for nearest neighbor search over the squared l2 distance:

  >>> import functools
  >>> import jax
  >>> import numpy as np
  >>> @functools.partial(jax.jit, static_argnames=["k", "recall_target"])
  ... def l2_ann(qy, db, half_db_norms, k=10, recall_target=0.95):
  ...   dists = half_db_norms - jax.lax.dot(qy, db.transpose())
  ...   return jax.lax.approx_min_k(dists, k=k, recall_target=recall_target)
  >>>
  >>> qy = jax.numpy.array(np.random.rand(50, 64))
  >>> db = jax.numpy.array(np.random.rand(1024, 64))
  >>> half_db_norm_sq = jax.numpy.linalg.norm(db, axis=1)**2 / 2
  >>> dists, neighbors = l2_ann(qy, db, half_db_norm_sq, k=10)

  In the example above, we compute ``db^2/2 - dot(qy, db^T)`` instead of
  ``qy^2 - 2 dot(qy, db^T) + db^2`` for performance reason. The former uses less
  arithmetic and produces the same set of neighbors.
  Fr   r   r!   s         r"   approx_min_kr&      s7    j 
			-!$A) 
 
+ 
+ +r$   c                @   |dk    rt          d|           t          | j                  dk    r't          d                    | j                            t          | j                  }||         |k     r)t          d                    ||         |                    t          j        | j        t          j
                  st          d          ||         }|r|||<   ngt          j        ||f          r;t          j                            |t          |          ||||          d         ||<   nt!          d| d| d          |                     || j        | j        	          |                     |t          j        t          j                  
          fS )Nr   zk must be positive, got z5approx_top_k operand must have >= 1 dimension, got {}z;k must be smaller than the size of reduction_dim {}, got {}zoperand must be a floating typezSapprox_top_k with aggregate_to_topk=False not yet implemented when either the `k` (z$) or the  reduction dimension size (z) are symbolic)shapedtype	weak_type)r(   r)   )
ValueErrorlenr(   	TypeErrorformatlistr   
issubdtyper)   npfloatingr   is_constant_shapexcopsApproxTopKReductionOutputSizeNotImplementedErrorupdater*   int32)	r   r   r   r   r   r   r   dimsreduction_input_sizes	            r"   _approx_top_k_abstract_evalr<      s    !VV
333
4
441
KRR    	gm		$	
""
ELL$%q	* 	*+ + + 
	7="+	6	6 8
6
7
7712 
M !D	3Q788 M " D Dc$iiM;L%!' !''(!*D	 
L
L 
L';
L 
L 
LM M M ..t7=#*#4  6 6
..t28BH+=+=.
>
>
@ @r$   c           	         t          j        d                    |rdnd                    }t          j        |dt           j                            |                     }t          j        |dt           j                            |                     }t          j        |dt           j                            t          j        t          j	                                       t          j        |dt           j                            t          j        t          j	                                       |r!t           j
                            ||          }n t           j
                            ||          }|                    |          S )Nztop_k_{}_comparatorgtltr            )r4   
XlaBuilderr.   r   	parameterShapescalar_shaper1   r)   r9   r5   GtLtbuild)op_typer   cp0p1
cmp_results         r"   _comparator_builderrO      s   m""8#=44>>@ @!
}Q280099::"
}Q280099::"-1bh++BHRX,>,>??@@@-1bh++BHRX,>,>??@@@ #2r""JJ2r""J	
		r$   c                 `    t          j        |rt           j         nt           j        |           S )N)r)   )r1   arrayinf)rJ   r   s     r"   _get_init_val_literalrS     s&    	X126''26	A	A	AAr$   c                    t           j                            g |          }t           j                            g t           j                            d                    }||||g}t           j                            g t           j                            d                    g}t           j                            ||          }t           j                            | j        j	        j
                  5  t          j        d                    |rdnd|          |          }d d d            n# 1 swxY w Y   | j        j                            |           |                                }	t          j        |	          5  |	j        \  }
}}}t$          j                            |rdnd          }t%          j        |
||          }t%          j        |g           d d d            n# 1 swxY w Y   |S )	N    r@   ztop_k_{}_{}_comparatorr>   r?   GTLT)comparison_direction)r   RankedTensorTypegetIntegerTypeget_signlessFunctionTypeInsertionPointat_block_beginmodule_contextmodulebodyr   FuncOpr.   symbol_tableinsertadd_entry_block	argumentsr   ComparisonDirectionAttrcomparereturn_)ctxrJ   r   scalarindexir_typesresult_typescomparator_type
comparatorentry_blockrL   rM   _	directionrN   s                  r"   _comparator_builder_mlirru   	  s)   ""2w//&


!
!"bn&A&A"&E&E
F
F%feU+(%))"bn.I.I!.L.LMMN,O'',??/	''(:(A(FGG   ''(BdGLL J               !((444**,,+	%%  (LBAq+//0JdKKIR)DDDJK	               
s%   0.D**D.1D.<AG$$G(+G(F)fallbackc          	      L     j         sJ t          d  j         D                       sJ  j         d         j        }	t          |	          dk    rt	          d|	           |	}
t          j         j         d         j                  }t          j	        
                                }|dk     rt          |
          |z   }t           ||          }t          j         t          j         j         d         j        t          j                  |          }t#          j        t          j        
                    t          j        d                              }t)           j         d         j        |          }t          j        |                    d                    }t          j        |          t
          j        j        
                    ||          t
          j        j        
                    |          t          j        |          d}|r't
          j        j        
                    |          |d<   t          d	  j        D                       rd }n fd
 j        D             }t          j        |          rPt          j        |          |d<   t          j        dd  j        D             ||||g|j        j        g||          }nQt          j         |f          \  }t          j        dd  j        D             |||||g|j        j        g||          }|j         S )Nc              3   J   K   | ]}t          |t          j                  V  d S N)
isinstancer   ShapedArray).0xs     r"   	<genexpr>z)_approx_top_k_lowering.<locals>.<genexpr>$  s/      CCZ4+,,CCCCCCr$   r   z"operand must be an array, but was )	dimensionr    )reduction_dimr   r   r   is_fallbackc              3   H   K   | ]}t          j        |j                  V  d S ry   )r   r3   r(   )r|   aval_outs     r"   r~   z)_approx_top_k_lowering.<locals>.<genexpr>A  s/      NNH		/	/NNNNNNr$   c                 h    g | ].}t          j        t          j        |j                            /S r   )r
   shape_tensoreval_dynamic_shaper(   )r|   r   rk   s     r"   
<listcomp>z*_approx_top_k_lowering.<locals>.<listcomp>D  sC     ' ' ' 	$1#x~FFGG' ' 'r$   top_k
ApproxTopKc                 6    g | ]}t          j        |          S r   r
   aval_to_ir_typer|   avals     r"   r   z*_approx_top_k_lowering.<locals>.<listcomp>L  #    KKKTd*400KKKr$   )ro   operandscalled_computationsbackend_configresult_shapeszstablehlo.dynamic_approx_top_kc                 6    g | ]}t          j        |          S r   r   r   s     r"   r   z*_approx_top_k_lowering.<locals>.<listcomp>U  r   r$   )!avals_inallr(   r,   r+   r
   dtype_to_ir_typer)   r   F32TyperZ   ru   iotar   r{   r1   r9   r   constantDenseElementsAttrrS   ir_constantreshapei64_attr	FloatAttrBoolAttr	avals_outis_constant_dimcustom_callnamevalueeval_dynamic_shape_as_valsresults)rk   r   r   r   r   r   r   r   rv   op_shapeop_dimsrJ   recall_typerq   r   init_arginit_val_arrayinit_valr   r   outk_values   `                     r"   _approx_top_k_loweringr     s    
	CCclCCC	C	CCCC\!_"(]]a
D(DD
E
EE'!#,q/"788'
  +1g,,)<<'Wh??*	3(a)>II0
2 
2 
2$ \".2228B<<@@AA((a)>II.n44R8899( m$788g'++KGG'*../@AA
m1224 4.  C$(G$4$8$8$B$BN=!NNNNNNN 'MM' ' ' '' ' 'M 
! %"mA..N7

KKS]KKK484'_23%#% % %CC .sQD99HG

(KKS]KKK48W='_23%#% % %C 
r$   c          	        
 t          |           dk    sJ t          |          dk    sJ | \  }|\  

fdt          |j                  D             }	|	|         }t                              |||||||          

ffS )Nr@   c                     g | ]}|u|	S r   r   )r|   d
batch_axiss     r"   r   z,_approx_top_k_batch_rule.<locals>.<listcomp>d  s#    CCC1q
/B/BQ/B/B/Br$   r   )r,   rangendimr   r    )batch_operands
batch_axesr   r   r   r   r   r   r   dim_mapr   s             @r"   _approx_top_k_batch_ruler   ]  s     
^			!	!	!	!	ZA				('+*CCCCgl++CCC' 34				-!$A) 
 
+ 
+ .8,D
E Er$   c                   | \  }|\  }	|rt          |||||          \  }
nt          |||||          \  }
t          |	          t          j        u r t          j                            |
          }noj        t                    }dk     r|z  fdt          |          D             t          fdt          |          D                       }|	|         }|
f|t          j                                      ffS )Nr   c                 F    g | ]}t          j        j        |          S r   )r   broadcasted_iotar)   )r|   iarg_out	arg_shapes     r"   r   z%_approx_top_k_jvp.<locals>.<listcomp>  s7       >?W]Iq99  r$   c              3   8   K   | ]}|k    rn|         V  d S ry   r   )r|   r   r   iotasr   s     r"   r~   z$_approx_top_k_jvp.<locals>.<genexpr>  sO       P P>?1+++qP P P P P Pr$   )
r#   r&   typer   Zero
from_valuer(   r,   r   tuple)primalstangentsr   r   r   r   r   r   r   tangentval_outtangent_outrankidxr   r   r   s      `          @@@r"   _approx_top_k_jvpr   x  sq    ('(' 	7#GQ0C$1$A$57 7GWW
 $GQ0C$1$A$57 7GW 
']]gl"",))'22KKIy>>DQT!    CH;;  E  P P P P P PCH;;P P P P PC#,K
7	k7<+B+B7+K+KL	LLr$   approx_top_ktpu)platform)r   r   r   T)1__doc__	functoolsr   numpyr1   jax._srcr   r   r   r   jax._src.interpretersr   r	   r
   r   jax._src.laxr   jax._src.libr   r4   jax._src.lib.mlirr   jax._src.lib.mlir.dialectsr   r   jax._src.typingr   intfloatboolr   r#   r&   r<   rO   rS   ru   r   r   r   	Primitiver   multiple_resultsdef_implapply_primitivedef_abstract_evalregister_loweringprimitive_batchersprimitive_jvpsr   r$   r"   <module>r      s[  7 7r                                   $ $ $ $ $ $ * * * * * * & & & & & & % % % % % %       ) ) ) ) ) )             + + + + + + * * * * * * ! ! ! ! ! !
 -/(,68+/8+ 8+% 8+8+&)8+ !&8+ 14	8+
 %)8+
 5:%,4G8+ 8+ 8+ 8+z -/(,68+/<+ <+% <+<+&)<+ !&<+ 14	<+
 %)<+
 5:%,4G<+ <+ <+ <+~@ @ @B  B B B  2 ?D< < < < <|E E E6M M M>  //"&     8.II J J J     !< = = =  ~g4tDDDF F F  ~'=!&( ( ( (.F N +$5 . ! ! !r$   