Skip to content

Commit 3b78b1a

Browse files
committed
feat(ssd): test code
1 parent 90d1d8c commit 3b78b1a

File tree

1 file changed

+85
-0
lines changed

1 file changed

+85
-0
lines changed

py/test.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import argparse
2+
import logging
3+
import os
4+
5+
import torch
6+
import torch.utils.data
7+
8+
from ssd.config import cfg
9+
from ssd.engine.inference import do_evaluation
10+
from ssd.models.detector import build_detection_model
11+
from ssd.utils import dist_util
12+
from ssd.utils.checkpoint import CheckPointer
13+
from ssd.utils.dist_util import synchronize
14+
from ssd.utils.logger import setup_logger
15+
16+
17+
def evaluation(cfg, ckpt, distributed):
    """Build the detector from *cfg*, restore weights from *ckpt*, and evaluate it.

    Args:
        cfg: frozen SSD config node (provides MODEL.DEVICE and OUTPUT_DIR).
        ckpt: path to a checkpoint file, or None to load the latest one
            found in cfg.OUTPUT_DIR.
        distributed: whether evaluation runs under torch.distributed.
    """
    logger = logging.getLogger("SSD.inference")

    device = torch.device(cfg.MODEL.DEVICE)
    detector = build_detection_model(cfg)
    detector.to(device)

    # use_latest kicks in only when no explicit checkpoint path was given.
    checkpointer = CheckPointer(detector, save_dir=cfg.OUTPUT_DIR, logger=logger)
    checkpointer.load(ckpt, use_latest=ckpt is None)

    do_evaluation(cfg, detector, distributed)
27+
28+
def main():
    """Parse CLI arguments, set up (optionally distributed) CUDA state, and run evaluation.

    Command line:
        --config-file  path to a yacs config file (optional; defaults apply if omitted)
        --local_rank   rank set by torch.distributed.launch (default 0)
        --ckpt         checkpoint to evaluate; latest in OUTPUT_DIR when omitted
        --output_dir   where to store evaluation results
        opts           trailing KEY VALUE pairs merged into the config
    """
    parser = argparse.ArgumentParser(description='SSD Evaluation on VOC and COCO dataset.')
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--ckpt",
        help="The path to the checkpoint for test, default is the latest checkpoint.",
        default=None,
        type=str,
    )

    # NOTE(review): parsed but never consumed in this script — kept for CLI
    # compatibility; presumably intended for do_evaluation's output path.
    parser.add_argument("--output_dir", default="eval_results", type=str,
                        help="The directory to store evaluation results.")

    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    # torch.distributed.launch exports WORLD_SIZE; without it assume one process.
    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1

    if torch.cuda.is_available():
        # This flag allows you to enable the inbuilt cudnn auto-tuner to
        # find the best algorithm to use for your hardware.
        torch.backends.cudnn.benchmark = True
    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    # Only merge a config file when one was actually supplied; previously the
    # empty-string default crashed on open("") before producing a useful error.
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    logger = setup_logger("SSD", dist_util.get_rank(), cfg.OUTPUT_DIR)
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    if args.config_file:
        logger.info("Loaded configuration file {}".format(args.config_file))
        with open(args.config_file, "r") as cf:
            config_str = "\n" + cf.read()
            logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))
    evaluation(cfg, ckpt=args.ckpt, distributed=distributed)


if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)