Python GDAL + skimage 实现基于遥感影像的传统图像分割及合并外加矢量化
根据我前述博客中对传统图像分割算法及图像块合并方法的实验探究，在此将这些方法用于遥感影像并尝试矢量化。
这个过程中我自己遇到了一个棘手的问题，在最后的结果那里有描述，希望知道的朋友帮忙解答一下，谢谢！
直接上代码:
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
from osgeo import ogr, osr, gdal
from PIL import Image
from skimage import data, filters
from skimage import morphology, color, measure
from skimage.future import graph
from skimage.morphology import disk
from skimage.segmentation import felzenszwalb, slic, quickshift
from skimage.segmentation import mark_boundaries
from skimage.util import img_as_float
def read_img(filename):
    """Read a whole raster file into memory with GDAL.

    Parameters
    ----------
    filename : str
        Path of the raster to open.

    Returns
    -------
    im_width, im_height : int
        Raster size in pixels (``RasterXSize`` / ``RasterYSize``).
    im_proj : str
        Projection string reported by ``GetProjection()``.
    im_geotrans : tuple
        Geotransform reported by ``GetGeoTransform()``.
    im_data : numpy.ndarray
        Pixel data of the full extent as returned by ``ReadAsArray``.

    Raises
    ------
    RuntimeError
        If GDAL cannot open ``filename``.
    """
    dataset = gdal.Open(filename)
    # gdal.Open returns None on failure when GDAL exceptions are disabled;
    # fail here with a clear message instead of an AttributeError below.
    if dataset is None:
        raise RuntimeError("Cannot open raster file: %s" % filename)

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)

    # Drop the handle so GDAL closes the file.
    del dataset
    return im_width, im_height, im_proj, im_geotrans, im_data
def write_img(filename, im_proj, im_geotrans, im_data):
    """Write an array to a georeferenced GeoTIFF.

    Parameters
    ----------
    filename : str
        Output path.
    im_proj : str
        Projection string passed to ``SetProjection``.
    im_geotrans : tuple
        Geotransform passed to ``SetGeoTransform``.
    im_data : numpy.ndarray
        Either ``(bands, rows, cols)`` or ``(rows, cols)``.

    Raises
    ------
    RuntimeError
        If the GeoTIFF cannot be created.

    Notes
    -----
    8-bit integers are stored as ``GDT_Byte``, ``uint16`` as ``GDT_UInt16``,
    ``int16`` as ``GDT_Int16`` and every other dtype as ``GDT_Float32``.
    """
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        # Signed 16-bit data must not be stored as unsigned, otherwise
        # negative values are corrupted.
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
        band_arrays = im_data
    else:
        im_height, im_width = im_data.shape
        im_bands = 1
        band_arrays = [im_data]

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("Cannot create raster file: %s" % filename)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    # Always write one 2-D slice per band so a (1, rows, cols) array works
    # the same way as a plain (rows, cols) one.
    for band_index in range(im_bands):
        dataset.GetRasterBand(band_index + 1).WriteArray(band_arrays[band_index])

    # Dropping the handle flushes and closes the file.
    del dataset
def stretch_n(bands, img_min, img_max, lower_percent=0, higher_percent=100):
    """Linearly stretch an array to ``[img_min, img_max]`` between percentiles.

    Parameters
    ----------
    bands : array_like
        Input values.
    img_min, img_max : float
        Output range; results are clipped to it.
    lower_percent, higher_percent : float, optional
        Percentiles of ``bands`` mapped to ``img_min`` and ``img_max``.

    Returns
    -------
    numpy.ndarray
        ``float32`` array with the shape of ``bands``.

    Notes
    -----
    When both percentiles are equal (for example a constant band) the
    stretch is undefined; the result is then ``img_min`` everywhere rather
    than NaN.
    """
    bands = np.asarray(bands)
    low = np.percentile(bands, lower_percent)
    high = np.percentile(bands, higher_percent)

    if high == low:
        # Avoid dividing by zero: no contrast to stretch.
        return np.full(bands.shape, img_min, dtype=np.float32)

    stretched = img_min + (bands - low) * (img_max - img_min) / (high - low)
    return np.clip(stretched, img_min, img_max).astype(np.float32)
def DoesDriverHandleExtension(drv, ext):
    """Tell whether a GDAL driver advertises the file extension ``ext``.

    Parameters
    ----------
    drv : gdal.Driver
        Driver whose ``DMD_EXTENSIONS`` metadata item is inspected.
    ext : str
        Extension without the leading dot.

    Returns
    -------
    bool
        True when ``ext`` occurs (case-insensitively, as a substring) in the
        driver's extension metadata; False when the driver declares none.
    """
    exts = drv.GetMetadataItem(gdal.DMD_EXTENSIONS)
    return exts is not None and ext.lower() in exts.lower()
def GetExtension(filename):
    """Return the extension of ``filename`` without its leading dot.

    Parameters
    ----------
    filename : str
        File name or path.

    Returns
    -------
    str
        The last extension (``"shp"`` for ``"out.shp"``), or an empty string
        when ``os.path.splitext`` finds none.
    """
    ext = os.path.splitext(filename)[1]
    return ext[1:] if ext.startswith('.') else ext
def GetOutputDriversFor(filename):
    """List the writable vector drivers that could produce ``filename``.

    A driver qualifies when it declares ``DCAP_CREATE`` or
    ``DCAP_CREATECOPY`` together with ``DCAP_VECTOR`` and either handles the
    file's extension or has a connection prefix that ``filename`` starts
    with.

    Parameters
    ----------
    filename : str
        Output file name or connection string.

    Returns
    -------
    list of str
        Short names of the matching drivers, in GDAL registration order.
    """
    drv_list = []
    ext = GetExtension(filename)
    for i in range(gdal.GetDriverCount()):
        drv = gdal.GetDriver(i)
        can_write = (drv.GetMetadataItem(gdal.DCAP_CREATE) is not None
                     or drv.GetMetadataItem(gdal.DCAP_CREATECOPY) is not None)
        if not can_write or drv.GetMetadataItem(gdal.DCAP_VECTOR) is None:
            continue

        if ext and DoesDriverHandleExtension(drv, ext):
            drv_list.append(drv.ShortName)
            continue

        # No extension match: fall back to connection prefixes such as "PG:".
        prefix = drv.GetMetadataItem(gdal.DMD_CONNECTION_PREFIX)
        if prefix is not None and filename.lower().startswith(prefix.lower()):
            drv_list.append(drv.ShortName)
    return drv_list
def GetOutputDriverFor(filename):
    """Pick the vector driver to use for writing ``filename``.

    Parameters
    ----------
    filename : str
        Output file name or connection string.

    Returns
    -------
    str
        Short name of the first matching driver, or ``'ESRI Shapefile'``
        when nothing matches and the name has no extension.

    Raises
    ------
    Exception
        If no driver matches and the name does have an extension.
    """
    drv_list = GetOutputDriversFor(filename)
    ext = GetExtension(filename)
    if not drv_list:
        if not ext:
            return 'ESRI Shapefile'
        raise Exception("Cannot guess driver for %s" % filename)

    if len(drv_list) > 1:
        # Ambiguous extension: report the choice and use the first match.
        print("Several drivers matching %s extension. Using %s" % (ext if ext else '', drv_list[0]))
    return drv_list[0]
def _weight_mean_color(graph, src, dst, n):
    """Callback to handle merging nodes by recomputing mean color.

    The method expects that the mean color of `dst` is already computed.

    Parameters
    ----------
    graph : RAG
        The graph under consideration.
    src, dst : int
        The vertices in `graph` to be merged.
    n : int
        A neighbor of `src` or `dst` or both.

    Returns
    -------
    data : dict
        A dictionary with the `"weight"` attribute set as the absolute
        difference of the mean color between node `dst` and `n`.
    """
    diff = graph.nodes[dst]['mean color'] - graph.nodes[n]['mean color']
    # Collapse the per-channel difference into one scalar distance.
    diff = np.linalg.norm(diff)
    return {'weight': diff}
def merge_mean_color(graph, src, dst):
    """Callback called before merging two nodes of a mean color distance graph.

    This method computes the mean color of `dst`: the color sum and pixel
    count of `src` are added to `dst`, then the mean of `dst` is recomputed.
    The node attributes ``'total color'``, ``'pixel count'`` and
    ``'mean color'`` are expected to exist already (presumably populated by
    ``rag_mean_color``; confirm against the installed scikit-image).

    Parameters
    ----------
    graph : RAG
        The graph under consideration.
    src, dst : int
        The vertices in `graph` to be merged.
    """
    dst_node = graph.nodes[dst]
    src_node = graph.nodes[src]
    dst_node['total color'] += src_node['total color']
    dst_node['pixel count'] += src_node['pixel count']
    dst_node['mean color'] = dst_node['total color'] / dst_node['pixel count']
def BetterMedianFilter(src_arr, k=3, padding=None):
    """Median-filter only the pixels that are local extremes.

    For every interior pixel, the ``k`` x ``k`` neighbourhood is inspected;
    the pixel is replaced by the neighbourhood median only if it equals the
    neighbourhood maximum or minimum. All other pixels, and the border
    pixels, keep their source value.

    Parameters
    ----------
    src_arr : numpy.ndarray
        2-D input array.
    k : int, optional
        Window size; the half-width used is ``int((k - 1) / 2)``.
    padding : optional
        Must be falsy. Padded filtering is not implemented.

    Returns
    -------
    numpy.ndarray or None
        ``uint16`` array with the shape of ``src_arr``, or None (after
        printing a message) when the window does not fit the array.

    Raises
    ------
    NotImplementedError
        If ``padding`` is truthy.

    Notes
    -----
    NOTE(review): the unfiltered margin on the bottom/right side is one
    pixel wider than on the top/left side; this is kept from the original
    implementation. Values outside the ``uint16`` range do not survive the
    output cast.
    """
    if padding:
        raise NotImplementedError("padding is not supported by BetterMedianFilter")

    height, width = src_arr.shape
    edge = int((k - 1) / 2)
    if height - 1 - edge <= edge or width - 1 - edge <= edge:
        print("The parameter k is to large.")
        return None

    # Start from a copy so border and non-extreme pixels keep their value.
    new_arr = src_arr.astype("uint16")
    for i in range(edge, height - 1 - edge):
        # Column bounds use the width (the original compared against the
        # height, which mis-filtered every non-square array).
        for j in range(edge, width - 1 - edge):
            window = src_arr[i - edge:i + edge + 1, j - edge:j + edge + 1]
            center = src_arr[i, j]
            if center == np.max(window) or center == np.min(window):
                new_arr[i, j] = np.median(window)
    return new_arr
if __name__ == '__main__':
    # Pipeline: quickshift over-segmentation -> RAG hierarchical merging ->
    # label smoothing -> polygonization of the merged label raster.
    img_path = "./temp/test2.tif"
    temp_path = "./temp/"

    im_width, im_height, im_proj, im_geotrans, im_data = read_img(img_path)
    # Only the first three bands are used for segmentation.
    im_data = im_data[0:3]
    # Reverse the axis order (band-first -> band-last); every product is
    # transposed back before it is written.
    temp = im_data.transpose((2, 1, 0))

    # ---- Step 1: initial over-segmentation -------------------------------
    segments_quick = quickshift(temp, kernel_size=3, max_dist=6, ratio=0.5)
    mark0 = mark_boundaries(temp, segments_quick)
    save_path = temp_path + "qs_seg0.tif"
    re0 = mark0.transpose((2, 1, 0))
    write_img(save_path, im_proj, im_geotrans, re0)

    # First channel of the boundary overlay, cast to 8-bit, as a line grid.
    grid_path = temp_path + "qs_grid0.tif"
    grid0 = np.uint8(re0[0, ...])
    write_img(grid_path, im_proj, im_geotrans, grid0)

    # Thin the grid to a skeleton and binarize it.
    skeleton = morphology.skeletonize(grid0)
    border0 = np.multiply(grid0, skeleton)
    ret, border0 = cv2.threshold(border0, 0, 1, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    border_path = temp_path + "qs_border0.tif"
    write_img(border_path, im_proj, im_geotrans, border0)

    # ---- Step 2: merge similar regions on the region adjacency graph -----
    g = graph.rag_mean_color(temp, segments_quick)
    labels2 = graph.merge_hierarchical(segments_quick, g, thresh=5,
                                       rag_copy=False,
                                       in_place_merge=True,
                                       merge_func=merge_mean_color,
                                       weight_func=_weight_mean_color)
    label_rgb2 = color.label2rgb(labels2, temp, kind='avg')

    rgb_path = temp_path + "qs_label.tif"
    lb = labels2.transpose((1, 0))
    write_img(rgb_path, im_proj, im_geotrans, lb)

    # Smooth isolated label extremes. An alternative that was tried is
    # filters.median(lb, disk(5)).
    label_smooth = temp_path + "qs_label_smooth.tif"
    lb = BetterMedianFilter(lb)
    write_img(label_smooth, im_proj, im_geotrans, lb)

    mark = mark_boundaries(label_rgb2, labels2)
    save_path = temp_path + "qs_seg.tif"
    re = mark.transpose((2, 1, 0))
    write_img(save_path, im_proj, im_geotrans, re)

    grid_path = temp_path + "qs_grid.tif"
    grid = np.uint8(re[0, ...])
    write_img(grid_path, im_proj, im_geotrans, grid)

    skeleton = morphology.skeletonize(grid)
    border = np.multiply(grid, skeleton)
    ret, border = cv2.threshold(border, 0, 1, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    border_path = temp_path + "qs_border.tif"
    write_img(border_path, im_proj, im_geotrans, border)

    # ---- Step 3: vectorize the merged (unsmoothed) label raster ----------
    border_driver = gdal.Open(rgb_path)
    if border_driver is None:
        raise RuntimeError("Cannot open raster file: %s" % rgb_path)
    border_band = border_driver.GetRasterBand(1)
    border_mask = border_band.GetMaskBand()

    dst_filename = temp_path + 'temp.shp'
    frmt = GetOutputDriverFor(dst_filename)
    drv = ogr.GetDriverByName(frmt)
    if drv is None:
        raise RuntimeError("OGR driver not available: %s" % frmt)
    dst_ds = drv.CreateDataSource(dst_filename)
    if dst_ds is None:
        raise RuntimeError("Cannot create vector file: %s" % dst_filename)

    dst_layername = 'out'
    srs = osr.SpatialReference(wkt=border_driver.GetProjection())
    # Polygon output; ogr.wkbLineString was tried as an alternative.
    dst_layer = dst_ds.CreateLayer(dst_layername, geom_type=ogr.wkbPolygon, srs=srs)

    # Field 0 ('DN') receives the pixel value of each polygon.
    dst_fieldname = 'DN'
    fd = ogr.FieldDefn(dst_fieldname, ogr.OFTInteger)
    dst_layer.CreateField(fd)
    dst_field = 0

    options = ['DATASET_FOR_GEOREF=' + rgb_path]
    prog_func = gdal.TermProgress_nocb
    gdal.Polygonize(border_band, border_mask, dst_layer, dst_field, options,
                    callback=prog_func)

    # Release GDAL/OGR handles (children first) so the shapefile is flushed.
    dst_layer = None
    dst_ds = None
    border_mask = None
    border_band = None
    border_driver = None
# Reference: WKB geometry type codes.
# enum WKBGeometryType {
#     wkbPoint = 1,
#     wkbLineString = 2,
#     wkbPolygon = 3,
#     wkbTriangle = 17,
#     wkbMultiPoint = 4,
#     wkbMultiLineString = 5,
#     wkbMultiPolygon = 6,
#     wkbGeometryCollection = 7,
#     wkbPolyhedralSurface = 15,
#     wkbTIN = 16,
#     wkbPointZ = 1001,
#     wkbLineStringZ = 1002,
#     wkbPolygonZ = 1003,
#     wkbTriangleZ = 1017,
#     wkbMultiPointZ = 1004,
#     wkbMultiLineStringZ = 1005,
#     wkbMultiPolygonZ = 1006,
#     wkbGeometryCollectionZ = 1007,
#     wkbPolyhedralSurfaceZ = 1015,
#     wkbTINZ = 1016,
#     wkbPointM = 2001,
#     wkbLineStringM = 2002,
#     wkbPolygonM = 2003,
#     wkbTriangleM = 2017,
#     wkbMultiPointM = 2004,
#     wkbMultiLineStringM = 2005,
#     wkbMultiPolygonM = 2006,
#     wkbGeometryCollectionM = 2007,
#     wkbPolyhedralSurfaceM = 2015,
#     wkbTINM = 2016,
#     wkbPointZM = 3001,
#     wkbLineStringZM = 3002,
#     wkbPolygonZM = 3003,
#     wkbTriangleZM = 3017,
#     wkbMultiPointZM = 3004,
#     wkbMultiLineStringZM = 3005,
#     wkbMultiPolygonZM = 3006,
#     wkbGeometryCollectionZM = 3007,
#     wkbPolyhedralSurfaceZM = 3015,
#     wkbTINZM = 3016,
# }
对应的结果图如下:
原图:
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。
发表评论