How to train a cross-encoder reranker with Sentence Transformers

A CrossEncoder reranker reads a query and one candidate document together, so it can learn relevance signals that a first-stage embedding search may miss. Fine-tuning the reranker on query-document labels helps the second ranking stage favor the answers, passages, or records that match the application's own support language.
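
Here is a minimal plain-Python sketch of that retrieve-then-rerank flow. Both scoring functions are hypothetical stand-ins, chosen so the example runs without a model. `embedding_score` builds a word set for each text separately and then compares the sets, much as a first-stage search compares embeddings computed separately. `pair_score` reads the query and the document together and rewards query word pairs that appear in the same order in the document. In a real service, the second stage would call a trained CrossEncoder instead.

```python
from typing import Callable, List, Set


def words(text: str) -> List[str]:
    """Toy tokenizer: lowercase words with surrounding punctuation stripped."""
    return [word.strip(".,?!").lower() for word in text.split()]


def embedding_score(query: str, document: str) -> float:
    """Hypothetical first stage: Jaccard overlap of independently built word sets."""
    q: Set[str] = set(words(query))
    d: Set[str] = set(words(document))
    union = q | d
    return len(q & d) / len(union) if union else 0.0


def pair_score(query: str, document: str) -> float:
    """Hypothetical second stage that sees both texts at once.

    Counts query bigrams that occur, in order, in the document and adds the
    first-stage score as a tie-breaker.
    """
    q = words(query)
    padded_doc = " " + " ".join(words(document)) + " "
    hits = sum(f" {a} {b} " in padded_doc for a, b in zip(q, q[1:]))
    return hits + embedding_score(query, document)


def retrieve_then_rerank(
    query: str,
    corpus: List[str],
    first_k: int = 2,
    final_k: int = 1,
    rerank_fn: Callable[[str, str], float] = pair_score,
) -> List[str]:
    # Stage 1: score every document cheaply and keep a short candidate list.
    candidates = sorted(corpus, key=lambda doc: embedding_score(query, doc), reverse=True)[:first_k]
    # Stage 2: run the more expensive pairwise scorer on the candidates only.
    reranked = sorted(candidates, key=lambda doc: rerank_fn(query, doc), reverse=True)
    return reranked[:final_k]


corpus = [
    "Open billing settings and update the saved credit card.",
    "Open the user profile, choose Reset password, and send a recovery email.",
    "Open the release dashboard and deploy the web container.",
]
print(retrieve_then_rerank("reset password for a user", corpus))
```

Only the candidates that survive stage 1 reach stage 2. That is why a cross-encoder can afford to read every remaining query-document pair jointly.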

Sentence Transformers 4.0 and later train cross-encoders with CrossEncoderTrainer and CrossEncoderTrainingArguments. For a reranker, the model usually has one output label, a single relevance logit per query-document pair, and BinaryCrossEntropyLoss fits labeled positive and negative query-document pairs directly.
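
With one output label, each query-document pair gets a single logit z, and binary cross-entropy compares sigmoid(z) with the 0/1 label. The plain-Python sketch below shows the per-pair loss in the numerically stable form used by PyTorch's `BCEWithLogitsLoss`. It assumes, as the Sentence Transformers documentation indicates, that BinaryCrossEntropyLoss delegates to that PyTorch loss. The helper names are illustrative and are not library functions.

```python
import math
from typing import Sequence


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def bce_naive(logit: float, label: float) -> float:
    """Textbook BCE on sigmoid(logit); fails with overflow or log(0) for large |logit|."""
    p = sigmoid(logit)
    return -(label * math.log(p) + (1.0 - label) * math.log(1.0 - p))


def bce_with_logits(logit: float, label: float) -> float:
    """Same value rewritten as max(z, 0) - z*y + log(1 + exp(-|z|)), which stays finite."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def mean_bce(logits: Sequence[float], labels: Sequence[float]) -> float:
    """Batch loss as the mean of the per-pair losses."""
    return sum(bce_with_logits(z, y) for z, y in zip(logits, labels)) / len(logits)


# A positive pair with a high logit and a negative pair with a low logit give a
# small loss; swapping the labels makes the same logits expensive.
print(mean_bce([4.0, -4.0], [1.0, 0.0]))
print(mean_bce([4.0, -4.0], [0.0, 1.0]))
```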

Eight labeled pairs keep the run small enough for a CPU smoke test while still exercising the trainer, loss, model save, and saved-model reload path. Replace the rows with reviewed production labels before relying on the model in search. Review mined negatives closely: a document labeled 0.0 that actually answers the query trains the reranker to push correct answers down.
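
The following sketch assembles rows in the script's query/document/label layout from per-query positives and mined negatives. It assumes both inputs arrive as dictionaries keyed by query text. The helper `build_pair_columns` is hypothetical. It drops any mined negative whose text matches a positive for the same query after trimming and lowercasing. That catches copied answers but not paraphrases, so human review is still needed.

```python
from typing import Dict, List


def build_pair_columns(
    positives: Dict[str, List[str]],
    mined_negatives: Dict[str, List[str]],
) -> Dict[str, list]:
    """Hypothetical helper: flatten per-query positives and mined negatives into columns.

    Queries without at least one positive are skipped. A mined negative whose
    normalized text matches one of the query's positives is treated as a false
    negative and left out.
    """
    columns: Dict[str, list] = {"query": [], "document": [], "label": []}
    for query, docs in positives.items():
        known_answers = {doc.strip().lower() for doc in docs}
        for doc in docs:
            columns["query"].append(query)
            columns["document"].append(doc)
            columns["label"].append(1.0)
        for doc in mined_negatives.get(query, []):
            if doc.strip().lower() in known_answers:
                continue  # this "negative" also answers the query
            columns["query"].append(query)
            columns["document"].append(doc)
            columns["label"].append(0.0)
    return columns


positives = {
    "export audit logs": ["Open compliance reports, choose Audit logs, and export the CSV file."],
}
mined_negatives = {
    "export audit logs": [
        "Open team settings and invite a new member by email.",
        "open compliance reports, choose audit logs, and export the CSV file.",  # copy of the answer
    ],
}
columns = build_pair_columns(positives, mined_negatives)
print(columns["label"])  # [1.0, 0.0]
```

The returned dictionary has the same shape as the one the script passes to Dataset.from_dict.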

Steps to train a Sentence Transformers cross-encoder reranker:

  1. Install Sentence Transformers with the training dependencies.
    $ pip install "sentence-transformers[train]"

    The train extra adds the datasets and accelerate packages that CrossEncoderTrainer relies on.

  2. Create a reranker training script with labeled query-document pairs.
    train_support_reranker.py
    import os
    import shutil
    import warnings
    from pathlib import Path
     
    import torch
    from datasets import Dataset
    from sentence_transformers import CrossEncoder
    from sentence_transformers.cross_encoder import CrossEncoderTrainer, CrossEncoderTrainingArguments
    from sentence_transformers.cross_encoder.losses import BinaryCrossEntropyLoss
    from transformers.utils import logging as hf_logging
     
     
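    # Quiet the smoke test: no tokenizer fork warnings and no Weights & Biases logging.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    os.environ["WANDB_DISABLED"] = "true"
    # Show only transformers errors and hide every FutureWarning and UserWarning;
    # remove these lines when debugging a real training run.
    hf_logging.set_verbosity_error()
    warnings.filterwarnings("ignore", category=FutureWarning)
    warnings.filterwarnings("ignore", category=UserWarning)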
     
    # Output location; any previous run is deleted so each run starts clean.
    model_dir = Path("models/support-reranker")
    final_dir = model_dir / "final"
    if model_dir.exists():
        shutil.rmtree(model_dir)
     
    train_dataset = Dataset.from_dict(
        {
            "query": [
                "reset a user password",
                "reset a user password",
                "enable two-factor authentication",
                "enable two-factor authentication",
                "export audit logs",
                "export audit logs",
                "restore a deleted project",
                "restore a deleted project",
            ],
            "document": [
                "Open the user profile, choose Reset password, and send a recovery email.",
                "Open billing settings and update the saved credit card.",
                "Open account security, scan the authenticator QR code, and save backup codes.",
                "Open the release dashboard and deploy the web container.",
                "Open compliance reports, choose Audit logs, and export the CSV file.",
                "Open team settings and invite a new member by email.",
                "Open deleted projects, select the project, and click Restore.",
                "Open API tokens, revoke the old token, and create a replacement token.",
            ],
            # 1.0 = the document answers the query, 0.0 = it does not.
            "label": [1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0],
        }
    )
     
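    # Start from an MS MARCO reranker; num_labels=1 gives one relevance logit per pair,
    # which is the output shape BinaryCrossEntropyLoss trains.
    model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2", num_labels=1)
    loss = BinaryCrossEntropyLoss(model=model)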
    args = CrossEncoderTrainingArguments(
        output_dir=str(model_dir / "checkpoints"),
        # 4 steps x 2 pairs per step = one pass over the 8 training rows.
        max_steps=4,
        per_device_train_batch_size=2,
        learning_rate=2e-5,
        warmup_steps=0,
        use_cpu=True,
        eval_strategy="no",
        save_strategy="no",
        logging_strategy="no",
        report_to="none",
        disable_tqdm=True,
        seed=7,
    )
     
    trainer = CrossEncoderTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
    )
     
    print(f"Training rows: {len(train_dataset)}")
    train_result = trainer.train()
    model.save_pretrained(final_dir)
    print(f"Training steps: {train_result.global_step}")
    print(f"Saved model: {final_dir}")
     
    # Reload the saved model from disk; Sigmoid turns each logit into a 0-1 relevance score.
    reranker = CrossEncoder(str(final_dir), activation_fn=torch.nn.Sigmoid())
    query = "How do I reset an account password?"
    pairs = [
        (query, "Open the user profile, choose Reset password, and send a recovery email."),
        (query, "Open the release dashboard and deploy the web container."),
    ]
    scores = reranker.predict(pairs, convert_to_numpy=True)
     
    print(f"Relevant score: {scores[0]:.3f}")
    print(f"Unrelated score: {scores[1]:.3f}")
    print("Top document:", pairs[int(scores.argmax())][1])

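    Use one row per query-document pair with a numeric label; this script uses 1.0 for a relevant document and 0.0 for an unrelated one. For larger runs, hold out an evaluation split and add a reranking evaluator, such as CrossEncoderRerankingEvaluator, before saving the final model.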

  3. Run the training script.
    $ python3 train_support_reranker.py
    Training rows: 8
    ##### snipped #####
    Training steps: 4
    Saved model: models/support-reranker/final
    Relevant score: 0.995
    Unrelated score: 0.000
    Top document: Open the user profile, choose Reset password, and send a recovery email.

    The snipped output includes trainer metrics such as train_runtime, whose values vary by CPU. Check the row count, step count, saved path, and score comparison before moving the saved model into a search workflow.

  4. Remove the sample training script after copying the pattern into the project code.
    $ rm train_support_reranker.py

    Keep models/support-reranker/final when the saved reranker still needs evaluation, packaging, or search-service loading.
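
Step 2 suggests splitting the data for larger runs. When each query has several rows, a random row split can place the same query in both training and evaluation, which inflates evaluation scores. The minimal sketch below holds out whole queries instead. It assumes the same query/document/label columns as the script, and the `doc-N` strings are placeholders. `split_by_query` is a hypothetical helper, and each of its two outputs can be passed to Dataset.from_dict.

```python
import random
from typing import Dict, List, Tuple

Columns = Dict[str, List]


def split_by_query(
    columns: Columns, eval_fraction: float = 0.25, seed: int = 7
) -> Tuple[Columns, Columns]:
    """Hypothetical helper: hold out whole queries so no query appears in both splits."""
    queries = sorted(set(columns["query"]))  # sort first so the shuffle is reproducible
    random.Random(seed).shuffle(queries)
    n_eval = max(1, round(len(queries) * eval_fraction))
    eval_queries = set(queries[:n_eval])

    train: Columns = {key: [] for key in columns}
    held_out: Columns = {key: [] for key in columns}
    for row, query in enumerate(columns["query"]):
        target = held_out if query in eval_queries else train
        for key in columns:
            target[key].append(columns[key][row])
    return train, held_out


columns = {
    "query": [
        "reset a user password", "reset a user password",
        "export audit logs", "export audit logs",
        "enable two-factor authentication", "enable two-factor authentication",
        "restore a deleted project", "restore a deleted project",
    ],
    "document": ["doc-1", "doc-2", "doc-3", "doc-4", "doc-5", "doc-6", "doc-7", "doc-8"],
    "label": [1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0],
}
train_cols, eval_cols = split_by_query(columns)
print(len(train_cols["query"]), len(eval_cols["query"]))  # 6 2
```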
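
Step 3 compares one relevant score with one unrelated score. Across a held-out set, mean reciprocal rank (MRR) condenses the same check into one number. For each query, sort the candidates by predicted score and take 1 divided by the rank of the first relevant document. The sketch below takes scores as plain lists, such as those returned by `reranker.predict(pairs)`, so it runs without a model. The helper names are hypothetical. For real evaluation runs, Sentence Transformers' CrossEncoderRerankingEvaluator already reports ranking metrics such as MRR@k.

```python
from typing import Iterable, Sequence, Tuple


def reciprocal_rank(scores: Sequence[float], labels: Sequence[float]) -> float:
    """1 / rank of the first relevant document after sorting by score, or 0.0 if none is relevant."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    for rank, index in enumerate(order, start=1):
        if labels[index] > 0:
            return 1.0 / rank
    return 0.0


def mean_reciprocal_rank(per_query: Iterable[Tuple[Sequence[float], Sequence[float]]]) -> float:
    """Average reciprocal rank over (scores, labels) pairs, one pair per query."""
    ranks = [reciprocal_rank(scores, labels) for scores, labels in per_query]
    return sum(ranks) / len(ranks) if ranks else 0.0


# Scores from the sample run in step 3, plus one hypothetical query where the
# relevant document only ranks second.
results = [
    ([0.995, 0.000], [1.0, 0.0]),
    ([0.20, 0.90, 0.10], [1.0, 0.0, 0.0]),
]
print(mean_reciprocal_rank(results))  # 0.75
```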
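
Before packaging models/support-reranker/final or loading it in a search service, a quick file check can catch an incomplete copy. This sketch assumes the usual Hugging Face layout: a config.json plus either model.safetensors or pytorch_model.bin. The exact files written depend on the library version, so treat `REQUIRED_FILES` and `WEIGHT_FILES` as assumptions to adjust rather than a documented contract.

```python
from pathlib import Path
from typing import List

# Assumed file names for a model saved with save_pretrained; adjust for your library version.
REQUIRED_FILES = ("config.json",)
WEIGHT_FILES = ("model.safetensors", "pytorch_model.bin")


def missing_model_files(model_dir) -> List[str]:
    """Return the expected files that are absent from model_dir (empty list if none)."""
    path = Path(model_dir)
    if not path.is_dir():
        return [str(path)]
    missing = [name for name in REQUIRED_FILES if not (path / name).is_file()]
    # Either weight format is acceptable.
    if not any((path / name).is_file() for name in WEIGHT_FILES):
        missing.append(" or ".join(WEIGHT_FILES))
    return missing


print(missing_model_files("models/support-reranker/final") or "expected model files present")
```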