"""Fine-tune a small cross-encoder reranker on toy support-doc pairs, then spot-check it.

The script fine-tunes ``cross-encoder/ms-marco-MiniLM-L6-v2`` for 4 CPU steps on
query/document pairs labelled 1.0 (relevant) or 0.0 (unrelated), using
``BinaryCrossEntropyLoss``. It saves the model to ``models/support-reranker/final``,
reloads it with a sigmoid activation, and prints the scores it gives one relevant
and one unrelated document for a paraphrased query.

Run it as a script; importing the module has no side effects.
"""

import os
import shutil
import warnings
from pathlib import Path

import torch
from datasets import Dataset
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder import (
    CrossEncoderTrainer,
    CrossEncoderTrainingArguments,
)
from sentence_transformers.cross_encoder.losses import BinaryCrossEntropyLoss
from transformers.utils import logging as hf_logging

BASE_MODEL = "cross-encoder/ms-marco-MiniLM-L6-v2"
MODEL_DIR = Path("models/support-reranker")
FINAL_DIR = MODEL_DIR / "final"

# Each (query, relevant document, unrelated document) triple becomes two training
# rows: (query, relevant) -> 1.0 and (query, unrelated) -> 0.0. Keeping each query
# next to its documents avoids hand-aligning three parallel lists.
_TRAINING_TRIPLES = [
    (
        "reset a user password",
        "Open the user profile, choose Reset password, and send a recovery email.",
        "Open billing settings and update the saved credit card.",
    ),
    (
        "enable two-factor authentication",
        "Open account security, scan the authenticator QR code, and save backup codes.",
        "Open the release dashboard and deploy the web container.",
    ),
    (
        "export audit logs",
        "Open compliance reports, choose Audit logs, and export the CSV file.",
        "Open team settings and invite a new member by email.",
    ),
    (
        "restore a deleted project",
        "Open deleted projects, select the project, and click Restore.",
        "Open API tokens, revoke the old token, and create a replacement token.",
    ),
]


def _configure_quiet_environment() -> None:
    """Turn off tokenizer parallelism and W&B, and silence HF logging and warnings."""
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    os.environ["WANDB_DISABLED"] = "true"
    hf_logging.set_verbosity_error()
    # NOTE: these filters are process-wide. They hide every FutureWarning and
    # UserWarning, not just the ones from the ML libraries.
    warnings.filterwarnings("ignore", category=FutureWarning)
    warnings.filterwarnings("ignore", category=UserWarning)


def _build_train_dataset() -> Dataset:
    """Return a ``Dataset`` with ``query``/``document``/``label`` columns.

    Rows follow the order of ``_TRAINING_TRIPLES``: for each triple, the positive
    pair comes first, then the negative pair.
    """
    queries: list[str] = []
    documents: list[str] = []
    labels: list[float] = []
    for query, relevant, unrelated in _TRAINING_TRIPLES:
        queries.extend([query, query])
        documents.extend([relevant, unrelated])
        labels.extend([1.0, 0.0])
    return Dataset.from_dict({"query": queries, "document": documents, "label": labels})


def _train_and_save(train_dataset: Dataset, model_dir: Path, final_dir: Path) -> int:
    """Fine-tune ``BASE_MODEL`` on *train_dataset*, save it to *final_dir*.

    Trainer checkpoints (if any) would go under ``model_dir / "checkpoints"``.
    Returns the number of optimizer steps the trainer reports.
    """
    model = CrossEncoder(BASE_MODEL, num_labels=1)
    loss = BinaryCrossEntropyLoss(model=model)
    args = CrossEncoderTrainingArguments(
        output_dir=str(model_dir / "checkpoints"),
        # 4 steps x batch size 2 = one pass over the 8 training rows.
        max_steps=4,
        per_device_train_batch_size=2,
        learning_rate=2e-5,
        warmup_steps=0,
        use_cpu=True,
        # No evaluation, checkpoint saving, logging or progress bars: the run is a
        # short demo and only the final model is kept.
        eval_strategy="no",
        save_strategy="no",
        logging_strategy="no",
        report_to="none",
        disable_tqdm=True,
        seed=7,
    )
    trainer = CrossEncoderTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
    )
    train_result = trainer.train()
    model.save_pretrained(str(final_dir))
    return train_result.global_step


def _score_demo(model_path: Path) -> None:
    """Reload the saved reranker and print its scores for one relevant and one unrelated doc."""
    # Sigmoid maps the single output logit to (0, 1), so the printed scores read
    # as relevance probabilities.
    reranker = CrossEncoder(str(model_path), activation_fn=torch.nn.Sigmoid())
    query = "How do I reset an account password?"
    pairs = [
        (query, "Open the user profile, choose Reset password, and send a recovery email."),
        (query, "Open the release dashboard and deploy the web container."),
    ]
    scores = reranker.predict(pairs, convert_to_numpy=True)
    print(f"Relevant score: {scores[0]:.3f}")
    print(f"Unrelated score: {scores[1]:.3f}")
    print("Top document:", pairs[int(scores.argmax())][1])


def main() -> None:
    """Run the full train -> save -> reload -> score demo."""
    _configure_quiet_environment()

    # Start from a clean directory, so stale checkpoints or an old final model
    # from a previous run never mix with this one.
    if MODEL_DIR.exists():
        shutil.rmtree(MODEL_DIR)

    train_dataset = _build_train_dataset()
    print(f"Training rows: {len(train_dataset)}")

    global_step = _train_and_save(train_dataset, MODEL_DIR, FINAL_DIR)
    print(f"Training steps: {global_step}")
    print(f"Saved model: {FINAL_DIR}")

    _score_demo(FINAL_DIR)


if __name__ == "__main__":
    main()