Skip to content

Commit

Permalink
fix: Koloxarto/fix ragknowledge for postgres (elizaOS#2153)
Browse files Browse the repository at this point in the history
* fix formatting out of the way

* fix postgres chunk uuid handling for ragKnowledge

---------

Co-authored-by: Odilitime <[email protected]>
  • Loading branch information
web3gh and odilitime authored Jan 12, 2025
1 parent 35d857e commit 6690ea6
Show file tree
Hide file tree
Showing 3 changed files with 449 additions and 218 deletions.
226 changes: 168 additions & 58 deletions packages/adapter-postgres/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,31 @@ import { v4 } from "uuid";
import pg from "pg";
type Pool = pg.Pool;

import {
QueryConfig,
QueryConfigValues,
QueryResult,
QueryResultRow,
} from "pg";
import {
Account,
Actor,
DatabaseAdapter,
EmbeddingProvider,
GoalStatus,
Participant,
RAGKnowledgeItem,
elizaLogger,
getEmbeddingConfig,
type Goal,
type IDatabaseCacheAdapter,
type Memory,
type Relationship,
type UUID,
type IDatabaseCacheAdapter,
Participant,
elizaLogger,
getEmbeddingConfig,
DatabaseAdapter,
EmbeddingProvider,
RAGKnowledgeItem
} from "@elizaos/core";
import fs from "fs";
import { fileURLToPath } from "url";
import path from "path";
import {
QueryConfig,
QueryConfigValues,
QueryResult,
QueryResultRow,
} from "pg";
import { fileURLToPath } from "url";

const __filename = fileURLToPath(import.meta.url); // get the resolved path to the file
const __dirname = path.dirname(__filename); // get the name of the directory
Expand Down Expand Up @@ -199,7 +199,7 @@ export class PostgresDatabaseAdapter
return true;
} catch (error) {
elizaLogger.error("Failed to validate vector extension:", {
error: error instanceof Error ? error.message : String(error)
error: error instanceof Error ? error.message : String(error),
});
return false;
}
Expand Down Expand Up @@ -239,8 +239,10 @@ export class PostgresDatabaseAdapter
);
`);

if (!rows[0].exists || !await this.validateVectorSetup()) {
elizaLogger.info("Applying database schema - tables or vector extension missing");
if (!rows[0].exists || !(await this.validateVectorSetup())) {
elizaLogger.info(
"Applying database schema - tables or vector extension missing"
);
const schema = fs.readFileSync(
path.resolve(__dirname, "../schema.sql"),
"utf8"
Expand Down Expand Up @@ -1523,12 +1525,17 @@ export class PostgresDatabaseAdapter

const { rows } = await this.pool.query(sql, queryParams);

return rows.map(row => ({
return rows.map((row) => ({
id: row.id,
agentId: row.agentId,
content: typeof row.content === 'string' ? JSON.parse(row.content) : row.content,
embedding: row.embedding ? new Float32Array(row.embedding) : undefined,
createdAt: row.createdAt.getTime()
content:
typeof row.content === "string"
? JSON.parse(row.content)
: row.content,
embedding: row.embedding
? new Float32Array(row.embedding)
: undefined,
createdAt: row.createdAt.getTime(),
}));
}, "getKnowledge");
}
Expand All @@ -1544,7 +1551,7 @@ export class PostgresDatabaseAdapter
const cacheKey = `embedding_${params.agentId}_${params.searchText}`;
const cachedResult = await this.getCache({
key: cacheKey,
agentId: params.agentId
agentId: params.agentId,
});

if (cachedResult) {
Expand Down Expand Up @@ -1594,24 +1601,29 @@ export class PostgresDatabaseAdapter
const { rows } = await this.pool.query(sql, [
vectorStr,
params.agentId,
`%${params.searchText || ''}%`,
`%${params.searchText || ""}%`,
params.match_threshold,
params.match_count
params.match_count,
]);

const results = rows.map(row => ({
const results = rows.map((row) => ({
id: row.id,
agentId: row.agentId,
content: typeof row.content === 'string' ? JSON.parse(row.content) : row.content,
embedding: row.embedding ? new Float32Array(row.embedding) : undefined,
content:
typeof row.content === "string"
? JSON.parse(row.content)
: row.content,
embedding: row.embedding
? new Float32Array(row.embedding)
: undefined,
createdAt: row.createdAt.getTime(),
similarity: row.combined_score
similarity: row.combined_score,
}));

await this.setCache({
key: cacheKey,
agentId: params.agentId,
value: JSON.stringify(results)
value: JSON.stringify(results),
});

return results;
Expand All @@ -1622,35 +1634,52 @@ export class PostgresDatabaseAdapter
return this.withDatabase(async () => {
const client = await this.pool.connect();
try {
await client.query('BEGIN');

const sql = `
INSERT INTO knowledge (
id, "agentId", content, embedding, "createdAt",
"isMain", "originalId", "chunkIndex", "isShared"
) VALUES ($1, $2, $3, $4, to_timestamp($5/1000.0), $6, $7, $8, $9)
ON CONFLICT (id) DO NOTHING
`;
await client.query("BEGIN");

const metadata = knowledge.content.metadata || {};
const vectorStr = knowledge.embedding ?
`[${Array.from(knowledge.embedding).join(",")}]` : null;

await client.query(sql, [
knowledge.id,
metadata.isShared ? null : knowledge.agentId,
knowledge.content,
vectorStr,
knowledge.createdAt || Date.now(),
metadata.isMain || false,
metadata.originalId || null,
metadata.chunkIndex || null,
metadata.isShared || false
]);
const vectorStr = knowledge.embedding
? `[${Array.from(knowledge.embedding).join(",")}]`
: null;

// If this is a chunk, use createKnowledgeChunk
if (metadata.isChunk && metadata.originalId) {
await this.createKnowledgeChunk({
id: knowledge.id,
originalId: metadata.originalId,
agentId: metadata.isShared ? null : knowledge.agentId,
content: knowledge.content,
embedding: knowledge.embedding,
chunkIndex: metadata.chunkIndex || 0,
isShared: metadata.isShared || false,
createdAt: knowledge.createdAt || Date.now(),
});
} else {
// This is a main knowledge item
await client.query(
`
INSERT INTO knowledge (
id, "agentId", content, embedding, "createdAt",
"isMain", "originalId", "chunkIndex", "isShared"
) VALUES ($1, $2, $3, $4, to_timestamp($5/1000.0), $6, $7, $8, $9)
ON CONFLICT (id) DO NOTHING
`,
[
knowledge.id,
metadata.isShared ? null : knowledge.agentId,
knowledge.content,
vectorStr,
knowledge.createdAt || Date.now(),
true,
null,
null,
metadata.isShared || false,
]
);
}

await client.query('COMMIT');
await client.query("COMMIT");
} catch (error) {
await client.query('ROLLBACK');
await client.query("ROLLBACK");
throw error;
} finally {
client.release();
Expand All @@ -1660,19 +1689,100 @@ export class PostgresDatabaseAdapter

/**
 * Remove a knowledge item and all of its chunks, atomically.
 *
 * Two deletion modes:
 *  - pattern id ("<mainId>-chunk-*"): delete only the chunks belonging to
 *    the main item, leaving the main row in place;
 *  - plain id: delete the item's chunks first, then the main row itself.
 *
 * Runs inside an explicit transaction so a failure between the two DELETEs
 * rolls back cleanly.
 *
 * @param id - knowledge item UUID, or a "<mainId>-chunk-*" pattern
 * @throws re-throws any database error after logging and ROLLBACK
 */
async removeKnowledge(id: UUID): Promise<void> {
    return this.withDatabase(async () => {
        const client = await this.pool.connect();
        try {
            await client.query("BEGIN");

            // Check if this is a pattern-based chunk deletion (e.g., "id-chunk-*")
            if (typeof id === "string" && id.includes("-chunk-*")) {
                const mainId = id.split("-chunk-")[0];
                // Delete chunks for this main ID only
                await client.query(
                    'DELETE FROM knowledge WHERE "originalId" = $1',
                    [mainId]
                );
            } else {
                // First delete all chunks associated with this knowledge item
                await client.query(
                    'DELETE FROM knowledge WHERE "originalId" = $1',
                    [id]
                );
                // Then delete the main knowledge item
                await client.query("DELETE FROM knowledge WHERE id = $1", [
                    id,
                ]);
            }

            await client.query("COMMIT");
        } catch (error) {
            await client.query("ROLLBACK");
            elizaLogger.error("Error removing knowledge", {
                error:
                    error instanceof Error ? error.message : String(error),
                id,
            });
            throw error;
        } finally {
            client.release();
        }
    }, "removeKnowledge");
}

/**
 * Delete all knowledge rows for an agent.
 *
 * @param agentId - agent whose knowledge should be cleared
 * @param shared  - when true, also removes rows flagged "isShared"
 *                  (knowledge visible across agents); when falsy, only
 *                  rows owned by this agent are deleted
 */
async clearKnowledge(agentId: UUID, shared?: boolean): Promise<void> {
    return this.withDatabase(async () => {
        const sql = shared
            ? 'DELETE FROM knowledge WHERE ("agentId" = $1 OR "isShared" = true)'
            : 'DELETE FROM knowledge WHERE "agentId" = $1';

        await this.pool.query(sql, [agentId]);
    }, "clearKnowledge");
}

/**
 * Insert a single knowledge chunk row.
 *
 * Postgres requires a genuine UUID primary key, so the legacy
 * "<originalId>-chunk-<index>" identifier is not used as the row id;
 * instead it is stored inside the content metadata (as `patternId`) for
 * backward compatibility, and a fresh UUID is generated for the row.
 *
 * NOTE(review): `params.id` is currently unused — confirm callers do not
 * expect the supplied id to become the row id.
 *
 * @param params.originalId - UUID of the main knowledge item this chunk belongs to
 * @param params.agentId    - owning agent, or null for shared knowledge
 * @param params.embedding  - optional vector, serialized to pgvector text form
 * @param params.createdAt  - creation time in epoch milliseconds
 */
private async createKnowledgeChunk(params: {
    id: UUID;
    originalId: UUID;
    agentId: UUID | null;
    content: any;
    embedding: Float32Array | undefined | null;
    chunkIndex: number;
    isShared: boolean;
    createdAt: number;
}): Promise<void> {
    // pgvector expects "[a,b,c]"; NULL when no embedding was supplied.
    const embeddingLiteral = params.embedding
        ? `[${Array.from(params.embedding).join(",")}]`
        : null;

    // Preserve the pattern-based ID inside metadata for compatibility.
    const legacyChunkId = `${params.originalId}-chunk-${params.chunkIndex}`;
    const enrichedContent = {
        ...params.content,
        metadata: {
            ...params.content.metadata,
            patternId: legacyChunkId,
        },
    };

    await this.pool.query(
        `
        INSERT INTO knowledge (
            id, "agentId", content, embedding, "createdAt",
            "isMain", "originalId", "chunkIndex", "isShared"
        ) VALUES ($1, $2, $3, $4, to_timestamp($5/1000.0), $6, $7, $8, $9)
        ON CONFLICT (id) DO NOTHING
        `,
        [
            v4(), // generate a proper UUID for the PostgreSQL primary key
            params.agentId,
            enrichedContent, // carries the legacy pattern id in metadata
            embeddingLiteral,
            params.createdAt,
            false, // a chunk is never a main item
            params.originalId,
            params.chunkIndex,
            params.isShared,
        ]
    );
}
}

export default PostgresDatabaseAdapter;
Loading

0 comments on commit 6690ea6

Please sign in to comment.