from __future__ import annotations import argparse import json import os import random import time from pathlib import Path import torch from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig def main() -> int: parser = argparse.ArgumentParser(description="LoRA fine-tune gpt-oss-20b on a local JSONL text corpus.") parser.add_argument("--model", default="openai/gpt-oss-20b") parser.add_argument("--data", type=Path, default=Path("training_data/relevant/dataset.jsonl")) parser.add_argument("--out", type=Path, default=Path("training_data/lora_adapter")) parser.add_argument("--max-length", type=int, default=256) parser.add_argument("--epochs", type=int, default=1) parser.add_argument("--max-steps", type=int, default=0, help="If >0, stop after this many optimizer steps.") parser.add_argument("--lr", type=float, default=2e-4) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--lora-r", type=int, default=8) parser.add_argument("--lora-alpha", type=int, default=16) parser.add_argument("--lora-dropout", type=float, default=0.05) parser.add_argument("--grad-accum", type=int, default=4) parser.add_argument("--device", default="auto") parser.add_argument("--device-map", choices=["auto", "cuda"], default="auto") parser.add_argument("--cpu-offload", action="store_true") parser.add_argument("--max-gpu-mem", default=None, help="Max GPU memory for device_map=auto, e.g. 10GiB") parser.add_argument("--max-cpu-mem", default=None, help="Max CPU memory for device_map=auto, e.g. 64GiB") parser.add_argument("--quant", choices=["auto", "none", "4bit"], default="auto") parser.add_argument("--log-steps", type=int, default=10) parser.add_argument("--log-seconds", type=int, default=120) parser.add_argument("--local-files-only", action="store_true") args = parser.parse_args() random.seed(args.seed) torch.manual_seed(args.seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(args.seed) # Reduce noisy parallelism; avoid oversubscribing if user has many cores. if "OMP_NUM_THREADS" not in os.environ: os.environ["OMP_NUM_THREADS"] = str(max(1, (os.cpu_count() or 8) // 2)) args.out.mkdir(parents=True, exist_ok=True) lines = args.data.read_text(encoding="utf-8", errors="ignore").splitlines() records = [json.loads(ln) for ln in lines if ln.strip()] texts = [r.get("text", "") for r in records if isinstance(r, dict) and r.get("text")] if not texts: raise SystemExit(f"No text records found in {args.data}") print(f"Loaded {len(texts)} training samples from {args.data}") if args.device == "auto": device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: device = torch.device(args.device) if device.type == "cuda": torch.backends.cuda.matmul.allow_tf32 = True if args.quant in {"auto", "4bit"} and device.type != "cuda": raise SystemExit("Quantized loading requires CUDA. Use --quant none for CPU.") print("Loading tokenizer...") tok = AutoTokenizer.from_pretrained(args.model, local_files_only=args.local_files_only) if tok.pad_token is None: tok.pad_token = tok.eos_token print("Loading model...") config = AutoConfig.from_pretrained(args.model, local_files_only=args.local_files_only) has_quant_attr = hasattr(config, "quantization_config") if args.quant == "auto": model = AutoModelForCausalLM.from_pretrained( args.model, local_files_only=args.local_files_only, device_map="auto", ) elif args.quant == "4bit": compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=compute_dtype, llm_int8_enable_fp32_cpu_offload=args.cpu_offload, ) if has_quant_attr: delattr(config, "quantization_config") device_map = {"": 0} if args.device_map == "cuda" else "auto" load_kwargs = dict( local_files_only=args.local_files_only, config=config, device_map=device_map, torch_dtype=compute_dtype, ) if device_map == "auto" and (args.max_gpu_mem or args.max_cpu_mem): max_memory = {} if args.max_gpu_mem: max_memory[0] = args.max_gpu_mem if args.max_cpu_mem: max_memory["cpu"] = args.max_cpu_mem load_kwargs["max_memory"] = max_memory load_kwargs["quantization_config"] = bnb_config model = AutoModelForCausalLM.from_pretrained(args.model, **load_kwargs) model = prepare_model_for_kbit_training(model) else: model = AutoModelForCausalLM.from_pretrained( args.model, local_files_only=args.local_files_only, torch_dtype=torch.bfloat16, device_map={"": "cpu" if device.type == "cpu" else 0}, low_cpu_mem_usage=True, ) model.to(device) lora_cfg = LoraConfig( r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout, bias="none", task_type="CAUSAL_LM", target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], ) model = get_peft_model(model, lora_cfg) model.print_trainable_parameters() model.config.use_cache = False if device.type == "cuda": model.gradient_checkpointing_enable() # Needed for PEFT + gradient checkpointing to ensure grads flow to LoRA params. model.enable_input_require_grads() trainable = [p for p in model.parameters() if p.requires_grad] opt = torch.optim.AdamW(trainable, lr=args.lr) # Training loop model.train() total_opt_steps = 0 total_batches = 0 loss_ema: float | None = None last_log = time.time() accum_steps = 0 for epoch in range(1, args.epochs + 1): order = list(range(len(texts))) random.shuffle(order) for i in order: text = texts[i] batch = tok(text, return_tensors="pt", truncation=True, max_length=args.max_length) if batch["input_ids"].numel() < 32: continue batch["labels"] = batch["input_ids"].clone() batch = {k: v.to(device) for k, v in batch.items()} t0 = time.time() out = model(**batch) loss = out.loss / max(1, args.grad_accum) loss.backward() accum_steps += 1 if accum_steps >= max(1, args.grad_accum): torch.nn.utils.clip_grad_norm_(trainable, 1.0) opt.step() opt.zero_grad(set_to_none=True) total_opt_steps += 1 accum_steps = 0 dt = time.time() - t0 total_batches += 1 lv = float(loss.detach().cpu().item()) * max(1, args.grad_accum) loss_ema = lv if loss_ema is None else (0.95 * loss_ema + 0.05 * lv) if (args.log_steps and total_opt_steps % args.log_steps == 0) or ( args.log_seconds and time.time() - last_log >= args.log_seconds ): tok_count = int(batch["input_ids"].numel()) print( f"epoch {epoch}/{args.epochs} step {total_opt_steps} " f"loss {lv:.4f} ema {loss_ema:.4f} " f"{dt:.2f}s {tok_count} tokens" ) last_log = time.time() if args.max_steps and total_opt_steps >= args.max_steps: break if accum_steps: torch.nn.utils.clip_grad_norm_(trainable, 1.0) opt.step() opt.zero_grad(set_to_none=True) total_opt_steps += 1 accum_steps = 0 if args.max_steps and total_opt_steps >= args.max_steps: break if args.max_steps and total_opt_steps >= args.max_steps: break print(f"Saving adapter to {args.out} ...") model.save_pretrained(args.out) tok.save_pretrained(args.out) summary = { "model": args.model, "data": str(args.data), "out": str(args.out), "max_length": args.max_length, "epochs": args.epochs, "max_steps": args.max_steps, "lr": args.lr, "lora_r": args.lora_r, "lora_alpha": args.lora_alpha, "lora_dropout": args.lora_dropout, "optimizer_steps": total_opt_steps, "loss_ema": loss_ema, } (args.out / "training_summary.json").write_text(json.dumps(summary, indent=2), encoding="utf-8") print("Done.") return 0 if __name__ == "__main__": raise SystemExit(main())