Source code for code_index.index.persist.persist_sqlite

from enum import StrEnum
from pathlib import Path
from pprint import pprint
from typing import Type, TypeVar

from sqlalchemy import (
    Column,
    ForeignKey,
    Integer,
    String,
    Table,
    UniqueConstraint,
    create_engine,
    select,
)
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Session,
    mapped_column,
    relationship,
    sessionmaker,
)

from ...models import (
    CodeLocation,
    Definition,
    Function,
    FunctionLikeInfo,
    IndexDataEntry,
    Method,
    PureDefinition,
    PureReference,
    Reference,
    Symbol,
    SymbolDefinition,
    SymbolReference,
)
from ...utils.logger import logger
from ..base import IndexData, PersistStrategy


# --- 1. Database model base class ---
class Base(DeclarativeBase):
    """Base class for all SQLAlchemy ORM models."""

    pass


# --- 2. Many-to-many association table ---
# Association table for definition-reference relationships.
# Since this table only contains foreign keys without additional data,
# we define it as a Table object rather than a full ORM model class.
definition_references_table = Table(
    "definition_references",
    Base.metadata,
    Column("definition_id", ForeignKey("definitions.id"), primary_key=True),
    Column("reference_id", ForeignKey("references.id"), primary_key=True),
)

# --- 3. Core entity models ---


class SymbolType(StrEnum):
    """Enumeration for symbol types in the database."""

    FUNCTION = "FUNCTION"
    METHOD = "METHOD"


class OrmSymbol(Base):
    """Database model for function and method symbols.

    Represents functions and methods with their identifying information.
    Each symbol can have multiple definitions and references.
    """

    __tablename__ = "symbols"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String)
    class_name: Mapped[str | None] = mapped_column(String)
    symbol_type: Mapped[SymbolType] = mapped_column(String)

    # 关系:一个 OrmSymbol 可���有多个 OrmDefinition 和 OrmReference
    definitions: Mapped[list["OrmDefinition"]] = relationship(back_populates="symbol")
    references: Mapped[list["OrmReference"]] = relationship(back_populates="symbol")

    __table_args__ = (UniqueConstraint("name", "class_name", name="uq_symbol_name_class"),)

    def __repr__(self) -> str:
        return f"OrmSymbol(id={self.id!r}, name={self.name!r}, class_name={self.class_name!r})"


class OrmCodeLocation(Base):
    """Database model for code locations.

    Represents the exact location of code elements in source files,
    including line numbers, columns, and byte positions.
    """

    __tablename__ = "code_locations"

    id: Mapped[int] = mapped_column(primary_key=True)
    file_path: Mapped[str] = mapped_column(String)
    start_lineno: Mapped[int] = mapped_column(Integer)
    start_col: Mapped[int] = mapped_column(Integer)
    end_lineno: Mapped[int] = mapped_column(Integer)
    end_col: Mapped[int] = mapped_column(Integer)
    start_byte: Mapped[int] = mapped_column(Integer)
    end_byte: Mapped[int] = mapped_column(Integer)

    # Relationships: code locations can be referenced by multiple definitions and references
    definitions: Mapped[list["OrmDefinition"]] = relationship(back_populates="location")
    references: Mapped[list["OrmReference"]] = relationship(back_populates="location")

    def __repr__(self) -> str:
        return (
            f"OrmCodeLocation(id={self.id!r}, path={self.file_path!r}, line={self.start_lineno!r})"
        )


class OrmDefinition(Base):
    """Database model for symbol definitions.

    Links symbols to their definition locations in the codebase.
    """

    __tablename__ = "definitions"

    id: Mapped[int] = mapped_column(primary_key=True)
    symbol_id: Mapped[int] = mapped_column(ForeignKey("symbols.id"))
    location_id: Mapped[int] = mapped_column(ForeignKey("code_locations.id"))

    # Relationships: definitions belong to symbols and locations
    symbol: Mapped["OrmSymbol"] = relationship(back_populates="definitions")
    location: Mapped["OrmCodeLocation"] = relationship(back_populates="definitions")

    # Relationships: definitions can contain multiple references (many-to-many)
    internal_references: Mapped[list["OrmReference"]] = relationship(
        secondary=definition_references_table, back_populates="callers"
    )

    def __repr__(self) -> str:
        return f"OrmDefinition(id={self.id!r}, symbol_id={self.symbol_id!r})"


class OrmReference(Base):
    """Database model for symbol references.

    Links symbols to their reference/usage locations in the codebase.
    """

    __tablename__ = "references"

    id: Mapped[int] = mapped_column(primary_key=True)
    symbol_id: Mapped[int] = mapped_column(ForeignKey("symbols.id"))
    location_id: Mapped[int] = mapped_column(ForeignKey("code_locations.id"))

    # Relationships: references point to symbols and occur at locations
    symbol: Mapped["OrmSymbol"] = relationship(back_populates="references")
    location: Mapped["OrmCodeLocation"] = relationship(back_populates="references")

    # Relationships: references can be contained in multiple definitions
    callers: Mapped[list["OrmDefinition"]] = relationship(
        secondary=definition_references_table, back_populates="internal_references"
    )

    def __repr__(self) -> str:
        return f"OrmReference(id={self.id!r}, symbol_id={self.symbol_id!r})"


class OrmMetadata(Base):
    """Database model for storing index metadata."""

    __tablename__ = "metadata"

    index_type: Mapped[str] = mapped_column(primary_key=True)


T = TypeVar("T", bound=Base)


# Helper function
def get_or_create(session: Session, model_cls: Type[T], **kwargs) -> tuple[T, bool]:
    """Gets an existing ORM instance or creates a new one.

    Args:
        session: SQLAlchemy session object.
        model_cls: ORM model class.
        **kwargs: Field parameters for querying or creating.

    Returns:
        A tuple of (instance, created) where created is True if the
        instance was newly created, False if it already existed.
    """
    instance = session.execute(select(model_cls).filter_by(**kwargs)).scalar_one_or_none()
    if instance is not None:
        return instance, False
    else:
        instance = model_cls(**kwargs)
        session.add(instance)
        return instance, True


[docs] class SqlitePersistStrategy(PersistStrategy): """SQLite database persistence strategy for index data. Stores index data in a SQLite database with proper relational structure. Supports both file-based and in-memory databases. Note: Currently does not support the LLM Note feature. Check `models.py` for details. """
[docs] def __init__(self): """Initializes the SQLite persistence strategy.""" super().__init__() logger.debug("Initialized SqlitePersistStrategy")
[docs] def __repr__(self): """Returns a string representation of the persistence strategy.""" return f"{self.__class__.__name__}()"
[docs] def get_engine(self, path: Path | None = None, make_empty_db: bool = False): """Gets the SQLite database engine. Args: path: Database file path. If None, uses in-memory database. make_empty_db: If True, removes existing database file if it exists. Returns: SQLAlchemy engine for the database. """ if path is None: return create_engine("sqlite:///:memory:") if path.exists() and path.is_dir(): path = path / "index.sqlite" # 创建父目录而不是文件路径本身 path.parent.mkdir(parents=True, exist_ok=True) # if the path is an existing file, remove it (rename it to avoid conflicts) if make_empty_db and path.exists() and path.is_file(): # add a .bak suffix. e.g. index.sqlite -> index.sqlite.bak backup_path = path.with_name(f"{path.name}.bak") path.rename(backup_path) logger.warning(f"Existing file {path} renamed to {backup_path}") return create_engine(f"sqlite:///{str(path.resolve())}")
@staticmethod def _func_like_as_criteria(func_like: Symbol) -> dict: match func_like: case Function(name=name): return {"name": name, "class_name": None, "symbol_type": SymbolType.FUNCTION} case Method(name=name, class_name=class_name): return {"name": name, "class_name": class_name, "symbol_type": SymbolType.METHOD} raise ValueError( f"Unsupported Symbol type: {type(func_like)}. Expected Function or Method." )
[docs] @staticmethod def _location_as_criteria(location: CodeLocation) -> dict: """ 将 CodeLocation 转换为查询条件字典。 """ return { "file_path": str(location.file_path), "start_lineno": location.start_lineno, "start_col": location.start_col, "end_lineno": location.end_lineno, "end_col": location.end_col, "start_byte": location.start_byte, "end_byte": location.end_byte, }
def _handle_definition_for_symbol( self, session: Session, symbol_db: OrmSymbol, definition_dc: Definition ): # make location loc_db, _ = get_or_create( session, OrmCodeLocation, **self._location_as_criteria(definition_dc.location), ) # make definition definition_db, _ = get_or_create( session, OrmDefinition, symbol=symbol_db, # this should add this definition to the symbol's definitions location=loc_db, ) # handle what this definition calls for func_ref in definition_dc.calls: called_symbol_dc: Symbol = func_ref.symbol called_reference_dc: PureReference = func_ref.reference # make called symbol called_symbol_db, _ = get_or_create( session, OrmSymbol, **self._func_like_as_criteria(called_symbol_dc), ) # make called location called_location_db, _ = get_or_create( session, OrmCodeLocation, **self._location_as_criteria(called_reference_dc.location), ) # make called reference called_reference_db, _ = get_or_create( session, OrmReference, symbol=called_symbol_db, location=called_location_db, ) # add the reference to the definition-reference relationship if called_reference_db not in definition_db.internal_references: definition_db.internal_references.append(called_reference_db) def _handle_reference_for_symbol( self, session: Session, symbol_db: OrmSymbol, reference_dc: Reference ): # make location loc_db, _ = get_or_create( session, OrmCodeLocation, **self._location_as_criteria(reference_dc.location), ) # make reference reference_db, _ = get_or_create( session, OrmReference, symbol=symbol_db, # this should add this reference to the symbol's references location=loc_db, ) for func_def in reference_dc.called_by: caller_symbol_dc: Symbol = func_def.symbol caller_definition_dc: PureDefinition = func_def.definition # make caller symbol caller_symbol_db, _ = get_or_create( session, OrmSymbol, **self._func_like_as_criteria(caller_symbol_dc), ) # make caller location caller_location_db, _ = get_or_create( session, OrmCodeLocation, **self._location_as_criteria(caller_definition_dc.location), ) # make caller definition caller_definition_db, _ = get_or_create( session, OrmDefinition, symbol=caller_symbol_db, location=caller_location_db, ) # add the reference to the reference-definition relationship if caller_definition_db not in reference_db.callers: reference_db.callers.append(caller_definition_db) def _handle_entry(self, session: Session, entry: IndexDataEntry): symbol_dc: Symbol = entry.symbol info_dc: FunctionLikeInfo = entry.info # make symbol symbol_criteria = self._func_like_as_criteria(symbol_dc) symbol_db, _ = get_or_create(session, OrmSymbol, **symbol_criteria) # handle info of this symbol for definition_dc in info_dc.definitions: self._handle_definition_for_symbol(session, symbol_db, definition_dc) # handle references of this symbol for reference_dc in info_dc.references: self._handle_reference_for_symbol(session, symbol_db, reference_dc) def _save(self, data: IndexData, session: Session): index_type = data.type index_data = data.data # save metadata metadata = OrmMetadata(index_type=index_type) session.add(metadata) for entry in index_data: self._handle_entry(session, entry)
[docs] def save(self, data: IndexData, path: Path): """ 将索引数据保存到 SQLite 数据库。 :param data: 要保存的索引数据字典 :param path: 保存数据库文件的路径 """ engine = self.get_engine(path, make_empty_db=True) logger.debug("Created engine at {}", path) Base.metadata.create_all(engine) session_maker = sessionmaker(bind=engine) session = session_maker() try: self._save(data, session) session.commit() except Exception as e: session.rollback() raise RuntimeError(f"保存索引数据到 SQLite 数据库时出错:{e}") finally: session.close()
[docs] @staticmethod def _make_function_like(symbol_db: OrmSymbol) -> Symbol: """ 根据 OrmSymbol 创建 Symbol 对象。 """ if symbol_db.symbol_type == SymbolType.FUNCTION: return Function(name=symbol_db.name) elif symbol_db.symbol_type == SymbolType.METHOD: return Method(name=symbol_db.name, class_name=symbol_db.class_name) else: raise ValueError(f"Unsupported symbol type: {symbol_db.symbol_type}")
def _handle_load_pure_reference( self, _session: Session, ref_db: OrmReference, ) -> PureReference: loc_db = ref_db.location location = CodeLocation( file_path=Path(loc_db.file_path), start_lineno=loc_db.start_lineno, start_col=loc_db.start_col, end_lineno=loc_db.end_lineno, end_col=loc_db.end_col, start_byte=loc_db.start_byte, end_byte=loc_db.end_byte, ) return PureReference(location=location) def _handle_load_pure_definition( self, session: Session, def_db: OrmDefinition, ) -> PureDefinition: # get the location for this definition loc_db = def_db.location # create the definition object return PureDefinition( location=CodeLocation( file_path=Path(loc_db.file_path), start_lineno=loc_db.start_lineno, start_col=loc_db.start_col, end_lineno=loc_db.end_lineno, end_col=loc_db.end_col, start_byte=loc_db.start_byte, end_byte=loc_db.end_byte, ) ) def _handle_load_reference(self, _session: Session, ref_db: OrmReference) -> Reference: loc_db = ref_db.location location = CodeLocation( file_path=Path(loc_db.file_path), start_lineno=loc_db.start_lineno, start_col=loc_db.start_col, end_lineno=loc_db.end_lineno, end_col=loc_db.end_col, start_byte=loc_db.start_byte, end_byte=loc_db.end_byte, ) called_by: list[SymbolDefinition] = [] for def_db in ref_db.callers: called_by.append( SymbolDefinition( symbol=self._make_function_like(def_db.symbol), definition=self._handle_load_pure_definition(_session, def_db), ) ) return Reference(location=location, called_by=called_by) def _handle_load_definition(self, session: Session, def_db: OrmDefinition) -> Definition: # get the location for this definition loc_db = def_db.location location = CodeLocation( file_path=Path(loc_db.file_path), start_lineno=loc_db.start_lineno, start_col=loc_db.start_col, end_lineno=loc_db.end_lineno, end_col=loc_db.end_col, start_byte=loc_db.start_byte, end_byte=loc_db.end_byte, ) # handle what this definition calls calls: list[SymbolReference] = [] for ref_db in def_db.internal_references: calls.append( SymbolReference( symbol=self._make_function_like(ref_db.symbol), reference=self._handle_load_pure_reference(session, ref_db), ) ) # create the definition object return Definition(location=location, calls=calls) def _handle_load_info_for_symbol( self, session: Session, symbol_db: OrmSymbol ) -> FunctionLikeInfo: # get the definitions for this symbol definitions: list[Definition] = [] for def_db in symbol_db.definitions: definition = self._handle_load_definition(session, def_db) definitions.append(definition) # get the references for this symbol references: list[Reference] = [] for ref_db in symbol_db.references: reference = self._handle_load_reference(session, ref_db) references.append(reference) # create the FunctionLikeInfo object return FunctionLikeInfo( definitions=definitions, references=references, )
[docs] def _load(self, session: Session) -> IndexData: """ 从 SQLAlchemy 会话中加载索引数据。 :param session: SQLAlchemy 会话对象 :return: IndexData 对象 """ metadata = session.query(OrmMetadata).one_or_none() if metadata is None: raise ValueError("数据库中没有找到索引元数据。") index_type: str = metadata.index_type # type: ignore symbols = session.scalars(select(OrmSymbol)).all() entries = [] for symbol in symbols: entries.append( IndexDataEntry( symbol=self._make_function_like(symbol), info=self._handle_load_info_for_symbol(session, symbol), ) ) return IndexData(type=index_type, data=entries)
[docs] def load(self, path: Path) -> IndexData: """ 从 SQLite 数据库加载索引数据。 :param path: SQLite 数据库文件的路径 :return: IndexData 对象 """ engine = self.get_engine(path) logger.debug("Created engine at {}", path) session_maker = sessionmaker(bind=engine) session = session_maker() try: return self._load(session) except Exception as e: raise RuntimeError(f"从 SQLite 数据库加载索引数据时出错:{e}") finally: session.close()
def demo_orm(): # 创建一个内存中的 SQLite 数据库引擎用于演示 engine = create_engine("sqlite:///:memory:") # 根据我们定义的模型,在数据库中创建所有表 Base.metadata.create_all(engine) # 创建一个 Session 类,用于与数据库交互 session_maker = sessionmaker(bind=engine) session = session_maker() # --- 创建一些示例数据 --- session.add(OrmMetadata(index_type="demo_index")) # 1. 创建符号 main_func_symbol = OrmSymbol(name="main", symbol_type=SymbolType.FUNCTION) helper_func_symbol = OrmSymbol(name="helper_func", symbol_type=SymbolType.FUNCTION) # 2. 创建位置 main_loc = OrmCodeLocation( file_path="main.c", start_lineno=10, start_col=1, end_lineno=15, end_col=1, start_byte=11, end_byte=45, ) helper_loc = OrmCodeLocation( file_path="main.c", start_lineno=1, start_col=1, end_lineno=5, end_col=1, start_byte=2, end_byte=30, ) call_loc = OrmCodeLocation( file_path="main.c", start_lineno=12, start_col=5, end_lineno=12, end_col=18, start_byte=50, end_byte=70, ) # 3. 创建定义 main_def = OrmDefinition(symbol=main_func_symbol, location=main_loc) helper_def = OrmDefinition(symbol=helper_func_symbol, location=helper_loc) # 4. 创建引用 helper_ref = OrmReference(symbol=helper_func_symbol, location=call_loc) # 5. 建立调用关系:main 函数内部调用了 helper_func main_def.internal_references.append(helper_ref) # 将所有对象添加到 session 中 session.add_all([main_func_symbol, helper_func_symbol, main_def, helper_def, helper_ref]) # 提交事务,将数据写入数据库 session.commit() # --- 查询数据 --- print("--- 查询数据库 ---") # 查找 main 函数的定义 retrieved_main_def = ( session.query(OrmDefinition).filter(OrmDefinition.symbol.has(name="main")).one() ) # 打印 main 函数内部调用的函数名 print(f"函数 '{retrieved_main_def.symbol.name}' 内部调用了:") for ref in retrieved_main_def.internal_references: print(f" - '{ref.symbol.name}' (位于行 {ref.location.start_lineno})") session.close() # 读取数据 persist_strategy = SqlitePersistStrategy() load_session = session_maker() try: index_data = persist_strategy._load(load_session) print("--- 读取索引数据 ---") pprint(index_data) except Exception as e: print(f"读取索引数据时出错: {e}") finally: load_session.close() # --- 4. 示例:如何使用这些模型 --- if __name__ == "__main__": demo_orm()