请解释这段代码def load_dataset(datafiles): def read(data_path): with open(data_path, 'r', encoding='utf-8') as fp: next(fp) for line in fp.readlines(): words, labels = line.strip('\n').split('\t') words = words.split('\002') labels = labels.split('\002') yield words, labels if isinstance(datafiles, str): return MapDataset(list(read(datafiles))) elif isinstance(datafiles, list) or isinstance(datafiles, tuple): return [MapDataset(list(read(datafile))) for datafile in datafiles]
时间: 2023-12-03 14:02:12 浏览: 175
这段代码定义了一个函数叫做`load_dataset`,它接受一个或多个数据文件的路径作为输入。该函数使用嵌套函数`read`来读取每个数据文件,并将其转换为一个`MapDataset`对象。
在`read`函数中,它使用`with`语句打开数据文件并逐行读取文件内容。每一行都包含一个句子及其对应的标签,它们使用`\t`分隔。句子和标签使用`\002`分隔。在读取每一行后,它将句子和标签分别存储在`words`和`labels`变量中,然后将它们转换为列表。最后,它使用`yield`语句将句子和标签作为一对返回。
在`load_dataset`函数中,如果输入是一个字符串,则假定它是单个数据文件的路径,并将其传递给`read`函数。否则,如果输入是一个列表或元组,则假定它包含多个数据文件的路径,并为每个数据文件调用`read`函数。在每种情况下,`MapDataset`被用于将数据转换为可映射的对象,以便于后续操作。最后,该函数返回一个`MapDataset`对象或一个`MapDataset`对象的列表。
相关问题
# main_window.py(主窗口逻辑) import numpy as np from PySide6.QtWidgets import QMainWindow, QFileDialog, QGraphicsScene, QGraphicsView, QMessageBox, QGraphicsPathItem from PySide6.QtGui import QPainterPath, QPen, QBrush, QAction, QTransform, QImage, QPixmap, QColor from PySide6.QtCore import Qt, QRectF, QPointF from osgeo import ogr, gdal from PySide6.QtWidgets import QInputDialog # 新增输入对话框 # 新增自定义图形项类(用于存储属性) class FeatureItem(QGraphicsPathItem): def __init__(self, path, attributes): super().__init__(path) self.attributes = attributes # 存储属性字典 class MainWindow(QMainWindow): def __init__(self): super().__init__() self.setWindowTitle("GIS软件") self.setGeometry(100, 100, 800, 600) ogr.UseExceptions() self.init_ui() self.scene = QGraphicsScene(self) self.graphicsView.setScene(self.scene) # 新增:存储所有几何边界 self.total_bounds = QRectF() def init_ui(self): self.toolBar = self.addToolBar("工具") self.actionOpen_Vector_Data = QAction("打开矢量数据", self) self.toolBar.addAction(self.actionOpen_Vector_Data) # 新增栅格动作 self.actionOpen_Raster_Data = QAction("打开栅格数据", self) self.toolBar.addAction(self.actionOpen_Raster_Data) # 添加到工具栏 # 新增缓冲区分析按钮 self.actionBuffer_Analysis = QAction("缓冲区分析", self) self.toolBar.addAction(self.actionBuffer_Analysis) self.graphicsView = QGraphicsView() self.setCentralWidget(self.graphicsView) # 新增属性查询按钮 self.actionQuery_Attribute = QAction("属性查询", self) self.toolBar.addAction(self.actionQuery_Attribute) self.actionOpen_Vector_Data.triggered.connect(self.open_vector_data) self.actionOpen_Raster_Data.triggered.connect(self.open_raster_data) # 新增连接 self.actionBuffer_Analysis.triggered.connect(self.buffer_analysis) self.actionQuery_Attribute.triggered.connect(self.enable_query_mode) # 新增鼠标点击事件 self.graphicsView.setMouseTracking(True) self.is_query_mode = False # 新增波段组合按钮 self.actionBand_Combination = QAction("波段组合", self) self.toolBar.addAction(self.actionBand_Combination) self.actionBand_Combination.triggered.connect(self.open_band_combination) def open_vector_data(self): file_path, _ = QFileDialog.getOpenFileName( self, "打开矢量文件", "", "Shapefile (*.shp);;GeoJSON (*.geojson);;All Files (*)" ) if file_path: self.load_vector_data(file_path) # 新增:自动缩放视图 self.auto_zoom() def load_vector_data(self, file_path): self.scene.clear() self.total_bounds = QRectF() # 重置边界 try: data_source = ogr.Open(file_path, 0) layer = data_source.GetLayer(0) for feature in layer: geom = feature.GetGeometryRef() path = self.geometry_to_qpainterpath(geom) # 更新总边界 if path.boundingRect().isValid(): self.total_bounds = self.total_bounds.united(path.boundingRect()) pen = QPen(Qt.blue, 1) brush = QBrush(Qt.cyan) self.scene.addPath(path, pen, brush) data_source = None except Exception as e: print(f"加载失败: {str(e)}") self.current_vector_path = file_path # 新增这一行 data_source = None def geometry_to_qpainterpath(self, geom): path = QPainterPath() if geom.GetGeometryType() == ogr.wkbPolygon: for ring in range(geom.GetGeometryCount()): linear_ring = geom.GetGeometryRef(ring) points = linear_ring.GetPoints() if points: path.moveTo(points[0][0], points[0][1]) for p in points[1:]: path.lineTo(p[0], p[1]) path.closeSubpath() elif geom.GetGeometryType() == ogr.wkbLineString: points = geom.GetPoints() if points: path.moveTo(points[0][0], points[0][1]) for p in points[1:]: path.lineTo(p[0], p[1]) elif geom.GetGeometryType() == ogr.wkbPoint: x, y = geom.GetX(), geom.GetY() path.addEllipse(x - 2, y - 2, 4, 4) return path def auto_zoom(self): """自动缩放视图到数据范围并放大2倍""" if not self.total_bounds.isValid(): return # 设置场景边界 self.scene.setSceneRect(self.total_bounds) # 获取视图可视区域 view_rect = self.graphicsView.viewport().rect() # 计算缩放比例(自动适应 + 2倍放大) transform = QTransform() transform.scale(2, 2) # 先放大2倍 # 应用缩放并居中 self.graphicsView.setTransform(transform) self.graphicsView.fitInView(self.total_bounds, Qt.KeepAspectRatio) # 新增缓冲区分析方法 def buffer_analysis(self): """执行缓冲区分析""" if not hasattr(self, 'current_vector_path'): QMessageBox.warning(self, "警告", "请先打开矢量数据文件!") return # 获取缓冲距离 distance, ok = QInputDialog.getDouble( self, "缓冲区分析", "输入缓冲距离(单位与数据坐标系一致):", 0.0, 0 ) if not ok: return try: # 重新打开数据源获取几何 data_source = ogr.Open(self.current_vector_path, 0) layer = data_source.GetLayer(0) # 创建缓冲区路径 buffer_path = QPainterPath() pen = QPen(Qt.red, 2, Qt.DashLine) brush = QBrush(QColor(255, 0, 0, 50)) # 半透明红色填充 for feature in layer: geom = feature.GetGeometryRef() buffer_geom = geom.Buffer(distance) path = self.geometry_to_qpainterpath(buffer_geom) buffer_path.addPath(path) # 添加到场景 self.scene.addPath(buffer_path, pen, brush) # 更新视图边界 if buffer_path.boundingRect().isValid(): self.total_bounds = self.total_bounds.united(buffer_path.boundingRect()) self.auto_zoom() data_source = None except Exception as e: QMessageBox.critical(self, "错误", f"缓冲区分析失败: {str(e)}") def load_vector_data(self, file_path): self.scene.clear() self.total_bounds = QRectF() try: data_source = ogr.Open(file_path, 0) layer = data_source.GetLayer(0) # 获取字段定义 layer_defn = layer.GetLayerDefn() field_names = [layer_defn.GetFieldDefn(i).GetName() for i in range(layer_defn.GetFieldCount())] for feature in layer: geom = feature.GetGeometryRef() path = self.geometry_to_qpainterpath(geom) # 创建属性字典 attributes = { "FID": feature.GetFID(), **{name: feature.GetField(name) for name in field_names} } # 使用自定义图形项 item = FeatureItem(path, attributes) item.setPen(QPen(Qt.blue, 1)) item.setBrush(QBrush(Qt.cyan)) self.scene.addItem(item) if path.boundingRect().isValid(): self.total_bounds = self.total_bounds.united(path.boundingRect()) data_source = None except Exception as e: print(f"加载失败: {str(e)}") self.current_vector_path = file_path data_source = None # 新增属性查询方法 def enable_query_mode(self): """启用属性查询模式""" self.is_query_mode = not self.is_query_mode self.actionQuery_Attribute.setText("退出查询" if self.is_query_mode else "属性查询") self.graphicsView.setCursor(Qt.CrossCursor if self.is_query_mode else Qt.ArrowCursor) # 新增鼠标事件处理 def mousePressEvent(self, event): if self.is_query_mode and event.button() == Qt.LeftButton: scene_pos = self.graphicsView.mapToScene(event.pos()) items = self.scene.items(scene_pos, Qt.IntersectsItemShape, Qt.DescendingOrder) for item in items: if isinstance(item, FeatureItem): # 构建属性信息字符串 info = "\n".join([f"{k}: {v}" for k, v in item.attributes.items()]) QMessageBox.information(self, "要素属性", info) return super().mousePressEvent(event) def open_raster_data(self): """打开栅格数据文件""" file_path, _ = QFileDialog.getOpenFileName( self, "打开栅格文件", "", "GeoTIFF (*.tif);;JPEG (*.jpg *.jpeg);;PNG (*.png);;All Files (*)" ) if file_path: try: self.load_raster_data(file_path) self.auto_zoom() except Exception as e: QMessageBox.critical(self, "错误", f"加载栅格失败: {str(e)}") def load_raster_data(self, file_path): """加载栅格数据到视图""" # 打开栅格文件(需要用户修改路径的部分) dataset = gdal.Open(file_path) # 相对路径示例:"./data/raster.tif" # 读取第一个波段 band = dataset.GetRasterBand(1) width = dataset.RasterXSize height = dataset.RasterYSize # 转换为numpy数组 data = band.ReadAsArray() # 创建QImage(注意数据类型转换) if data.dtype == np.uint8: format = QImage.Format.Format_Grayscale8 else: format = QImage.Format.Format_ARGB32 q_img = QImage(data.tobytes(), width, height, format) # 创建像素图项 pixmap = QPixmap.fromImage(q_img) raster_item = self.scene.addPixmap(pixmap) # 处理地理坐标(如果存在) geotransform = dataset.GetGeoTransform() if geotransform: # 计算四个角的坐标 x_origin = geotransform[0] y_origin = geotransform[3] pixel_width = geotransform[1] pixel_height = geotransform[5] # 更新场景边界 x_min = x_origin x_max = x_origin + pixel_width * width y_min = y_origin + pixel_height * height y_max = y_origin self.total_bounds = QRectF( QPointF(x_min, y_min), QPointF(x_max, y_max) ) dataset = None # 关闭数据集 def open_band_combination(self): if not hasattr(self, 'current_raster_path'): QMessageBox.warning(self, "警告", "请先打开栅格数据文件!") return # 复用open_raster_data的逻辑 self.open_raster_data() def open_raster_data(self): file_path, _ = QFileDialog.getOpenFileName( self, "打开栅格文件", "", "GeoTIFF (*.tif);;JPEG (*.jpg *.jpeg);;PNG (*.png);;All Files (*)" ) if file_path: try: dataset = gdal.Open(file_path) num_bands = dataset.RasterCount # 获取用户输入的波段组合 red_band, ok1 = QInputDialog.getInt( self, "波段选择", f"红通道波段号 (1-{num_bands}):", 1, 1, num_bands ) green_band, ok2 = QInputDialog.getInt( self, "波段选择", f"绿通道波段号 (1-{num_bands}):", min(2, num_bands), 1, num_bands ) blue_band, ok3 = QInputDialog.getInt( self, "波段选择", f"蓝通道波段号 (1-{num_bands}):", min(3, num_bands), 1, num_bands ) if not (ok1 and ok2 and ok3): return self.load_raster_data(file_path, red_band, green_band, blue_band) self.auto_zoom() self.current_raster_path = file_path # 新增存储当前路径 except Exception as e: QMessageBox.critical(self, "错误", f"加载栅格失败: {str(e)}") def load_raster_data(self, file_path, red_band=1, green_band=2, blue_band=3): """加载栅格数据到视图(支持波段组合)""" dataset = gdal.Open(file_path) width = dataset.RasterXSize height = dataset.RasterYSize # 读取三个波段数据 def read_band(band_num): band = dataset.GetRasterBand(band_num) data = band.ReadAsArray() # 自动拉伸到0-255范围 data_min = data.min() data_max = data.max() return np.clip(((data - data_min) / (data_max - data_min) * 255), 0, 255).astype(np.uint8) # 合并波段 rgb_array = np.dstack([ read_band(red_band), read_band(green_band), read_band(blue_band) ]) # 创建QImage q_img = QImage( rgb_array.data, width, height, 3 * width, # 每像素3字节(RGB) QImage.Format.Format_RGB888 ) # 创建像素图项 pixmap = QPixmap.fromImage(q_img) self.scene.addPixmap(pixmap) # 处理地理坐标(保持原有逻辑) geotransform = dataset.GetGeoTransform() if geotransform: x_origin = geotransform[0] y_origin = geotransform[3] pixel_width = geotransform[1] pixel_height = geotransform[5] x_min = x_origin x_max = x_origin + pixel_width * width y_min = y_origin + pixel_height * height # 计算下边界 y_max = y_origin # 上边界 # 确保坐标顺序正确 if x_min > x_max: x_min, x_max = x_max, x_min if y_min > y_max: y_min, y_max = y_max, y_min self.total_bounds = QRectF(QPointF(x_min, y_min), QPointF(x_max, y_max)) dataset = None这是我的完整代码,里面以及实现了打开矢量数据和打开栅格数据的功能,请你在打开栅格数据的基础上增加一个显示裁剪的功能,这个功能是单独的哈,直接在我这个代码的基础上增加代码就行,原有的代码不要改,然后把增加裁剪这一功能的完整代码发给我就行。
<think>好的,我现在需要帮助用户在现有的GIS软件代码中添加一个栅格裁剪功能。用户希望这个功能是独立的,不要修改原有代码,只添加新的部分。首先,我需要理解用户现有的代码结构,找到合适的位置插入新功能。
首先,用户的主窗口类MainWindow中已经有打开栅格数据的功能,相关的函数是open_raster_data和load_raster_data。裁剪功能需要一个新的按钮,可能放在工具栏里,所以需要修改init_ui方法,添加一个QAction,比如“栅格裁剪”,并连接到新的处理函数。
接下来,处理函数需要执行以下步骤:
1. 检查是否已经打开了栅格数据,如果没有,弹出警告。
2. 打开文件对话框让用户选择矢量裁剪区域文件(比如Shapefile)。
3. 使用GDAL读取栅格数据和矢量数据,提取矢量的几何范围。
4. 根据矢量的范围对栅格进行裁剪,可能需要用到GDAL的Warp功能。
5. 将裁剪后的栅格显示在场景中,并调整视图。
需要注意的是,用户现有的代码中已经处理了栅格的加载和显示,所以裁剪后的结果应该复用现有的load_raster_data方法,或者类似的逻辑,确保显示正确。同时,要确保不修改原有代码,只添加新的部分。
在代码结构上,需要添加新的方法,比如clip_raster,处理裁剪逻辑。还需要在init_ui中添加新的按钮和动作。另外,可能需要导入一些新的模块,比如os模块来处理文件路径,但用户现有的代码中可能已经导入了必要的库,需要检查。
另外,错误处理也很重要,比如处理文件打开失败、GDAL操作异常等,需要添加try-except块,并弹出错误提示。
最后,测试新功能是否独立,不会影响原有功能,比如打开栅格和矢量数据、波段组合等。确保裁剪后的栅格正确显示,并且视图自动调整。
总结步骤:
1. 在init_ui中添加“栅格裁剪”按钮。
2. 编写clip_raster方法,处理文件选择、数据读取、裁剪操作。
3. 使用GDAL的Warp功能进行裁剪,可能需要设置目标范围。
4. 调用现有的load_raster_data显示结果,或者新建一个方法来显示裁剪后的栅格。
5. 添加必要的错误处理和用户提示。
现在需要将这些步骤转化为具体的代码,插入到用户现有的代码中,确保不修改原有部分,只添加新代码。
</think>
以下是在您现有代码基础上新增的栅格裁剪功能完整代码(新增部分用注释标出):
```python
# 新增栅格裁剪功能(在init_ui方法中添加)
def init_ui(self):
# ... [原有其他按钮代码] ...
# 新增栅格裁剪按钮(在init_ui方法末尾添加)
self.actionClip_Raster = QAction("栅格裁剪", self)
self.toolBar.addAction(self.actionClip_Raster)
self.actionClip_Raster.triggered.connect(self.clip_raster) # 新增连接
# 新增栅格裁剪方法(在类末尾添加)
def clip_raster(self):
"""执行栅格裁剪功能"""
if not hasattr(self, 'current_raster_path'):
QMessageBox.warning(self, "警告", "请先打开栅格数据文件!")
return
# 选择裁剪矢量文件
vector_path, _ = QFileDialog.getOpenFileName(
self, "选择裁剪区域文件", "",
"Shapefile (*.shp);;GeoJSON (*.geojson);;All Files (*)"
)
if not vector_path:
return
try:
# 获取原始栅格信息
src_ds = gdal.Open(self.current_raster_path)
geotransform = src_ds.GetGeoTransform()
proj = src_ds.GetProjection()
# 获取矢量范围
vector_ds = ogr.Open(vector_path)
layer = vector_ds.GetLayer()
feature = layer.GetNextFeature()
geom = feature.GetGeometryRef()
x_min, x_max, y_min, y_max = geom.GetEnvelope()
# 创建临时裁剪结果文件
import os
output_path = os.path.splitext(self.current_raster_path)[0] + "_clipped.tif"
# 执行裁剪操作
options = gdal.WarpOptions(
format='GTiff',
outputBounds=[x_min, y_min, x_max, y_max],
dstSRS=proj
)
gdal.Warp(output_path, src_ds, options=options)
# 加载裁剪结果
self.load_raster_data(output_path)
self.auto_zoom()
# 清理资源
src_ds = None
vector_ds = None
except Exception as e:
QMessageBox.critical(self, "错误", f"栅格裁剪失败: {str(e)}")
```
需要新增的代码位置:
1. 在`init_ui`方法末尾添加新按钮(约第50行)
2. 在类定义末尾添加`clip_raster`方法(约第300行)
现在我已经完成了这些功能,最后我需要一个保存的功能,这个是我的完整的代码直接在我这个代码的基础上增加代码就行,原有的代码不要改,然后把保存这一功能的完整代码发给我就行。就是要保存之前改了的某些数据。# main_window.py(主窗口逻辑) import numpy as np from PySide6.QtWidgets import QMainWindow, QFileDialog, QGraphicsScene, QGraphicsView, QMessageBox, QGraphicsPathItem from PySide6.QtGui import QPainterPath, QPen, QBrush, QAction, QTransform, QImage, QPixmap, QColor from PySide6.QtCore import Qt, QRectF, QPointF from osgeo import ogr, gdal from PySide6.QtWidgets import QInputDialog # 新增输入对话框 # 新增自定义图形项类(用于存储属性) class FeatureItem(QGraphicsPathItem): def __init__(self, path, attributes): super().__init__(path) self.attributes = attributes # 存储属性字典 class MainWindow(QMainWindow): def __init__(self): super().__init__() self.setWindowTitle("GIS软件") self.setGeometry(100, 100, 800, 600) ogr.UseExceptions() self.init_ui() self.scene = QGraphicsScene(self) self.graphicsView.setScene(self.scene) # 新增:存储所有几何边界 self.total_bounds = QRectF() def init_ui(self): self.toolBar = self.addToolBar("工具") self.actionOpen_Vector_Data = QAction("打开矢量数据", self) self.toolBar.addAction(self.actionOpen_Vector_Data) # 新增栅格动作 self.actionOpen_Raster_Data = QAction("打开栅格数据", self) self.toolBar.addAction(self.actionOpen_Raster_Data) # 添加到工具栏 # 新增缓冲区分析按钮 self.actionBuffer_Analysis = QAction("缓冲区分析", self) self.toolBar.addAction(self.actionBuffer_Analysis) self.graphicsView = QGraphicsView() self.setCentralWidget(self.graphicsView) # 新增属性查询按钮 self.actionQuery_Attribute = QAction("属性查询", self) self.toolBar.addAction(self.actionQuery_Attribute) self.actionOpen_Vector_Data.triggered.connect(self.open_vector_data) self.actionOpen_Raster_Data.triggered.connect(self.open_raster_data) # 新增连接 self.actionBuffer_Analysis.triggered.connect(self.buffer_analysis) self.actionQuery_Attribute.triggered.connect(self.enable_query_mode) # 新增鼠标点击事件 self.graphicsView.setMouseTracking(True) self.is_query_mode = False # 新增波段组合按钮 self.actionBand_Combination = QAction("波段组合", self) self.toolBar.addAction(self.actionBand_Combination) self.actionBand_Combination.triggered.connect(self.open_band_combination) # 新增栅格裁剪按钮(在init_ui方法末尾添加) self.actionClip_Raster = QAction("栅格裁剪", self) self.toolBar.addAction(self.actionClip_Raster) self.actionClip_Raster.triggered.connect(self.clip_raster) # 新增连接 self.actionBand_Calculation = QAction("波段运算", self) self.toolBar.addAction(self.actionBand_Calculation) self.actionBand_Calculation.triggered.connect(self.band_calculation) # 新增质心绘制按钮(放在init_ui方法中) self.actionDraw_Centroids = QAction("绘制质心", self) self.toolBar.addAction(self.actionDraw_Centroids) self.actionDraw_Centroids.triggered.connect(self.draw_centroids) self.centroid_items = [] # 新增:存储质心图形项 # 新增空间查询按钮(放在init_ui方法中) self.actionSpatial_Query = QAction("空间查询", self) self.toolBar.addAction(self.actionSpatial_Query) self.actionSpatial_Query.triggered.connect(self.enable_spatial_query_mode) self.is_spatial_query_mode = False self.spatial_query_results = [] # 存储查询结果 def open_vector_data(self): file_path, _ = QFileDialog.getOpenFileName( self, "打开矢量文件", "", "Shapefile (*.shp);;GeoJSON (*.geojson);;All Files (*)" ) if file_path: self.load_vector_data(file_path) # 新增:自动缩放视图 self.auto_zoom() def load_vector_data(self, file_path): self.scene.clear() self.total_bounds = QRectF() # 重置边界 try: data_source = ogr.Open(file_path, 0) layer = data_source.GetLayer(0) for feature in layer: geom = feature.GetGeometryRef() path = self.geometry_to_qpainterpath(geom) # 更新总边界 if path.boundingRect().isValid(): self.total_bounds = self.total_bounds.united(path.boundingRect()) pen = QPen(Qt.blue, 1) brush = QBrush(Qt.cyan) self.scene.addPath(path, pen, brush) data_source = None except Exception as e: print(f"加载失败: {str(e)}") self.current_vector_path = file_path # 新增这一行 data_source = None def geometry_to_qpainterpath(self, geom): path = QPainterPath() if geom.GetGeometryType() == ogr.wkbPolygon: for ring in range(geom.GetGeometryCount()): linear_ring = geom.GetGeometryRef(ring) points = linear_ring.GetPoints() if points: path.moveTo(points[0][0], points[0][1]) for p in points[1:]: path.lineTo(p[0], p[1]) path.closeSubpath() elif geom.GetGeometryType() == ogr.wkbLineString: points = geom.GetPoints() if points: path.moveTo(points[0][0], points[0][1]) for p in points[1:]: path.lineTo(p[0], p[1]) elif geom.GetGeometryType() == ogr.wkbPoint: x, y = geom.GetX(), geom.GetY() path.addEllipse(x - 2, y - 2, 4, 4) return path def auto_zoom(self): """自动缩放视图到数据范围并放大2倍""" if not self.total_bounds.isValid(): return # 设置场景边界 self.scene.setSceneRect(self.total_bounds) # 获取视图可视区域 view_rect = self.graphicsView.viewport().rect() # 计算缩放比例(自动适应 + 2倍放大) transform = QTransform() transform.scale(2, 2) # 先放大2倍 # 应用缩放并居中 self.graphicsView.setTransform(transform) self.graphicsView.fitInView(self.total_bounds, Qt.KeepAspectRatio) # 新增缓冲区分析方法 def buffer_analysis(self): """执行缓冲区分析""" if not hasattr(self, 'current_vector_path'): QMessageBox.warning(self, "警告", "请先打开矢量数据文件!") return # 获取缓冲距离 distance, ok = QInputDialog.getDouble( self, "缓冲区分析", "输入缓冲距离(单位与数据坐标系一致):", 0.0, 0 ) if not ok: return try: # 重新打开数据源获取几何 data_source = ogr.Open(self.current_vector_path, 0) layer = data_source.GetLayer(0) # 创建缓冲区路径 buffer_path = QPainterPath() pen = QPen(Qt.red, 2, Qt.DashLine) brush = QBrush(QColor(255, 0, 0, 50)) # 半透明红色填充 for feature in layer: geom = feature.GetGeometryRef() buffer_geom = geom.Buffer(distance) path = self.geometry_to_qpainterpath(buffer_geom) buffer_path.addPath(path) # 添加到场景 self.scene.addPath(buffer_path, pen, brush) # 更新视图边界 if buffer_path.boundingRect().isValid(): self.total_bounds = self.total_bounds.united(buffer_path.boundingRect()) self.auto_zoom() data_source = None except Exception as e: QMessageBox.critical(self, "错误", f"缓冲区分析失败: {str(e)}") def load_vector_data(self, file_path): self.scene.clear() self.total_bounds = QRectF() try: data_source = ogr.Open(file_path, 0) layer = data_source.GetLayer(0) # 获取字段定义 layer_defn = layer.GetLayerDefn() field_names = [layer_defn.GetFieldDefn(i).GetName() for i in range(layer_defn.GetFieldCount())] for feature in layer: geom = feature.GetGeometryRef() path = self.geometry_to_qpainterpath(geom) # 创建属性字典 attributes = { "FID": feature.GetFID(), **{name: feature.GetField(name) for name in field_names} } # 使用自定义图形项 item = FeatureItem(path, attributes) item.setPen(QPen(Qt.blue, 1)) item.setBrush(QBrush(Qt.cyan)) self.scene.addItem(item) if path.boundingRect().isValid(): self.total_bounds = self.total_bounds.united(path.boundingRect()) data_source = None except Exception as e: print(f"加载失败: {str(e)}") self.current_vector_path = file_path data_source = None # 新增属性查询方法 def enable_query_mode(self): """启用属性查询模式""" self.is_query_mode = not self.is_query_mode self.actionQuery_Attribute.setText("退出查询" if self.is_query_mode else "属性查询") self.graphicsView.setCursor(Qt.CrossCursor if self.is_query_mode else Qt.ArrowCursor) # 新增鼠标事件处理 def mousePressEvent(self, event): if self.is_query_mode and event.button() == Qt.LeftButton: scene_pos = self.graphicsView.mapToScene(event.pos()) items = self.scene.items(scene_pos, Qt.IntersectsItemShape, Qt.DescendingOrder) for item in items: if isinstance(item, FeatureItem): # 构建属性信息字符串 info = "\n".join([f"{k}: {v}" for k, v in item.attributes.items()]) QMessageBox.information(self, "要素属性", info) return super().mousePressEvent(event) def draw_centroids(self): """独立质心绘制功能""" if not hasattr(self, 'current_vector_path'): QMessageBox.warning(self, "警告", "请先打开矢量数据文件!") return # 清除已有质心 for item in self.centroid_items: self.scene.removeItem(item) self.centroid_items.clear() try: data_source = ogr.Open(self.current_vector_path, 0) layer = data_source.GetLayer(0) for feature in layer: geom = feature.GetGeometryRef() centroid = geom.Centroid() if centroid: # 创建质心图形项 path = QPainterPath() path.addEllipse( QRectF( centroid.GetX() - 0.3, # 修改为0.3像素半径 centroid.GetY() - 0.3, 0.6, 0.6 # 直径0.6像素 ) ) item = self.scene.addPath( path, QPen(Qt.red, 0.1), QBrush(Qt.red) ) self.centroid_items.append(item) data_source = None self.auto_zoom() except Exception as e: QMessageBox.critical(self, "错误", f"质心绘制失败: {str(e)}") # 新增空间查询模式切换方法 def enable_spatial_query_mode(self): """启用空间查询模式""" self.is_spatial_query_mode = not self.is_spatial_query_mode self.actionSpatial_Query.setText("退出空间查询" if self.is_spatial_query_mode else "空间查询") self.graphicsView.setCursor(Qt.CrossCursor if self.is_spatial_query_mode else Qt.ArrowCursor) if not self.is_spatial_query_mode: self.clear_spatial_query_results() # 新增空间查询处理方法 def mousePressEvent(self, event): if self.is_spatial_query_mode and event.button() == Qt.LeftButton: scene_pos = self.graphicsView.mapToScene(event.pos()) items = self.scene.items(scene_pos, Qt.IntersectsItemShape, Qt.DescendingOrder) for item in items: if isinstance(item, FeatureItem): # 获取空间关系选择 relations = ["相交", "包含", "被包含", "接触", "重叠"] relation, ok = QInputDialog.getItem( self, "空间关系选择", "请选择空间关系:", relations, 0, False ) if not ok: return # 执行空间查询 self.perform_spatial_query(item, relation) return super().mousePressEvent(event) # 新增空间查询核心方法 def perform_spatial_query(self, source_item, relation): """执行空间查询并高亮结果""" self.clear_spatial_query_results() try: # 获取源要素几何 source_geom = self.item_to_geometry(source_item) if not source_geom: return # 获取所有要素 all_items = [item for item in self.scene.items() if isinstance(item, FeatureItem)] # 遍历检查空间关系 for target_item in all_items: target_geom = self.item_to_geometry(target_item) if not target_geom: continue # 执行空间关系判断 if relation == "相交" and source_geom.Intersects(target_geom): self.highlight_item(target_item) elif relation == "包含" and source_geom.Contains(target_geom): self.highlight_item(target_item) elif relation == "被包含" and target_geom.Contains(source_geom): self.highlight_item(target_item) elif relation == "接触" and source_geom.Touches(target_geom): self.highlight_item(target_item) elif relation == "重叠" and source_geom.Overlaps(target_geom): self.highlight_item(target_item) except Exception as e: QMessageBox.critical(self, "错误", f"空间查询失败: {str(e)}") # 新增辅助方法 def item_to_geometry(self, item): """将图形项转换为OGR几何对象""" path = item.path() elements = path.toSubpathPolygons(QTransform()) if not elements: return None # 创建多边形几何 geom = ogr.Geometry(ogr.wkbPolygon) ring = ogr.Geometry(ogr.wkbLinearRing) for point in elements[0]: ring.AddPoint(point.x(), point.y()) ring.CloseRings() geom.AddGeometry(ring) return geom def highlight_item(self, item): """高亮显示查询结果""" original_pen = item.pen() highlight_pen = QPen(Qt.yellow, 3) item.setPen(highlight_pen) self.spatial_query_results.append((item, original_pen)) def clear_spatial_query_results(self): """清除查询结果高亮""" for item, original_pen in self.spatial_query_results: item.setPen(original_pen) self.spatial_query_results.clear() def open_raster_data(self): """打开栅格数据文件""" file_path, _ = QFileDialog.getOpenFileName( self, "打开栅格文件", "", "GeoTIFF (*.tif);;JPEG (*.jpg *.jpeg);;PNG (*.png);;All Files (*)" ) if file_path: try: self.load_raster_data(file_path) self.auto_zoom() except Exception as e: QMessageBox.critical(self, "错误", f"加载栅格失败: {str(e)}") def load_raster_data(self, file_path): """加载栅格数据到视图""" # 打开栅格文件(需要用户修改路径的部分) dataset = gdal.Open(file_path) # 相对路径示例:"./data/raster.tif" # 读取第一个波段 band = dataset.GetRasterBand(1) width = dataset.RasterXSize height = dataset.RasterYSize # 转换为numpy数组 data = band.ReadAsArray() # 创建QImage(注意数据类型转换) if data.dtype == np.uint8: format = QImage.Format.Format_Grayscale8 else: format = QImage.Format.Format_ARGB32 q_img = QImage(data.tobytes(), width, height, format) # 创建像素图项 pixmap = QPixmap.fromImage(q_img) raster_item = self.scene.addPixmap(pixmap) # 处理地理坐标(如果存在) geotransform = dataset.GetGeoTransform() if geotransform: # 计算四个角的坐标 x_origin = geotransform[0] y_origin = geotransform[3] pixel_width = geotransform[1] pixel_height = geotransform[5] # 更新场景边界 x_min = x_origin x_max = x_origin + pixel_width * width y_min = y_origin + pixel_height * height y_max = y_origin self.total_bounds = QRectF( QPointF(x_min, y_min), QPointF(x_max, y_max) ) dataset = None # 关闭数据集 def open_band_combination(self): if not hasattr(self, 'current_raster_path'): QMessageBox.warning(self, "警告", "请先打开栅格数据文件!") return # 复用open_raster_data的逻辑 self.open_raster_data() def open_raster_data(self): file_path, _ = QFileDialog.getOpenFileName( self, "打开栅格文件", "", "GeoTIFF (*.tif);;JPEG (*.jpg *.jpeg);;PNG (*.png);;All Files (*)" ) if file_path: try: dataset = gdal.Open(file_path) num_bands = dataset.RasterCount # 获取用户输入的波段组合 red_band, ok1 = QInputDialog.getInt( self, "波段选择", f"红通道波段号 (1-{num_bands}):", 1, 1, num_bands ) green_band, ok2 = QInputDialog.getInt( self, "波段选择", f"绿通道波段号 (1-{num_bands}):", min(2, num_bands), 1, num_bands ) blue_band, ok3 = QInputDialog.getInt( self, "波段选择", f"蓝通道波段号 (1-{num_bands}):", min(3, num_bands), 1, num_bands ) if not (ok1 and ok2 and ok3): return self.load_raster_data(file_path, red_band, green_band, blue_band) self.auto_zoom() self.current_raster_path = file_path # 新增存储当前路径 except Exception as e: QMessageBox.critical(self, "错误", f"加载栅格失败: {str(e)}") def load_raster_data(self, file_path, red_band=1, green_band=2, blue_band=3): """加载栅格数据到视图(支持波段组合)""" dataset = gdal.Open(file_path) width = dataset.RasterXSize height = dataset.RasterYSize # 读取三个波段数据 def read_band(band_num): band = dataset.GetRasterBand(band_num) data = band.ReadAsArray() # 自动拉伸到0-255范围 data_min = data.min() data_max = data.max() return np.clip(((data - data_min) / (data_max - data_min) * 255), 0, 255).astype(np.uint8) # 合并波段 rgb_array = np.dstack([ read_band(red_band), read_band(green_band), read_band(blue_band) ]) # 创建QImage q_img = QImage( rgb_array.data, width, height, 3 * width, # 每像素3字节(RGB) QImage.Format.Format_RGB888 ) # 创建像素图项 pixmap = QPixmap.fromImage(q_img) self.scene.addPixmap(pixmap) # 处理地理坐标(保持原有逻辑) geotransform = dataset.GetGeoTransform() if geotransform: x_origin = geotransform[0] y_origin = geotransform[3] pixel_width = geotransform[1] pixel_height = geotransform[5] x_min = x_origin x_max = x_origin + pixel_width * width y_min = y_origin + pixel_height * height # 计算下边界 y_max = y_origin # 上边界 # 确保坐标顺序正确 if x_min > x_max: x_min, x_max = x_max, x_min if y_min > y_max: y_min, y_max = y_max, y_min self.total_bounds = QRectF(QPointF(x_min, y_min), QPointF(x_max, y_max)) dataset = None # 新增栅格裁剪方法(必须缩进在类内部) def clip_raster(self): """执行栅格裁剪功能""" if not hasattr(self, 'current_raster_path'): QMessageBox.warning(self, "警告", "请先打开栅格数据文件!") return # 选择裁剪矢量文件 vector_path, _ = QFileDialog.getOpenFileName( self, "选择裁剪区域文件", "", "Shapefile (*.shp);;GeoJSON (*.geojson);;All Files (*)" ) if not vector_path: return try: # 获取原始栅格信息 src_ds = gdal.Open(self.current_raster_path) geotransform = src_ds.GetGeoTransform() proj = src_ds.GetProjection() # 获取矢量范围 vector_ds = ogr.Open(vector_path) layer = vector_ds.GetLayer() feature = layer.GetNextFeature() geom = feature.GetGeometryRef() x_min, x_max, y_min, y_max = geom.GetEnvelope() # 创建临时裁剪结果文件 import os # 确保导入os模块 output_path = os.path.splitext(self.current_raster_path)[0] + "_clipped.tif" # 执行裁剪操作 options = gdal.WarpOptions( format='GTiff', outputBounds=[x_min, y_min, x_max, y_max], dstSRS=proj ) gdal.Warp(output_path, src_ds, options=options) # 加载裁剪结果 self.load_raster_data(output_path) self.auto_zoom() # 清理资源 src_ds = None vector_ds = None except Exception as e: QMessageBox.critical(self, "错误", f"栅格裁剪失败: {str(e)}") # 新增波段运算方法 def band_calculation(self): """执行波段运算(示例为NDVI计算)""" if not hasattr(self, 'current_raster_path'): QMessageBox.warning(self, "警告", "请先打开栅格数据文件!") return try: # 获取用户输入参数 red_band, ok1 = QInputDialog.getInt( self, "波段选择", "输入红波段编号 (1-based):", 1, 1, 100 ) nir_band, ok2 = QInputDialog.getInt( self, "波段选择", "输入近红外波段编号 (1-based):", 4, 1, 100 ) if not (ok1 and ok2): return # 读取栅格数据 dataset = gdal.Open(self.current_raster_path) red = dataset.GetRasterBand(red_band).ReadAsArray() nir = dataset.GetRasterBand(nir_band).ReadAsArray() # 执行NDVI计算 ndvi = np.where( (nir + red) == 0, 0, (nir - red) / (nir + red) ).astype(np.float32) # 创建输出文件 output_path, _ = QFileDialog.getSaveFileName( self, "保存结果", "", "GeoTIFF (*.tif)" ) if not output_path: return # 写入结果 driver = gdal.GetDriverByName('GTiff') out_ds = driver.Create( output_path, dataset.RasterXSize, dataset.RasterYSize, 1, gdal.GDT_Float32 ) out_ds.SetGeoTransform(dataset.GetGeoTransform()) out_ds.SetProjection(dataset.GetProjection()) out_ds.GetRasterBand(1).WriteArray(ndvi) out_ds.FlushCache() # 清理资源 dataset = None out_ds = None QMessageBox.information(self, "成功", f"NDVI计算结果已保存至:\n{output_path}") except Exception as e: QMessageBox.critical(self, "错误", f"波段运算失败: {str(e)}")
<think>好的,用户需要为现有的GIS软件添加保存功能,特别是保存之前修改过的数据。首先,我需要回顾用户提供的代码,了解当前的数据结构和功能。用户已经实现了矢量数据的加载、缓冲区分析、属性查询、空间查询等功能,现在需要保存这些修改后的数据。
首先,保存功能需要考虑保存哪些数据。用户可能修改了矢量数据(如缓冲区分析生成的缓冲区、质心、空间查询结果等),所以需要将这些修改保存到新的矢量文件中。栅格数据的处理可能涉及裁剪和波段运算,但用户当前的问题主要集中在矢量数据的保存上。
接下来,我需要确定如何将QGraphicsScene中的图形项转换回OGR的几何对象,并保存到文件中。现有的FeatureItem类存储了原始属性和路径,但修改后的图形项(如缓冲区、质心)可能没有这些属性,需要单独处理。
保存功能需要以下几个步骤:
1. 添加保存按钮到工具栏,并连接相应的槽函数。
2. 收集场景中所有需要保存的图形项,包括原始要素、缓冲区、质心等。
3. 将这些图形项转换为OGR的几何对象,并保留属性信息。
4. 创建新的矢量文件,定义图层和字段,将几何和属性写入文件。
需要注意的是,用户可能修改了原始数据,如添加了新的几何图形(缓冲区),这些图形可能没有原始属性,需要生成新的属性或保留原有属性。此外,空间查询结果的高亮显示可能只是视觉上的变化,不需要保存,除非用户明确要保存查询结果。
在代码实现中,需要遍历场景中的所有图形项,区分FeatureItem和其他图形项(如缓冲区、质心)。对于FeatureItem,可以直接使用其存储的属性和几何;对于其他图形项,可能需要生成新的属性或忽略。
还需要处理文件保存的对话框,让用户选择保存路径和文件名,以及选择保存的格式(如Shapefile、GeoJSON等)。
最后,测试保存功能,确保生成的矢量文件正确包含所有修改后的要素,并且属性信息完整。
可能遇到的问题包括几何转换的准确性、属性字段的定义、不同格式的支持等。需要确保使用OGR库正确创建数据源和图层,处理坐标系信息,以及正确写入几何和属性数据。
总结,保存功能的核心是将场景中的图形项转换回矢量数据格式,并保存到文件中。需要处理不同类型的图形项,保留必要的属性,并确保文件格式的正确性。
</think>
以下是添加保存功能的完整代码(在原有代码基础上新增部分):
```python
# 在MainWindow类的init_ui方法中添加(保持原有代码不变)
def init_ui(self):
# ...原有代码...
# 新增保存按钮(放在init_ui方法中)
self.actionSave_Vector = QAction("保存矢量数据", self)
self.toolBar.addAction(self.actionSave_Vector)
self.actionSave_Vector.triggered.connect(self.save_vector_data)
# 新增保存方法
def save_vector_data(self):
"""保存当前场景中的矢量数据"""
if not hasattr(self, 'current_vector_path'):
QMessageBox.warning(self, "警告", "没有可保存的矢量数据!")
return
# 获取保存路径
file_path, _ = QFileDialog.getSaveFileName(
self, "保存矢量文件", "",
"Shapefile (*.shp);;GeoJSON (*.geojson);;All Files (*)"
)
if not file_path:
return
try:
# 获取原始数据源信息
src_ds = ogr.Open(self.current_vector_path)
src_layer = src_ds.GetLayer(0)
src_defn = src_layer.GetLayerDefn()
# 创建目标数据源
driver = ogr.GetDriverByName("ESRI Shapefile" if file_path.endswith(".shp") else "GeoJSON")
if os.path.exists(file_path):
driver.DeleteDataSource(file_path)
dst_ds = driver.CreateDataSource(file_path)
# 创建图层(保持与原始数据相同的坐标系)
dst_layer = dst_ds.CreateLayer(
"saved_features",
srs=src_layer.GetSpatialRef(),
geom_type=ogr.wkbPolygon
)
# 复制字段定义
for i in range(src_defn.GetFieldCount()):
field_defn = src_defn.GetFieldDefn(i)
dst_layer.CreateField(field_defn)
# 遍历场景中的所有要素项
for item in self.scene.items():
if isinstance(item, FeatureItem):
# 创建新要素
feature = ogr.Feature(dst_layer.GetLayerDefn())
# 复制属性
for key, value in item.attributes.items():
if key == "FID":
continue # FID通常自动生成
if feature.GetFieldIndex(key) != -1:
feature.SetField(key, str(value))
# 转换几何
geom = self.item_to_geometry(item)
if geom:
feature.SetGeometry(geom)
dst_layer.CreateFeature(feature)
feature = None
# 添加缓冲区要素(如果存在)
self.save_additional_features(dst_layer, "buffer")
# 添加质心要素(如果存在)
self.save_additional_features(dst_layer, "centroid")
dst_ds = None
src_ds = None
QMessageBox.information(self, "成功", f"数据已保存至:\n{file_path}")
except Exception as e:
QMessageBox.critical(self, "错误", f"保存失败: {str(e)}")
# 新增辅助保存方法
def save_additional_features(self, layer, feature_type):
"""保存附加要素(缓冲区/质心)"""
items = []
if feature_type == "buffer":
items = [item for item in self.scene.items()
if item.pen().style() == Qt.DashLine and item.pen().color() == Qt.red]
elif feature_type == "centroid":
items = self.centroid_items
for item in items:
geom = self.item_to_geometry(item)
if geom:
feature = ogr.Feature(layer.GetLayerDefn())
feature.SetGeometry(geom)
# 添加类型标识字段
feature.SetField("FEATURE_TYPE", feature_type.upper())
layer.CreateFeature(feature)
feature = None
# 在item_to_geometry方法中添加点要素支持
def item_to_geometry(self, item):
"""增强版几何转换(支持点要素)"""
path = item.path()
elements = path.toSubpathPolygons(QTransform())
if not elements:
# 处理点要素
if isinstance(item, QGraphicsPathItem):
path = item.path()
if path.elementCount() == 1 and path.elementAt(0).isMoveTo():
pt = path.elementAt(0)
geom = ogr.Geometry(ogr.wkbPoint)
geom.AddPoint(pt.x, pt.y)
return geom
return None
# 原有多边形处理逻辑
geom = ogr.Geometry(ogr.wkbPolygon)
ring = ogr.Geometry(ogr.wkbLinearRing)
for point in elements[0]:
ring.AddPoint(point.x(), point.y())
ring.CloseRings()
geom.AddGeometry(ring)
return geom
```
阅读全文
相关推荐











