JAX-RNAfold: scalable differentiable folding

Bioinformatics. 2025 May 6;41(5):btaf203. doi: 10.1093/bioinformatics/btaf203.

Abstract

Summary: Differentiable folding is an emerging paradigm for RNA design in which a probabilistic sequence representation is optimized via gradient descent. However, because differentiating the expected partition function over all RNA sequences carries significant memory overhead, the existing proof-of-concept algorithm scales only to sequences of ≤50 nucleotides. We present JAX-RNAfold, an open-source software package for our drastically improved differentiable folding algorithm, which scales to 1,250 nucleotides on a single GPU. Our software allows differentiable folding to be included naturally as a module in larger deep learning pipelines, and it supports complex RNA design procedures such as mRNA design with flexible objective functions.
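The core idea of a probabilistic sequence optimized by gradient descent can be illustrated with a deliberately tiny JAX sketch. It is not JAX-RNAfold's algorithm or API. Instead of the expected partition function, which sums over all secondary structures, it scores a single hypothetical target hairpin using a made-up pair-energy table. All names here (`PAIR_ENERGY`, `KT`, `LR`, `PAIRS`, `expected_log_weight`, `step`) are illustrative assumptions. What it does show is that each position holds a softmax distribution over A, C, G and U, the expected objective is differentiable with respect to those logits, and plain gradient descent can optimize it.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy pair energies in kcal/mol; rows/cols ordered A, C, G, U.
# Illustrative values only, NOT a real nearest-neighbor energy model.
PAIR_ENERGY = jnp.array([
    [0.0, 0.0, 0.0, -2.0],   # A pairs with U
    [0.0, 0.0, -3.0, 0.0],   # C pairs with G
    [0.0, -3.0, 0.0, -1.0],  # G pairs with C (and U as a wobble)
    [-2.0, 0.0, -1.0, 0.0],  # U pairs with A (and G as a wobble)
])
KT = 0.6  # roughly RT at 37 C in kcal/mol
LR = 0.5  # gradient-descent step size (hypothetical choice)

# Hypothetical target: 12-nt hairpin, 4-bp stem closing a loop at positions 4-7.
PAIRS = jnp.array([[0, 11], [1, 10], [2, 9], [3, 8]])


def expected_log_weight(logits, pairs):
    """Log of the expected Boltzmann weight of ONE fixed structure,
    with seq ~ prod_i softmax(logits[i]) (positions independent)."""
    probs = jax.nn.softmax(logits, axis=-1)  # (n, 4) per-position base distributions
    boltz = jnp.exp(-PAIR_ENERGY / KT)       # (4, 4) Boltzmann factor per base pair
    i, j = pairs[:, 0], pairs[:, 1]
    # Independent positions, each in at most one pair, so the expectation
    # factorizes: E[w_ij] = sum_{a,b} p_i(a) * B[a, b] * p_j(b).
    per_pair = jnp.einsum("pa,ab,pb->p", probs[i], boltz, probs[j])
    return jnp.sum(jnp.log(per_pair))  # unpaired positions contribute a factor of 1


@jax.jit
def step(logits):
    # One gradient-descent step on the negative log expected weight.
    loss, grad = jax.value_and_grad(lambda l: -expected_log_weight(l, PAIRS))(logits)
    return logits - LR * grad, loss


logits = jax.random.normal(jax.random.PRNGKey(0), (12, 4))  # random init breaks symmetry
losses = []
for _ in range(200):
    logits, loss = step(logits)
    losses.append(float(loss))

# Read out the most likely base at each position.
seq = "".join("ACGU"[k] for k in jnp.argmax(logits, axis=-1).tolist())
print(seq)
```

In this toy, stem positions tend to drift toward the strongly weighted G-C and C-G pairs. Loop positions appear in no pair term, so they receive zero gradient and keep their random initial logits. The expectation is cheap here only because one fixed structure lets it factorize per pair. An expected partition function must instead account for every possible structure, which is where the memory cost discussed above arises.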

Availability and implementation: JAX-RNAfold is hosted on GitHub (https://github.com/rkruegs123/jax-rnafold) and can be installed locally as a Python package. All source code is also archived on Zenodo (https://doi.org/10.5281/zenodo.15003072).

MeSH terms

  • Algorithms
  • Computational Biology* / methods
  • Nucleic Acid Conformation
  • RNA Folding*
  • RNA* / chemistry
  • RNA* / genetics
  • Sequence Analysis, RNA* / methods
  • Software*

Substances

  • RNA