教程 1: 学习接口文件¶
在”fgvclib/api”这个文件夹下,我们为fgvclib设置了各类api接口。这里有四种类型的api接口:build.py, evluate_model.py, save_model.py, 和 update_model.py。
“fgvclib/apis/build.py”:提供了各种用于快速构建训练系统或评估系统的api;
“fgvc/apis/evluate_model.py”:提供了用于评估FGVC算法的api;
“fgvclib/apis/save_model.py”:提供了各种用于保存模型的api;
“fgvclib/apis/update_model”:提供了各种用于更新模型和记录损失的api。
模型构建¶
build_model: 根据全局配置构建一个FGVC模型。
参数:
model_cfg (CfgNode): 根配置的模型配置节点返回值:
nn.Module: FGVC模型
build_logger: 根据配置构建日志对象。
参数:
cfg (CfgNode): 根配置节点返回值:
Logger: 日志对象 build_transforms: 根据配置为训练或测试数据集构建转换参数:
transforms_cfg (CfgNode): 根配置节点返回值:
transforms.Compose: Pytorch中的transforms.Compose对象
build_dataset: 为训练过程或评估过程构建数据加载器
参数:
root (str): 数据集的目录cfg (CfgNode): 根配置节点返回值:
DataLoader: Pytorch数据加载器
build_optimizer: 为训练过程构建优化器
参数:
optim_cfg (CfgNode): 根配置节点的优化配置节点返回值:
Optimizer: Pytorch优化器
build_criterion : 为训练过程构建损失函数
参数:
criterion_cfg(CfgNode): 根配置节点的标准配置节点返回值:
nn.Module: 损失函数
build_interpreter: 为训练过程构建一个解释器
参数:
cfg (CfgNode): 根配置节点返回值:
Interpreter: 一个解释器
build_metrics: 为评估过程构建度量标准
参数:
metrics_cfg (CfgNode): 根配置节点的度量标准配置节点返回值:
t.List[NamedMetric]: NamedMetric列表
模型评估¶
evaluate_model:对FGVC模型进行评估
参数:
model (nn.Module): FGVC模型p_bar (iterable): 提供测试数据的迭代器metrics (List[NamedMetric]): 指标的列表use_cuda (boolean, optional): 是否使用gpu返回值:
dict: 结果的字典
模型保存¶
save_model: 保存被训练的FGVC模型
参数:
cfg (CfgNode): 根配置节点model (nn.Module): FGVC模型logger (Logger): 日志对象
模型更新¶
update_model: 更新FGVC模型并且记录损失
参数:
model (nn.Module): FGVC模型optimizer (Optimizer): 日志对象pbar (Iterable): 提供训练数据的可迭代对象strategy (string): 更新的策略use_cuda (boolean): 是否使用GPU训练模型logger (Logger): 日志对象
API的应用¶
当你进行算法设计时,你需要使用from fgvclib.apis import * 导入上述这些api去调用这些接口。你可以直接使用以下的函数:build_logger, build_criterion, build_model, build_metrics, build_transforms, build_dataset, build_optimizer, update_model, evaluate_model, save_model, build_interpreter
应用举例:建立模型
import os
import torch
from fgvclib.apis import *
from fgvclib.configs import FGVCConfig
model = build_model(cfg.MODEL)
weight_path = os.path.join(cfg.WEIGHT.SAVE_DIR, cfg.WEIGHT.NAME)
assert os.path.exists(weight_path), f"The resume weight {cfg.RESUME_WEIGHT} dosn't exists."
state_dict = torch.load(weight_path, map_location="cpu")
model.load_state_dict(state_dict=state_dict)
if cfg.USE_CUDA:
assert torch.cuda.is_available(), f"Cuda is not available."
model = torch.nn.DataParallel(model)
transforms = build_transforms(cfg.TRANSFORMS.TEST)
loader = build_dataset(root=os.path.join(cfg.DATASETS.ROOT, 'test'), cfg=cfg.DATASETS.TEST, transforms=transforms)
interpreter = build_interpreter(model, cfg)
voxel = VOXEL(dataset=loader.dataset, name=cfg.FIFTYONE.NAME, interpreter=interpreter)
voxel.predict(model, transforms, 10, cfg.MODEL.NAME)
voxel.launch()