230 lines
7.1 KiB
TypeScript
230 lines
7.1 KiB
TypeScript
import Database from 'better-sqlite3';
|
||
import { ToolDef } from '../../llm/openai-compat.js';
|
||
import type { ToolContext, ToolResult } from './core.js';
|
||
import { resolveAndGuard } from './core.js';
|
||
import { logger } from '../../logger.js';
|
||
|
||
// --- SQL statement analysis ---
|
||
|
||
/**
 * Skip whitespace and SQL comments starting at `from` and return the index of
 * the next significant character (or `sql.length`).
 *
 * Both SQLite comment forms are recognised: `-- …` up to the end of the line,
 * and C-style block comments. An unterminated block comment runs to the end of
 * the input, as in SQLite itself.
 */
function skipSqlTrivia(sql: string, from: number): number {
  let i = from;
  for (;;) {
    while (i < sql.length && sql.charAt(i).trim() === '') i += 1;
    if (sql.startsWith('--', i)) {
      const eol = sql.indexOf('\n', i + 2);
      if (eol === -1) return sql.length;
      i = eol + 1;
    } else if (sql.startsWith('/*', i)) {
      const close = sql.indexOf('*/', i + 2);
      if (close === -1) return sql.length;
      i = close + 2;
    } else {
      return i;
    }
  }
}

/**
 * Return up to `count` leading keywords of `sql`, uppercased.
 *
 * Whitespace and comments before and between the words are skipped. A word is
 * a run of `[A-Za-z_][A-Za-z0-9_]*`, so punctuation ends it: `DROP/**` yields
 * `DROP`, and `table_info(t)` yields `TABLE_INFO`. Collection stops at the
 * first non-word token.
 */
function leadingKeywords(sql: string, count: number): string[] {
  const word = /[A-Za-z_][A-Za-z0-9_]*/y;
  const keywords: string[] = [];
  let pos = skipSqlTrivia(sql, 0);
  while (keywords.length < count) {
    word.lastIndex = pos;
    const token = word.exec(sql)?.[0];
    if (token === undefined) break;
    keywords.push(token.toUpperCase());
    pos = skipSqlTrivia(sql, pos + token.length);
  }
  return keywords;
}

/**
 * Return the first keyword of a SQL statement, uppercased.
 *
 * Leading whitespace and comments are ignored, so a comment prefix cannot
 * disguise the statement type. Returns '' if the statement does not begin with
 * a word.
 */
function firstKeyword(sql: string): string {
  return leadingKeywords(sql, 1)[0] ?? '';
}

// Statement types rejected in every mode, whether or not editing is allowed.
const ALWAYS_BLOCKED: ReadonlySet<string> = new Set(['DROP', 'ALTER', 'ATTACH', 'DETACH', 'REINDEX', 'VACUUM']);

// Words that may sit between CREATE and the object kind
// (e.g. CREATE TEMP TRIGGER, CREATE UNIQUE INDEX).
const CREATE_MODIFIERS: ReadonlySet<string> = new Set(['TEMP', 'TEMPORARY', 'UNIQUE']);

/**
 * Decide whether a single SQL statement is forbidden regardless of edit mode.
 *
 * Blocked:
 *  - every statement whose first keyword is in ALWAYS_BLOCKED;
 *  - CREATE [TEMP|TEMPORARY] [UNIQUE] INDEX / TRIGGER (CREATE TABLE and
 *    CREATE VIEW stay allowed);
 *  - every PRAGMA except `table_info` and `table_list`. A schema-qualified
 *    pragma such as `main.table_info` is also rejected.
 *
 * @returns `{ blocked: false }`, or `{ blocked: true, reason }` where `reason`
 *   is a user-facing explanation.
 */
function isStatementBlocked(sql: string): { blocked: boolean; reason?: string } {
  // CREATE + up to two modifiers + object kind = at most 4 words needed.
  const words = leadingKeywords(sql, 4);
  const kw = words[0];
  if (kw === undefined) return { blocked: false };

  if (ALWAYS_BLOCKED.has(kw)) {
    return { blocked: true, reason: `"${kw}" statements are not allowed` };
  }

  if (kw === 'CREATE') {
    let k = 1;
    while (k < words.length && CREATE_MODIFIERS.has(words[k] ?? '')) k += 1;
    const objectKind = words[k];
    if (objectKind === 'INDEX' || objectKind === 'TRIGGER') {
      return { blocked: true, reason: `"${words.slice(0, k + 1).join(' ')}" is not allowed` };
    }
  }

  if (kw === 'PRAGMA') {
    // Only the bare pragma name is accepted; anything else (including a
    // schema prefix like `main.`) leaves a name that fails this check.
    const pragmaName = (words[1] ?? '').toLowerCase();
    if (pragmaName !== 'table_info' && pragmaName !== 'table_list') {
      return {
        blocked: true,
        reason: `PRAGMA "${pragmaName}" is not allowed. Only table_info and table_list are permitted`,
      };
    }
  }

  return { blocked: false };
}
|
||
|
||
/**
 * Split a SQL script into individual statements on top-level `;`.
 *
 * A semicolon does not split when it appears inside:
 *  - a string literal ('…', with '' as an escaped quote);
 *  - a quoted identifier ("…", `…` or […]);
 *  - a comment (`-- …` to end of line, or a C-style block comment).
 *
 * Each statement is trimmed. Pieces containing only whitespace and/or comments
 * are dropped, because better-sqlite3 would fail to prepare them.
 *
 * Limitation: CREATE TRIGGER bodies (BEGIN … ; … END) are still split on their
 * inner semicolons. isStatementBlocked rejects CREATE TRIGGER anyway.
 */
function splitStatements(sql: string): string[] {
  const statements: string[] = [];
  let start = 0; // index where the current statement begins
  let hasCode = false; // current piece holds something besides whitespace/comments
  let i = 0;

  const flush = (end: number): void => {
    const piece = sql.slice(start, end).trim();
    if (hasCode && piece.length > 0) statements.push(piece);
    start = end + 1;
    hasCode = false;
  };

  while (i < sql.length) {
    const ch = sql.charAt(i);
    if (ch === '-' && sql.charAt(i + 1) === '-') {
      const eol = sql.indexOf('\n', i + 2);
      i = eol === -1 ? sql.length : eol + 1;
    } else if (ch === '/' && sql.charAt(i + 1) === '*') {
      const close = sql.indexOf('*/', i + 2);
      i = close === -1 ? sql.length : close + 2;
    } else if (ch === "'" || ch === '"' || ch === '`' || ch === '[') {
      hasCode = true;
      i = endOfQuoted(sql, i + 1, ch === '[' ? ']' : ch);
    } else if (ch === ';') {
      flush(i);
      i += 1;
    } else {
      if (ch.trim() !== '') hasCode = true;
      i += 1;
    }
  }
  flush(sql.length);
  return statements;
}

/**
 * Return the index just past the closing `closer` of a quoted token whose body
 * starts at `from`.
 *
 * For quote characters, a doubled closer ('' / "" / ``) is an escaped literal
 * character. `]` has no escape form. Returns `sql.length` if the token is
 * unterminated.
 */
function endOfQuoted(sql: string, from: number, closer: string): number {
  let j = from;
  while (j < sql.length) {
    if (sql.charAt(j) === closer) {
      if (closer !== ']' && sql.charAt(j + 1) === closer) {
        j += 2;
        continue;
      }
      return j + 1;
    }
    j += 1;
  }
  return sql.length;
}
|
||
|
||
/**
 * Report whether every statement's first keyword is SELECT.
 *
 * This is the gate for read-only (edit disabled) mode. PRAGMA and other non-DML
 * statements do NOT pass it. An empty list yields true.
 */
function allAreSelect(statements: string[]): boolean {
  for (const statement of statements) {
    if (firstKeyword(statement) !== 'SELECT') {
      return false;
    }
  }
  return true;
}
|
||
|
||
// Leading keywords that mark a row-modifying statement.
// NOTE(review): UPSERT is not a SQLite statement keyword (upserts are written
// INSERT … ON CONFLICT), so that entry appears never to match; it is harmless.
const WRITE_KEYWORDS: ReadonlySet<string> = new Set(['INSERT', 'UPDATE', 'DELETE', 'REPLACE', 'UPSERT']);

/** Return true when the statement's first keyword is one of WRITE_KEYWORDS. */
function isWriteStatement(sql: string): boolean {
  const keyword = firstKeyword(sql);
  return WRITE_KEYWORDS.has(keyword);
}
|
||
|
||
// --- Format result as text table ---
|
||
|
||
/**
 * Render a result set as a fixed-width text table followed by a row-count
 * footer, for example:
 *
 *   id | name
 *   ---+-----
 *   1  | a
 *   (1 row)
 *
 * Column order follows the keys of the first row. `null`/`undefined` cells
 * print as `NULL`. Binary cells (`Uint8Array`, which includes Node's `Buffer`)
 * print as `<BLOB n bytes>` instead of being decoded as text.
 *
 * Widths are measured in UTF-16 code units, so wide characters (e.g. CJK) may
 * misalign visually. An empty result yields just `(0 rows)`.
 */
function formatTable(rows: Record<string, unknown>[]): string {
  const firstRow = rows[0];
  if (firstRow === undefined) return '(0 rows)';

  const columns = Object.keys(firstRow);
  // Stringify every cell once; width calculation and output both reuse it.
  const cells = rows.map((row) => columns.map((col) => formatCell(row[col])));
  const widths = columns.map((col, i) =>
    cells.reduce((max, cellRow) => Math.max(max, (cellRow[i] ?? '').length), col.length),
  );

  const renderRow = (values: readonly string[]): string =>
    values.map((value, i) => value.padEnd(widths[i] ?? 0)).join(' | ');

  const lines = [
    renderRow(columns),
    widths.map((w) => '-'.repeat(w)).join('-+-'),
    ...cells.map((cellRow) => renderRow(cellRow)),
    `(${rows.length} row${rows.length === 1 ? '' : 's'})`,
  ];
  return lines.join('\n');
}

/** Convert one SQLite result value into its table-cell text. */
function formatCell(value: unknown): string {
  if (value === null || value === undefined) return 'NULL';
  if (value instanceof Uint8Array) return `<BLOB ${value.length} bytes>`;
  return String(value);
}
|
||
|
||
// --- Tool definition ---
|
||
|
||
// Model-facing schema text. These strings are sent to the model verbatim and
// are intentionally kept in Japanese. English gist:
//  - tool:    "Run a query against a SQLite DB (SELECT only when edit=false).
//              Details: ReadToolDoc({ name: "SQLite" })."
//  - query:   "SQL query"
//  - db_path: "DB file path (defaults to temp.db in the workspace when omitted)"
const SQLITE_TOOL_DESCRIPTION =
  'SQLite DB にクエリを実行する(edit=false 時は SELECT のみ)。詳細は ReadToolDoc({ name: "SQLite" })。';
const SQLITE_QUERY_DESCRIPTION = 'SQL クエリ';
const SQLITE_DB_PATH_DESCRIPTION = 'DB ファイルパス (省略時は workspace 内の temp.db)';

/**
 * Function-calling definition for the `SQLite` tool.
 *
 * Parameters: `query` (required) and `db_path` (optional).
 */
const SQLITE_DEF: ToolDef = {
  type: 'function',
  function: {
    name: 'SQLite',
    description: SQLITE_TOOL_DESCRIPTION,
    parameters: {
      type: 'object',
      properties: {
        query: { type: 'string', description: SQLITE_QUERY_DESCRIPTION },
        db_path: { type: 'string', description: SQLITE_DB_PATH_DESCRIPTION },
      },
      required: ['query'],
    },
  },
};
|
||
|
||
/**
 * Tool definitions exported by this module, keyed by tool name.
 *
 * The key is taken from the definition itself, so it always matches the name
 * the model will call.
 */
export const TOOL_DEFS: Record<string, ToolDef> = {
  [SQLITE_DEF.function.name]: SQLITE_DEF,
};
|
||
|
||
// --- Tool execution ---
|
||
|
||
/**
 * Run the `SQLite` tool.
 *
 * Validates the SQL in `input.query`, opens the database at `input.db_path`
 * (default `temp.db`, resolved inside the workspace by resolveAndGuard), and
 * executes each `;`-separated statement in order.
 *
 * Checks performed before anything executes:
 *  - statements rejected by isStatementBlocked fail in every mode;
 *  - when `ctx.editAllowed` is false, every statement's first keyword must be
 *    SELECT, and the database is opened read-only.
 *
 * Output is one entry per statement, joined by blank lines:
 *  - a text table for statements that return rows;
 *  - `N row(s) affected` for INSERT/UPDATE/DELETE/REPLACE;
 *  - `OK` for anything else.
 *
 * The first failing statement aborts the run with `isError: true`. Statements
 * already executed are not rolled back, because no transaction is used.
 */
function executeSQLite(input: Record<string, unknown>, ctx: ToolContext): ToolResult {
  const query = input['query'];
  if (typeof query !== 'string') {
    return { output: '"query" must be a string', isError: true };
  }
  // An empty/blank db_path is treated like an omitted one.
  const rawDbPath = input['db_path'];
  const dbPathInput = typeof rawDbPath === 'string' && rawDbPath.trim() !== '' ? rawDbPath : 'temp.db';

  // Resolve DB path
  let resolvedDb: string;
  try {
    resolvedDb = resolveAndGuard(ctx.workspacePath, dbPathInput);
  } catch (e) {
    return { output: errorMessage(e), isError: true };
  }

  // Split and validate all statements
  const statements = splitStatements(query);
  if (statements.length === 0) {
    return { output: 'Empty query', isError: true };
  }

  // DDL check (always blocked regardless of editAllowed)
  for (const stmt of statements) {
    const { blocked, reason } = isStatementBlocked(stmt);
    if (blocked) {
      return { output: `Forbidden SQL: ${reason}`, isError: true };
    }
  }

  if (!ctx.editAllowed) {
    // Read-only mode: only SELECT allowed
    if (!allAreSelect(statements)) {
      return {
        output: 'Only SELECT queries are allowed when edit mode is disabled',
        isError: true,
      };
    }
    // Belt-and-suspenders: explicit write check on top of the SELECT-only rule.
    if (statements.some((stmt) => isWriteStatement(stmt))) {
      return {
        output: 'INSERT/UPDATE/DELETE are not allowed when edit mode is disabled',
        isError: true,
      };
    }
  }

  logger.debug(`[SQLite] db=${resolvedDb} editAllowed=${ctx.editAllowed} statements=${statements.length}`);

  // Open database (read-only unless editing is allowed)
  let db: Database.Database;
  try {
    db = new Database(resolvedDb, ctx.editAllowed ? {} : { readonly: true });
  } catch (e) {
    return { output: `Failed to open database: ${errorMessage(e)}`, isError: true };
  }

  try {
    const outputs: string[] = [];
    for (const stmt of statements) {
      try {
        outputs.push(runStatement(db, stmt));
      } catch (e) {
        return { output: `Query error: ${errorMessage(e)}`, isError: true };
      }
    }
    return { output: outputs.join('\n\n'), isError: false };
  } finally {
    db.close();
  }
}

/**
 * Execute one already-validated statement and describe its result.
 *
 * Dispatches on better-sqlite3's `Statement#reader` flag rather than on the
 * leading keyword. This way, row-returning statements that do not start with
 * SELECT still have their rows shown instead of just `OK`. Examples, reachable
 * in edit mode: `WITH … SELECT`, `VALUES …`, `INSERT … RETURNING`.
 */
function runStatement(db: Database.Database, sql: string): string {
  const prepared = db.prepare(sql);
  if (prepared.reader) {
    return formatTable(prepared.all() as Record<string, unknown>[]);
  }
  const info = prepared.run();
  return WRITE_KEYWORDS.has(firstKeyword(sql)) ? `${info.changes} row(s) affected` : 'OK';
}

/** Best-effort message text for a caught value, which may not be an Error. */
function errorMessage(e: unknown): string {
  return e instanceof Error ? e.message : String(e);
}
|
||
|
||
/**
 * Dispatch a tool call to this module's implementation by tool name.
 *
 * @returns the tool's result, or `null` when `name` is not a tool handled here.
 */
export async function executeTool(
  name: string,
  input: Record<string, unknown>,
  ctx: ToolContext,
): Promise<ToolResult | null> {
  if (name === 'SQLite') {
    return executeSQLite(input, ctx);
  }
  return null;
}
|