How to fine-tune an embedding model with Sentence Transformers

Domain language can make a general embedding model miss matches that matter inside a support or search product. Fine-tuning a Sentence Transformers bi-encoder on labeled sentence pairs adapts the vector space so related phrases from the same domain move closer together.

The current Sentence Transformers training stack uses SentenceTransformerTrainer, SentenceTransformerTrainingArguments, a Dataset object from the Hugging Face datasets library, and a loss function that matches the row format. Sentence-pair similarity data works well with CosineSimilarityLoss because each row carries two texts plus a numeric similarity score.
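
As a rough illustration of why the row format matters, the sketch below mirrors the idea behind CosineSimilarityLoss in plain Python: take the cosine similarity of the two text embeddings and penalize the squared gap to the labeled score. The three-dimensional vectors and both function names are made up for illustration; the library loss works on real model outputs across a batch and may differ in detail.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def squared_error_to_label(emb1, emb2, score):
    """Squared gap between the predicted cosine similarity and the labeled score."""
    return (cosine_similarity(emb1, emb2) - score) ** 2


# Hypothetical embeddings standing in for the encoded sentence1 and sentence2 texts.
emb_reset = [1.0, 0.0, 0.0]
emb_email = [1.0, 1.0, 0.0]

row_loss = squared_error_to_label(emb_reset, emb_email, score=0.94)
print(f"cosine={cosine_similarity(emb_reset, emb_email):.4f} loss={row_loss:.4f}")
```

A row whose texts already embed at the labeled similarity contributes almost nothing, while a row with a large gap pulls the two embeddings together or pushes them apart.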

A tiny support dataset keeps the CPU run short while still exercising the trainer, the evaluator, the model save and reload, and a first similarity smoke test. Replace the sample rows with reviewed production labels before using the trained model in a retrieval index, and keep a separate evaluation set so that the training rows are not the only evidence of model behavior.
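
One way to keep evaluation rows separate is to split the reviewed labels once, with a fixed seed, before any training happens. The helper below is a minimal sketch using only the standard library; the name split_pairs, the 20% hold-out, and the seed are assumptions chosen for illustration, and the rows are borrowed from the sample data in this guide.

```python
import random


def split_pairs(rows, eval_fraction=0.2, seed=7):
    """Shuffle labeled rows once and hold out a fraction for evaluation."""
    if not 0.0 < eval_fraction < 1.0:
        raise ValueError("eval_fraction must be between 0 and 1")
    if len(rows) < 2:
        raise ValueError("need at least two rows to split")
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)  # seeded, so the split is repeatable
    eval_count = max(1, round(len(shuffled) * eval_fraction))
    return shuffled[eval_count:], shuffled[:eval_count]


# (sentence1, sentence2, score) records; replace with reviewed production labels.
labeled_rows = [
    ("reset a forgotten password", "send a password reset email", 0.94),
    ("enable two-factor authentication", "enroll an authenticator app", 0.91),
    ("export audit logs", "download the audit log CSV file", 0.89),
    ("restore a deleted project", "recover a project from deleted items", 0.92),
    ("rotate an expired API token", "replace an expired access token", 0.90),
]

train_rows, eval_rows = split_pairs(labeled_rows)
print(f"train={len(train_rows)} eval={len(eval_rows)}")
```

A random split does not catch near-duplicate paraphrases that land on both sides, so a manual review of the held-out rows is still worthwhile. The datasets library also provides Dataset.train_test_split for data that is already loaded as a Dataset.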

Steps to fine-tune a Sentence Transformers embedding model:

  1. Install Sentence Transformers with the training dependencies in the active Python environment.
    $ python -m pip install --upgrade "sentence-transformers[train]"

    The train extra adds the datasets and accelerate packages, which SentenceTransformerTrainer depends on.
    Related: How to install Sentence Transformers with pip
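
    The install can be checked from Python before writing any training code. The snippet below is a small sketch that reports installed versions through importlib.metadata; the helper name installed_versions and the list of packages to inspect are assumptions, so adjust the list to the project.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Distribution names assumed to matter for the training run in this guide.
for name, version in installed_versions(
    ["sentence-transformers", "datasets", "accelerate", "transformers"]
).items():
    print(f"{name}: {version or 'not installed'}")
```

    Any line that reports "not installed" points to an environment where the extra did not resolve, often because a different interpreter is active.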

  2. Create the fine-tuning script with labeled sentence-pair scores.
    train_support_embedding.py
    import os
    import shutil
    import warnings
    from pathlib import Path
     
    import numpy as np
    from datasets import Dataset
    from sentence_transformers import (
        SentenceTransformer,
        SentenceTransformerTrainer,
        SentenceTransformerTrainingArguments,
    )
    from sentence_transformers.evaluation import (
        EmbeddingSimilarityEvaluator,
        SimilarityFunction,
    )
    from sentence_transformers.losses import CosineSimilarityLoss
    from transformers.utils import logging as hf_logging
     
     
    # Quiet tokenizer, W&B, and warning output so the summary lines stay readable.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    os.environ["WANDB_DISABLED"] = "true"
    hf_logging.set_verbosity_error()
    warnings.filterwarnings("ignore", category=FutureWarning)
    warnings.filterwarnings("ignore", category=UserWarning)
     
    model_dir = Path("models/support-embedding")
    training_dir = Path("training-output/support-embedding")
    final_dir = model_dir / "final"
     
    # Start from clean output directories so stale files cannot mask a failed save.
    shutil.rmtree(model_dir, ignore_errors=True)
    shutil.rmtree(training_dir, ignore_errors=True)
     
    train_dataset = Dataset.from_dict(
        {
            "sentence1": [
                "reset a forgotten password",
                "enable two-factor authentication",
                "export audit logs",
                "restore a deleted project",
                "invite a new support agent",
                "rotate an expired API token",
                "change the billing contact",
                "archive a resolved support case",
            ],
            "sentence2": [
                "send a password reset email",
                "enroll an authenticator app",
                "download the audit log CSV file",
                "recover a project from deleted items",
                "add another person to the support team",
                "replace an expired access token",
                "update the primary billing contact",
                "close and archive the completed case",
            ],
            "score": [0.94, 0.91, 0.89, 0.92, 0.86, 0.9, 0.84, 0.82],
        }
    )
     
    validation_sentences1 = [
        "send a password reset",
        "replace an API token",
        "restore deleted project data",
        "download compliance logs",
    ]
    validation_sentences2 = [
        "email a user a password recovery link",
        "create a replacement access token",
        "recover a removed project",
        "export audit records to CSV",
    ]
    validation_scores = [0.93, 0.91, 0.9, 0.88]
     
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    evaluator = EmbeddingSimilarityEvaluator(
        sentences1=validation_sentences1,
        sentences2=validation_sentences2,
        scores=validation_scores,
        main_similarity=SimilarityFunction.COSINE,
        name="support-validation",
        write_csv=False,
    )
     
    # Score the base model first so the fine-tuned score has a baseline.
    before = evaluator(model)
    loss = CosineSimilarityLoss(model)
    args = SentenceTransformerTrainingArguments(
        output_dir=str(training_dir),
        max_steps=4,
        per_device_train_batch_size=4,
        learning_rate=2e-5,
        warmup_steps=0,
        save_strategy="no",
        eval_strategy="no",
        logging_strategy="no",
        disable_tqdm=True,
        report_to=[],
        use_cpu=True,
        seed=7,
    )
     
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
    )
     
    train_result = trainer.train()
    after = evaluator(model)
    model.save_pretrained(final_dir)
     
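    # Reload from disk to confirm the saved directory works on its own.
    reloaded = SentenceTransformer(str(final_dir))
    embedding = reloaded.encode(["replace an expired access token"], convert_to_numpy=True)
    comparison = reloaded.encode(
        [
            "replace an expired access token",
            "create a replacement access token",
            "download the latest invoice",
        ],
        normalize_embeddings=True,
        convert_to_numpy=True,
    )
    # Rows are unit length, so the matrix product yields cosine similarities:
    # scores[0, 0] is the related pair, scores[0, 1] the unrelated one.
    scores = comparison[:1] @ comparison[1:].T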
     
    metric_key = evaluator.primary_metric
     
    print(f"Training rows: {len(train_dataset)}")
    print(f"Loss: {loss.__class__.__name__}")
    print(f"Evaluation metric: {metric_key}")
    print(f"Before score: {before[metric_key]:.4f}")
    print(f"Training steps: {train_result.global_step}")
    print(f"After score: {after[metric_key]:.4f}")
    print(f"Saved model: {final_dir}")
    print(f"Reloaded embedding shape: {embedding.shape}")
    print(f"Support score: {scores[0, 0]:.4f}")
    print(f"Unrelated score: {scores[0, 1]:.4f}")
     
    if embedding.shape != (1, 384):
        raise SystemExit(f"unexpected embedding shape: {embedding.shape}")
    if not np.isfinite(scores).all():
        raise SystemExit("non-finite similarity score returned")
    if scores[0, 0] <= scores[0, 1]:
        raise SystemExit("related support text did not score higher")

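    CosineSimilarityLoss needs two text columns and one float label. The trainer treats a column named label or score as the target and passes the remaining columns to the loss in order, so sentence1, sentence2, and score fit directly. Keep validation rows separate from training rows when judging whether a fine-tune helped the real task.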
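
    Labels often arrive as one record per pair rather than as columns. The sketch below pivots such records into the sentence1, sentence2, and score layout used in the script and rejects obviously bad rows. The helper name rows_to_columns and the 0 to 1 score range are assumptions that match the sample data here, not a library requirement.

```python
def rows_to_columns(rows):
    """Turn (sentence1, sentence2, score) records into the column layout used above."""
    columns = {"sentence1": [], "sentence2": [], "score": []}
    for index, (text1, text2, score) in enumerate(rows):
        if not text1.strip() or not text2.strip():
            raise ValueError(f"row {index}: empty text")
        # Assumed labeling convention: similarity scores between 0 and 1.
        if not 0.0 <= score <= 1.0:
            raise ValueError(f"row {index}: score {score} outside [0, 1]")
        columns["sentence1"].append(text1)
        columns["sentence2"].append(text2)
        columns["score"].append(float(score))
    return columns


reviewed_rows = [
    ("reset a forgotten password", "send a password reset email", 0.94),
    ("export audit logs", "download the audit log CSV file", 0.89),
]
columns = rows_to_columns(reviewed_rows)
print({name: len(values) for name, values in columns.items()})
```

    The returned dictionary can be passed to Dataset.from_dict in place of the literal dictionary in the script.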

  3. Run the fine-tuning script.
    $ python train_support_embedding.py
    ##### snipped #####
    Training rows: 8
    Loss: CosineSimilarityLoss
    Evaluation metric: support-validation_spearman_cosine
    Before score: 0.2000
    Training steps: 4
    After score: 0.2000
    Saved model: models/support-embedding/final
    Reloaded embedding shape: (1, 384)
    Support score: 0.8227
    Unrelated score: 0.1526

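    The before and after scores can match on a validation set this small: the metric is a Spearman rank correlation, so it changes only when fine-tuning reorders the four validation pairs. The Training steps line confirms that the trainer ran, and the saved path, reload shape, and gap between the support and unrelated scores confirm that the model saved, reloaded, and encoded text from the saved directory.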
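
    A small worked example shows why a rank-based metric can stay flat while the similarities themselves move. The sketch below implements Spearman correlation for tie-free lists in plain Python. The gold labels are the validation scores from the script; the two similarity lists are made up to illustrate the effect and are not the model's actual outputs.

```python
def ranks(values):
    """Rank values from 1 (smallest) upward; assumes no ties."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0] * len(values)
    for rank, index in enumerate(order, start=1):
        result[index] = rank
    return result


def spearman(gold, predicted):
    """Spearman rank correlation for tie-free lists of equal length."""
    n = len(gold)
    d_squared = sum((g - p) ** 2 for g, p in zip(ranks(gold), ranks(predicted)))
    return 1 - 6 * d_squared / (n * (n * n - 1))


gold = [0.93, 0.91, 0.90, 0.88]         # validation labels from the script
sims_before = [0.80, 0.52, 0.60, 0.71]  # hypothetical cosine similarities
sims_after = [0.86, 0.58, 0.67, 0.75]   # every value moved, ordering unchanged

print(f"before={spearman(gold, sims_before):.4f} after={spearman(gold, sims_after):.4f}")
```

    Both calls return the same value because only the ordering of the pairs enters the formula. With four pairs the metric can take only a handful of distinct values, so a larger validation set is needed before the score says much about the fine-tune.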

  4. Remove the sample script after copying the pattern into project code.
    $ rm train_support_embedding.py

    Keep models/support-embedding/final when the trained model still needs evaluation, packaging, or loading into a retrieval service.
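
    For a sense of what a retrieval service does with the saved model, the sketch below ranks documents against a query by cosine similarity. It uses short hand-written vectors in place of real embeddings so it runs without the model; in practice the vectors would come from encode on the reloaded model, as in the script. The helper names, the vectors, and the document texts are all hypothetical.

```python
import math


def normalize(vector):
    """Scale a vector to unit length so a dot product equals cosine similarity."""
    length = math.sqrt(sum(x * x for x in vector))
    return [x / length for x in vector]


def top_matches(query_vec, doc_vecs, k=2):
    """Return (doc_index, score) pairs for the k most similar documents."""
    query = normalize(query_vec)
    scored = [
        (index, sum(q * d for q, d in zip(query, normalize(doc))))
        for index, doc in enumerate(doc_vecs)
    ]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:k]


# Hypothetical 3-dimensional vectors standing in for encoded texts.
docs = [
    "replace an expired access token",
    "download the latest invoice",
    "rotate an API key",
]
doc_vecs = [[0.9, 0.1, 0.0], [0.0, 0.2, 0.9], [0.8, 0.3, 0.1]]
query_vec = [1.0, 0.2, 0.0]

for index, score in top_matches(query_vec, doc_vecs):
    print(f"{score:.4f}  {docs[index]}")
```

    A production index would normally precompute and store the document vectors and use a vector search library rather than a full scan.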