feat: add PythonExtractor for tree-sitter Python structural analysis

Implements the LanguageExtractor interface for Python, extracting functions
(with type annotations, defaults, *args/**kwargs), classes (methods +
annotated properties), imports (plain, from, aliased, wildcard), exports
(top-level defs), and caller-callee call graphs. Includes 31 tests using
the real tree-sitter parser.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Lum1104
2026-04-15 18:28:38 +08:00
Unverified
parent 2ad0563878
commit 398f5b15b0
2 changed files with 1012 additions and 0 deletions
@@ -0,0 +1,659 @@
import { describe, it, expect, beforeAll } from "vitest";
import { createRequire } from "node:module";
import { PythonExtractor } from "../python-extractor.js";
const require = createRequire(import.meta.url);
// Shared tree-sitter handles. They are assigned once in `beforeAll` and then
// reused by every test through `parse()`.
let Parser: any;
let Language: any;
let pythonLang: any;
/**
 * One-time suite setup: import web-tree-sitter, initialise the parser
 * runtime, and load the Python grammar from its `.wasm` file.
 */
beforeAll(async () => {
  const { Parser: ParserCtor, Language: LanguageCtor } = await import(
    "web-tree-sitter"
  );
  Parser = ParserCtor;
  Language = LanguageCtor;
  await Parser.init();
  // The grammar file is located through Node's module resolution.
  pythonLang = await Language.load(
    require.resolve("tree-sitter-python/tree-sitter-python.wasm"),
  );
});
/**
 * Parse a Python source string with a freshly created parser.
 *
 * @param code - Python source text.
 * @returns The parse tree, the parser that produced it, and the tree's root
 *   node. The tests call `delete()` on both `tree` and `parser` when done.
 */
function parse(code: string) {
  const pythonParser = new Parser();
  pythonParser.setLanguage(pythonLang);
  const syntaxTree = pythonParser.parse(code);
  return {
    tree: syntaxTree,
    parser: pythonParser,
    root: syntaxTree.rootNode,
  };
}
describe("PythonExtractor", () => {
const extractor = new PythonExtractor();
// Sanity check: the extractor advertises exactly one language id, "python".
it("has correct languageIds", () => {
expect(extractor.languageIds).toEqual(["python"]);
});
// ---- Functions ----
/**
 * Top-level function extraction through `extractStructure`: function names,
 * parameter names, return annotations, and line ranges.
 *
 * Every test parses an inline Python fixture and deletes the tree and parser
 * as its last two statements.
 * NOTE(review): that cleanup is not reached when an earlier `expect` throws.
 */
describe("extractStructure - functions", () => {
it("extracts simple functions with type annotations", () => {
const { tree, parser, root } = parse(`
def hello(name: str) -> str:
return f"Hello {name}"
def add(a: int, b: int) -> int:
return a + b
`);
const result = extractor.extractStructure(root);
expect(result.functions).toHaveLength(2);
expect(result.functions[0].name).toBe("hello");
// Parameter names are reported without their type annotations.
expect(result.functions[0].params).toEqual(["name"]);
expect(result.functions[0].returnType).toBe("str");
// Only checks that the start line is positive, not its exact value.
expect(result.functions[0].lineRange[0]).toBeGreaterThan(0);
expect(result.functions[1].name).toBe("add");
expect(result.functions[1].params).toEqual(["a", "b"]);
expect(result.functions[1].returnType).toBe("int");
tree.delete();
parser.delete();
});
it("extracts functions without type annotations", () => {
const { tree, parser, root } = parse(`
def greet(name):
print(name)
def noop():
pass
`);
const result = extractor.extractStructure(root);
expect(result.functions).toHaveLength(2);
expect(result.functions[0].name).toBe("greet");
expect(result.functions[0].params).toEqual(["name"]);
// No `->` annotation means no return type is reported.
expect(result.functions[0].returnType).toBeUndefined();
expect(result.functions[1].name).toBe("noop");
expect(result.functions[1].params).toEqual([]);
tree.delete();
parser.delete();
});
it("extracts functions with default parameters", () => {
const { tree, parser, root } = parse(`
def connect(host: str, port: int = 8080, timeout: float = 30.0):
pass
`);
const result = extractor.extractStructure(root);
expect(result.functions).toHaveLength(1);
expect(result.functions[0].name).toBe("connect");
// Default values are not part of the reported parameter names.
expect(result.functions[0].params).toEqual(["host", "port", "timeout"]);
tree.delete();
parser.delete();
});
it("extracts functions with *args and **kwargs", () => {
const { tree, parser, root } = parse(`
def flexible(*args, **kwargs):
pass
`);
const result = extractor.extractStructure(root);
expect(result.functions).toHaveLength(1);
// Splat parameters keep their "*" / "**" prefix in the reported name.
expect(result.functions[0].params).toEqual(["*args", "**kwargs"]);
tree.delete();
parser.delete();
});
it("extracts decorated functions", () => {
const { tree, parser, root } = parse(`
@decorator
def decorated_func():
pass
@app.route("/api")
def api_handler():
pass
`);
const result = extractor.extractStructure(root);
// Both a bare decorator and a decorator call must be seen through.
expect(result.functions).toHaveLength(2);
expect(result.functions[0].name).toBe("decorated_func");
expect(result.functions[1].name).toBe("api_handler");
tree.delete();
parser.delete();
});
it("reports correct line ranges", () => {
const { tree, parser, root } = parse(`
def multiline(
a: int,
b: int,
) -> int:
result = a + b
return result
`);
const result = extractor.extractStructure(root);
expect(result.functions).toHaveLength(1);
// The template literal opens with a newline, so `def` is on line 2 of the
// parsed text; line numbers are 1-based.
expect(result.functions[0].lineRange[0]).toBe(2);
expect(result.functions[0].lineRange[1]).toBe(7);
tree.delete();
parser.delete();
});
});
// ---- Classes ----
/**
 * Class extraction through `extractStructure`: class names, method names,
 * annotated class-level properties, and line ranges.
 *
 * Every test parses an inline Python fixture and deletes the tree and parser
 * as its last two statements.
 */
describe("extractStructure - classes", () => {
it("extracts classes with methods and properties", () => {
const { tree, parser, root } = parse(`
class DataProcessor:
name: str
def __init__(self, name: str):
self.name = name
def process(self, data: list) -> dict:
return transform(data)
`);
const result = extractor.extractStructure(root);
expect(result.classes).toHaveLength(1);
expect(result.classes[0].name).toBe("DataProcessor");
expect(result.classes[0].methods).toContain("__init__");
expect(result.classes[0].methods).toContain("process");
// `name: str` in the class body is the annotated property under test.
expect(result.classes[0].properties).toContain("name");
tree.delete();
parser.delete();
});
it("extracts dataclass-style annotated properties", () => {
const { tree, parser, root } = parse(`
class Config:
name: str
value: int
debug: bool
`);
const result = extractor.extractStructure(root);
expect(result.classes).toHaveLength(1);
// Properties are expected in declaration order.
expect(result.classes[0].properties).toEqual(["name", "value", "debug"]);
expect(result.classes[0].methods).toEqual([]);
tree.delete();
parser.delete();
});
it("extracts decorated classes", () => {
const { tree, parser, root } = parse(`
@dataclass
class Config:
name: str
value: int = 0
`);
const result = extractor.extractStructure(root);
expect(result.classes).toHaveLength(1);
expect(result.classes[0].name).toBe("Config");
// Annotated properties count with or without a default value.
expect(result.classes[0].properties).toContain("name");
expect(result.classes[0].properties).toContain("value");
tree.delete();
parser.delete();
});
it("extracts decorated methods within a class", () => {
const { tree, parser, root } = parse(`
class MyClass:
@staticmethod
def static_method():
pass
@classmethod
def class_method(cls):
pass
@property
def prop(self):
return self._prop
`);
const result = extractor.extractStructure(root);
expect(result.classes).toHaveLength(1);
// A decorated method, including a @property getter, is listed as a method.
expect(result.classes[0].methods).toContain("static_method");
expect(result.classes[0].methods).toContain("class_method");
expect(result.classes[0].methods).toContain("prop");
tree.delete();
parser.delete();
});
// NOTE(review): despite its title, no assertion in this test inspects a
// parameter list. It checks only that methods are listed on the class and
// are not reported as top-level functions.
it("filters self and cls from method params", () => {
const { tree, parser, root } = parse(`
class Foo:
def instance_method(self, x: int):
pass
@classmethod
def class_method(cls, y: str):
pass
`);
const result = extractor.extractStructure(root);
// Methods are on the class, but top-level functions should not include them
expect(result.functions).toHaveLength(0);
expect(result.classes[0].methods).toEqual(["instance_method", "class_method"]);
tree.delete();
parser.delete();
});
it("reports correct class line ranges", () => {
const { tree, parser, root } = parse(`
class MyClass:
def method_a(self):
pass
def method_b(self):
pass
`);
const result = extractor.extractStructure(root);
expect(result.classes).toHaveLength(1);
// The template literal opens with a newline, so `class` is on line 2 of the
// parsed text; line numbers are 1-based.
expect(result.classes[0].lineRange[0]).toBe(2);
expect(result.classes[0].lineRange[1]).toBe(7);
tree.delete();
parser.delete();
});
});
// ---- Imports ----
/**
 * Import extraction through `extractStructure`: plain, from-, aliased,
 * dotted and wildcard imports, plus their reported line numbers.
 *
 * Every test parses an inline Python fixture and deletes the tree and parser
 * as its last two statements.
 */
describe("extractStructure - imports", () => {
it("extracts simple import statements", () => {
const { tree, parser, root } = parse(`
import os
import sys
`);
const result = extractor.extractStructure(root);
expect(result.imports).toHaveLength(2);
// A plain `import X` reports X as both the source and the only specifier.
expect(result.imports[0].source).toBe("os");
expect(result.imports[0].specifiers).toEqual(["os"]);
expect(result.imports[1].source).toBe("sys");
expect(result.imports[1].specifiers).toEqual(["sys"]);
tree.delete();
parser.delete();
});
it("extracts from-import statements", () => {
const { tree, parser, root } = parse(`
from pathlib import Path
from typing import Optional, List
`);
const result = extractor.extractStructure(root);
// One import entry per statement, however many names it brings in.
expect(result.imports).toHaveLength(2);
expect(result.imports[0].source).toBe("pathlib");
expect(result.imports[0].specifiers).toEqual(["Path"]);
expect(result.imports[1].source).toBe("typing");
expect(result.imports[1].specifiers).toEqual(["Optional", "List"]);
tree.delete();
parser.delete();
});
it("extracts aliased imports", () => {
const { tree, parser, root } = parse(`
from foo import bar as baz
`);
const result = extractor.extractStructure(root);
expect(result.imports).toHaveLength(1);
expect(result.imports[0].source).toBe("foo");
// For `bar as baz` the local alias, not the original name, is reported.
expect(result.imports[0].specifiers).toEqual(["baz"]);
tree.delete();
parser.delete();
});
it("extracts dotted module imports", () => {
const { tree, parser, root } = parse(`
import os.path
from os.path import join, exists
`);
const result = extractor.extractStructure(root);
expect(result.imports).toHaveLength(2);
// The dotted path is kept whole rather than split into segments.
expect(result.imports[0].source).toBe("os.path");
expect(result.imports[0].specifiers).toEqual(["os.path"]);
expect(result.imports[1].source).toBe("os.path");
expect(result.imports[1].specifiers).toEqual(["join", "exists"]);
tree.delete();
parser.delete();
});
it("extracts wildcard imports", () => {
const { tree, parser, root } = parse(`
from os.path import *
`);
const result = extractor.extractStructure(root);
expect(result.imports).toHaveLength(1);
expect(result.imports[0].source).toBe("os.path");
// A wildcard import is represented by the single specifier "*".
expect(result.imports[0].specifiers).toEqual(["*"]);
tree.delete();
parser.delete();
});
it("handles all import types together", () => {
const { tree, parser, root } = parse(`
import os
from pathlib import Path
from typing import Optional, List
`);
const result = extractor.extractStructure(root);
// NOTE(review): this is a lower bound only; the fixture has three import
// statements but an exact count is not asserted.
expect(result.imports.length).toBeGreaterThanOrEqual(3);
tree.delete();
parser.delete();
});
it("reports correct import line numbers", () => {
const { tree, parser, root } = parse(`
import os
from pathlib import Path
`);
const result = extractor.extractStructure(root);
// The template literal opens with a newline, so the first statement is on
// line 2 of the parsed text; line numbers are 1-based.
expect(result.imports[0].lineNumber).toBe(2);
expect(result.imports[1].lineNumber).toBe(3);
tree.delete();
parser.delete();
});
});
// ---- Exports ----
/**
 * Export extraction through `extractStructure`: top-level functions and
 * classes, decorated or not, are expected in `exports`; imports are not.
 *
 * Every test parses an inline Python fixture and deletes the tree and parser
 * as its last two statements.
 */
describe("extractStructure - exports", () => {
it("treats top-level functions as exports", () => {
const { tree, parser, root } = parse(`
def public_func():
pass
def another_func(x: int) -> str:
return str(x)
`);
const result = extractor.extractStructure(root);
const exportNames = result.exports.map((e) => e.name);
expect(exportNames).toContain("public_func");
expect(exportNames).toContain("another_func");
// The exact length rules out any additional, unexpected export entries.
expect(result.exports).toHaveLength(2);
tree.delete();
parser.delete();
});
it("treats top-level classes as exports", () => {
const { tree, parser, root } = parse(`
class MyService:
pass
class MyModel:
pass
`);
const result = extractor.extractStructure(root);
const exportNames = result.exports.map((e) => e.name);
expect(exportNames).toContain("MyService");
expect(exportNames).toContain("MyModel");
expect(result.exports).toHaveLength(2);
tree.delete();
parser.delete();
});
it("treats decorated top-level definitions as exports", () => {
const { tree, parser, root } = parse(`
@dataclass
class Config:
name: str
@app.route("/")
def index():
pass
`);
const result = extractor.extractStructure(root);
const exportNames = result.exports.map((e) => e.name);
// The export name is the definition's own name, not the decorator's.
expect(exportNames).toContain("Config");
expect(exportNames).toContain("index");
tree.delete();
parser.delete();
});
it("does not treat imports as exports", () => {
const { tree, parser, root } = parse(`
import os
from pathlib import Path
def my_func():
pass
`);
const result = extractor.extractStructure(root);
// Only the function definition is exported; `os` and `Path` are not.
expect(result.exports).toHaveLength(1);
expect(result.exports[0].name).toBe("my_func");
tree.delete();
parser.delete();
});
});
// ---- Call Graph ----
/**
 * Call graph extraction through `extractCallGraph`: each entry pairs the
 * enclosing function (caller) with the called expression (callee) and the
 * line of the call.
 *
 * Every test parses an inline Python fixture and deletes the tree and parser
 * as its last two statements.
 */
describe("extractCallGraph", () => {
it("extracts simple function calls", () => {
const { tree, parser, root } = parse(`
def process(data):
result = transform(data)
return format_output(result)
def main():
process([1, 2, 3])
`);
const result = extractor.extractCallGraph(root);
// Lower bound only; the individual edges are checked below.
expect(result.length).toBeGreaterThanOrEqual(2);
const processCallers = result.filter((e) => e.caller === "process");
expect(processCallers.some((e) => e.callee === "transform")).toBe(true);
expect(processCallers.some((e) => e.callee === "format_output")).toBe(true);
const mainCallers = result.filter((e) => e.caller === "main");
expect(mainCallers.some((e) => e.callee === "process")).toBe(true);
tree.delete();
parser.delete();
});
it("extracts attribute-based calls (method calls)", () => {
const { tree, parser, root } = parse(`
def process():
self.method()
os.path.join("a", "b")
result.save()
`);
const result = extractor.extractCallGraph(root);
const callees = result.map((e) => e.callee);
// For an attribute call the callee is the full dotted text as written.
expect(callees).toContain("self.method");
expect(callees).toContain("os.path.join");
expect(callees).toContain("result.save");
tree.delete();
parser.delete();
});
it("tracks correct caller context for nested calls", () => {
const { tree, parser, root } = parse(`
def outer():
helper()
def inner():
deep_call()
another()
`);
const result = extractor.extractCallGraph(root);
// A call is attributed to the innermost enclosing function; calls made
// after the nested function ends belong to the outer one again.
const outerCalls = result.filter((e) => e.caller === "outer");
expect(outerCalls.some((e) => e.callee === "helper")).toBe(true);
expect(outerCalls.some((e) => e.callee === "another")).toBe(true);
const innerCalls = result.filter((e) => e.caller === "inner");
expect(innerCalls.some((e) => e.callee === "deep_call")).toBe(true);
tree.delete();
parser.delete();
});
it("reports correct line numbers for calls", () => {
const { tree, parser, root } = parse(`
def main():
foo()
bar()
`);
const result = extractor.extractCallGraph(root);
expect(result).toHaveLength(2);
// The template literal opens with a newline, so `def main` is on line 2
// and the two calls follow on lines 3 and 4 (1-based).
expect(result[0].lineNumber).toBe(3);
expect(result[1].lineNumber).toBe(4);
tree.delete();
parser.delete();
});
it("ignores top-level calls (no caller)", () => {
const { tree, parser, root } = parse(`
print("hello")
main()
`);
const result = extractor.extractCallGraph(root);
// Top-level calls have no enclosing function, so they are skipped
expect(result).toHaveLength(0);
tree.delete();
parser.delete();
});
it("handles calls inside class methods", () => {
const { tree, parser, root } = parse(`
class Service:
def start(self):
self.setup()
run_server()
`);
const result = extractor.extractCallGraph(root);
// The caller is the bare method name ("start"), without the class name.
const startCalls = result.filter((e) => e.caller === "start");
expect(startCalls.some((e) => e.callee === "self.setup")).toBe(true);
expect(startCalls.some((e) => e.callee === "run_server")).toBe(true);
tree.delete();
parser.delete();
});
});
// ---- Comprehensive ----
/**
 * End-to-end check on one larger fixture that combines imports, a class with
 * annotated properties and methods, and top-level functions. It exercises
 * both `extractStructure` and `extractCallGraph` on the same tree.
 */
describe("comprehensive Python file", () => {
it("handles a realistic Python module", () => {
const { tree, parser, root } = parse(`
import os
from pathlib import Path
from typing import Optional, List
class FileProcessor:
name: str
verbose: bool
def __init__(self, name: str, verbose: bool = False):
self.name = name
self.verbose = verbose
def process(self, paths: List[str]) -> dict:
results = {}
for p in paths:
results[p] = self._read_file(p)
return results
def _read_file(self, path: str) -> Optional[str]:
full = Path(path)
if full.exists():
return full.read_text()
return None
def create_processor(name: str) -> FileProcessor:
return FileProcessor(name)
@staticmethod
def utility_func(*args, **kwargs) -> None:
print(args, kwargs)
`);
const result = extractor.extractStructure(root);
// Imports
// Lower bound only; the fixture contains three import statements.
expect(result.imports.length).toBeGreaterThanOrEqual(3);
// Class
expect(result.classes).toHaveLength(1);
expect(result.classes[0].name).toBe("FileProcessor");
expect(result.classes[0].methods).toContain("__init__");
expect(result.classes[0].methods).toContain("process");
expect(result.classes[0].methods).toContain("_read_file");
expect(result.classes[0].properties).toContain("name");
expect(result.classes[0].properties).toContain("verbose");
// Top-level functions
expect(result.functions.some((f) => f.name === "create_processor")).toBe(
true,
);
// `utility_func` is a decorated definition in the fixture.
expect(result.functions.some((f) => f.name === "utility_func")).toBe(
true,
);
// Exports (top-level defs)
const exportNames = result.exports.map((e) => e.name);
expect(exportNames).toContain("FileProcessor");
expect(exportNames).toContain("create_processor");
expect(exportNames).toContain("utility_func");
// Call graph
// Reuses the same root node; only asserts that some call edges were found.
const calls = extractor.extractCallGraph(root);
expect(calls.length).toBeGreaterThan(0);
tree.delete();
parser.delete();
});
});
});
@@ -0,0 +1,353 @@
import type { StructuralAnalysis, CallGraphEntry } from "../../types.js";
import type { LanguageExtractor, TreeSitterNode } from "./types.js";
import { findChild, findChildren } from "./base-extractor.js";
/**
 * Extract parameter names from a Python `parameters` node.
 *
 * Handles: identifier (plain), typed_parameter, default_parameter,
 * typed_default_parameter, list_splat_pattern (*args),
 * dictionary_splat_pattern (**kwargs), and annotated splats such as
 * `*args: int` / `**kwargs: str`, which the grammar represents as a
 * typed_parameter wrapping a splat pattern.
 *
 * `self` and `cls` are omitted from the result. Splat parameters keep their
 * `*` / `**` prefix. Any other child (punctuation, separators) is ignored.
 *
 * @param paramsNode - The `parameters` node of a function definition, or
 *   `null` when there is none.
 * @returns Parameter names in source order; empty when `paramsNode` is null.
 */
function extractParams(paramsNode: TreeSitterNode | null): string[] {
  if (!paramsNode) return [];
  const params: string[] = [];

  // Record a named parameter, skipping the implicit receiver names.
  const pushNamed = (ident: TreeSitterNode | null | undefined): void => {
    if (ident && ident.text !== "self" && ident.text !== "cls") {
      params.push(ident.text);
    }
  };

  // Record `*name` / `**name` for a splat pattern node.
  // Returns true when a name was recorded.
  const pushSplat = (
    pattern: TreeSitterNode | null | undefined,
    prefix: string,
  ): boolean => {
    const ident = pattern ? findChild(pattern, "identifier") : null;
    if (!ident) return false;
    params.push(prefix + ident.text);
    return true;
  };

  for (let i = 0; i < paramsNode.childCount; i++) {
    const child = paramsNode.child(i);
    if (!child) continue;
    switch (child.type) {
      case "identifier":
        pushNamed(child);
        break;
      case "typed_parameter":
        // An annotated splat has no identifier of its own; the name lives
        // inside the wrapped splat pattern, so check for that first.
        if (pushSplat(findChild(child, "list_splat_pattern"), "*")) break;
        if (pushSplat(findChild(child, "dictionary_splat_pattern"), "**")) break;
        pushNamed(findChild(child, "identifier"));
        break;
      case "default_parameter":
      case "typed_default_parameter":
        pushNamed(findChild(child, "identifier"));
        break;
      case "list_splat_pattern":
        pushSplat(child, "*");
        break;
      case "dictionary_splat_pattern":
        pushSplat(child, "**");
        break;
    }
  }
  return params;
}
/**
 * Read the return annotation of a `function_definition` node.
 *
 * @param node - The function definition to inspect.
 * @returns The source text of the node's `return_type` field, or `undefined`
 *   when the function has no `->` annotation.
 */
function extractReturnType(node: TreeSitterNode): string | undefined {
  const annotation = node.childForFieldName("return_type");
  return annotation ? annotation.text : undefined;
}
/**
 * Look through a `decorated_definition` to the definition it wraps.
 *
 * @param node - Any node.
 * @returns The wrapped `function_definition` if there is one, otherwise the
 *   wrapped `class_definition`. Returns `node` itself when it is not a
 *   `decorated_definition` or wraps neither.
 */
function unwrapDecorated(node: TreeSitterNode): TreeSitterNode {
  if (node.type !== "decorated_definition") {
    return node;
  }
  const definition =
    findChild(node, "function_definition") ??
    findChild(node, "class_definition");
  return definition ? definition : node;
}
/**
 * Python extractor for tree-sitter structural analysis and call graph extraction.
 *
 * Handles functions, classes, imports, exports, and call graphs for Python code.
 * Python has no formal export syntax, so all top-level function and class
 * definitions are treated as exports.
 */
export class PythonExtractor implements LanguageExtractor {
  readonly languageIds = ["python"];

  /**
   * Collect the module-level functions, classes, imports and exports.
   *
   * Only direct children of `rootNode` are inspected. Decorated definitions
   * are unwrapped first, so a decorated function or class is handled like an
   * undecorated one.
   *
   * @param rootNode - Root node of a parsed Python module.
   * @returns The functions, classes, imports and exports found, each in
   *   source order.
   */
  extractStructure(rootNode: TreeSitterNode): StructuralAnalysis {
    const functions: StructuralAnalysis["functions"] = [];
    const classes: StructuralAnalysis["classes"] = [];
    const imports: StructuralAnalysis["imports"] = [];
    const exports: StructuralAnalysis["exports"] = [];

    for (let i = 0; i < rootNode.childCount; i++) {
      const node = rootNode.child(i);
      if (!node) continue;

      // Unwrap decorated definitions to get the inner node
      const inner = unwrapDecorated(node);

      switch (inner.type) {
        case "function_definition":
          this.extractFunction(inner, functions);
          // Top-level functions are exports in Python
          this.addExport(inner, node, exports);
          break;
        case "class_definition":
          this.extractClass(inner, classes);
          // Top-level classes are exports in Python
          this.addExport(inner, node, exports);
          break;
        case "import_statement":
          this.extractImport(inner, imports);
          break;
        case "import_from_statement":
          this.extractFromImport(inner, imports);
          break;
      }
    }

    return { functions, classes, imports, exports };
  }

  /**
   * Walk the whole tree and record one entry per call made inside a function.
   *
   * The caller is the name of the innermost enclosing `function_definition`.
   * Calls with no enclosing function are skipped, as are calls whose callee
   * is neither an identifier nor an attribute expression.
   *
   * @param rootNode - Root node of a parsed Python module.
   * @returns Caller/callee pairs with the 1-based line of each call, in the
   *   order the calls are visited.
   */
  extractCallGraph(rootNode: TreeSitterNode): CallGraphEntry[] {
    const entries: CallGraphEntry[] = [];
    // Names of the function definitions currently being walked, innermost last.
    const functionStack: string[] = [];

    const walkForCalls = (node: TreeSitterNode) => {
      let pushedName = false;

      // Track entering function/method definitions
      if (node.type === "function_definition") {
        const nameNode = node.childForFieldName("name");
        if (nameNode) {
          functionStack.push(nameNode.text);
          pushedName = true;
        }
      }

      // Extract call expressions
      if (node.type === "call") {
        const calleeNode = node.children.find(
          (c) => c.type === "identifier" || c.type === "attribute",
        );
        if (calleeNode && functionStack.length > 0) {
          entries.push({
            caller: functionStack[functionStack.length - 1],
            callee: calleeNode.text,
            lineNumber: node.startPosition.row + 1,
          });
        }
      }

      for (let i = 0; i < node.childCount; i++) {
        const child = node.child(i);
        if (child) walkForCalls(child);
      }

      // Leave the function only after its whole subtree has been visited.
      if (pushedName) {
        functionStack.pop();
      }
    };

    walkForCalls(rootNode);
    return entries;
  }

  // ---- Private helpers ----

  /**
   * Append one function record (name, 1-based line range, parameter names,
   * return annotation) for a `function_definition` node. Does nothing when
   * the node has no `name` field.
   */
  private extractFunction(
    node: TreeSitterNode,
    functions: StructuralAnalysis["functions"],
  ): void {
    const nameNode = node.childForFieldName("name");
    if (!nameNode) return;

    const paramsNode = node.childForFieldName("parameters");
    const params = extractParams(paramsNode ?? null);
    const returnType = extractReturnType(node);

    functions.push({
      name: nameNode.text,
      lineRange: [node.startPosition.row + 1, node.endPosition.row + 1],
      params,
      returnType,
    });
  }

  /**
   * Append one class record for a `class_definition` node.
   *
   * Methods are the function definitions found directly in the class body,
   * decorated or not. Properties are class-level annotated assignments whose
   * target is a plain identifier, e.g. `name: str` or `value: int = 0`.
   * Does nothing when the node has no `name` field.
   */
  private extractClass(
    node: TreeSitterNode,
    classes: StructuralAnalysis["classes"],
  ): void {
    const nameNode = node.childForFieldName("name");
    if (!nameNode) return;

    const methods: string[] = [];
    const properties: string[] = [];

    const body = node.childForFieldName("body");
    if (body) {
      for (let i = 0; i < body.childCount; i++) {
        const member = body.child(i);
        if (!member) continue;

        // Methods: function_definition or decorated_definition wrapping a function_definition
        const innerMember = unwrapDecorated(member);
        if (innerMember.type === "function_definition") {
          const methodName = innerMember.childForFieldName("name");
          if (methodName) methods.push(methodName.text);
        }

        // Properties: type-annotated assignments at class body level
        if (member.type === "expression_statement") {
          const assignment = findChild(member, "assignment");
          if (assignment) {
            // Read the target and annotation through the grammar's `left`
            // and `type` fields. Taking the first identifier child instead
            // would pick up the right-hand side when the target is not a
            // bare identifier (e.g. `obj.attr: int = other`).
            const target = assignment.childForFieldName("left");
            const annotation = assignment.childForFieldName("type");
            if (annotation && target && target.type === "identifier") {
              properties.push(target.text);
            }
          }
        }
      }
    }

    classes.push({
      name: nameNode.text,
      lineRange: [node.startPosition.row + 1, node.endPosition.row + 1],
      methods,
      properties,
    });
  }

  /**
   * Append one import record per module named in an `import_statement`
   * (`import os`, `import os.path`, `import os, sys`, `import numpy as np`).
   *
   * Records are emitted in source order. For an aliased module the specifier
   * is the alias; otherwise it is the module path itself.
   */
  private extractImport(
    node: TreeSitterNode,
    imports: StructuralAnalysis["imports"],
  ): void {
    const lineNumber = node.startPosition.row + 1;

    // Walk children once so plain and aliased names keep their written order.
    for (let i = 0; i < node.childCount; i++) {
      const child = node.child(i);
      if (!child) continue;

      if (child.type === "dotted_name") {
        imports.push({
          source: child.text,
          specifiers: [child.text],
          lineNumber,
        });
      } else if (child.type === "aliased_import") {
        const dottedName = findChild(child, "dotted_name");
        // The only direct identifier child is the name after `as`.
        const alias = child.children.find((c) => c.type === "identifier");
        if (dottedName) {
          imports.push({
            source: dottedName.text,
            specifiers: [alias ? alias.text : dottedName.text],
            lineNumber,
          });
        }
      }
    }
  }

  /**
   * Append a single import record for an `import_from_statement`
   * (`from pathlib import Path`, `from foo import bar as baz`,
   * `from os import *`).
   *
   * The source is the text of the `module_name` field, or `""` when that
   * field is absent. Specifiers are listed in source order: the imported
   * name, the alias for `x as y`, or `"*"` for a wildcard import.
   */
  private extractFromImport(
    node: TreeSitterNode,
    imports: StructuralAnalysis["imports"],
  ): void {
    const moduleNode = node.childForFieldName("module_name");
    const source = moduleNode ? moduleNode.text : "";
    const moduleNodeId = moduleNode?.id;

    const specifiers: string[] = [];

    // Walk children once so plain, aliased and wildcard specifiers keep
    // their written order.
    for (let i = 0; i < node.childCount; i++) {
      const child = node.child(i);
      if (!child) continue;

      switch (child.type) {
        case "dotted_name":
          // Skip the module_name dotted_name (compare by node id, not reference)
          if (child.id !== moduleNodeId) specifiers.push(child.text);
          break;
        case "aliased_import": {
          // The alias identifier follows the `as` keyword
          const alias = child.children.find((c) => c.type === "identifier");
          if (alias) specifiers.push(alias.text);
          break;
        }
        case "wildcard_import":
          specifiers.push("*");
          break;
      }
    }

    imports.push({
      source,
      specifiers,
      lineNumber: node.startPosition.row + 1,
    });
  }

  /**
   * Append an export record for a top-level definition.
   *
   * @param inner - The function or class definition; supplies the name.
   * @param outer - The top-level node as it appears in the module. For a
   *   decorated definition this is the wrapper, so the reported line is where
   *   the wrapper starts.
   * @param exports - Accumulator the record is pushed onto.
   */
  private addExport(
    inner: TreeSitterNode,
    outer: TreeSitterNode,
    exports: StructuralAnalysis["exports"],
  ): void {
    const nameNode = inner.childForFieldName("name");
    if (nameNode) {
      exports.push({
        name: nameNode.text,
        lineNumber: outer.startPosition.row + 1,
      });
    }
  }
}