One-Shot Learning on Humpback Whales with Siamese Neural Networks

In this notebook I will explore setting up a Siamese Neural Network (SNN) in the fastai/PyTorch framework to try to identify whales by their flukes (tail fins). The dataset comes from the Kaggle Humpback Whale Identification challenge. The inspiration for this technique came from Martin Piotte's Kaggle kernel, which implemented an SNN in Keras.

I'll focus on training an SNN because SNNs are specifically tailored to one-shot learning tasks. One-shot learning is classification under the restriction that we may observe only a single example of each possible class before making a prediction about a test instance. This is very useful here: in my previous post I found that most whales in the dataset have only 1-4 examples in the training set.
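The one-shot setting works like this once a network maps each image to an embedding. A test image is compared against the single stored example of each class, and the closest one wins. The sketch below is a minimal illustration under that assumption: the embeddings are taken as already computed, and both the vectors and the whale names are made up.

```python
import math

def euclidean(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def one_shot_predict(query, support):
    """Return the label whose single support embedding is closest to `query`.

    `support` maps each class label to the embedding of its only example.
    """
    return min(support, key=lambda label: euclidean(query, support[label]))

# Hypothetical 2-D embeddings, one example per whale.
support = {"w_a": [0.0, 0.0], "w_b": [5.0, 5.0], "w_c": [0.0, 6.0]}
print(one_shot_predict([4.5, 5.2], support))  # w_b
```

The Siamese network's job is to produce embeddings for which this nearest-example rule works well.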


%reload_ext autoreload
%autoreload 2
%matplotlib inline
from fastai import *
from fastai.vision import *
from fastai.callbacks import *
import fastai
path = Path('../input')
train = pd.read_csv(path/'train.csv'); 
## Keep only whales with more than 60 images, for quick experiments
cnter = Counter(train.Id.values)
train['cnt'] = train['Id'].apply(lambda x: cnter[x])
# Use 500 as the upper cutoff so that new_whale images are excluded
train = train[(train['cnt']>60) & (train['cnt']<500)]
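Both bounds in the filter are strict, so a whale with exactly 60 images is dropped, and the 500 ceiling removes new_whale. The toy counts below are invented purely to show which labels survive the same rule when it is applied with collections.Counter.

```python
from collections import Counter

# Toy stand-in for train.Id: each label repeated as many times as that whale has images.
ids = ["new_whale"] * 600 + ["w_a"] * 70 + ["w_b"] * 61 + ["w_c"] * 60 + ["w_d"] * 3

cnter = Counter(ids)
# Same strict bounds as the dataframe filter: more than 60 and fewer than 500 images.
kept = sorted(w for w, n in cnter.items() if 60 < n < 500)
print(kept)  # ['w_a', 'w_b']
```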

Next, duplicate the entire dataframe so that each image yields one positive pair and one negative pair. This keeps the Siamese network's training data at 50% positive and 50% negative pairs. The reason for this becomes clear in the sections below.

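train['target'] = 1                 # first copy: rows that will yield same-whale pairs
train1 = train.copy()
train1['target'] = 0                # second copy: rows that will yield different-whale pairs
train = pd.concat([train, train1])  # same as train.append(train1); DataFrame.append was removed in pandas 2.0
target_col = 3                      # columns are Image, Id, cnt, target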
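The duplication step can be mirrored on plain tuples. Every row appears twice, once with target 1 and once with target 0, so the labels are exactly balanced. Because target is appended after Image, Id and cnt, it lands at index 3. The rows below are made-up stand-ins for the filtered dataframe.

```python
# Toy rows standing in for the filtered dataframe: (Image, Id, cnt).
rows = [("a1.jpg", "w_a", 61), ("a2.jpg", "w_a", 61), ("b1.jpg", "w_b", 75)]
columns = ["Image", "Id", "cnt", "target"]

positive = [row + (1,) for row in rows]  # first copy: target 1 (same-whale pair)
negative = [row + (0,) for row in rows]  # second copy: target 0 (different-whale pair)
paired = positive + negative             # like concatenating train and train1

target_col = columns.index("target")
print(target_col)                                        # 3
print(sum(r[target_col] for r in paired) / len(paired))  # 0.5
```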

Custom ItemBase and ItemList for Siamese Networks

  • Much of this follows the fastai Custom ItemList tutorial.
  • Blindly following the steps in the tutorial might work, but I was not comfortable doing that. Instead, I spent a few hours digging into the library to understand how its various pieces work.
  • I suggest starting with the tabular application. It is easier to follow and gives an overall idea of the pieces involved in getting the DataBunch ready.
  • At a very high level, creating a DataBunch for an image application follows this path:
    • When a factory method like from_df is called, it instantiates the corresponding ImageItemList. The important steps include creating the items and storing the xtra dataframe with the necessary info. (vision/ - ImageItemList class)
    • The most important thing to note here is the get() function, which retrieves data given an index. It links the ImageItemList to Image (an ItemBase subclass) by returning the retrieved item as an ItemBase object. I missed this at first because it is very subtle in the code, and I couldn't connect the dots initially.
    • .split_by_... - Most of this happens in the ItemLists class and is easier to follow.
    • label_from_df - Much of the heavy lifting happens here. ItemLists has no label_from_df attribute, so the call falls back to the ItemList class via __getattr__, which does the job. Finally, the object's class is changed to LabelLists, which gives it the functionality to create a DataBunch. Refer to the LabelList and LabelLists classes.
    • The class of the target variable is also inferred at this step, and the processor, loss_func, etc. are assigned.
    • .databunch does the rest.
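The label_from_df step relies on Python's __getattr__ hook. That hook runs only when normal attribute lookup fails, so a method missing on the container can be forwarded to the train and valid lists, after which the container switches its own class. The classes below are hypothetical toys, not fastai's. They reproduce just that forwarding-and-class-swap mechanism.

```python
class ToyLabelList:
    """Hypothetical stand-in for a labelled list: inputs x and targets y."""
    def __init__(self, x, y):
        self.x, self.y = x, y

class ToyItemList:
    """Hypothetical stand-in for an ItemList that can label itself."""
    def __init__(self, items):
        self.items = items

    def label_from_func(self, func):
        return ToyLabelList(self.items, [func(o) for o in self.items])

class ToyItemLists:
    """Holds a train and a valid list; unknown attributes are forwarded to them."""
    def __init__(self, train, valid):
        self.train, self.valid = train, valid

    def __getattr__(self, name):
        # Python calls this only when normal lookup fails, e.g. for label_from_func.
        train_fn, valid_fn = getattr(self.train, name), getattr(self.valid, name)

        def _inner(*args, **kwargs):
            self.train = train_fn(*args, **kwargs)
            self.valid = valid_fn(*args, **kwargs)
            self.__class__ = ToyLabelLists  # swap the class, like ItemLists -> LabelLists
            return self
        return _inner

class ToyLabelLists(ToyItemLists):
    """After labelling; a real version would add DataBunch creation here."""

lists = ToyItemLists(ToyItemList(["a.jpg", "b.jpg"]), ToyItemList(["c.jpg"]))
labelled = lists.label_from_func(lambda fn: fn[0])
print(type(labelled).__name__, labelled.train.y, labelled.valid.y)
# ToyLabelLists ['a', 'b'] ['c']
```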
mean, std = torch.tensor(imagenet_stats)
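# The primary difference from the tutorial is how normalization is done here:
# each SiamImage normalizes its own two images with the ImageNet statistics.
class SiamImage(ItemBase):
    def __init__(self, img1, img2): ## These should be of Image type
        self.img1, self.img2 = img1, img2
        self.obj = (img1, img2)
        # self.data is what the DataLoader collates: one normalized tensor per image
        self.data = [(img1.data-mean[...,None,None])/std[...,None,None], (img2.data-mean[...,None,None])/std[...,None,None]]
    def apply_tfms(self, tfms, *args, **kwargs):
        # Transform each image independently, then recompute the normalized data
        self.img1 = self.img1.apply_tfms(tfms, *args, **kwargs)
        self.img2 = self.img2.apply_tfms(tfms, *args, **kwargs)
        self.data = [(self.img1.data-mean[...,None,None])/std[...,None,None], (self.img2.data-mean[...,None,None])/std[...,None,None]]
        return self
    def __repr__(self): return f'{self.__class__.__name__} {self.img1.shape, self.img2.shape}'
    def to_one(self):
        # Undo the normalization and concatenate both images along the width for display
        return Image(mean[...,None,None]+torch.cat(self.data,2)*std[...,None,None])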
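The [..., None, None] indexing turns the 3-element mean and std into 3x1x1 tensors, so each channel is normalized with its own statistics across every pixel. The sketch below spells out the same per-channel arithmetic, its inverse, and the width-wise concatenation that to_one() uses for display. It assumes a C x H x W nested list in place of a tensor and uses no torch or fastai, so it is only a check of the math.

```python
# ImageNet per-channel statistics (the values in fastai's imagenet_stats).
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def normalize(img):
    """img is a C x H x W nested list; each channel uses its own mean and std."""
    return [[[(v - m) / s for v in row] for row in ch] for ch, m, s in zip(img, MEAN, STD)]

def denormalize(img):
    """Inverse of normalize: the step to_one() and reconstruct() perform."""
    return [[[v * s + m for v in row] for row in ch] for ch, m, s in zip(img, MEAN, STD)]

def side_by_side(a, b):
    """Concatenate two C x H x W images along the width, like torch.cat(..., 2)."""
    return [[ra + rb for ra, rb in zip(ca, cb)] for ca, cb in zip(a, b)]

img = [[[0.485, 1.0]], [[0.456, 0.0]], [[0.406, 0.5]]]  # 3 x 1 x 2
norm = normalize(img)
print([ch[0][0] for ch in norm])          # [0.0, 0.0, 0.0]
print(len(side_by_side(img, img)[0][0]))  # 4
```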

We define get() so that it returns either a pair of images of the same whale or a pair of images of different whales, depending on the row. Rows from the first copy of the dataframe (target 1) produce similar pairs, and rows from the second copy (target 0) produce different pairs. This is why the dataframe was duplicated at the start. The current approach is deliberately simple, just to demonstrate the idea. Pair generation, handling of under-represented classes, and many other things could be done much better.

class SiamImageItemList(ImageItemList):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
#         self._label_cls=FloatList

    def __len__(self)->int: return len(self.items) or 1 

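    def get(self, i):
        # Rows from the first copy of the dataframe (target 1) yield similar pairs and rows from
        # the second copy (target 0) yield different pairs. Reading the target column, rather than
        # checking whether i is in the second half, keeps each pair consistent with its label
        # after random_split_by_pct.
        match = self.xtra.target.values[i] == 1
        img1 = super().get(i) # Returns Image class object

        imgs = self.xtra.Image.values # file names only
        ids = self.xtra.Id.values
        fn, wcls = imgs[i], ids[i]
        if match:
            # another image of the same whale (the anchor itself if it is the only one)
            simgs = np.unique(imgs[(ids == wcls) & (imgs != fn)])
            fn2 = np.random.choice(simgs) if len(simgs) else fn
        else:
            fn2 = np.random.choice(imgs[ids != wcls]) # any image of a different whale
        # self.items holds full paths, so map the chosen file name back to its path
        fn2 = self.items[np.where(imgs == fn2)[0][0]]
        img2 = super().open(fn2) # Returns Image class object
        return SiamImage(img1, img2)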

    def reconstruct(self, t): return SiamImage(mean[...,None,None]+t[0]*std[...,None,None], mean[...,None,None]+t[1]*std[...,None,None])

    def show_xys(self, xs, ys, figsize:Tuple[int,int]=(9,10), **kwargs):
        rows = int(math.sqrt(len(xs)))
        fig, axs = plt.subplots(rows,rows,figsize=figsize)
        for i, ax in enumerate(axs.flatten() if rows > 1 else [axs]):
            xs[i].to_one().show(ax=ax, y=ys[i], **kwargs)
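The pairing rule in get() can be tried outside fastai. The sketch below makes three assumptions. First, the pair type comes from the row's target (1 = same whale, 0 = different whale). Second, the anchor file itself is excluded from same-whale candidates. Third, the anchor is reused when a whale has a single image. sample_partner is a hypothetical helper, not part of the notebook or fastai.

```python
import random

def sample_partner(i, images, ids, targets, rng=random):
    """Pick the second image for row i (hypothetical helper mirroring get()).

    target 1 -> another image of the same whale (the anchor itself if it is the only one);
    target 0 -> any image of a different whale.
    """
    anchor, whale = images[i], ids[i]
    if targets[i] == 1:
        same = sorted({im for im, w in zip(images, ids) if w == whale and im != anchor})
        return rng.choice(same) if same else anchor
    return rng.choice([im for im, w in zip(images, ids) if w != whale])

images = ["a1", "a2", "b1"] * 2  # duplicated frame, as in the notebook
ids = ["w_a", "w_a", "w_b"] * 2
targets = [1, 1, 1, 0, 0, 0]
rng = random.Random(0)
for i in range(len(images)):
    print(images[i], sample_partner(i, images, ids, targets, rng), targets[i])
```

Sorting the candidate set keeps the draw reproducible for a seeded generator, which makes it easy to test.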
whl_tfms = get_transforms()
bs = 8
data = (SiamImageItemList.from_df(df=train, path=path/'train', cols=0)
         .random_split_by_pct(valid_pct=0.2, seed=34)
         .label_from_df(cols=target_col, label_cls=FloatList)
         .transform(whl_tfms, size=224)
         .databunch(bs=bs))  # no .normalize(): SiamImage normalizes its own data


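# Checking if the normalization done above is correct: undo it for one image and display it
x, y = next(iter(data.train_dl))
t = x[0][0].cpu()  # first img1 of the batch, still normalized
to = mean[...,None,None] + t*std[...,None,None]
import torchvision
ti = torchvision.transforms.ToPILImage()(to.clamp(0, 1))
plt.imshow(ti)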

Siamese Network

from fastai.vision import learner

class SiameseNet(nn.Module):
    def __init__(self, arch=models.resnet18, lin_ftrs=[256, 128], emb_sz=128,ps=0.5, bn_final=False):
        super(SiameseNet, self).__init__()
        self.arch, self.emb_sz = arch, emb_sz
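        self.lin_ftrs, self.ps, self.bn_final = lin_ftrs, ps, bn_final
        # Shared body: the same pretrained CNN, minus its classifier, embeds both images
        self.body = learner.create_body(self.arch, True, learner.cnn_config(self.arch)['cut'])
        # num_features * 2 because the head's concat pooling joins max- and average-pooled features
        self.head = learner.create_head(num_features_model(self.body) * 2, self.emb_sz, self.lin_ftrs, ps=self.ps, bn_final=self.bn_final)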
        self.cnn = nn.Sequential(self.body, self.head)

    def forward(self, x1, x2):
        # The same cnn embeds both images, so the two branches share weights
        output1 = self.cnn(x1)
        output2 = self.cnn(x2)
        return output1, output2

    def get_embedding(self, x):
        return self.cnn(x)
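The * 2 in the head's input size comes from concatenation pooling. fastai's create_head starts with AdaptiveConcatPool2d, which pools each channel in two ways (max and average) and concatenates the results, so a body with C output channels feeds 2C features to the linear layers. The plain-Python sketch below shows that doubling and the Siamese property that both branches use the same weights. It uses no torch, and adaptive_concat_pool and embed are hypothetical helpers, not fastai functions.

```python
import math

def adaptive_concat_pool(fmap):
    """fmap is a C x H x W nested list; returns C max-pooled values followed by C averages."""
    maxes = [max(v for row in ch for v in row) for ch in fmap]
    means = [sum(v for row in ch for v in row) / sum(len(row) for row in ch) for ch in fmap]
    return maxes + means

def embed(fmap, weights):
    """Hypothetical linear head: one output per weight row, applied to the pooled features."""
    pooled = adaptive_concat_pool(fmap)
    return [sum(w * p for w, p in zip(row, pooled)) for row in weights]

# A 2-channel, 1 x 2 feature map pools to 4 = 2 * C features.
fmap = [[[1.0, 3.0]], [[2.0, 2.0]]]
print(adaptive_concat_pool(fmap))  # [3.0, 2.0, 2.0, 2.0]

# The same weights embed both branches, so identical inputs give identical embeddings.
weights = [[1.0, 0.0, 0.5, -1.0], [0.0, 1.0, 1.0, 0.0]]
e1, e2 = embed(fmap, weights), embed(fmap, weights)
print(math.dist(e1, e2))  # 0.0
```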

Loss function

After comparing several loss functions, I found that contrastive loss works best in the current setup. Distance-based logistic loss gives similar performance when the model is trained in single precision, but worse results when it is trained in half precision.

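class ContrastiveLoss(nn.Module):
    """Takes embeddings of two samples and a target label == 1 if samples are from the same class and label == 0 otherwise"""
    def __init__(self, margin=5.):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
    def forward(self, ops, target, size_average=True):
        op1, op2 = ops[0], ops[1]
        dist = F.pairwise_distance(op1, op2)
        pdist = dist*target      # distance kept only for matching pairs
        ndist = dist*(1-target)  # distance kept only for non-matching pairs
        # Matching pairs are pulled together; non-matching pairs are pushed at least `margin` apart.
        # For a matching pair ndist is 0, so the hinge term is the constant margin**2 (no gradient).
        loss = 0.5* ((pdist**2) + (F.relu(self.margin-ndist)**2))
        return loss.mean() if size_average else loss.sum()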
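For a single pair with distance d, target y and margin m, the textbook contrastive loss is 0.5 * [y * d^2 + (1 - y) * max(0, m - d)^2]. ContrastiveLoss above agrees with it on non-matching pairs. Matching pairs additionally pick up 0.5 * m^2, because ndist is 0 for them and the hinge term evaluates to m^2. That constant has no gradient, so the parameter updates are unchanged, but it does shift the reported loss values. The plain-Python sketch below compares the two per pair. It uses no torch, and the helper names are hypothetical.

```python
import math

def contrastive_loss(e1, e2, target, margin=5.0):
    """Textbook per-pair contrastive loss: pull matches together, push non-matches past margin."""
    d = math.dist(e1, e2)
    if target == 1:
        return 0.5 * d ** 2
    return 0.5 * max(0.0, margin - d) ** 2

def notebook_loss(e1, e2, target, margin=5.0):
    """Per-pair version of the ContrastiveLoss.forward arithmetic above."""
    d = math.dist(e1, e2)
    pdist, ndist = d * target, d * (1 - target)
    return 0.5 * (pdist ** 2 + max(0.0, margin - ndist) ** 2)

a, b = [0.0, 0.0], [3.0, 4.0]  # distance 5
print(contrastive_loss(a, b, 1), notebook_loss(a, b, 1))  # 12.5 25.0
print(contrastive_loss(a, b, 0), notebook_loss(a, b, 0))  # 0.0 0.0
```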
model = SiameseNet().cuda()
apply_init(model.head, nn.init.kaiming_normal_)
siam_learner = Learner(data, model, loss_func=ContrastiveLoss(), model_dir=Path(os.getcwd()))
siam_learner.lr_find()


siam_learner.fit_one_cycle(5, slice(3*1e-2))

Total time: 01:01

epoch train_loss valid_loss
1 47.814510 171.353699
2 20.523115 10.899207
3 13.116446 10.829012
4 10.895652 11.075461
5 9.889898 11.046269