"""Fine-tune a pretrained sentence embedding model on support FAQ pairs.

Running this script:

1. removes any previous ``support-mnrl-model`` and ``mnrl-training-output``
   directories,
2. builds an eight-row anchor/positive training dataset,
3. scores ``sentence-transformers/all-MiniLM-L6-v2`` with an
   ``InformationRetrievalEvaluator`` before training,
4. trains for one epoch with ``MultipleNegativesRankingLoss``,
5. scores the model again, saves it, reloads it, encodes one query, and
   prints a short summary.
"""

import shutil
from pathlib import Path

from datasets import Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.sentence_transformer.evaluation import (
    InformationRetrievalEvaluator,
)
from sentence_transformers.sentence_transformer.losses import (
    MultipleNegativesRankingLoss,
)
from sentence_transformers.sentence_transformer.training_args import (
    BatchSamplers,
)

# Where the final model is saved, and where the trainer writes its output.
output_dir = Path("support-mnrl-model")
training_dir = Path("mnrl-training-output")

# Start from a clean slate; ignore_errors=True makes a missing directory a no-op.
for stale_dir in (output_dir, training_dir):
    shutil.rmtree(stale_dir, ignore_errors=True)

# (anchor, positive) training pairs: a short task phrase and its help answer.
support_pairs = [
    (
        "reset a locked account",
        "Unlock the account from the admin users page.",
    ),
    (
        "reset a user password",
        "Open the user profile and send a password reset email.",
    ),
    (
        "enable two-factor authentication",
        "Open account security and enroll an authenticator app.",
    ),
    (
        "rotate an API token",
        "Revoke the old API token and create a replacement token.",
    ),
    (
        "restore a deleted project",
        "Open deleted projects and restore the selected project.",
    ),
    (
        "invite a new team member",
        "Open team settings and send an invitation email.",
    ),
    (
        "export audit logs",
        "Open compliance reports and export the audit log CSV.",
    ),
    (
        "change the billing contact",
        "Open billing settings and update the primary contact.",
    ),
]

# Column order is kept as "anchor" first, then "positive".
train_dataset = Dataset.from_dict(
    {
        "anchor": [anchor for anchor, _ in support_pairs],
        "positive": [positive for _, positive in support_pairs],
    }
)

# Held-out retrieval task: three queries, four candidate documents, and the
# single relevant document id for each query.
queries = {
    "q1": "How do I send a password reset?",
    "q2": "How do I replace an API token?",
    "q3": "How do I recover a deleted project?",
}
corpus = {
    "d1": "Open the user profile and send a password reset email.",
    "d2": "Revoke the old API token and create a replacement token.",
    "d3": "Open deleted projects and restore the selected project.",
    "d4": "Open billing settings and update the primary contact.",
}
relevant_docs = {
    "q1": {"d1"},
    "q2": {"d2"},
    "q3": {"d3"},
}

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    name="support-faq",
    accuracy_at_k=[1],
    precision_recall_at_k=[1, 3],
    map_at_k=[3],
    ndcg_at_k=[3],
    show_progress_bar=False,
)

# Key looked up in the evaluator's result dict for the printed scores.
metric_key = "support-faq_cosine_ndcg@3"

# Baseline score of the pretrained model, taken before any training step.
before = evaluator(model)

# Checkpointing, in-training evaluation, progress bars and external
# reporting are all switched off; only the final model is saved below.
# NOTE(review): NO_DUPLICATES is presumably chosen because the loss uses
# other rows of the batch as negatives -- confirm against the library docs.
args = SentenceTransformerTrainingArguments(
    output_dir=str(training_dir),
    num_train_epochs=1,
    per_device_train_batch_size=4,
    learning_rate=2e-5,
    warmup_steps=0.0,
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    save_strategy="no",
    eval_strategy="no",
    logging_steps=1,
    disable_tqdm=True,
    report_to=[],
)

loss = MultipleNegativesRankingLoss(model)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()

# Same evaluator, same model object, now after training.
after = evaluator(model)

# Save, then load from disk and encode one query with the reloaded copy.
model.save_pretrained(output_dir)
reloaded = SentenceTransformer(str(output_dir))
embedding = reloaded.encode(
    ["How do I replace an API token?"],
    convert_to_numpy=True,
)

print(f"train rows: {len(train_dataset)}")
print(f"loss: {type(loss).__name__}")
print(f"retrieval metric: {metric_key}")
print(f"before score: {before[metric_key]:.4f}")
print(f"after score: {after[metric_key]:.4f}")
print(f"saved model path: {output_dir}")
print(f"reloaded embedding shape: {embedding.shape}")