DWTNet的Direction Network

/DN 目录列表

  • /DN/train_direction.py ; 构建 训练/测试 流程;
  • /DN/direction_model.py ; 构建网络;
  • /DN/ioUtils.py ; 构建batch数据;
  • /DN/lossFunction.py ; 损失函数;

# /DN/train_direction.py
    def initialize_model(outputChannels, wd=None, modelWeightPaths=None)
    def forward_model(model, feeder, outputSavePath)
    def train_model(model, outputChannels, learningRate, trainFeeder, valFeeder, modelSavePath=None, savePrefix=None, initialIteration=1)
    def modelSaver(sess, modelSavePath, savePrefix, iteration, maxToKeep=5)
    def checkSaveFlag(modelSavePath)
    if __name__ == "__main__":  # training information...

# /DN/direction_model.py
    class Network:
        def __init__(self, params, wd=5e-5, modelWeightPaths=None)
        def build(self, inputData, ss=None, ssMask=None, keepProb=1.0)
        def _max_pool(self, bottom, name)
        def _average_pool(self, bottom, name)
        def _conv_layer(self, bottom, params, keepProb=1.0)
        def get_bias(self, params)
        def get_conv_filter(self, params)
        def _upscore_layer(self, bottom, shape, params)
        def get_deconv_filter(self, f_shape, params)

# /DN/ioUtils.py
    def read_mat(path)
    def write_mat(path, m)
    def read_ids(path)
    class Batch_Feeder:
        def __init__(self, dataset, indices, train, batchSize, padWidth=None, padHeight=None, flip=False, keepEmpty=True)
        def set_paths(self, idList=None, imageDir=None, gtDir=None, ssDir=None)
        def shuffle(self)
        def next_batch(self)
        def total_samples(self)
        def image_scaling(self, rgb_in)
        def pad(self, data)

# /DN/lossFunction.py
    def angularErrorTotal(pred, gt, weight, ss, outputChannels=2)
    def exceedingAngleThreshold(pred, gt, ss, threshold, outputChannels=2)
    def countCorrect(pred, gt, ss, k, outputChannels)
    def countTotal(ss)
    def countTotalWeighted(ss, weight)

batch数据处理(/DN/ioUtils.py)


''' train : self_paths[图片id, 原始图片地址, 图片对应真值(dir_map,depth_map,weight_map;mat格式), 图片对应语义分割的结果(mat格式)] test : self_paths[图片id, 图片对应语义分割的结果(mat格式)] ''' def set_paths(self, idList=None, imageDir=None, gtDir=None, ssDir=None): self._paths = [] if self._train: for id in idList: self._paths.append([id, imageDir + '/' + id + '_leftImg8bit.png', gtDir + '/' + id + '_unified_GT.mat', ssDir + '/' + id + '_unified_ss.mat']) self.shuffle() else: for id in idList: self._paths.append([id, imageDir + '/' + id + '_leftImg8bit.png', ssDir + '/' + id + '_unified_ss.mat']) self._numData = len(self._paths) if self._numData < self._batchSize: self._batchSize = self._numData def next_batch(self): idBatch = [] imageBatch = [] gtBatch = [] ssBatch = [] ssMaskBatch = [] weightBatch = [] # train if self._train: while(len(idBatch) < self._batchSize): ss = (sio.loadmat(self._paths[self._index_in_epoch][3])['mask']).astype(float) ssMask = ss ss = np.sum(ss[:,:,self._indices], 2) background = np.zeros(ssMask.shape[0:2] + (1,)) ssMask = np.concatenate((ssMask[:,:,[1,2,3,4]], background, ssMask[:,:,[0,5,6,7]]), axis=-1) ssMask = np.argmax(ssMask, axis=-1) ssMask = ssMask.astype(float) ssMask = (ssMask - 4) * 32 # centered at 0, with 0 being background, spaced 32 apart for classes if ss.sum() > 0 or self._keepEmpty: idBatch.append(self._paths[self._index_in_epoch][0]) image = (self.image_scaling(skimage.io.imread(self._paths[self._index_in_epoch][1]))).astype(float) image = scipy.misc.imresize(image, 50) gt = (sio.loadmat(self._paths[self._index_in_epoch][2])['dir_map']).astype(float) weight = (sio.loadmat(self._paths[self._index_in_epoch][2])['weight_map']).astype(float) imageBatch.append(self.pad(image)) gtBatch.append(self.pad(gt)) weightBatch.append(self.pad(weight)) ssBatch.append(self.pad(ss)) ssMaskBatch.append(self.pad(ssMask)) else: pass self._index_in_epoch += 1 if self._index_in_epoch == self._numData: self._index_in_epoch = 0 self.shuffle() imageBatch = np.array(imageBatch) gtBatch = 
np.array(gtBatch) ssBatch = np.array(ssBatch) ssMaskBatch = np.array(ssMaskBatch) weightBatch = np.array(weightBatch) if self._flip and np.random.uniform() > 0.5: for i in range(len(imageBatch)): for j in range(3): imageBatch[i,:,:,j] = np.fliplr(imageBatch[i,:,:,j]) weightBatch[i] = np.fliplr(weightBatch[i]) ssBatch[i] = np.fliplr(ssBatch[i]) ssMaskBatch[i] = np.fliplr(ssMaskBatch[i]) for j in range(2): gtBatch[i,:,:,j] = np.fliplr(gtBatch[i,:,:,j]) gtBatch[i,:,:,0] = -1 * gtBatch[i,:,:,0] return imageBatch, gtBatch, weightBatch, ssBatch, ssMaskBatch, idBatch # test else: for example in self._paths[self._index_in_epoch:min(self._index_in_epoch+self._batchSize, self._numData)]: imageBatch.append(self.pad((self.image_scaling(skimage.io.imread(example[1]))).astype(float))) idBatch.append(example[0]) ss = (sio.loadmat(example[2])['mask']).astype(float) ssMask = ss ss = np.sum(ss[:, :, self._indices], 2) background = np.zeros(ssMask.shape[0:2] + (1,)) ssMask = np.concatenate((ssMask[:,:,[1,2,3,4]], background, ssMask[:,:,[0,5,6,7]]), axis=-1) ssMask = np.argmax(ssMask, axis=-1) ssMask = ssMask.astype(float) ssMask = (ssMask - 4) * 32 # centered at 0, with 0 being background, spaced 32 apart for classes ssBatch.append(self.pad(ss)) ssMaskBatch.append(self.pad(ssMask)) imageBatch = np.array(imageBatch) ssBatch = np.array(ssBatch) ssMaskBatch = np.array(ssMaskBatch) self._index_in_epoch += self._batchSize return imageBatch, ssBatch, ssMaskBatch, idBatch

损失函数(/DN/lossFunction.py)

def angularErrorTotal(pred, gt, weight, ss, outputChannels=2):
    with tf.name_scope("angular_error"):
        pred = tf.reshape(pred, (-1, outputChannels))
        gt = tf.to_float(tf.reshape(gt, (-1, outputChannels)))
        weight = tf.to_float(tf.reshape(weight, (-1, 1)))
        ss = tf.to_float(tf.reshape(ss, (-1, 1)))

        pred = tf.nn.l2_normalize(pred, 1) * 0.999999
        gt = tf.nn.l2_normalize(gt, 1) * 0.999999

        errorAngles = tf.acos(tf.reduce_sum(pred * gt, reduction_indices=[1], keep_dims=True))

        lossAngleTotal = tf.reduce_sum((tf.abs(errorAngles*errorAngles))*ss*weight)

        return lossAngleTotal

def angularErrorLoss(pred, gt, weight, ss, outputChannels=2):
        lossAngleTotal = angularErrorTotal(pred=pred, gt=gt, ss=ss, weight=weight, outputChannels=outputChannels) / (countTotal(ss)+1)

        tf.add_to_collection('losses', lossAngleTotal)

        totalLoss = tf.add_n(tf.get_collection('losses'), name='total_loss')

        return totalLoss