Torch dataset пример: от теории к практике

Когда речь заходит о машинном обучении с использованием PyTorch, одной из первых и фундаментальных концепций является подготовка информации. Эффективная работа с выборками — ключ к созданию производительных и точных моделей. Именно для этой цели в библиотеке существует мощная абстракция — класс `torch.utils.data.Dataset`. Этот материал представляет собой подробный torch dataset пример, который поможет разобраться в его устройстве и научиться создавать собственные структуры для загрузки изображений, текстов или любых других форматов.

По своей сути, `Dataset` — это абстрактный класс, который представляет собой источник ваших сэмплов. Он оборачивает вашу информацию и предоставляет стандартизированный интерфейс для доступа к ней. Вместо того чтобы вручную писать циклы для чтения файлов с диска и их обработки, вы инкапсулируете эту логику внутри одного объекта. Это делает код чище, модульнее и проще для повторного применения.

Основы класса `Dataset`

Чтобы создать свой собственный набор, необходимо унаследовать его от `torch.utils.data.Dataset` и реализовать три ключевых метода. Представьте, что вы создаете каталог для фотоальбома. Вам нужно знать три вещи:

  • `__init__(self, ...)`: Конструктор. Здесь происходит вся первоначальная подготовка. Это как собрать все фотографии в одном месте. Вы можете прочитать файл с аннотациями, составить список путей к файлам или загрузить небольшие метаданные в память.
  • `__len__(self)`: Должен возвращать общее количество образцов в наборе. В аналогии с альбомом, это общее число фотографий, которые у вас есть.
  • `__getitem__(self, idx)`: Самый важный элемент. Он отвечает за получение одного образца по его индексу `idx`. Для нашего альбома это означало бы: «дай мне фотографию номер `idx`». Метод загружает сам элемент (например, изображение с диска), выполняет необходимые преобразования и возвращает его, как правило, в виде кортежа `(данные, метка)`.
Понимание этих трех методов — основа для работы с любыми типами информации в PyTorch. Они формируют контракт, который позволяет другим компонентам фреймворка, таким как `DataLoader`, эффективно взаимодействовать с вашим источником сэмплов.

Создание простого набора данных: работа с числами

Начнем с очень простого примера для иллюстрации концепции. Создадим набор, который по индексу `i` возвращает само число `i` и его квадрат `i*i`. Это синтетический случай, но он отлично демонстрирует механику работы.


import torch
from torch.utils.data import Dataset

class NumbersDataset(Dataset):
    def __init__(self, num_samples=100):
        # В конструкторе определяем общее количество элементов
        self.num_samples = num_samples

    def __len__(self):
        # Возвращаем общее число образцов
        return self.num_samples

    def __getitem__(self, idx):
        # По индексу генерируем образец и его метку
        if idx >= self.num_samples:
            raise IndexError("Index out of range")
        
        feature = torch.tensor([float(idx)], dtype=torch.float32)
        label = torch.tensor([float(idx**2)], dtype=torch.float32)
        return feature, label

# Используем наш кастомный класс
my_numbers_dataset = NumbersDataset(num_samples=10)

# Получим третий элемент (индекс 2)
feature, label = my_numbers_dataset[2]
print(f"Feature: {feature.item()}, Label: {label.item()}")
# Вывод: Feature: 2.0, Label: 4.0

print(f"Total samples: {len(my_numbers_dataset)}")
# Вывод: Total samples: 10

Этот код показывает, как три обязательных метода работают вместе. `__init__` задает размер, `__len__` его сообщает, а `__getitem__` генерирует пару (признак, метка) на лету для любого запрошенного индекса.

Практический torch dataset пример: набор данных изображений

Теперь перейдем к более реалистичной и распространенной задаче — созданию набора для классификации картинок. Предположим, у нас есть папка с изображениями, рассортированными по подпапкам, где имя каждой подпапки является названием класса.

Структура проекта и подготовка данных

Для нашего примера ожидается следующая файловая структура. Это стандартный подход, который используется во многих библиотеках.


data/
├── cat/
│   ├── 1.jpg
│   ├── 2.png
│   └── ...
└── dog/
    ├── 1.jpg
    ├── 2.png
    └── ...

Задача нашего `Dataset` — просканировать эту структуру, составить список всех путей к картинкам и сопоставить каждому пути соответствующую метку класса («cat» или «dog»).

Реализация кастомного `ImageDataset`

Для работы с картинками нам понадобятся библиотеки `os` для навигации по файловой системе и `PIL` (Pillow) для открытия файлов. Также мы будем использовать `torchvision.transforms` для предварительной обработки изображений.


import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        self.classes = sorted(entry.name for entry in os.scandir(root_dir) if entry.is_dir())
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

        # Собираем пути к файлам и их метки
        for target_class in self.classes:
            class_index = self.class_to_idx[target_class]
            target_dir = os.path.join(root_dir, target_class)
            for fname in os.listdir(target_dir):
                path = os.path.join(target_dir, fname)
                item = (path, class_index)
                self.samples.append(item)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, target = self.samples[idx]
        with Image.open(path).convert('RGB') as img:
            # Применяем трансформации, если они есть
            if self.transform:
                img = self.transform(img)
        
        return img, target

В `__init__` мы сканируем директорию, создаем словарь для сопоставления имен классов с числовыми индексами (`class_to_idx`) и формируем список `self.samples`, содержащий кортежи `(путь_к_картинке, индекс_класса)`. Метод `__getitem__` открывает изображение по пути, конвертирует его в RGB и применяет переданные трансформации.

Зачем нужны трансформации (`transforms`)?

Нейронные сети ожидают на входе тензоры фиксированного размера. Изображения же могут быть разного разрешения и формата. Трансформации (`transforms`) — это функции, которые выполняют предварительную обработку. Основные из них:

  1. `transforms.Resize((w, h))`: Изменяет размер каждого изображения до `w` x `h` пикселей.
  2. `transforms.ToTensor()`: Преобразует изображение из формата PIL или NumPy в тензор PyTorch. Также нормализует значения пикселей из диапазона [0, 255] в [0.0, 1.0].
  3. `transforms.Normalize(mean, std)`: Нормализует тензор, вычитая среднее `mean` и деля на стандартное отклонение `std` для каждого канала. Это улучшает сходимость модели.
  4. Аугментация: Случайные преобразования, такие как `RandomHorizontalFlip` (случайное отражение) или `RandomRotation` (случайный поворот), помогают модели лучше обобщать и предотвращают переобучение.

Загрузка данных с помощью `DataLoader`

Сам по себе `Dataset` только предоставляет доступ к элементам по одному. Для обучения модели нам нужно подавать их пачками (батчами), перемешивать на каждой эпохе и, желательно, делать это в несколько потоков для ускорения. Эту задачу решает `torch.utils.data.DataLoader`.

`DataLoader` — это итератор, который оборачивает `Dataset` и предоставляет всю необходимую функциональность.

Эффективная работа с батчами и перемешиванием

Создать `DataLoader` очень просто. Ему нужно передать экземпляр нашего `Dataset` и указать несколько параметров.


from torch.utils.data import DataLoader

# Определяем последовательность трансформаций
data_transforms = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Создаем экземпляр нашего набора
image_dataset = CustomImageDataset(root_dir='data/', transform=data_transforms)

# Оборачиваем его в DataLoader
data_loader = DataLoader(image_dataset, batch_size=32, shuffle=True, num_workers=4)

# Теперь можно итерироваться по data_loader в цикле обучения
for epoch in range(num_epochs):
    for images_batch, labels_batch in data_loader:
        # images_batch имеет размерность (32, 3, 128, 128)
        # labels_batch имеет размерность (32)
        # Здесь происходит обучение модели...
        pass

Ключевые параметры `DataLoader`:

  • `batch_size`: Количество образцов в одной пачке.
  • `shuffle=True`: Перемешивать образцы перед каждой эпохой. Это критически важно для качественного обучения.
  • `num_workers`: Число параллельных процессов для загрузки. Значение больше 0 значительно ускоряет подготовку, так как она происходит параллельно с работой GPU.

Использование связки `Dataset` и `DataLoader` является стандартом в экосистеме PyTorch. Это не только делает код чистым и читаемым, но и скрывает сложную логику многопоточной загрузки, позволяя вам сосредоточиться на архитектуре нейронной сети и процессе ее тренировки. Освоив этот подход, вы сможете эффективно работать с любыми источниками информации, будь то гигантские наборы картинок или потоковые текстовые потоки.