"""Smoke-test fine-tuning of a tiny SPLADE sparse encoder on support FAQ pairs.

The script:

1. Trains ``sparse-encoder-testing/splade-bert-tiny-nq`` for two steps on
   eight (anchor, positive) pairs.
2. Saves the model under ``models/support-sparse-encoder/final`` and reloads it.
3. Scores one query against three candidate documents.
4. Prints a short report.
5. Exits via ``SystemExit`` with a message if any sanity check fails.
"""

import os
import shutil
import warnings
from pathlib import Path

import torch
from datasets import Dataset
from sentence_transformers import (
    SparseEncoder,
    SparseEncoderTrainer,
    SparseEncoderTrainingArguments,
)
from sentence_transformers.sparse_encoder.losses import (
    SparseMultipleNegativesRankingLoss,
    SpladeLoss,
)
from transformers.utils import logging as hf_logging

# Keep the run quiet: tokenizer parallelism off, W&B off, transformers errors only.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB_DISABLED"] = "true"
hf_logging.set_verbosity_error()
# NOTE(review): these filters hide *every* FutureWarning/UserWarning in the
# process, not only library noise -- narrow them if a real warning matters.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

model_dir = Path("models/support-sparse-encoder")
training_dir = Path("training-output/support-sparse-encoder")
final_dir = model_dir / "final"

# Start clean so a model left over from an earlier run can never be reloaded
# in place of the one trained below.
shutil.rmtree(model_dir, ignore_errors=True)
shutil.rmtree(training_dir, ignore_errors=True)

# Row i of "anchor" (a short support question) pairs with row i of "positive"
# (the matching answer).
train_dataset = Dataset.from_dict(
    {
        "anchor": [
            "reset a locked account",
            "reset a user password",
            "enable two-factor authentication",
            "rotate an API token",
            "restore a deleted project",
            "invite a new support agent",
            "export audit logs",
            "change the billing contact",
        ],
        "positive": [
            "Unlock the account from the admin users page.",
            "Open the user profile and send a password reset email.",
            "Open account security and enroll an authenticator app.",
            "Revoke the old API token and create a replacement token.",
            "Open deleted projects and restore the selected project.",
            "Open team settings and send an invitation email.",
            "Open compliance reports and export the audit log CSV.",
            "Open billing settings and update the primary contact.",
        ],
    }
)

model = SparseEncoder("sparse-encoder-testing/splade-bert-tiny-nq")
# SpladeLoss wraps the pairwise ranking loss and adds query/document
# regularizer terms with the weights given here.
base_loss = SparseMultipleNegativesRankingLoss(model=model)
loss = SpladeLoss(
    model=model,
    loss=base_loss,
    document_regularizer_weight=3e-5,
    query_regularizer_weight=5e-5,
)

# Deliberately tiny run: 2 optimizer steps of batch size 2 on CPU.
# No checkpoints, evaluation, logging, progress bars or reporters.
args = SparseEncoderTrainingArguments(
    output_dir=str(training_dir),
    max_steps=2,
    per_device_train_batch_size=2,
    learning_rate=2e-5,
    warmup_steps=0,
    use_cpu=True,
    save_strategy="no",
    eval_strategy="no",
    logging_strategy="no",
    disable_tqdm=True,
    report_to=[],
    seed=7,
)

trainer = SparseEncoderTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)
train_result = trainer.train()

# Round-trip through disk so the smoke test exercises the saved artifacts,
# not the in-memory model.
model.save_pretrained(final_dir)
reloaded = SparseEncoder(str(final_dir))

# .coalesce() and ._nnz() are sparse-tensor methods, so these calls rely on
# convert_to_tensor=True yielding sparse tensors.
query_embedding = reloaded.encode_query(
    ["How do I replace an expired API token?"], convert_to_tensor=True
).coalesce()
document_embeddings = reloaded.encode_document(
    [
        "Revoke the old API token and create a replacement token.",
        "Open billing settings and update the saved credit card.",
        "Open team settings and send an invitation email.",
    ],
    convert_to_tensor=True,
).coalesce()

# Densify once and reuse for both the dot-product scores and the per-document check.
query_dense = query_embedding.to_dense()
document_dense = document_embeddings.to_dense()
scores = torch.mm(query_dense, document_dense.T)  # (num_queries, num_documents)
# Rank documents within the single query's row rather than over the
# flattened matrix, so the index is always a document index.
best_index = int(scores[0].argmax().item())

# _nnz() on the batch is a total across all documents. Count active
# dimensions per row so a single empty document cannot hide behind the others.
active_per_document = (document_dense != 0).sum(dim=1).tolist()
empty_documents = [
    f"d{position + 1}"
    for position, active in enumerate(active_per_document)
    if active == 0
]

print(f"Training rows: {len(train_dataset)}")
print(f"Loss wrapper: {loss.__class__.__name__}")
print(f"Base loss: {base_loss.__class__.__name__}")
print(f"Training steps: {train_result.global_step}")
print(f"Saved model: {final_dir}")
print(f"Query embedding shape: {tuple(query_embedding.shape)}")
print(f"Query active dimensions: {query_embedding._nnz()}")
print(f"Document embedding shape: {tuple(document_embeddings.shape)}")
print(f"Document active dimensions: {document_embeddings._nnz()}")
print(f"Top smoke-test document: d{best_index + 1}")
print(f"Top smoke-test score: {scores[0, best_index]:.4f}")

if train_result.global_step != 2:
    raise SystemExit(f"unexpected training steps: {train_result.global_step}")
if query_embedding._nnz() == 0 or document_embeddings._nnz() == 0:
    raise SystemExit("sparse encoder returned no active dimensions")
if empty_documents:
    raise SystemExit(
        "sparse encoder returned no active dimensions for: "
        + ", ".join(empty_documents)
    )
if not torch.isfinite(scores).all():
    raise SystemExit("non-finite sparse similarity score returned")