IOGS - ARDF Projects - Chaos segmentation

This notebook is a helper for loading the CHAOS challenge data.

Imports

In [1]:
import numpy as np
import os
import pickle
import random
import torch
import torch.utils.data

Define the dataset

The dataset is a PyTorch class. Dataset objects, used together with a DataLoader, provide the data in PyTorch format. The following class takes dictionaries mapping each id to a NumPy array, one for the input data and one for the targets.

Note: this class does not normalize the data; normalization must either be done before creating the dataset or be added by modifying the class definition.

In [2]:
class ImageDataset(torch.utils.data.Dataset):
    """Slice-wise PyTorch dataset over a set of image arrays and their targets.

    Every index along the first axis of every selected array is one sample,
    so the dataset length is the total number of slices of the selected ids.

    Args:
        data: dict mapping an id to an array of input slices, indexed along
            its first axis. Each slice is either 2-D, or 3-D with the
            channels on the last axis.
        targets: dict mapping the same ids to target arrays with one target
            slice per input slice. Must contain every selected id.
        ids: ids of the entries of `data` to include.
        imsize: if not None, side length of a random square crop applied at
            the same position to an input slice and its target.

    Raises:
        KeyError: (at construction) if a selected id is missing from `targets`.

    No normalization is applied; normalize the arrays beforehand if needed.
    """

    def __init__(self, data, targets, ids, imsize=None):
        """Collect the selected arrays and index every one of their slices."""

        self.imsize = imsize
        self.ids = ids

        # Walk the keys of `data` once and fetch each target by the SAME key,
        # so inputs and targets stay paired even if the two dicts are ordered
        # differently or contain different extra keys.
        selected = [key for key in data if key in ids]
        self.data = [data[key] for key in selected]
        self.targets = [targets[key] for key in selected]

        # One [array index, slice index] pair per sample.
        self.coord = [
            [d_id, i]
            for d_id, d in enumerate(self.data)
            for i in range(d.shape[0])
        ]

    def __getitem__(self, index):
        """Return the (input, target) float32 tensor pair for sample `index`.

        The input is returned channels-first: (1, H, W) for a 2-D slice, or
        (C, H, W) for a channels-last slice. Both tensors are cropped to
        `imsize` x `imsize` when `imsize` is set.

        Raises:
            ValueError: if `imsize` is larger than the slice.
        """

        # Use exactly the slice addressed by `index`, so one epoch visits
        # every slice once.
        vol_id, slice_id = self.coord[index]
        data = self.data[vol_id][slice_id].astype(np.float32)
        target = self.targets[vol_id][slice_id].astype(np.float32)

        # In torch, channels come first.
        if data.ndim == 2:
            data = np.expand_dims(data, 0)
        else:
            data = data.transpose(2, 0, 1)

        if self.imsize is not None:
            _, height, width = data.shape
            if self.imsize > height or self.imsize > width:
                raise ValueError(
                    f"imsize={self.imsize} is larger than the slice size "
                    f"({height}x{width})"
                )
            # Same crop window for the input and its target.
            y1 = random.randint(0, height - self.imsize)
            x1 = random.randint(0, width - self.imsize)
            data = data[:, y1:y1 + self.imsize, x1:x1 + self.imsize]
            target = target[y1:y1 + self.imsize, x1:x1 + self.imsize]

        return torch.from_numpy(data), torch.from_numpy(target)

    def __len__(self):
        """Total number of slices over all selected arrays."""
        return len(self.coord)

Download data

We provide the data as an archive containing 6 pickle files (training data and ground truth for each of the CT, MR T1DUAL and MR T2SPIR modalities). You can download them here.

If the files are stored in the data/chaos folder of your Google Drive, the following code can mount the drive and read them from there: set USE_COLAB to True.

In [3]:
# Set to True when running on Google Colab with the data stored on Google Drive.
USE_COLAB = False

if not USE_COLAB:
    # Local run: the pickle files are read from this relative folder.
    data_dir = "data/chaos"
else:
    # Mount Google Drive so that its content appears under /content/drive.
    from google.colab import drive
    drive.mount('/content/drive')
    # Folder on Google Drive holding the Chaos pickle files.
    data_dir = "/content/drive/My Drive/data/chaos"

Load data

In [4]:
# Modality to load: "CT", "MR_T1DUAL" or "MR_T2SPIR"
# (reads "<MODALITY>_train_data.pkl" and "<MODALITY>_train_gt.pkl").
MODALITY = "CT"


def load_pickle(path):
    """Load and return the object stored in the pickle file at `path`.

    The file is closed as soon as it has been read. Warning: unpickling can
    execute arbitrary code, so only load files from a trusted source.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


data = load_pickle(os.path.join(data_dir, f"{MODALITY}_train_data.pkl"))
gt = load_pickle(os.path.join(data_dir, f"{MODALITY}_train_gt.pkl"))

Create the data loader and iterate

In [6]:
# Slices of ids 1, 2 and 3, randomly cropped to 256 x 256.
dataset = ImageDataset(data, gt, ids=[1, 2, 3], imsize=256)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

# One pass over the data, reporting the shape of every batch.
for batch_inputs, batch_targets in dataloader:
    print(batch_inputs.size(), batch_targets.size())
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([2, 1, 256, 256]) torch.Size([2, 256, 256])
torch.Size([1, 1, 256, 256]) torch.Size([1, 256, 256])