417 lines
10 KiB
Python
417 lines
10 KiB
Python
# Lint as: python2, python3
|
|
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Visualizes the segmentation results via specified color map.
|
|
|
|
Visualizes the semantic segmentation results by the color map
|
|
defined by the different datasets. Supported colormaps are:
|
|
|
|
* ADE20K (http://groups.csail.mit.edu/vision/datasets/ADE20K/).
|
|
|
|
* Cityscapes dataset (https://www.cityscapes-dataset.com).
|
|
|
|
* Mapillary Vistas (https://research.mapillary.com).
|
|
|
|
* PASCAL VOC 2012 (http://host.robots.ox.ac.uk/pascal/VOC/).
|
|
"""
|
|
|
|
from __future__ import absolute_import, division, print_function
|
|
|
|
import numpy as np
|
|
|
|
# from six.moves import range
|
|
|
|
# Dataset names.
|
|
_ADE20K = 'ade20k'
|
|
_CITYSCAPES = 'cityscapes'
|
|
_MAPILLARY_VISTAS = 'mapillary_vistas'
|
|
_PASCAL = 'pascal'
|
|
|
|
# Max number of entries in the colormap for each dataset.
|
|
_DATASET_MAX_ENTRIES = {
|
|
_ADE20K: 151,
|
|
_CITYSCAPES: 256,
|
|
_MAPILLARY_VISTAS: 66,
|
|
_PASCAL: 512,
|
|
}
|
|
|
|
|
|
def create_ade20k_label_colormap():
|
|
"""Creates a label colormap used in ADE20K segmentation benchmark.
|
|
|
|
Returns:
|
|
A colormap for visualizing segmentation results.
|
|
"""
|
|
return np.asarray([
|
|
[0, 0, 0],
|
|
[120, 120, 120],
|
|
[180, 120, 120],
|
|
[6, 230, 230],
|
|
[80, 50, 50],
|
|
[4, 200, 3],
|
|
[120, 120, 80],
|
|
[140, 140, 140],
|
|
[204, 5, 255],
|
|
[230, 230, 230],
|
|
[4, 250, 7],
|
|
[224, 5, 255],
|
|
[235, 255, 7],
|
|
[150, 5, 61],
|
|
[120, 120, 70],
|
|
[8, 255, 51],
|
|
[255, 6, 82],
|
|
[143, 255, 140],
|
|
[204, 255, 4],
|
|
[255, 51, 7],
|
|
[204, 70, 3],
|
|
[0, 102, 200],
|
|
[61, 230, 250],
|
|
[255, 6, 51],
|
|
[11, 102, 255],
|
|
[255, 7, 71],
|
|
[255, 9, 224],
|
|
[9, 7, 230],
|
|
[220, 220, 220],
|
|
[255, 9, 92],
|
|
[112, 9, 255],
|
|
[8, 255, 214],
|
|
[7, 255, 224],
|
|
[255, 184, 6],
|
|
[10, 255, 71],
|
|
[255, 41, 10],
|
|
[7, 255, 255],
|
|
[224, 255, 8],
|
|
[102, 8, 255],
|
|
[255, 61, 6],
|
|
[255, 194, 7],
|
|
[255, 122, 8],
|
|
[0, 255, 20],
|
|
[255, 8, 41],
|
|
[255, 5, 153],
|
|
[6, 51, 255],
|
|
[235, 12, 255],
|
|
[160, 150, 20],
|
|
[0, 163, 255],
|
|
[140, 140, 140],
|
|
[250, 10, 15],
|
|
[20, 255, 0],
|
|
[31, 255, 0],
|
|
[255, 31, 0],
|
|
[255, 224, 0],
|
|
[153, 255, 0],
|
|
[0, 0, 255],
|
|
[255, 71, 0],
|
|
[0, 235, 255],
|
|
[0, 173, 255],
|
|
[31, 0, 255],
|
|
[11, 200, 200],
|
|
[255, 82, 0],
|
|
[0, 255, 245],
|
|
[0, 61, 255],
|
|
[0, 255, 112],
|
|
[0, 255, 133],
|
|
[255, 0, 0],
|
|
[255, 163, 0],
|
|
[255, 102, 0],
|
|
[194, 255, 0],
|
|
[0, 143, 255],
|
|
[51, 255, 0],
|
|
[0, 82, 255],
|
|
[0, 255, 41],
|
|
[0, 255, 173],
|
|
[10, 0, 255],
|
|
[173, 255, 0],
|
|
[0, 255, 153],
|
|
[255, 92, 0],
|
|
[255, 0, 255],
|
|
[255, 0, 245],
|
|
[255, 0, 102],
|
|
[255, 173, 0],
|
|
[255, 0, 20],
|
|
[255, 184, 184],
|
|
[0, 31, 255],
|
|
[0, 255, 61],
|
|
[0, 71, 255],
|
|
[255, 0, 204],
|
|
[0, 255, 194],
|
|
[0, 255, 82],
|
|
[0, 10, 255],
|
|
[0, 112, 255],
|
|
[51, 0, 255],
|
|
[0, 194, 255],
|
|
[0, 122, 255],
|
|
[0, 255, 163],
|
|
[255, 153, 0],
|
|
[0, 255, 10],
|
|
[255, 112, 0],
|
|
[143, 255, 0],
|
|
[82, 0, 255],
|
|
[163, 255, 0],
|
|
[255, 235, 0],
|
|
[8, 184, 170],
|
|
[133, 0, 255],
|
|
[0, 255, 92],
|
|
[184, 0, 255],
|
|
[255, 0, 31],
|
|
[0, 184, 255],
|
|
[0, 214, 255],
|
|
[255, 0, 112],
|
|
[92, 255, 0],
|
|
[0, 224, 255],
|
|
[112, 224, 255],
|
|
[70, 184, 160],
|
|
[163, 0, 255],
|
|
[153, 0, 255],
|
|
[71, 255, 0],
|
|
[255, 0, 163],
|
|
[255, 204, 0],
|
|
[255, 0, 143],
|
|
[0, 255, 235],
|
|
[133, 255, 0],
|
|
[255, 0, 235],
|
|
[245, 0, 255],
|
|
[255, 0, 122],
|
|
[255, 245, 0],
|
|
[10, 190, 212],
|
|
[214, 255, 0],
|
|
[0, 204, 255],
|
|
[20, 0, 255],
|
|
[255, 255, 0],
|
|
[0, 153, 255],
|
|
[0, 41, 255],
|
|
[0, 255, 204],
|
|
[41, 0, 255],
|
|
[41, 255, 0],
|
|
[173, 0, 255],
|
|
[0, 245, 255],
|
|
[71, 0, 255],
|
|
[122, 0, 255],
|
|
[0, 255, 184],
|
|
[0, 92, 255],
|
|
[184, 255, 0],
|
|
[0, 133, 255],
|
|
[255, 214, 0],
|
|
[25, 194, 194],
|
|
[102, 255, 0],
|
|
[92, 0, 255],
|
|
])
|
|
|
|
|
|
def create_cityscapes_label_colormap():
|
|
"""Creates a label colormap used in CITYSCAPES segmentation benchmark.
|
|
|
|
Returns:
|
|
A colormap for visualizing segmentation results.
|
|
"""
|
|
colormap = np.zeros((256, 3), dtype=np.uint8)
|
|
colormap[0] = [128, 64, 128]
|
|
colormap[1] = [244, 35, 232]
|
|
colormap[2] = [70, 70, 70]
|
|
colormap[3] = [102, 102, 156]
|
|
colormap[4] = [190, 153, 153]
|
|
colormap[5] = [153, 153, 153]
|
|
colormap[6] = [250, 170, 30]
|
|
colormap[7] = [220, 220, 0]
|
|
colormap[8] = [107, 142, 35]
|
|
colormap[9] = [152, 251, 152]
|
|
colormap[10] = [70, 130, 180]
|
|
colormap[11] = [220, 20, 60]
|
|
colormap[12] = [255, 0, 0]
|
|
colormap[13] = [0, 0, 142]
|
|
colormap[14] = [0, 0, 70]
|
|
colormap[15] = [0, 60, 100]
|
|
colormap[16] = [0, 80, 100]
|
|
colormap[17] = [0, 0, 230]
|
|
colormap[18] = [119, 11, 32]
|
|
return colormap
|
|
|
|
|
|
def create_mapillary_vistas_label_colormap():
|
|
"""Creates a label colormap used in Mapillary Vistas segmentation benchmark.
|
|
|
|
Returns:
|
|
A colormap for visualizing segmentation results.
|
|
"""
|
|
return np.asarray([
|
|
[165, 42, 42],
|
|
[0, 192, 0],
|
|
[196, 196, 196],
|
|
[190, 153, 153],
|
|
[180, 165, 180],
|
|
[102, 102, 156],
|
|
[102, 102, 156],
|
|
[128, 64, 255],
|
|
[140, 140, 200],
|
|
[170, 170, 170],
|
|
[250, 170, 160],
|
|
[96, 96, 96],
|
|
[230, 150, 140],
|
|
[128, 64, 128],
|
|
[110, 110, 110],
|
|
[244, 35, 232],
|
|
[150, 100, 100],
|
|
[70, 70, 70],
|
|
[150, 120, 90],
|
|
[220, 20, 60],
|
|
[255, 0, 0],
|
|
[255, 0, 0],
|
|
[255, 0, 0],
|
|
[200, 128, 128],
|
|
[255, 255, 255],
|
|
[64, 170, 64],
|
|
[128, 64, 64],
|
|
[70, 130, 180],
|
|
[255, 255, 255],
|
|
[152, 251, 152],
|
|
[107, 142, 35],
|
|
[0, 170, 30],
|
|
[255, 255, 128],
|
|
[250, 0, 30],
|
|
[0, 0, 0],
|
|
[220, 220, 220],
|
|
[170, 170, 170],
|
|
[222, 40, 40],
|
|
[100, 170, 30],
|
|
[40, 40, 40],
|
|
[33, 33, 33],
|
|
[170, 170, 170],
|
|
[0, 0, 142],
|
|
[170, 170, 170],
|
|
[210, 170, 100],
|
|
[153, 153, 153],
|
|
[128, 128, 128],
|
|
[0, 0, 142],
|
|
[250, 170, 30],
|
|
[192, 192, 192],
|
|
[220, 220, 0],
|
|
[180, 165, 180],
|
|
[119, 11, 32],
|
|
[0, 0, 142],
|
|
[0, 60, 100],
|
|
[0, 0, 142],
|
|
[0, 0, 90],
|
|
[0, 0, 230],
|
|
[0, 80, 100],
|
|
[128, 64, 64],
|
|
[0, 0, 110],
|
|
[0, 0, 70],
|
|
[0, 0, 192],
|
|
[32, 32, 32],
|
|
[0, 0, 0],
|
|
[0, 0, 0],
|
|
])
|
|
|
|
|
|
def create_pascal_label_colormap():
|
|
"""Creates a label colormap used in PASCAL VOC segmentation benchmark.
|
|
|
|
Returns:
|
|
A colormap for visualizing segmentation results.
|
|
"""
|
|
colormap = np.zeros((_DATASET_MAX_ENTRIES[_PASCAL], 3), dtype=int)
|
|
ind = np.arange(_DATASET_MAX_ENTRIES[_PASCAL], dtype=int)
|
|
|
|
for shift in reversed(list(range(8))):
|
|
for channel in range(3):
|
|
colormap[:, channel] |= bit_get(ind, channel) << shift
|
|
ind >>= 3
|
|
|
|
return colormap
|
|
|
|
|
|
def get_ade20k_name():
|
|
return _ADE20K
|
|
|
|
|
|
def get_cityscapes_name():
|
|
return _CITYSCAPES
|
|
|
|
|
|
def get_mapillary_vistas_name():
|
|
return _MAPILLARY_VISTAS
|
|
|
|
|
|
def get_pascal_name():
|
|
return _PASCAL
|
|
|
|
|
|
def bit_get(val, idx):
|
|
"""Gets the bit value.
|
|
|
|
Args:
|
|
val: Input value, int or numpy int array.
|
|
idx: Which bit of the input val.
|
|
|
|
Returns:
|
|
The "idx"-th bit of input val.
|
|
"""
|
|
return (val >> idx) & 1
|
|
|
|
|
|
def create_label_colormap(dataset=_PASCAL):
|
|
"""Creates a label colormap for the specified dataset.
|
|
|
|
Args:
|
|
dataset: The colormap used in the dataset.
|
|
|
|
Returns:
|
|
A numpy array of the dataset colormap.
|
|
|
|
Raises:
|
|
ValueError: If the dataset is not supported.
|
|
"""
|
|
if dataset == _ADE20K:
|
|
return create_ade20k_label_colormap()
|
|
elif dataset == _CITYSCAPES:
|
|
return create_cityscapes_label_colormap()
|
|
elif dataset == _MAPILLARY_VISTAS:
|
|
return create_mapillary_vistas_label_colormap()
|
|
elif dataset == _PASCAL:
|
|
return create_pascal_label_colormap()
|
|
else:
|
|
raise ValueError('Unsupported dataset.')
|
|
|
|
|
|
def label_to_color_image(label, dataset=_PASCAL):
|
|
"""Adds color defined by the dataset colormap to the label.
|
|
|
|
Args:
|
|
label: A 2D array with integer type, storing the segmentation label.
|
|
dataset: The colormap used in the dataset.
|
|
|
|
Returns:
|
|
result: A 2D array with floating type. The element of the array
|
|
is the color indexed by the corresponding element in the input label
|
|
to the dataset color map.
|
|
|
|
Raises:
|
|
ValueError: If label is not of rank 2 or its value is larger than color
|
|
map maximum entry.
|
|
"""
|
|
if label.ndim != 2:
|
|
raise ValueError('Expect 2-D input label. Got {}'.format(label.shape))
|
|
|
|
if np.max(label) >= _DATASET_MAX_ENTRIES[dataset]:
|
|
raise ValueError(
|
|
'label value too large: {} >= {}.'.format(
|
|
np.max(label), _DATASET_MAX_ENTRIES[dataset]))
|
|
|
|
colormap = create_label_colormap(dataset)
|
|
return colormap[label]
|
|
|
|
|
|
def get_dataset_colormap_max_entries(dataset):
|
|
return _DATASET_MAX_ENTRIES[dataset]
|