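The script below runs a learning-rate range test with `torch-lr-finder` against the decoder-only transformer: it streams `(input, target)` token batches from a lightweight loader, runs the exponential LR sweep under mixed precision, and plots loss versus learning rate.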
```python
import tiktoken
import torch
import torch.optim as optim
from torch.nn import CrossEntropyLoss
from torch_lr_finder import LRFinder
from transformers import GPT2Tokenizer

from transformer import Config, DecoderOnlyTransformer


class DataLoaderLite:
    def __init__(self, B, T):
        self.B = B
        self.T = T
        # at init load tokens from disk and store them in memory
        with open('input.txt', 'r') as f:
            text = f.read()
        enc = tiktoken.get_encoding('gpt2')
        tokens = enc.encode(text)
        self.tokens = torch.tensor(tokens)
        print(f'loaded {len(self.tokens)} tokens')
        print(f'1 epoch = {len(self.tokens) // (B * T)} batches')
        # state
        self.current_position = 0

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position: self.current_position + B * T + 1]
        x = (buf[:-1]).view(B, T)  # inputs
        y = (buf[1:]).view(B, T)   # targets
        # advance the position in the tensor
        self.current_position += B * T
        # if loading the next batch would be out of bounds, reset
        if self.current_position + (B * T + 1) > len(self.tokens):
            self.current_position = 0
        return x, y


batches, no_of_tokens = 16, 128
train_loader = DataLoaderLite(B=batches, T=no_of_tokens)
steps_per_epoch = len(train_loader.tokens) // (batches * no_of_tokens)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Model configuration
config = Config()

# Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  # use the GPT-2 tokenizer for compatibility

# Load the trained model
model = DecoderOnlyTransformer(config)
model.load_state_dict(torch.load("decoder_only_transformer.pth", map_location=torch.device('cpu')))
model.eval()
model.to(device)

# Mixed-precision settings for the range test
amp_config = {
    'device_type': 'cuda',
    'dtype': torch.float16,
}
criterion = CrossEntropyLoss()
grad_scaler = torch.cuda.amp.GradScaler()
optimizer = optim.Adam(model.parameters(), lr=1e-3)


# Wrap DataLoaderLite.next_batch in an iterator so LRFinder can consume it like a data loader
class CustomDataLoader:
    def __init__(self, next_batch_func, num_batches):
        self.next_batch_func = next_batch_func
        self.num_batches = num_batches
        self.current_batch = 0

    def __iter__(self):
        self.current_batch = 0
        return self

    def __next__(self):
        if self.current_batch < self.num_batches:
            self.current_batch += 1
            return self.next_batch_func()
        else:
            raise StopIteration


# Pass the bound method itself, not the result of calling it once
custom_train_loader = CustomDataLoader(train_loader.next_batch, num_batches=steps_per_epoch)

# Run the learning-rate range test
lr_finder = LRFinder(
    model, optimizer, criterion, device=device,
    amp_backend='torch', amp_config=amp_config, grad_scaler=grad_scaler
)
lr_finder.range_test(custom_train_loader, end_lr=5, num_iter=1000, step_mode='exp')
lr_finder.plot()
lr_finder.reset()
```
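Depending on the installed `torch-lr-finder` version, `range_test` may only accept a `torch.utils.data.DataLoader` or a `TrainDataLoaderIter` subclass rather than an arbitrary iterable. The sketch below is one way to adapt if the plain wrapper above is rejected, and it also shows reading the recorded curve back from `lr_finder.history`; it assumes `TrainDataLoaderIter` is importable from the package top level, as shown in the project's README.

```python
# Sketch only: needed if range_test rejects the plain iterable above.
from torch_lr_finder import TrainDataLoaderIter

# CustomDataLoader already yields (inputs, targets) tuples, so the default
# batch unpacking works; the iterator is restarted automatically when exhausted.
wrapped_loader = TrainDataLoaderIter(custom_train_loader)
lr_finder.range_test(wrapped_loader, end_lr=5, num_iter=1000, step_mode='exp')

# Inspect the recorded sweep instead of (or in addition to) the plot.
lrs, losses = lr_finder.history['lr'], lr_finder.history['loss']
min_loss_lr = lrs[losses.index(min(losses))]
print(f'loss minimum at lr={min_loss_lr:.2e}')
# One common heuristic is to train at roughly an order of magnitude below the
# minimum-loss learning rate.
lr_finder.reset()  # restore the model and optimizer to their pre-test state
```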