
CryoFM1: Sampling

Prerequisites

Before using CryoFM1, complete the following steps:

1. Install CryoFM with compatible dependencies

CryoFM1 uses the HDiT model architecture, which depends on the natten package. Each natten release is built for specific PyTorch and CUDA versions, so the three must be installed as a matched set. For a reproducible installation, follow these steps:

# natten 0.17.5 uses type-union syntax (X | Y), which requires Python >= 3.10
conda create -n cryofm python=3.10 -y
conda activate cryofm

# Install PyTorch 2.5.1 with CUDA 12.4 support
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu124

# Install the natten 0.17.5 wheel built for PyTorch 2.5 and CUDA 12.4 (tag torch250cu124)
pip install natten==0.17.5+torch250cu124 -f https://whl.natten.org

# Clone and install CryoFM
git clone https://github.com/ByteDance-Seed/cryofm
cd cryofm
pip install .
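After installation, it can help to confirm that the pinned versions actually ended up in the environment. The sketch below uses only the Python standard library (`importlib.metadata`); the helper names are hypothetical, and the expected versions are copied from the commands above. Local version labels such as `+cu124` or `+torch250cu124` are stripped before comparing.

```python
import sys
from importlib import metadata

# Versions pinned in the install commands above (hypothetical helper, not part of CryoFM).
EXPECTED = {"torch": "2.5.1", "torchvision": "0.20.1", "natten": "0.17.5"}


def installed_version(pkg):
    """Return the installed version string of `pkg`, or None if it is absent."""
    try:
        return metadata.version(pkg)
    except metadata.PackageNotFoundError:
        return None


def check_environment(expected=EXPECTED):
    """Map each requirement to True/False depending on whether it is satisfied."""
    report = {"python>=3.10": sys.version_info >= (3, 10)}
    for pkg, want in expected.items():
        got = installed_version(pkg)
        # Drop local labels such as "+cu124" or "+torch250cu124" before comparing.
        report[pkg] = got is not None and got.split("+")[0] == want
    return report


if __name__ == "__main__":
    for name, ok in check_environment().items():
        print(f"{name:14s} {'OK' if ok else 'MISMATCH'}")
```

Run it inside the activated `cryofm` environment; any `MISMATCH` line points to a package whose version differs from the pins.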

2. Download model checkpoints and configuration files

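Download the CryoFM1 model weights and configuration files from the Hugging Face repository. The examples below load them from path_to/cryofm-v1/, which contains one subdirectory per model (cryofm-s/ and cryofm-l/), each holding a config.yaml and a model.safetensors file.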
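The sampling scripts below expect `config.yaml` and `model.safetensors` inside a per-model directory (`path_to/cryofm-v1/cryofm-s/` and `path_to/cryofm-v1/cryofm-l/`). A small check like the following, a hypothetical helper built on `pathlib`, fails early with a clear message if a download is incomplete or the layout differs from that assumption.

```python
from pathlib import Path


def checkpoint_paths(root, variant):
    """Return (config, weights) paths for `variant` (e.g. "cryofm-s") under `root`.

    Hypothetical helper: assumes the layout used in the examples,
    root/<variant>/config.yaml and root/<variant>/model.safetensors.
    """
    variant_dir = Path(root) / variant
    config = variant_dir / "config.yaml"
    weights = variant_dir / "model.safetensors"
    missing = [str(p) for p in (config, weights) if not p.is_file()]
    if missing:
        raise FileNotFoundError("Missing checkpoint files: " + ", ".join(missing))
    return config, weights


if __name__ == "__main__":
    for variant in ("cryofm-s", "cryofm-l"):
        try:
            cfg_path, weights_path = checkpoint_paths("path_to/cryofm-v1", variant)
            print(variant, "->", cfg_path, weights_path)
        except FileNotFoundError as err:
            print(variant, "->", err)
```

The returned paths are `Path` objects; wrap them in `str(...)` if a loader expects a plain string.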


Basic Usage

CryoFM-S: Unconditional Generation

CryoFM-S generates 64×64×64-voxel density maps with a pixel size of 1.5 Å. Example outputs are shown below:


CryoFM-S sampling examples at 1.5 Å/pixel resolution.
import torch
from mmengine import Config
from cryofm.core.utils.mrc_io import save_mrc
from cryofm.projects.cryofm1.lit_modules import CryoFM1
from cryofm.core.utils.sampling_fm import sample_from_fm

# Load configuration and model weights
cfg = Config.fromfile("path_to/cryofm-v1/cryofm-s/config.yaml")
lit_model = CryoFM1.load_from_safetensors(
    "path_to/cryofm-v1/cryofm-s/model.safetensors",
    cfg=cfg,
)

# Move the model to the GPU if one is available and switch to inference mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lit_model = lit_model.to(device)
lit_model.eval()

# Vector field v(x_t, t) predicted by the model; the ODE solver integrates it
def v_xt_t(_xt, _t):
    return lit_model(_xt, _t)

num_samples = 3

# Use bfloat16 autocast only on GPUs that support it; otherwise run in float32
use_bf16 = device.type == "cuda" and torch.cuda.is_bf16_supported()

# Generate samples
with torch.no_grad(), torch.autocast(device.type, dtype=torch.bfloat16, enabled=use_bf16):
    out = sample_from_fm(
        v_xt_t,
        lit_model.noise_scheduler,
        method="euler",
        num_steps=200,
        num_samples=num_samples,
        device=device,
        side_shape=64,
    )
    # Map samples back from the z-scored scale if the config defines z_scale
    if hasattr(lit_model.cfg, "z_scale") and lit_model.cfg.z_scale.mean is not None:
        out = out * lit_model.cfg.z_scale.std + lit_model.cfg.z_scale.mean

# Save each generated density map as an MRC file
for i in range(num_samples):
    save_mrc(
        out[i].float().cpu().numpy(),
        f"cryofm-s-sample-{i}.mrc",
        apix=1.5,  # pixel size in Å
    )
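Because sampling runs under bfloat16 autocast on supported GPUs, it can be worth checking a generated map before saving it. The hypothetical helper below (not part of CryoFM) uses NumPy to report whether a volume has the expected cubic shape and contains only finite values, along with basic intensity statistics. It squeezes singleton axes first, in case a sample carries a channel dimension; that layout is an assumption.

```python
import numpy as np


def summarize_volume(vol, expected_side):
    """Sanity statistics for one generated density map (hypothetical helper).

    Singleton axes are squeezed first, in case a sample carries a channel dimension.
    """
    vol = np.squeeze(np.asarray(vol, dtype=np.float32))
    return {
        "shape_ok": vol.shape == (expected_side,) * 3,
        "all_finite": bool(np.isfinite(vol).all()),
        "min": float(vol.min()),
        "max": float(vol.max()),
        "mean": float(vol.mean()),
        "std": float(vol.std()),
    }
```

For example, call `summarize_volume(out[i].float().cpu().numpy(), expected_side=64)` inside the save loop and inspect any map whose `all_finite` entry is `False`.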

CryoFM-L: Unconditional Generation

CryoFM-L generates 128×128×128-voxel density maps with a pixel size of 3.0 Å. Example outputs are shown below:


CryoFM-L sampling examples at 3.0 Å/pixel resolution.
import torch
from mmengine import Config
from cryofm.core.utils.mrc_io import save_mrc
from cryofm.projects.cryofm1.lit_modules import CryoFM1
from cryofm.core.utils.sampling_fm import sample_from_fm

# Load configuration and model weights
cfg = Config.fromfile("path_to/cryofm-v1/cryofm-l/config.yaml")
lit_model = CryoFM1.load_from_safetensors(
    "path_to/cryofm-v1/cryofm-l/model.safetensors",
    cfg=cfg,
)

# Move the model to the GPU if one is available and switch to inference mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lit_model = lit_model.to(device)
lit_model.eval()

# Vector field v(x_t, t) predicted by the model; the ODE solver integrates it
def v_xt_t(_xt, _t):
    return lit_model(_xt, _t)

num_samples = 3

# Use bfloat16 autocast only on GPUs that support it; otherwise run in float32
use_bf16 = device.type == "cuda" and torch.cuda.is_bf16_supported()

# Generate samples
with torch.no_grad(), torch.autocast(device.type, dtype=torch.bfloat16, enabled=use_bf16):
    out = sample_from_fm(
        v_xt_t,
        lit_model.noise_scheduler,
        method="euler",
        num_steps=200,
        num_samples=num_samples,
        device=device,
        side_shape=128,
    )
    # Map samples back from the z-scored scale if the config defines z_scale
    if hasattr(lit_model.cfg, "z_scale") and lit_model.cfg.z_scale.mean is not None:
        out = out * lit_model.cfg.z_scale.std + lit_model.cfg.z_scale.mean

# Save each generated density map as an MRC file
for i in range(num_samples):
    save_mrc(
        out[i].float().cpu().numpy(),
        f"cryofm-l-sample-{i}.mrc",
        apix=3.0,  # pixel size in Å
    )
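The two checkpoints cover different physical extents. The box edge length is side_shape × apix, and the finest detail a map can represent (its Nyquist limit) is twice the pixel size. The sketch below collects the values stated in this guide in a dictionary; `VARIANTS` and the helper names are hypothetical, and the same table could drive a single parameterized sampling script.

```python
# Grid size (voxels per edge) and pixel size (Å) for each checkpoint, as stated in this guide.
VARIANTS = {
    "cryofm-s": {"side_shape": 64, "apix": 1.5},
    "cryofm-l": {"side_shape": 128, "apix": 3.0},
}


def box_edge_angstrom(variant):
    """Physical edge length of the cubic output box, in Å."""
    spec = VARIANTS[variant]
    return spec["side_shape"] * spec["apix"]


def nyquist_angstrom(variant):
    """Nyquist limit in Å: twice the pixel size."""
    return 2 * VARIANTS[variant]["apix"]


if __name__ == "__main__":
    for name in VARIANTS:
        print(f"{name}: box {box_edge_angstrom(name):.0f} Å, Nyquist {nyquist_angstrom(name):.1f} Å")
```

This works out to a 96 Å box with a 3.0 Å Nyquist limit for CryoFM-S, and a 384 Å box (four times wider) with a 6.0 Å limit for CryoFM-L.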

Advanced Usage

Sampling Methods

The method argument of sample_from_fm selects the ODE solver used to integrate the flow-matching vector field:

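  • "euler": Euler method, first order, one model evaluation per step (default, fastest)
  • "rk4": classical 4th-order Runge-Kutta method, four evaluations per step (more accurate, slower)
  • "midpoint": midpoint method, second order, two evaluations per step
  • "heun": Heun's method, second order, two evaluations per step
  • "ralston": Ralston's method, second order, two evaluations per step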

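For most use cases, "euler" with 200 steps provides a good balance between quality and speed. For higher quality, consider "rk4" or a larger num_steps. Sampling time scales with the total number of model evaluations, so "rk4" with 50 steps costs roughly as much as "euler" with 200 steps.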
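To make the cost and accuracy trade-off concrete, the sketch below implements the five solvers in their textbook fixed-step forms and integrates a toy ODE, dx/dt = x from t = 0 to t = 1 (exact answer e), while counting vector-field evaluations. It is an illustration, not CryoFM's `sample_from_fm`: the integration direction and time grid CryoFM uses are not assumed here, and each call to `v` merely stands in for one network forward pass.

```python
import math

# Textbook fixed-step solvers for dx/dt = v(x, t); illustration only, not CryoFM code.

def euler_step(v, x, t, h):
    return x + h * v(x, t)

def midpoint_step(v, x, t, h):
    k1 = v(x, t)
    return x + h * v(x + 0.5 * h * k1, t + 0.5 * h)

def heun_step(v, x, t, h):
    k1 = v(x, t)
    k2 = v(x + h * k1, t + h)
    return x + 0.5 * h * (k1 + k2)

def ralston_step(v, x, t, h):
    k1 = v(x, t)
    k2 = v(x + (2 / 3) * h * k1, t + (2 / 3) * h)
    return x + h * (0.25 * k1 + 0.75 * k2)

def rk4_step(v, x, t, h):
    k1 = v(x, t)
    k2 = v(x + 0.5 * h * k1, t + 0.5 * h)
    k3 = v(x + 0.5 * h * k2, t + 0.5 * h)
    k4 = v(x + h * k3, t + h)
    return x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

STEPPERS = {
    "euler": euler_step,
    "midpoint": midpoint_step,
    "heun": heun_step,
    "ralston": ralston_step,
    "rk4": rk4_step,
}

def integrate(v, x0, method="euler", num_steps=200):
    """Integrate from t=0 to t=1; return the final state and the number of v calls."""
    calls = 0

    def counted_v(x, t):
        nonlocal calls
        calls += 1
        return v(x, t)

    step = STEPPERS[method]
    h = 1.0 / num_steps
    x = x0
    for i in range(num_steps):
        x = step(counted_v, x, i * h, h)
    return x, calls

if __name__ == "__main__":
    for method, steps in [("euler", 200), ("rk4", 50)]:
        x1, calls = integrate(lambda x, t: x, 1.0, method, steps)
        print(f"{method:8s} steps={steps:3d} evals={calls} error={abs(x1 - math.e):.2e}")
```

With the same budget of 200 evaluations, "rk4" at 50 steps lands much closer to e than "euler" at 200 steps on this smooth problem. Learned vector fields are less smooth, so the gap on real density maps may be smaller; compare samples empirically.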

Adjusting Sampling Parameters

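  • num_steps: Number of integration steps (default: 200). More steps generally yield better quality; runtime grows linearly with the step count.
  • num_samples: Number of samples generated in a single batch. GPU memory use grows with this value, so large sets are best generated in several batches.
  • side_shape: Edge length, in voxels, of the cubic output volume (64 for CryoFM-S, 128 for CryoFM-L). Use the value that matches the loaded checkpoint.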
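When num_samples is too large for GPU memory, one option is to split the request into smaller batches. The helper below is a hypothetical sketch, not a CryoFM API: it calls any sampler that accepts a `num_samples` keyword once per chunk and collects the individual samples into a list.

```python
def chunk_sizes(total, max_batch):
    """Split `total` samples into batch sizes of at most `max_batch`."""
    if total < 0 or max_batch < 1:
        raise ValueError("total must be >= 0 and max_batch >= 1")
    full, rest = divmod(total, max_batch)
    return [max_batch] * full + ([rest] if rest else [])


def sample_in_batches(sampler, total, max_batch, **sampler_kwargs):
    """Call `sampler(num_samples=n, **sampler_kwargs)` once per chunk.

    Hypothetical helper, not a CryoFM API. Returns a flat list with one
    entry per sample (out[i] of each chunk), so memory for the results
    still grows with `total`; save or offload chunks if that matters.
    """
    samples = []
    for n in chunk_sizes(total, max_batch):
        out = sampler(num_samples=n, **sampler_kwargs)
        samples.extend(out[i] for i in range(n))
    return samples
```

To use it with the example scripts, bind the positional arguments first, for instance `sampler = functools.partial(sample_from_fm, v_xt_t, lit_model.noise_scheduler)`, then call `sample_in_batches(sampler, total=12, max_batch=4, method="euler", num_steps=200, device=device, side_shape=64)` inside the same `torch.no_grad()` and autocast context, and apply the `z_scale` de-normalization to each returned sample.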