diff --git a/nuplan/database/nuplan_db/nuplan_scenario_queries.py b/nuplan/database/nuplan_db/nuplan_scenario_queries.py
index b717e9fa..96b51dc4 100644
--- a/nuplan/database/nuplan_db/nuplan_scenario_queries.py
+++ b/nuplan/database/nuplan_db/nuplan_scenario_queries.py
@@ -24,6 +24,7 @@
 from nuplan.database.nuplan_db.query_session import execute_many, execute_one
 from nuplan.database.nuplan_db.sensor_data_table_row import SensorDataTableRow
 from nuplan.database.utils.label.utils import local2agent_type, raw_mapping
+from nuplan.planning.simulation.trajectory.trajectory_sampling import TrajectorySampling
 
 
 def _parse_tracked_object_row(row: sqlite3.Row) -> TrackedObject:
@@ -447,6 +448,66 @@ def get_sampled_lidarpcs_from_db(
         yield LidarPc.from_db_row(row)
 
 
+def get_sampled_lidarpcs_from_db_batch(
+    log_file: str,
+    initial_token: str,
+    sensor_source: SensorDataSource,
+    sample_indexes: List[int],
+    future: bool,
+) -> List[LidarPc]:
+    """
+    Batched variant of get_sampled_lidarpcs_from_db: fetch the lidar_pcs at the requested sample offsets
+    relative to the initial token with a single query instead of one query per sample.
+    :param log_file: The log file to query.
+    :param initial_token: The token of the initial lidar_pc.
+    :param sensor_source: Parameters for querying the correct table.
+    :param sample_indexes: The 0-indexed offsets relative to the initial token to sample.
+    :param future: If true, the indexes refer to future samples; otherwise, to past samples.
+    :return: The matching LidarPcs, ordered by timestamp ascending.
+    """
+    if not sample_indexes:
+        return []
+
+    sensor_token = get_sensor_token(log_file, sensor_source.sensor_table, sensor_source.channel)
+
+    order_direction = "ASC" if future else "DESC"
+    order_cmp = ">=" if future else "<="
+
+    query = f"""
+        WITH initial_lidarpc AS
+        (
+            SELECT token, timestamp
+            FROM lidar_pc
+            WHERE token = ?
+        ),
+        ordered AS
+        (
+            SELECT lp.token,
+                   lp.next_token,
+                   lp.prev_token,
+                   lp.ego_pose_token,
+                   lp.lidar_token,
+                   lp.scene_token,
+                   lp.filename,
+                   lp.timestamp,
+                   ROW_NUMBER() OVER (ORDER BY lp.timestamp {order_direction}) AS row_num
+            FROM lidar_pc AS lp
+            CROSS JOIN initial_lidarpc AS il
+            WHERE lp.timestamp {order_cmp} il.timestamp
+              AND lp.lidar_token = ?
+        )
+        SELECT token,
+               next_token,
+               prev_token,
+               ego_pose_token,
+               lidar_token,
+               scene_token,
+               filename,
+               timestamp
+        FROM ordered
+
+        -- ROW_NUMBER() starts at 1, but consumers expect sample_indexes to be 0-indexed.
+        WHERE (row_num - 1) IN ({('?,'*len(sample_indexes))[:-1]})
+
+        ORDER BY timestamp ASC;
+    """
+
+    args = [bytearray.fromhex(initial_token), bytearray.fromhex(sensor_token)] + sample_indexes  # type: ignore
+    rows = execute_many(query, args, log_file)
+    return [LidarPc.from_db_row(row) for row in rows]
+
+
 def get_sampled_ego_states_from_db(
     log_file: str,
     initial_token: str,
@@ -778,6 +839,56 @@ def get_future_waypoints_for_agents_from_db(
         yield (row["track_token"].hex(), Waypoint(TimePoint(row["timestamp"]), oriented_box, velocity))
 
 
+def get_future_waypoints_for_agents_from_db_optimized(
+    log_file: str, track_tokens: List[str], start_timestamp: int, future_trajectory_sampling: TrajectorySampling
+) -> Generator[Tuple[str, Waypoint], None, None]:
+    """
+    Obtain the future waypoints for the selected agents from the DB in the provided time window,
+    taking the sampling interval for future waypoints into account.
+
+    :param log_file: The log file to query.
+    :param track_tokens: The track_tokens for which to query.
+    :param start_timestamp: The starting timestamp for which to query.
+    :param future_trajectory_sampling: The trajectory sampling strategy.
+    :return: A generator of (track_token, Waypoint) tuples, sorted by track_token, then by timestamp in ascending order.
+    """
+    interval_microseconds = int(1e6 * future_trajectory_sampling.interval_length)
+    end_timestamp = start_timestamp + int(1e6 * future_trajectory_sampling.time_horizon)
+
+    # Build the query so that the returned waypoints respect the specified sampling interval.
+    # The following SQL is an example and might need adjustments based on the actual schema.
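+    # Sketch of the intent (my reading of the query below, not verified against the nuPlan schema):
+    # the recursive CTE expands start_timestamp into bucket start times spaced interval_microseconds apart,
+    # and the join keeps the lidar_box rows whose timestamps fall inside those interval-wide buckets within
+    # [start_timestamp, end_timestamp]; the final resampling onto the exact sampling grid still happens
+    # downstream in interpolate_future_waypoints.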
+    query = f"""
+        WITH RECURSIVE sampled_timestamps(ts) AS (
+            SELECT ? UNION ALL
+            SELECT ts + ? FROM sampled_timestamps
+            WHERE ts + ? <= ?
+        )
+        SELECT
+            lb.x, lb.y, lb.z, lb.yaw, lb.width, lb.length, lb.height, lb.vx, lb.vy, lb.track_token, lp.timestamp
+        FROM
+            lidar_box AS lb
+        INNER JOIN
+            lidar_pc AS lp ON lp.token = lb.lidar_pc_token
+        INNER JOIN
+            sampled_timestamps st ON lp.timestamp >= st.ts AND lp.timestamp < st.ts + ?
+        WHERE
+            lp.timestamp >= ? AND lp.timestamp <= ? AND lb.track_token IN ({('?,' * len(track_tokens))[:-1]})
+        ORDER BY
+            lb.track_token ASC, lp.timestamp ASC;
+    """
+    args = [start_timestamp, interval_microseconds, interval_microseconds, end_timestamp, interval_microseconds, start_timestamp, end_timestamp] + [bytearray.fromhex(t) for t in track_tokens]
+
+    for row in execute_many(query, args, log_file):
+        # Parse the row directly here and build the Waypoint object.
+        pose = StateSE2(row["x"], row["y"], row["yaw"])
+        oriented_box = OrientedBox(pose, width=row["width"], length=row["length"], height=row["height"])
+        velocity = StateVector2D(row["vx"], row["vy"])
+        waypoint = Waypoint(TimePoint(row["timestamp"]), oriented_box, velocity)
+
+        # Yield (track_token, Waypoint) pairs.
+        yield (row["track_token"].hex(), waypoint)
+
+
 def get_scenarios_from_db(
     log_file: str,
     filter_tokens: Optional[List[str]],
diff --git a/nuplan/database/nuplan_db/query_session.py b/nuplan/database/nuplan_db/query_session.py
index 2ef42d1f..6d51d89c 100644
--- a/nuplan/database/nuplan_db/query_session.py
+++ b/nuplan/database/nuplan_db/query_session.py
@@ -1,57 +1,82 @@
 import sqlite3
 from typing import Any, Generator, Optional
+from collections import OrderedDict
 
+memory_dbs = OrderedDict()  # Maps DB file path -> cached in-memory sqlite3.Connection, in LRU order.
+MAX_CACHE_SIZE = 5  # Maximum number of cached in-memory connections.
 
-def execute_many(query_text: str, query_parameters: Any, db_file: str) -> Generator[sqlite3.Row, None, None]:
+
+def get_or_copy_db_to_memory(db_file: str) -> sqlite3.Connection:
     """
-    Runs a query with the provided arguments on a specified Sqlite DB file.
-    This query can return any number of rows.
+    Get an existing in-memory database connection, or copy the SQLite DB file into a new in-memory database if it is not cached yet.
+    Keeps the cache at or below MAX_CACHE_SIZE by evicting the least recently used (LRU) connection.
+    :param db_file: The DB file to look up or copy to memory.
+    :return: A connection to the in-memory database.
+    """
+    # If the file is already cached, move it to the end of the dict to mark it as most recently used.
+    print(f"Cached in-memory DBs: {list(memory_dbs.keys())}, requested: {db_file}")
+    if db_file in memory_dbs:
+        memory_dbs.move_to_end(db_file)
+        return memory_dbs[db_file]
+
+    # If the cache is full, evict the oldest entry.
+    if len(memory_dbs) >= MAX_CACHE_SIZE:
+        oldest_db_file, oldest_conn = memory_dbs.popitem(last=False)  # Remove the least recently used item.
+        oldest_conn.close()
+        print(f"Closed and removed the oldest DB from cache: {oldest_db_file}")
+
+    # Create a new in-memory database and copy the on-disk DB into it.
+    disk_connection = sqlite3.connect(db_file)
+    mem_connection = sqlite3.connect(':memory:')
+    disk_connection.backup(mem_connection)  # Requires Python 3.7+
+    disk_connection.close()
+
+    # Add the connection to the cache and return it.
+    memory_dbs[db_file] = mem_connection
+    return mem_connection
+
+
+def execute_many(query_text: str, query_parameters: Any, db_file: str, use_mem: bool = True) -> Generator[sqlite3.Row, None, None]:
+    """
+    Runs a query on a specified Sqlite DB file, by default against a cached in-memory copy for improved speed.
+    This query can return any number of rows.
     :param query_text: The query to run.
     :param query_parameters: The parameters to provide to the query.
-    :param db_file: The DB file on which to run the query.
+    :param db_file: The DB file to use, copied to memory if not already cached.
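+    :param use_mem: If True (default), run the query against a cached in-memory copy of the DB; if False, query the on-disk file directly.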
     :return: A generator of rows emitted from the query.
     """
-    # Caching a connection saves around 600 uS for local databases.
-    # By making it stateless, we get isolation, which is a huge plus.
-    connection = sqlite3.connect(db_file)
+    if use_mem:
+        connection = get_or_copy_db_to_memory(db_file)
+    else:
+        connection = sqlite3.connect(db_file)
+
     connection.row_factory = sqlite3.Row
     cursor = connection.cursor()
 
     try:
         cursor.execute(query_text, query_parameters)
-
         for row in cursor:
             yield row
     finally:
         cursor.close()
-        connection.close()
-
+        # Do not close the connection here; it is cached for reuse.
 
 
 def execute_one(query_text: str, query_parameters: Any, db_file: str) -> Optional[sqlite3.Row]:
     """
-    Runs a query with the provided arguments on a specified Sqlite DB file.
-    Validates that the query returns at most one row.
+    Runs a query on a specified Sqlite DB file, against a cached in-memory copy for improved speed.
+    Validates that the query returns at most one row.
     :param query_text: The query to run.
     :param query_parameters: The parameters to provide to the query.
-    :param db_file: The DB file on which to run the query.
+    :param db_file: The DB file to use, copied to memory if not already cached.
     :return: The returned row, if it exists. None otherwise.
     """
-    # Caching a connection saves around 600 uS for local databases.
-    # By making it stateless, we get isolation, which is a huge plus.
-    connection = sqlite3.connect(db_file)
+    connection = get_or_copy_db_to_memory(db_file)
     connection.row_factory = sqlite3.Row
     cursor = connection.cursor()
 
     try:
         cursor.execute(query_text, query_parameters)
-
         result: Optional[sqlite3.Row] = cursor.fetchone()
-
-        # Check for more rows. If more exist, throw an error.
         if result is not None and cursor.fetchone() is not None:
             raise RuntimeError("execute_one query returned multiple rows.")
-
         return result
     finally:
         cursor.close()
-        connection.close()
+        # Do not close the connection here; it is cached for reuse.
diff --git a/nuplan/planning/scenario_builder/nuplan_db/nuplan_scenario.py b/nuplan/planning/scenario_builder/nuplan_db/nuplan_scenario.py
index 0fbfb6fd..8a37c2d4 100644
--- a/nuplan/planning/scenario_builder/nuplan_db/nuplan_scenario.py
+++ b/nuplan/planning/scenario_builder/nuplan_db/nuplan_scenario.py
@@ -3,6 +3,7 @@
 import os
 from functools import cached_property
 from pathlib import Path
+import time
 from typing import Any, Generator, List, Optional, Set, Tuple, Type, cast
 
 from nuplan.common.actor_state.ego_state import EgoState
@@ -24,6 +25,7 @@
     get_roadblock_ids_for_lidarpc_token_from_db,
     get_sampled_ego_states_from_db,
     get_sampled_lidarpcs_from_db,
+    get_sampled_lidarpcs_from_db_batch,
     get_sensor_data_from_sensor_data_tokens_from_db,
     get_sensor_data_token_timestamp_from_db,
     get_sensor_transform_matrix_for_sensor_data_token_from_db,
@@ -361,11 +363,22 @@ def get_future_tracked_objects(
         self,
         iteration: int,
         time_horizon: float,
         num_samples: Optional[int] = None,
         future_trajectory_sampling: Optional[TrajectorySampling] = None,
-    ) -> Generator[DetectionsTracks, None, None]:
+    ) -> List[DetectionsTracks]:
         """Inherited, see superclass."""
-        # TODO: This can be made even more efficient with a batch query
-        for lidar_pc in self._find_matching_lidar_pcs(iteration, num_samples, time_horizon, True):
-            yield DetectionsTracks(extract_tracked_objects(lidar_pc.token, self._log_file, future_trajectory_sampling))
+        start_time = time.time()
+        lidar_pcs = self._find_matching_lidar_pcs_batch(iteration, num_samples, time_horizon, True)
+        mid_time = time.time()
+        print(f"_find_matching_lidar_pcs_batch took: {(mid_time - start_time) * 1000} ms")
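+        # Each lidar_pc still triggers its own extract_tracked_objects() DB query; only the lidar_pc lookup itself is batched.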
+        detections_tracks = [
+            DetectionsTracks(extract_tracked_objects(lidar_pc.token, self._log_file, future_trajectory_sampling))
+            for lidar_pc in lidar_pcs
+        ]
+        end_time = time.time()
+        print(f"Building all DetectionsTracks objects took: {(end_time - mid_time) * 1000} ms")
+        print(f"Total get_future_tracked_objects time: {(end_time - start_time) * 1000} ms")
+        return detections_tracks
+
     def get_past_sensors(
         self,
@@ -446,6 +459,19 @@ def _find_matching_lidar_pcs(
                 self._log_file, self._lidarpc_tokens[iteration], get_lidarpc_sensor_data(), indices, look_into_future
             ),
         )
+
+    def _find_matching_lidar_pcs_batch(
+        self, iteration: int, num_samples: Optional[int], time_horizon: float, look_into_future: bool
+    ) -> List[LidarPc]:
+        """
+        Batched counterpart of _find_matching_lidar_pcs: fetch all matching lidar_pcs with a single DB query.
+        :param iteration: The iteration from which to start sampling.
+        :param num_samples: The number of samples to fetch; derived from time_horizon if not provided.
+        :param time_horizon: The time horizon [s] over which to sample.
+        :param look_into_future: If true, sample into the future; otherwise, into the past.
+        :return: The matching LidarPcs.
+        """
+        num_samples = num_samples if num_samples else int(time_horizon / self.database_interval)
+        indices = sample_indices_with_time_horizon(num_samples, time_horizon, self._database_row_interval)
+
+        # Replace the per-sample generator with a single batched query.
+        lidarpcs = get_sampled_lidarpcs_from_db_batch(
+            self._log_file, self._lidarpc_tokens[iteration], get_lidarpc_sensor_data(), indices, look_into_future
+        )
+        return list(lidarpcs)  # Ensure a list is returned.
 
     def _extract_expert_trajectory(self, max_future_seconds: int = 60) -> Generator[EgoState, None, None]:
         """
diff --git a/nuplan/planning/scenario_builder/nuplan_db/nuplan_scenario_utils.py b/nuplan/planning/scenario_builder/nuplan_db/nuplan_scenario_utils.py
index b60a8afd..fb5e82e1 100644
--- a/nuplan/planning/scenario_builder/nuplan_db/nuplan_scenario_utils.py
+++ b/nuplan/planning/scenario_builder/nuplan_db/nuplan_scenario_utils.py
@@ -19,6 +19,7 @@
 from nuplan.database.nuplan_db.nuplan_db_utils import SensorDataSource, get_lidarpc_sensor_data
 from nuplan.database.nuplan_db.nuplan_scenario_queries import (
     get_future_waypoints_for_agents_from_db,
+    get_future_waypoints_for_agents_from_db_optimized,
     get_sampled_sensor_tokens_in_time_window_from_db,
     get_sensor_data_token_timestamp_from_db,
     get_tracked_objects_for_lidarpc_token_from_db,
@@ -336,50 +337,54 @@ def extract_tracked_objects(
     future_trajectory_sampling: Optional[TrajectorySampling] = None,
 ) -> TrackedObjects:
     """
-    Extracts all boxes from a lidarpc.
-    :param lidar_pc: Input lidarpc.
-    :param future_trajectory_sampling: If provided, the future trajectory sampling to use for future waypoints.
-    :return: Tracked objects contained in the lidarpc.
+    Extracts all boxes from a lidarpc, considering future trajectory sampling if provided.
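+    :param token: The lidar_pc token for which to extract the boxes.
+    :param log_file: The log file to query.
+    :param future_trajectory_sampling: If provided, the sampling parameters used to attach future waypoints to agents.
+    :return: Tracked objects contained in the lidarpc.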
""" tracked_objects: List[TrackedObject] = [] agent_indexes: Dict[str, int] = {} - agent_future_trajectories: Dict[str, List[Waypoint]] = {} + # 获取当前lidar点云对应的所有追踪对象 for idx, tracked_object in enumerate(get_tracked_objects_for_lidarpc_token_from_db(log_file, token)): if future_trajectory_sampling and isinstance(tracked_object, Agent): agent_indexes[tracked_object.metadata.track_token] = idx - agent_future_trajectories[tracked_object.metadata.track_token] = [] tracked_objects.append(tracked_object) - if future_trajectory_sampling and len(tracked_objects) > 0: + if future_trajectory_sampling: timestamp_time = get_sensor_data_token_timestamp_from_db(log_file, get_lidarpc_sensor_data(), token) + if timestamp_time is None: + return TrackedObjects(tracked_objects=tracked_objects) + end_time = timestamp_time + int( 1e6 * (future_trajectory_sampling.time_horizon + future_trajectory_sampling.interval_length) ) - # TODO: This is somewhat inefficient because the resampling should happen in SQL layer - for track_token, waypoint in get_future_waypoints_for_agents_from_db( - log_file, list(agent_indexes.keys()), timestamp_time, end_time - ): + # 使用优化后的方式获取未来轨迹点 + future_waypoints = get_future_waypoints_for_agents_from_db_optimized( + log_file, list(agent_indexes.keys()), timestamp_time, future_trajectory_sampling + ) + + # 重新组织未来轨迹点数据,按照追踪对象的token组织 + agent_future_trajectories = {track_token: [] for track_token in agent_indexes} + for track_token, waypoint in future_waypoints: agent_future_trajectories[track_token].append(waypoint) - for key in agent_future_trajectories: - # We can only interpolate waypoints if there is more than one in the future. - if len(agent_future_trajectories[key]) == 1: - tracked_objects[agent_indexes[key]]._predictions = [ - PredictedTrajectory(1.0, agent_future_trajectories[key]) - ] - elif len(agent_future_trajectories[key]) > 1: - tracked_objects[agent_indexes[key]]._predictions = [ + # 根据获取到的未来轨迹点更新追踪对象的预测轨迹 + for track_token, waypoints in agent_future_trajectories.items(): + idx = agent_indexes[track_token] + if len(waypoints) > 1: # 只有当存在多个未来轨迹点时才进行插值 + tracked_objects[idx]._predictions = [ PredictedTrajectory( - 1.0, + 1.0, # 假设置信度为1.0 interpolate_future_waypoints( - agent_future_trajectories[key], + waypoints, future_trajectory_sampling.time_horizon, future_trajectory_sampling.interval_length, ), ) ] + elif len(waypoints) == 1: + tracked_objects[idx]._predictions = [ + PredictedTrajectory(1.0, waypoints) + ] return TrackedObjects(tracked_objects=tracked_objects)