np.concatenate()&py-MDNet/tracking/run 短命女 2022-11-12 13:45 123阅读 0赞 以下代码取自MDNet def train(model, criterion, optimizer, pos_feats, neg_feats, maxiter, in_layer='fc4'): model.train()#启用Batch Normalization和Dropout batch_pos = opts['batch_pos']#32 batch_neg = opts['batch_neg']#96 batch_test = opts['batch_test']#256 batch_neg_cand = max(opts['batch_neg_cand'], batch_neg)#1024 pos_idx = np.random.permutation(pos_feats.size(0))#permutaition重新排列,将正样本和负样本特征的索引重新排列。 neg_idx = np.random.permutation(neg_feats.size(0)) while(len(pos_idx) < batch_pos * maxiter): pos_idx = np.concatenate([pos_idx, np.random.permutation(pos_feats.size(0))])#将pos_idx从500扩展到96*50 while(len(neg_idx) < batch_neg_cand * maxiter): neg_idx = np.concatenate([neg_idx, np.random.permutation(neg_feats.size(0))])#将neg_idx扩展到1024*50 主要是分析np.concatenate()的作用 示例代码: a = np.random.permutation(10) print(a) a = np.concatenate([a,np.random.permutation(10)]) print(a) ''' [2 8 7 6 1 3 5 4 9 0] [2 8 7 6 1 3 5 4 9 0 3 7 2 0 5 8 1 6 9 4] ''' 所以,在源代码中,就是不断扩展,pos\_idx以及neg\_idx
还没有评论,来说两句吧...