From 0229aedfa1de06fab2429502a0ee8b84320899d1 Mon Sep 17 00:00:00 2001
From: Nick Wall <46641379+walln@users.noreply.github.com>
Date: Mon, 16 Jun 2025 11:35:20 -0400
Subject: [PATCH] fix: handle hf dataset tensor warnings

---
 .../causal_langauge_modeling_dataset.py       | 24 ++++-----
 .../datasets/image_classification_dataset.py  | 52 ++++++++++---------
 .../datasets/question_answering_dataset.py    |  8 +--
 .../sequence_classification_dataset.py        | 38 ++++++--------
 .../datasets/token_classification_dataset.py  |  6 +--
 5 files changed, 62 insertions(+), 66 deletions(-)

diff --git a/src/scratch/datasets/causal_langauge_modeling_dataset.py b/src/scratch/datasets/causal_langauge_modeling_dataset.py
index 6f36a7c..8b51d50 100644
--- a/src/scratch/datasets/causal_langauge_modeling_dataset.py
+++ b/src/scratch/datasets/causal_langauge_modeling_dataset.py
@@ -90,18 +90,18 @@ def load_hf_dataset(
     )
 
     if shuffle:
-        data = data.shuffle().with_format("torch")
+        data = data.shuffle()
 
     if validate:
-        data = data.filter(validate).with_format("torch")
+        data = data.filter(validate)
 
     def tokenize_function(examples):
         return tokenizer(examples["text"], padding="max_length", truncation=True)
 
-    data = data.map(tokenize_function, batched=True).with_format("torch")
+    data = data.map(tokenize_function, batched=True)
 
     if prepare:
-        data = data.map(prepare).with_format("torch")
+        data = data.map(prepare)
 
     return data.with_format("torch")
 
@@ -162,9 +162,9 @@ def transform(batch: CausalLanguageModelingBatch):
             batch["attention_mask"],
             batch["labels"],
         )
-        input_ids = torch.tensor(input_ids, dtype=torch.int64)
-        attention_mask = torch.tensor(attention_mask, dtype=torch.int64)
-        labels = torch.tensor(labels, dtype=torch.int64)
+        input_ids = torch.as_tensor(input_ids, dtype=torch.int64)
+        attention_mask = torch.as_tensor(attention_mask, dtype=torch.int64)
+        labels = torch.as_tensor(labels, dtype=torch.int64)
         return CausalLanguageModelingBatch(
             input_ids=input_ids, attention_mask=attention_mask, labels=labels
         )
@@ -196,12 +196,12 @@ def wikitext2_dataset(
     tokenizer = load_tokenizer(tokenizer_name, max_length=max_length)
 
     def prepare(sample):
-        input_ids = sample["input_ids"]
-        input_ids = torch.tensor(input_ids, dtype=torch.int64)
-        labels = input_ids.clone()
+        input_ids = np.array(sample["input_ids"], dtype=np.int64)
+        labels = input_ids.copy()
         # Make a lower triangular attention mask
-        attention_mask = np.tril(np.ones((len(input_ids), len(input_ids))))
-        attention_mask = torch.tensor(attention_mask, dtype=torch.int64)
+        attention_mask = np.tril(
+            np.ones((len(input_ids), len(input_ids)), dtype=np.int64)
+        )
         sample["input_ids"], sample["attention_mask"], sample["labels"] = (
             input_ids,
             attention_mask,
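
Note on the torch.tensor -> torch.as_tensor substitution above: once with_format("torch")
hands tensors to transform(), calling torch.tensor() on them copy-constructs and PyTorch
emits "UserWarning: To copy construct from a tensor, it is recommended to use
sourceTensor.clone().detach() ...". torch.as_tensor() reuses the existing storage when the
dtype and device already match, so no copy is made and the warning disappears. A minimal
sketch of the difference (the tensor t is illustrative, not from this repo):

    import warnings

    import torch

    t = torch.ones(3, dtype=torch.int64)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        copied = torch.tensor(t)      # copy-constructs: emits the UserWarning
        aliased = torch.as_tensor(t)  # same dtype/device: reuses storage, silent
    print(len(caught))                         # 1 -- only torch.tensor(t) warned
    print(aliased.data_ptr() == t.data_ptr())  # True: no copy was made
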
diff --git a/src/scratch/datasets/image_classification_dataset.py b/src/scratch/datasets/image_classification_dataset.py
index 38afb5e..ca43b23 100644
--- a/src/scratch/datasets/image_classification_dataset.py
+++ b/src/scratch/datasets/image_classification_dataset.py
@@ -94,17 +94,20 @@ def load_hf_dataset(
         The IterableDataset object
     """
     data = load_dataset(
-        dataset_name, split=dataset_split, trust_remote_code=True, streaming=True
-    ).with_format("torch")
+        dataset_name,
+        split=dataset_split,
+        trust_remote_code=True,
+        streaming=True,
+    )
 
     if shuffle:
-        data = data.shuffle().with_format("torch")
+        data = data.shuffle()
 
     if validate:
-        data = data.filter(validate).with_format("torch")
+        data = data.filter(validate)
 
     if prepare:
-        data = data.map(prepare).with_format("torch")
+        data = data.map(prepare)
 
     return data.with_format("torch")
 
@@ -172,15 +175,13 @@ def mnist_dataset(batch_size=32, shuffle=True):
     def prepare(sample):
         images, labels = sample["image"], sample["label"]
 
-        # Ensure the images are float tensors
-        images = images.to(torch.float32)
-        # Normalize the images
-        images = images / 255.0
-        # Convert labels to one-hot encoding
-        labels = labels.to(torch.int64)  # Ensure labels are int32 tensors
-        labels = F.one_hot(labels, num_classes=10).to(torch.int32)
-
-        sample["image"], sample["label"] = images, labels
+        images = transforms.ToTensor()(images).to(torch.float32)
+        labels = F.one_hot(
+            torch.as_tensor(labels, dtype=torch.int64),
+            num_classes=10,
+        ).to(torch.int32)
+
+        sample["image"], sample["label"] = images.numpy(), labels.numpy()
         return sample
 
     train_data, test_data = (
@@ -219,20 +220,21 @@ def tiny_imagenet_dataset(batch_size=32, shuffle=True):
     def prepare(sample):
         images, labels = sample["image"], sample["label"]
 
-        # Ensure the images are float tensors
-        images = images.clone().detach().to(torch.float32)
-        # Normalize the images
-        images = images / 255.0
-        # Convert labels to one-hot encoding
-        labels = labels.clone().detach().to(torch.int64)  # Ensure labels are int32
-        labels = F.one_hot(labels, num_classes=200).to(torch.int32)
-
-        sample["image"], sample["label"] = images, labels
+        images = transforms.ToTensor()(images).to(torch.float32)
+        labels = F.one_hot(
+            torch.as_tensor(labels, dtype=torch.int64),
+            num_classes=200,
+        ).to(torch.int32)
+
+        sample["image"], sample["label"] = images.numpy(), labels.numpy()
         return sample
 
     def validate(sample):
-        transform = transforms.ToTensor()
-        img = transform(sample["image"])
+        img = (
+            sample["image"]
+            if isinstance(sample["image"], torch.Tensor)
+            else transforms.ToTensor()(sample["image"])
+        )
         return (
             img.shape == (3, 64, 64)
             and torch.isnan(img).sum() == 0
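
The rewritten prepare() functions above lean on two facts: transforms.ToTensor() converts
a PIL image to a CHW float tensor already scaled to [0, 1] (which is why the explicit
/ 255.0 step could be dropped), and returning numpy arrays from map() keeps the streamed
records format-agnostic until the single trailing with_format("torch") tensorizes them.
A sketch of that pattern, assuming the MNIST shape and num_classes=10 as in the hunk above:

    import torch
    import torch.nn.functional as F
    from PIL import Image
    from torchvision import transforms

    def prepare(sample):
        # PIL -> CHW float32 in [0, 1]; ToTensor already divides by 255
        image = transforms.ToTensor()(sample["image"]).to(torch.float32)
        label = F.one_hot(
            torch.as_tensor(sample["label"], dtype=torch.int64), num_classes=10
        ).to(torch.int32)
        # hand numpy back to datasets; with_format("torch") re-tensorizes later
        sample["image"], sample["label"] = image.numpy(), label.numpy()
        return sample

    out = prepare({"image": Image.new("L", (28, 28)), "label": 3})
    print(out["image"].shape, out["label"].shape)  # (1, 28, 28) (10,)
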
diff --git a/src/scratch/datasets/question_answering_dataset.py b/src/scratch/datasets/question_answering_dataset.py
index 9f4e58e..f83e039 100644
--- a/src/scratch/datasets/question_answering_dataset.py
+++ b/src/scratch/datasets/question_answering_dataset.py
@@ -88,10 +88,10 @@ def transform(batch):
             batch["start_positions"],
             batch["end_positions"],
         )
-        input_ids = torch.tensor(input_ids, dtype=torch.long)
-        attention_mask = torch.tensor(attention_mask, dtype=torch.long)
-        start_positions = torch.tensor(start_positions, dtype=torch.long)
-        end_positions = torch.tensor(end_positions, dtype=torch.long)
+        input_ids = torch.as_tensor(input_ids, dtype=torch.long)
+        attention_mask = torch.as_tensor(attention_mask, dtype=torch.long)
+        start_positions = torch.as_tensor(start_positions, dtype=torch.long)
+        end_positions = torch.as_tensor(end_positions, dtype=torch.long)
         return QuestionAnsweringBatch(
             input_ids=input_ids,
             attention_mask=attention_mask,
diff --git a/src/scratch/datasets/sequence_classification_dataset.py b/src/scratch/datasets/sequence_classification_dataset.py
index 7f39053..658676b 100644
--- a/src/scratch/datasets/sequence_classification_dataset.py
+++ b/src/scratch/datasets/sequence_classification_dataset.py
@@ -81,22 +81,25 @@ def load_hf_dataset(
         The IterableDataset object
     """
     data = load_dataset(
-        dataset_name, split=dataset_split, trust_remote_code=True, streaming=True
+        dataset_name,
+        split=dataset_split,
+        trust_remote_code=True,
+        streaming=True,
     )
 
     if shuffle:
-        data = data.shuffle().with_format("torch")
+        data = data.shuffle()
 
     if validate:
-        data = data.filter(validate).with_format("torch")
+        data = data.filter(validate)
 
     def tokenize_function(examples):
         return tokenizer(examples["text"], padding="max_length", truncation=True)
 
-    data = data.map(tokenize_function, batched=True).with_format("torch")
+    data = data.map(tokenize_function, batched=True)
 
     if prepare:
-        data = data.map(prepare).with_format("torch")
+        data = data.map(prepare)
 
     return data.with_format("torch")
 
@@ -157,8 +160,8 @@ def transform(batch: SequenceClassificationBatch):
             batch["input_ids"],
             batch["label"],
         )
-        input_ids = torch.tensor(input_ids, dtype=torch.int64)
-        label = torch.tensor(label, dtype=torch.int64)
+        input_ids = torch.as_tensor(input_ids, dtype=torch.int64)
+        label = torch.as_tensor(label, dtype=torch.int64)
         label = F.one_hot(label, num_classes=num_classes).to(torch.int32)
         return SequenceClassificationBatch(
             input_ids=input_ids,
@@ -192,21 +195,12 @@ def imdb_dataset(
     tokenizer = load_tokenizer(tokenizer_name, max_length=max_length)
 
     def prepare(sample):
-        input_ids, labels = (
-            sample["input_ids"],
-            sample["label"],
-        )
-        input_ids = torch.tensor(input_ids, dtype=torch.int64)
-        labels = torch.tensor(labels, dtype=torch.int64)
-        labels = F.one_hot(labels, num_classes=2).to(torch.int32)
-
-        (
-            sample["input_ids"],
-            sample["label"],
-        ) = (
-            input_ids,
-            labels,
-        )
+        input_ids, labels = sample["input_ids"], sample["label"]
+        input_ids = np.array(input_ids, dtype=np.int64)
+        labels_tensor = torch.as_tensor(labels, dtype=torch.int64)
+        labels = F.one_hot(labels_tensor, num_classes=2).to(torch.int32).numpy()
+
+        sample["input_ids"], sample["label"] = input_ids, labels
         return sample
 
     train_data, test_data = (
diff --git a/src/scratch/datasets/token_classification_dataset.py b/src/scratch/datasets/token_classification_dataset.py
index e4ff398..55a5d9c 100644
--- a/src/scratch/datasets/token_classification_dataset.py
+++ b/src/scratch/datasets/token_classification_dataset.py
@@ -91,9 +91,9 @@ def transform(batch):
             batch["attention_mask"],
             batch["labels"],
         )
-        input_ids = torch.tensor(input_ids, dtype=torch.long)
-        attention_mask = torch.tensor(attention_mask, dtype=torch.long)
-        labels = torch.tensor(labels, dtype=torch.long)
+        input_ids = torch.as_tensor(input_ids, dtype=torch.long)
+        attention_mask = torch.as_tensor(attention_mask, dtype=torch.long)
+        labels = torch.as_tensor(labels, dtype=torch.long)
         return TokenClassificationBatch(
             input_ids=input_ids, attention_mask=attention_mask, labels=labels
         )
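
Taken together, the structural change repeated across each load_hf_dataset() is that the
shuffle/filter/map stages no longer each call with_format("torch"); the conversion happens
exactly once, on the value that is returned. A minimal sketch of the resulting pipeline
shape (the imdb dataset is illustrative and streaming it requires network access):

    from datasets import load_dataset

    def prepare(sample):
        # stand-in per-sample prep; returns plain python values
        sample["label"] = int(sample["label"])
        return sample

    data = load_dataset("imdb", split="train", streaming=True)
    data = data.shuffle(seed=0)       # no intermediate .with_format("torch")
    data = data.map(prepare)
    data = data.with_format("torch")  # tensorize once, at the boundary
    print(type(next(iter(data))["label"]))  # <class 'torch.Tensor'>
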