"""Prepare a tiny support-ticket dataset and dry-run it through Sentence Transformers.

The script:

1. builds an in-memory ``Dataset`` from six hand-written support rows,
2. derives a ``ranking_train`` split (``anchor``/``positive``) and a
   ``similarity_eval`` split (``sentence1``/``sentence2``/``score``),
3. saves both splits to ``prepared/support-training`` and reloads them,
4. constructs a trainer and pulls a single batch from its train dataloader
   (``trainer.train()`` is never called),
5. constructs an ``EmbeddingSimilarityEvaluator`` from the eval split, and
6. prints a short summary of what was built.
"""

import os
import warnings

from datasets import Dataset, DatasetDict, disable_progress_bars
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.sentence_transformer.evaluation import (
    EmbeddingSimilarityEvaluator,
)
from sentence_transformers.sentence_transformer.losses import MultipleNegativesRankingLoss
from transformers.utils import logging as hf_logging

# Keep the console output limited to the summary printed at the end:
# environment switches first, then library logging, Python warnings and
# the datasets progress bars.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB_DISABLED"] = "true"
hf_logging.set_verbosity_error()
for _ignored_category in (FutureWarning, UserWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
disable_progress_bars()

# Raw rows are written as tuples and zipped with the field names below, so
# every row dict carries its keys in this order.
_ROW_FIELDS = ("query", "positive", "negative", "score", "source")
_ROW_VALUES = [
    (
        "reset an expired API token",
        "create a replacement access token",
        "change the workspace color theme",
        0.94,
        "support-ticket-001",
    ),
    (
        "restore a deleted project",
        "recover a project from deleted items",
        "download the latest invoice",
        0.92,
        "support-ticket-002",
    ),
    (
        "export audit logs",
        "download the audit log CSV file",
        "invite a new support agent",
        0.89,
        "support-ticket-003",
    ),
    (
        "enable two factor authentication",
        "enroll an authenticator app",
        "archive a resolved support case",
        0.91,
        "support-ticket-004",
    ),
    (
        "change the billing contact",
        "update the primary billing contact",
        "restore a deleted project",
        0.86,
        "support-ticket-005",
    ),
    (
        "invite a new support agent",
        "add another person to the support team",
        "reset an expired API token",
        0.84,
        "support-ticket-006",
    ),
]
raw_rows = [dict(zip(_ROW_FIELDS, values)) for values in _ROW_VALUES]

raw_dataset = Dataset.from_list(raw_rows)

# Training split: keep only the query/positive pair and expose the query
# under the name "anchor". The "negative", "score" and "source" columns are
# dropped by the column selection.
_pair_columns = raw_dataset.select_columns(["query", "positive"])
ranking_train = _pair_columns.rename_columns({"query": "anchor"})

# Evaluation split: the same pairs plus their score, renamed to
# sentence1/sentence2.
_scored_columns = raw_dataset.select_columns(["query", "positive", "score"])
similarity_eval = _scored_columns.rename_columns(
    {"query": "sentence1", "positive": "sentence2"}
)

prepared = DatasetDict(
    {
        "ranking_train": ranking_train,
        "similarity_eval": similarity_eval,
    }
)

# Round-trip through disk so everything below consumes the reloaded copy
# rather than the in-memory datasets built above.
_PREPARED_DIR = "prepared/support-training"
prepared.save_to_disk(_PREPARED_DIR)
reloaded = DatasetDict.load_from_disk(_PREPARED_DIR)
_train_split = reloaded["ranking_train"]
_eval_split = reloaded["similarity_eval"]

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
loss = MultipleNegativesRankingLoss(model)

# CPU-only configuration with saving, evaluation, logging, progress bars and
# reporting integrations all switched off.
args = SentenceTransformerTrainingArguments(
    output_dir="training-output/dataset-dry-run",
    max_steps=1,
    per_device_train_batch_size=2,
    use_cpu=True,
    save_strategy="no",
    eval_strategy="no",
    logging_strategy="no",
    disable_tqdm=True,
    report_to=[],
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=_train_split,
    loss=loss,
)

# Pull exactly one batch; the batch itself is discarded. This only confirms
# that the dataloader can produce a batch from the reloaded split.
_train_loader = trainer.get_train_dataloader()
next(iter(_train_loader))

evaluator = EmbeddingSimilarityEvaluator(
    sentences1=_eval_split["sentence1"],
    sentences2=_eval_split["sentence2"],
    scores=_eval_split["score"],
    name="support-similarity",
    write_csv=False,
)

# Summary of everything that was built.
print(f"raw columns: {raw_dataset.column_names}")
print(f"ranking_train columns: {_train_split.column_names}")
print(f"similarity_eval columns: {_eval_split.column_names}")
print(f"ranking rows: {len(_train_split)}")
print(f"trainer loss: {loss.__class__.__name__}")
print("trainer dataloader: ready")
print(f"evaluator: {evaluator.__class__.__name__}")
print(f"evaluator pairs: {len(evaluator.sentences1)}")
print(f"saved dataset splits: {list(reloaded.keys())}")

_first_pair = _train_split[0]
print("sample anchor:", _first_pair["anchor"])
print("sample positive:", _first_pair["positive"])