本文使用预训练的Resnet50网络对皮肤病图片进行二分类,基于portch框架。
数据集说明
数据集存放目录为: used_dataset , 共200张图片,标签为:benign(良性)、malignant(患病)。
数据集划分如下:
代码目录介绍
- args.py 存放训练和测试所用的各种参数。 --mode字段表示运行模式:train or test. --model_path字段是训练模型的保存路径。 其余字段都有默认值。
- create_dataset.py 该脚本是用来读json中的数据的,可以忽略。
- data_gen.py 该脚本实现划分数据集以及数据增强和数据加载。
- main.py 包含训练、评估和测试。
- transform.py 实现图片增强。
- utils.py 存放一些工具函数。
- models/Res.py 是重写的ResNet各种类型的网络。
- checkpoints 保存模型
main.py 脚本介绍
main()函数 实现模型的训练和评估
step1: 加载数据
step2: 构建模型
step3: 模型的训练和评估
train()函数 每个epoch下的模型训练过程
主要实现每个批次下梯度的反向传播,计算accuarcy 和 loss, 并更新,最后返回其均值。
val()函数 每个epoch下的模型评估过程
主要代码与train()函数一致,但没有梯度的计算,还有将model.train()改成model.eval()。
test()函数 模型的测试
实验结果
版权声明:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若内容造成侵权、违法违规、事实不符,请将相关资料发送至xkadmin@xkablog.com进行投诉反馈,一经查实,立即处理!
转载请注明出处,原文链接:https://www.xkablog.com/rfx/49827.html