Summary: Differentiable folding is an emerging paradigm for RNA design in which a probabilistic sequence representation is optimized via gradient descent. However, because differentiating the expected partition function over all RNA sequences incurs significant memory overhead, the existing proof-of-concept algorithm scales only to sequences of ≤50 nucleotides. We present JAX-RNAfold, an open-source software package implementing our drastically improved differentiable folding algorithm, which scales to 1,250 nucleotides on a single GPU. Our software allows differentiable folding to be included naturally as a module in larger deep learning pipelines, and it supports complex RNA design procedures such as mRNA design with flexible objective functions.
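As a toy illustration of what "optimizing a probabilistic sequence representation via gradient descent" can mean, the plain-Python sketch below is offered under stated assumptions. It is not JAX-RNAfold's API or objective. It assumes each position carries an independent distribution over A, C, G, and U, parameterized by softmax logits. In place of the expected partition function, it uses a hypothetical stand-in objective E: the expected probability that two positions i and j can form a Watson-Crick or G-U wobble pair. All names (`PAIRS`, `expected_pairing`, `softmax_backward`, `optimize`) are illustrative, and the gradient is derived by hand rather than by automatic differentiation.

```python
import math

BASES = "ACGU"
# Canonical Watson-Crick pairs plus the G-U wobble pair.
VALID_PAIRS = {("A", "U"), ("U", "A"), ("G", "C"), ("C", "G"), ("G", "U"), ("U", "G")}
# PAIRS[a][b] is 1.0 if base a (position i) can pair with base b (position j).
PAIRS = [[1.0 if (x, y) in VALID_PAIRS else 0.0 for y in BASES] for x in BASES]


def softmax(logits):
    """Map unconstrained logits to a probability distribution over A, C, G, U."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def expected_pairing(p_i, p_j):
    """Toy stand-in objective (assumption): E = sum_{a,b} PAIRS[a][b] * p_i[a] * p_j[b],
    i.e. the pairing probability when positions i and j are sampled independently."""
    return sum(PAIRS[a][b] * p_i[a] * p_j[b] for a in range(4) for b in range(4))


def softmax_backward(p, g):
    """Chain rule through softmax: dE/dz_k = p_k * (g_k - sum_m p_m * g_m)."""
    mean = sum(pk * gk for pk, gk in zip(p, g))
    return [pk * (gk - mean) for pk, gk in zip(p, g)]


def optimize(z_i, z_j, steps=1000, lr=0.5):
    """Gradient ascent on E over the logits (equivalently, gradient descent on -E)."""
    z_i, z_j = list(z_i), list(z_j)
    for _ in range(steps):
        p_i, p_j = softmax(z_i), softmax(z_j)
        # Partial derivatives of E with respect to each probability vector.
        g_i = [sum(PAIRS[a][b] * p_j[b] for b in range(4)) for a in range(4)]
        g_j = [sum(PAIRS[a][b] * p_i[a] for a in range(4)) for b in range(4)]
        d_i = softmax_backward(p_i, g_i)
        d_j = softmax_backward(p_j, g_j)
        z_i = [z + lr * d for z, d in zip(z_i, d_i)]
        z_j = [z + lr * d for z, d in zip(z_j, d_j)]
    return softmax(z_i), softmax(z_j)


if __name__ == "__main__":
    # Asymmetric start: position i leans toward G, position j toward C.
    z_i0, z_j0 = [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 0.0]
    before = expected_pairing(softmax(z_i0), softmax(z_j0))
    p_i, p_j = optimize(z_i0, z_j0)
    after = expected_pairing(p_i, p_j)
    print(f"expected pairing: {before:.3f} -> {after:.3f}")
```

The starting point is deliberately asymmetric: E is bilinear, and a perfectly symmetric start (identical distributions at i and j) stays symmetric and stalls at a saddle. In the setting the summary describes, the toy objective is replaced by a function of the expected partition function, which is far costlier to differentiate. An autodiff framework such as JAX (`jax.grad`) can then supply the gradient in place of the hand-derived chain rule used here.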
Availability and implementation: JAX-RNAfold is hosted on GitHub (https://github.com/rkruegs123/jax-rnafold) and can be installed locally as a Python package. All source code is also archived on Zenodo (https://doi.org/10.5281/zenodo.15003072).
© The Author(s) 2025. Published by Oxford University Press.