Metadata-Version: 2.4
Name: cuquantum-python-jax
Version: 0.0.3
Summary: NVIDIA cuQuantum Python JAX
Home-page: https://developer.nvidia.com/cuquantum-sdk
Author: NVIDIA Corporation
Author-email: cuquantum-python@nvidia.com
License: BSD-3-Clause
Classifier: Development Status :: 5 - Production/Stable
Classifier: Operating System :: POSIX :: Linux
Classifier: Topic :: Education
Classifier: Topic :: Scientific/Engineering
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Programming Language :: Python :: Implementation :: CPython
Classifier: Environment :: GPU :: NVIDIA CUDA
Classifier: Environment :: GPU :: NVIDIA CUDA :: 13
Requires-Python: >=3.11.0
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: pybind11
Requires-Dist: cuquantum-python-cu13~=25.11
Requires-Dist: jax[cuda13-local]<0.9,>=0.8
Dynamic: author
Dynamic: author-email
Dynamic: classifier
Dynamic: description
Dynamic: description-content-type
Dynamic: home-page
Dynamic: license
Dynamic: license-file
Dynamic: requires-dist
Dynamic: requires-python
Dynamic: summary

# cuQuantum Python JAX

cuQuantum Python JAX provides a JAX extension for cuQuantum Python. It exposes selected functionality of the cuQuantum SDK through a JAX-compatible interface, so that JAX-based frameworks can call the exposed cuQuantum APIs directly. In the current release, cuQuantum Python JAX provides a JAX interface to the Operator Action API of the cuDensityMat library.

## Documentation

Please visit the [NVIDIA cuQuantum Python documentation](https://docs.nvidia.com/cuda/cuquantum/latest/python).

## Building and installing cuQuantum Python JAX

### Requirements

The install-time dependencies of cuQuantum Python JAX include:

* cuquantum-python-cu12~=25.11 for CUDA 12 or cuquantum-python-cu13~=25.11 for CUDA 13
* jax[cuda12-local]>=0.5,<0.7 for CUDA 12 or jax[cuda13-local]>=0.8,<0.9 for CUDA 13
* pybind11
* setuptools>=77.0.3

Note: cuQuantum Python JAX is only supported with CUDA 12 and CUDA 13.
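When you pin these dependencies by hand, quote the version specifiers. The characters `<`, `>` and `[...]` have special meaning in the shell, so an unquoted `jax[cuda13-local]>=0.8,<0.9` is parsed as output and input redirections. The function below is a hypothetical convenience wrapper and does not ship with cuQuantum. It installs the set listed above for a given CUDA major version, and it assumes `python3 -m pip` targets the environment you intend to use.

```shell
# Hypothetical helper (not part of cuQuantum): install the install-time
# dependencies listed above for CUDA 12 or CUDA 13.
# Each specifier is quoted so the shell does not treat '<' or '>' as
# redirections or '[...]' as a filename pattern.
install_cuquantum_jax_deps() {
  case "$1" in
    12)
      python3 -m pip install "cuquantum-python-cu12~=25.11" \
        "jax[cuda12-local]>=0.5,<0.7" pybind11 "setuptools>=77.0.3"
      ;;
    13)
      python3 -m pip install "cuquantum-python-cu13~=25.11" \
        "jax[cuda13-local]>=0.8,<0.9" pybind11 "setuptools>=77.0.3"
      ;;
    *)
      echo "usage: install_cuquantum_jax_deps 12|13" >&2
      return 1
      ;;
  esac
}
```

For example, run `install_cuquantum_jax_deps 13` on a CUDA 13 system.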

#### Installation using `jax[cudaXX-local]`

`cuquantum-python-jax` depends explicitly on `jax[cudaXX-local]`. `pip install cuquantum-python-jax` will install `jax[cudaXX-local]`.

Using `jax[cudaXX-local]` assumes that the user provides both cuDNN and the CUDA Toolkit. cuDNN is not part of the CUDA Toolkit and must be installed separately. The user must also set `LD_LIBRARY_PATH` to include the library folders containing `libcudnn.so` and `libcupti.so`.

`libcupti.so` is provided by the CUDA Toolkit. If the CUDA Toolkit is installed under `/usr/local/cuda`, `libcupti.so` is located under `/usr/local/cuda/extras/CUPTI/lib64` and `LD_LIBRARY_PATH` should contain this path.

`libcudnn.so` is installed separately from the CUDA Toolkit. The default installation location is `/usr/local/cuda/lib64`, and `LD_LIBRARY_PATH` should contain this path.
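If the CUDA Toolkit and cuDNN are installed in the default locations above, the following minimal sketch prepends both directories to `LD_LIBRARY_PATH`. `CUDA_HOME` is an assumed variable that only makes the prefix overridable. If cuDNN was installed elsewhere, adjust both paths.

```shell
# Assumption: CUDA Toolkit and cuDNN live under /usr/local/cuda (the default
# locations described above). Set CUDA_HOME beforehand to use another prefix.
CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}"

# libcupti.so ships with the CUDA Toolkit under extras/CUPTI/lib64.
CUPTI_LIB_DIR="${CUDA_HOME}/extras/CUPTI/lib64"
# libcudnn.so, installed separately, defaults to lib64 of the same prefix.
CUDNN_LIB_DIR="${CUDA_HOME}/lib64"

# Prepend both directories. ${LD_LIBRARY_PATH:+:...} appends the old value
# only when it is non-empty. This avoids a trailing ':', which would add an
# empty entry that the dynamic loader treats as the current directory.
export LD_LIBRARY_PATH="${CUPTI_LIB_DIR}:${CUDNN_LIB_DIR}${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"
```

To keep the setting across sessions, add these lines to a shell startup file such as `~/.bashrc`.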

Both `libcudnn.so` and `libcupti.so` can also be installed with pip. For CUDA 12:

```
pip install nvidia-cudnn-cu12
pip install nvidia-cuda-cupti-cu12
```

After installing cuDNN and CUPTI, the user can install `cuquantum-python-jax` with `pip` using either:

```
pip install cuquantum-python-jax
```

or

```
# CUDA 12
pip install "cuquantum-python-cu12[jax]"
# CUDA 13
pip install "cuquantum-python-cu13[jax]"
```

Note: if cuDNN and CUPTI are installed with `pip`, the user does not need to add their library folders to `LD_LIBRARY_PATH`.
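After installation, it is worth confirming that JAX actually loads its CUDA backend. A missing cuDNN or CUPTI library can cause JAX to fall back to the CPU. The function below is a hypothetical helper and is not part of cuQuantum. It uses JAX's public `jax.default_backend()` and `jax.devices()`, and it assumes `python3` is the interpreter from the environment where the package was installed.

```shell
# Hypothetical post-install check (not part of cuQuantum). Exits non-zero
# unless JAX selects its GPU backend.
check_jax_gpu() {
  python3 - <<'EOF'
import sys

import jax

backend = jax.default_backend()  # "gpu" when the CUDA plugin loaded
print("JAX version:", jax.__version__)
print("Default backend:", backend)
print("Devices:", jax.devices())
sys.exit(0 if backend == "gpu" else 1)
EOF
}
```

Run `check_jax_gpu` after installation. A non-zero exit status means JAX did not select the GPU backend.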

#### Installing from source

To install cuQuantum Python JAX from source, first compile cuQuantum Python from source using the [instructions on GitHub](https://github.com/NVIDIA/cuQuantum/blob/main/python/README.md). Once complete, navigate to `python/extensions`, then:

```
export CUDENSITYMAT_ROOT=...
pip install .
```

Here, `CUDENSITYMAT_ROOT` is the parent directory of the folder containing the cuDensityMat library. For example, if `CUDENSITYMAT_ROOT=/usr/local`, `libcudensitymat.so` is expected under `/usr/local/lib` or `/usr/local/lib64`.
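Before running `pip install .`, you can confirm that the chosen `CUDENSITYMAT_ROOT` actually contains the library in one of the two expected folders. `find_cudensitymat` below is a hypothetical helper and is not part of the cuQuantum build. It also accepts versioned file names such as `libcudensitymat.so.0`, which is an assumption about how the library may be named on a given system.

```shell
# Hypothetical helper (not part of cuQuantum): print the path of the
# cuDensityMat shared library under a candidate CUDENSITYMAT_ROOT,
# checking <root>/lib first and then <root>/lib64.
find_cudensitymat() {
  local root="$1" libdir f
  for libdir in lib lib64; do
    # Match the unversioned name or a versioned one (libcudensitymat.so.*).
    for f in "$root/$libdir"/libcudensitymat.so*; do
      # If the glob matched nothing, $f is the literal pattern; -e rejects it.
      if [ -e "$f" ]; then
        printf '%s\n' "$f"
        return 0
      fi
    done
  done
  echo "libcudensitymat.so not found under $root/lib or $root/lib64" >&2
  return 1
}
```

For example, `find_cudensitymat /usr/local && export CUDENSITYMAT_ROOT=/usr/local` sets the variable only when the library is found.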

## Running

### Requirements

Runtime dependencies of cuQuantum Python JAX include:

* An NVIDIA GPU with compute capability 7.5+
* cuquantum-python-cu12~=25.11 for CUDA 12 or cuquantum-python-cu13~=25.11 for CUDA 13
* jax[cuda12-local]>=0.5,<0.7 for CUDA 12 or jax[cuda13-local]>=0.8,<0.9 for CUDA 13
* pybind11
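You can check from the command line whether a GPU meets the compute capability 7.5 requirement. This sketch assumes an NVIDIA driver recent enough for `nvidia-smi` to support the `compute_cap` query field. The helper names are hypothetical.

```shell
# Hypothetical helper: succeed if a "major.minor" compute capability is >= 7.5.
cc_at_least_75() {
  local major="${1%%.*}" minor="${1#*.}"
  [ "$major" -gt 7 ] || { [ "$major" -eq 7 ] && [ "$minor" -ge 5 ]; }
}

# Hypothetical helper: report every visible GPU's compute capability.
# Assumes the installed driver's nvidia-smi supports the compute_cap field.
check_gpus() {
  nvidia-smi --query-gpu=compute_cap --format=csv,noheader |
    while read -r cc; do
      if cc_at_least_75 "$cc"; then
        echo "compute capability $cc: OK"
      else
        echo "compute capability $cc: below 7.5, not supported"
      fi
    done
}
```

Run `check_gpus` on the target machine. Each visible GPU is reported on its own line.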

## Developer Notes

* cuQuantum Python JAX does not support editable installation.
* Both cuQuantum Python and cuQuantum Python JAX must be installed into `site-packages` for the library to be imported correctly.
* cuQuantum Python JAX assumes cuQuantum Python will be available under the current `site-packages` directory.
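cuQuantum Python JAX expects cuQuantum Python in the same `site-packages` directory. One way to check an environment is to compare the `Location:` field that `pip show` reports for each distribution. The helpers below are a hypothetical sketch. On CUDA 12 systems, replace `cuquantum-python-cu13` with `cuquantum-python-cu12`.

```shell
# Hypothetical helpers (not part of cuQuantum).
# Extract the install directory from `pip show` output read on stdin.
pip_location() {
  sed -n 's/^Location: //p'
}

# Succeed when both distributions are installed in the same directory.
# Running pip as `python3 -m pip` inspects that interpreter's own environment.
same_site_packages() {
  local loc_a loc_b
  loc_a="$(python3 -m pip show "$1" 2>/dev/null | pip_location)"
  loc_b="$(python3 -m pip show "$2" 2>/dev/null | pip_location)"
  [ -n "$loc_a" ] && [ "$loc_a" = "$loc_b" ]
}
```

For example: `same_site_packages cuquantum-python-cu13 cuquantum-python-jax && echo "same site-packages"`.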
