Torchkeras
了解过深度学习框架的都知道,TensorFlow是早期的主流框架,而后又出现了Keras,Keras对TensorFlow进行了封装,使得搭建深度学习模型的过程简化到了几个简单的步骤:summary、compile、fit、evaluate、predict。PyTorch虽然比TensorFlow出现得晚,但是其在框架的实现方式上更为优雅,可以很好地与Python的原生编程思维结合起来,因此越来越多的人开始转向PyTorch。开发的过程就是不断模块化的过程,Torchkeras便是这一原则下的产物,它将PyTorch进行了封装,使得利用PyTorch搭建深度学习模型的过程也可以像Keras那样简单,并且非常灵活。
源码
# -*- coding: utf-8 -*-
import os
import datetime
import numpy as np
import pandas as pd
import torch
from collections import OrderedDict
from prettytable import PrettyTable
# Version string of this module.
__version__ = "1.5.3"
# On macOS, this must be set in order to run PyTorch and matplotlib at the
# same time in Jupyter. NOTE: it is applied as an import-time side effect.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
# Some modules do the computation themselves using parameters or the parameters
# of children, so they are treated as layers even though they have submodules.
layer_modules = (torch.nn.MultiheadAttention, )


def summary(model, input_shape, input_dtype=torch.FloatTensor, batch_size=-1,
            layer_modules=layer_modules, *args, **kwargs):
    """Print a Keras-style, layer-by-layer summary of ``model``.

    Forward hooks are attached to every module that has no submodules (and to
    every instance of ``layer_modules``), one forward pass is run on random
    input under ``torch.no_grad()``, and a table of output shapes and
    parameter counts followed by size estimates is printed.

    Args:
        model: the ``torch.nn.Module`` to inspect.
        input_shape: shape of one sample without the batch dimension, given as
            a tuple; for a model with several inputs, a list of such tuples.
        input_dtype: tensor type the random input is converted to with
            ``Tensor.type``.
        batch_size: value shown as the leading dimension of every output shape
            and used in the size estimates. The probe forward pass itself
            always uses a batch of 2.
        layer_modules: module classes reported as one layer even though they
            contain submodules.
        *args: extra positional arguments passed to ``model`` after the
            generated inputs.
        **kwargs: extra keyword arguments passed to ``model``.

    Returns:
        None. The report is written to stdout.

    Raises:
        Exception: anything raised by the forward pass is re-raised after a
            short notice is printed; the hooks are removed in either case.
    """
    layer_info = OrderedDict()
    hooks = []

    def register_hook(module):
        def hook(module, inputs, outputs):
            key = "%s-%i" % (type(module).__name__, len(layer_info) + 1)

            info = OrderedDict()
            info["id"] = id(module)
            if isinstance(outputs, (list, tuple)):
                try:
                    info["out"] = [batch_size] + list(outputs[0].size())[1:]
                except AttributeError:
                    # pack_padded_seq and pad_packed_seq store feature into data attribute
                    info["out"] = [batch_size] + list(outputs[0].data.size())[1:]
            else:
                info["out"] = [batch_size] + list(outputs.size())[1:]

            info["params"], info["params_nt"] = 0, 0
            has_params = False
            for param in module.parameters():
                has_params = True
                if param.requires_grad:
                    info["params"] += param.nelement()
                else:
                    info["params_nt"] += param.nelement()

            # A parameterised module that already produced a row is being
            # reused. Flag it instead of overwriting the numeric count, so the
            # totals below stay numeric and shared weights are counted once.
            info["recursive"] = has_params and any(
                info["id"] == seen["id"] for seen in layer_info.values()
            )
            layer_info[key] = info

        # ignore Sequential and ModuleList and other containers
        if isinstance(module, layer_modules) or not module._modules:
            hooks.append(module.register_forward_hook(hook))

    model.apply(register_hook)

    # multiple inputs to the network
    if isinstance(input_shape, tuple):
        input_shape = [input_shape]

    # batch_size of 2 for batchnorm
    x = [torch.rand(2, *size).type(input_dtype) for size in input_shape]

    try:
        with torch.no_grad():
            model(*x, *args, **kwargs)
    except Exception:
        # This can be useful for debugging
        print("Failed to run torchkeras.summary...")
        raise
    finally:
        for handle in hooks:
            handle.remove()

    print("----------------------------------------------------------------")
    line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
    print(line_new)
    print("================================================================")
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer, info in layer_info.items():
        if info["recursive"]:
            params_text = "(recursive)"
        else:
            layer_params = info["params"] + info["params_nt"]
            params_text = "{0:,}".format(layer_params)
            total_params += layer_params
            trainable_params += info["params"]
        # Activations are produced on every call, reused module or not.
        total_output += np.prod(info["out"])
        # layer, output_shape, params
        print("{:>20} {:>25} {:>15}".format(layer, str(info["out"]), params_text))

    # assume 4 bytes/number
    # Inputs are separate tensors, so their element counts add up.
    input_elements = sum(int(np.prod(shape)) for shape in input_shape)
    total_input_size = abs(input_elements * batch_size * 4. / (1024 ** 2.))
    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    print("================================================================")
    print("Total params: {0:,}".format(total_params))
    print("Trainable params: {0:,}".format(trainable_params))
    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
    print("----------------------------------------------------------------")
    print("Input size (MB): %0.6f" % total_input_size)
    print("Forward/backward pass size (MB): %0.6f" % total_output_size)
    print("Params size (MB): %0.6f" % total_params_size)
    print("Estimated Total Size (MB): %0.6f" % total_size)
    print("----------------------------------------------------------------")
class Model(torch.nn.Module):
    """Keras-style training wrapper around a PyTorch network.

    Either pass an existing network as ``net`` or subclass and override
    ``forward``. Call ``compile`` before ``fit``, ``evaluate`` or ``predict``:
    it sets the loss, optimizer, metrics, history and device they read.
    """

    # print time bar...
    @staticmethod
    def print_bar():
        """Print a blank line, 80 '=' characters and the current local time."""
        nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        print("\n"+"="*80 + "%s"%nowtime)

    def __init__(self, net=None):
        """Store the optional wrapped network ``net``."""
        super(Model, self).__init__()
        self.net = net

    def forward(self, x):
        """Run ``x`` through the wrapped network.

        Raises:
            NotImplementedError: if no network was given and the method was
                not overridden.
        """
        # Compare with None explicitly: containers such as nn.Sequential
        # define __len__, so an empty one would be falsy under a truth test.
        if self.net is not None:
            # Call the module rather than .forward() so its hooks run.
            return self.net(x)
        raise NotImplementedError

    def compile(self, loss_func,
                optimizer=None, metrics_dict=None, device=None):
        """Configure the model for training and evaluation.

        Args:
            loss_func: callable ``(predictions, labels) -> scalar tensor``.
            optimizer: optimizer to use; defaults to Adam with lr=0.001 over
                this model's parameters.
            metrics_dict: mapping of metric name to a callable
                ``(predictions, labels) -> scalar tensor``.
            device: device to move the model and batches to. It is only
                honoured when CUDA is available; otherwise it is ignored.
        """
        self.loss_func = loss_func
        self.optimizer = optimizer if optimizer is not None else torch.optim.Adam(self.parameters(), lr=0.001)
        self.metrics_dict = metrics_dict if metrics_dict else {}
        self.history = {}
        self.device = device if torch.cuda.is_available() else None
        if self.device:
            self.to(self.device)

    def summary(self, input_shape, input_dtype=torch.FloatTensor, batch_size=-1):
        """Print a layer summary of this model using the module-level ``summary``."""
        summary(self, input_shape, input_dtype, batch_size)

    def train_step(self, features, labels):
        """Run one optimisation step on a batch.

        Returns:
            dict with ``"loss"`` and one entry per metric, as Python floats.
        """
        self.train()
        self.optimizer.zero_grad()
        if self.device:
            features = features.to(self.device)
            labels = labels.to(self.device)

        # forward
        predictions = self(features)
        loss = self.loss_func(predictions, labels)

        # evaluate metrics
        train_metrics = {"loss": loss.item()}
        for name, metric_func in self.metrics_dict.items():
            train_metrics[name] = metric_func(predictions, labels).item()

        # backward
        loss.backward()

        # update parameters
        self.optimizer.step()
        self.optimizer.zero_grad()
        return train_metrics

    @torch.no_grad()
    def evaluate_step(self, features, labels):
        """Evaluate one batch without tracking gradients.

        Returns:
            dict with ``"val_loss"`` and one ``"val_<name>"`` entry per metric.
        """
        self.eval()
        if self.device:
            features = features.to(self.device)
            labels = labels.to(self.device)

        # The decorator already disables autograd for the whole method.
        predictions = self(features)
        loss = self.loss_func(predictions, labels)

        val_metrics = {"val_loss": loss.item()}
        for name, metric_func in self.metrics_dict.items():
            val_metrics["val_"+name] = metric_func(predictions, labels).item()
        return val_metrics

    def fit(self, epochs, dl_train, dl_val=None, log_step_freq=1):
        """Train for ``epochs`` epochs and return the history.

        Args:
            epochs: number of passes over ``dl_train``.
            dl_train: iterable of ``(features, labels)`` training batches.
            dl_val: optional iterable of ``(features, labels)`` validation
                batches, evaluated after every epoch.
            log_step_freq: print running averages every this many steps.

        Returns:
            ``pandas.DataFrame`` built from the per-epoch metric averages
            accumulated in ``self.history``.
        """
        print("Start Training ...")
        Model.print_bar()

        # Test for None explicitly: truth-testing a DataLoader calls len(),
        # which raises TypeError for iterable-style datasets without __len__.
        dl_val = dl_val if dl_val is not None else []

        for epoch in range(1, epochs+1):

            # 1,training loop -------------------------------------------------
            train_metrics_sum, step = {}, 0
            for features, labels in dl_train:
                step = step + 1
                train_metrics = self.train_step(features, labels)

                for name, metric in train_metrics.items():
                    train_metrics_sum[name] = train_metrics_sum.get(name, 0.0)+metric

                if step%log_step_freq == 0:
                    logs = {"step": step}
                    logs.update({k: round(v/step, 3) for k, v in train_metrics_sum.items()})
                    print(logs)

            for name, metric_sum in train_metrics_sum.items():
                self.history[name] = self.history.get(name, [])+[metric_sum/step]

            # 2,validate loop -------------------------------------------------
            val_metrics_sum, step = {}, 0
            for features, labels in dl_val:
                step = step + 1
                val_metrics = self.evaluate_step(features, labels)
                for name, metric in val_metrics.items():
                    val_metrics_sum[name] = val_metrics_sum.get(name, 0.0)+metric
            for name, metric_sum in val_metrics_sum.items():
                self.history[name] = self.history.get(name, [])+[metric_sum/step]

            # 3,print logs -------------------------------------------------
            infos = {"epoch": epoch}
            infos.update({k: round(self.history[k][-1], 3) for k in self.history})
            tb = PrettyTable()
            # Hand PrettyTable real lists rather than dict views.
            tb.field_names = list(infos.keys())
            tb.add_row(list(infos.values()))
            print("\n",tb)
            Model.print_bar()

        print("Finished Training...")
        return pd.DataFrame(self.history)

    @torch.no_grad()
    def evaluate(self, dl_val):
        """Return the unweighted mean of each metric over the batches in ``dl_val``."""
        self.eval()
        val_metrics_list = {}
        for features, labels in dl_val:
            val_metrics = self.evaluate_step(features, labels)
            for name, metric in val_metrics.items():
                val_metrics_list[name] = val_metrics_list.get(name, [])+[metric]

        return {name: np.mean(metric_list) for name, metric_list in val_metrics_list.items()}

    @torch.no_grad()
    def predict(self, dl):
        """Concatenate predictions for the first element of every batch in ``dl``."""
        self.eval()
        if self.device:
            result = torch.cat([self(t[0].to(self.device)) for t in dl])
        else:
            result = torch.cat([self(t[0]) for t in dl])
        return result.data