GitHub Repository: fastai/course22
Path: blob/master/clean/03-which-image-models-are-best.ipynb
Kernel: python3

timm

PyTorch Image Models (timm) is a wonderful library by Ross Wightman which provides state-of-the-art pre-trained computer vision models. It's like Hugging Face Transformers, but for computer vision instead of NLP (and it's not restricted to transformer-based models)!

Ross has been kind enough to help me understand how to best take advantage of this library by identifying the top models. I'm going to share here some of what I've learned from him, plus some additional ideas.

The data

Ross regularly benchmarks new models as they are added to timm, and puts the results in a CSV in the project's GitHub repo. To analyse the data, we'll first clone the repo:

!git clone --depth 1 https://github.com/rwightman/pytorch-image-models.git
%cd pytorch-image-models/results

import pandas as pd
df_results = pd.read_csv('results-imagenet.csv')

def get_data(part, col):
    # Join the benchmark timings with the ImageNet accuracy results on model name
    df = pd.read_csv(f'benchmark-{part}-amp-nhwc-pt111-cu113-rtx3090.csv').merge(df_results, on='model')
    # Convert throughput (samples per second) into seconds per sample
    df['secs'] = 1. / df[col]
    # The family is the leading run of letters (plus an optional "v2") in the model name
    df['family'] = df.model.str.extract(r'^([a-z]+?(?:v2)?)(?:\d|_|$)')
    # Drop models whose name ends in "gn"
    df = df[~df.model.str.endswith('gn')]
    # Give "in22" models and "d" ResNet variants their own families
    is_in22 = df.model.str.contains('in22')
    df.loc[is_in22, 'family'] = df.loc[is_in22, 'family'] + '_in22'
    is_resnet_d = df.model.str.contains('resnet.*d')
    df.loc[is_resnet_d, 'family'] = df.loc[is_resnet_d, 'family'] + 'd'
    # Keep only the families compared in this notebook
    return df[df.family.str.contains('^re[sg]netd?|beit|convnext|levit|efficient|vit|vgg')]

df = get_data('infer', 'infer_samples_per_sec')
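To see what the family extraction in get_data does to a model name, here is a minimal sketch that mirrors the same steps with the standard `re` module instead of pandas. The helper `family_of` and the model names in the loop are illustrative assumptions, not part of the notebook; only the regular expressions are taken from get_data.

```python
import re

# Same pattern as in get_data: a lazy run of lowercase letters, an optional
# "v2", then a digit, an underscore, or the end of the name.
FAMILY_RE = re.compile(r'^([a-z]+?(?:v2)?)(?:\d|_|$)')

def family_of(model):
    """Hypothetical helper: mirror get_data's family logic for one name."""
    m = FAMILY_RE.search(model)
    if m is None:
        return None  # pandas would store NaN here
    family = m.group(1)
    if 'in22' in model:                 # same test as str.contains('in22')
        family += '_in22'
    if re.search('resnet.*d', model):   # same test as str.contains('resnet.*d')
        family += 'd'
    return family

# Illustrative model names only
for name in ['resnet50', 'resnet50d', 'efficientnetv2_s',
             'convnext_base_in22k', 'vgg16', 'levit_128']:
    print(name, '->', family_of(name))
```

Because the letter run is matched lazily and followed by an optional `v2`, a name such as `efficientnetv2_s` keeps the `v2` suffix in its family, while `resnet50` stops at the first digit.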

Inference results

import plotly.express as px
w,h = 1000,800

def show_all(df, title, size):
    return px.scatter(df, width=w, height=h, size=df[size]**2, title=title,
        x='secs', y='top1', log_x=True, color='family', hover_name='model', hover_data=[size])
show_all(df, 'Inference', 'infer_img_size')
subs = 'levit|resnetd?|regnetx|vgg|convnext.*|efficientnetv2|beit'
def show_subs(df, title, size):
    df_subs = df[df.family.str.fullmatch(subs)]
    return px.scatter(df_subs, width=w, height=h, size=df_subs[size]**2, title=title,
        trendline="ols", trendline_options={'log_x':True},
        x='secs', y='top1', log_x=True, color='family', hover_name='model', hover_data=[size])
show_subs(df, 'Inference', 'infer_img_size')
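The per-family trendlines in show_subs use trendline="ols" with log_x set, which, as far as I understand Plotly's behaviour, fits an ordinary least squares line of top1 against the base-10 logarithm of secs (and relies on statsmodels being installed). The sketch below reproduces that kind of fit in plain Python so the meaning of the slope is explicit. The function `fit_log_x` and the three data points are made up for illustration; they are not taken from the benchmark CSVs.

```python
import math

def fit_log_x(xs, ys):
    """Least-squares fit of y = intercept + slope * log10(x)."""
    lx = [math.log10(x) for x in xs]
    n = len(lx)
    mean_x = sum(lx) / n
    mean_y = sum(ys) / n
    sxx = sum((a - mean_x) ** 2 for a in lx)
    sxy = sum((a - mean_x) * (b - mean_y) for a, b in zip(lx, ys))
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    return slope, intercept

# Made-up points: seconds per sample and top-1 accuracy
secs = [0.001, 0.01, 0.1]
top1 = [70.0, 75.0, 80.0]

slope, intercept = fit_log_x(secs, top1)
print(f'top1 = {intercept:.1f} + {slope:.1f} * log10(secs)')
```

Read this way, the slope is the number of accuracy points a family gains for each tenfold increase in time per sample, which is what makes the trendlines comparable across families on a log-scaled x axis.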
px.scatter(df, width=w, height=h,
    x='param_count_x', y='secs', log_x=True, log_y=True, color='infer_img_size',
    hover_name='model', hover_data=['infer_samples_per_sec', 'family']
)

Training results

tdf = get_data('train', 'train_samples_per_sec')
show_all(tdf, 'Training', 'train_img_size')
show_subs(tdf, 'Training', 'train_img_size')