Skip to content

Commit 002c854

Browse files
committed
Add weighted vector search
1 parent d703065 commit 002c854

File tree

6 files changed

+167
-8
lines changed

6 files changed

+167
-8
lines changed

sql/lantern.sql

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ $$ LANGUAGE plpgsql VOLATILE;
110110
-- Create access method
111111
DO $BODY$
112112
DECLARE
113-
hnsw_am_exists boolean;
114113
pgvector_exists boolean;
115114
BEGIN
116115
-- Check if the vector type from pgvector exists
@@ -593,4 +592,90 @@ $async_tasks_related$ LANGUAGE plpgsql;
593592
SELECT _lantern_internal.maybe_setup_lantern_tasks();
594593
DROP FUNCTION _lantern_internal.maybe_setup_lantern_tasks();
595594

595+
-- Weighted vector search
596596
-- Asynchronous task scheduling BEGIN
597+
-- One-shot installer for lantern.weighted_vector_search.
-- The search function takes arguments of type "vector", so it can only be
-- created when a type with that name is present in pg_type. When it is not,
-- a NOTICE is raised and nothing is created.
CREATE OR REPLACE FUNCTION _lantern_internal.maybe_setup_weighted_vector_search() RETURNS VOID AS
$weighted_vector_search$
DECLARE
  pgvector_exists boolean;
BEGIN
  -- Check if the vector type from pgvector exists
  SELECT EXISTS (
    SELECT 1
    FROM pg_type
    WHERE typname = 'vector'
  ) INTO pgvector_exists;

  IF NOT pgvector_exists THEN
    RAISE NOTICE 'pgvector extension not found. Skipping lantern weighted vector search setup';
    RETURN;
  END IF;

  -- lantern.weighted_vector_search
  --
  -- Returns rows of a relation whose weighted sum of "<=>" distances to up to
  -- three query vectors is below max_dist.
  --
  -- Arguments:
  --   relation_type  a (typically NULL) value cast to the row type of the
  --                  relation to search, e.g. CAST(NULL AS "my_table"); only
  --                  its type is used, both as the FROM target and to resolve
  --                  the polymorphic result type
  --   max_dist       exclusive upper bound on the weighted distance sum
  --   w1, col1, vec1 weight, column name and query vector of the first term
  --   w2, col2, vec2 optional second term; ignored unless w2 > 0 and both
  --                  col2 and vec2 are non-NULL
  --   w3, col3, vec3 optional third term; same activation rule as the second
  --
  -- Result:
  --   One candidate query is run per active term, each ordered by that term's
  --   own distance and limited to 500 rows; their results are emitted one
  --   after another. No de-duplication is performed, so a row that falls
  --   within the limit of more than one candidate query is returned once per
  --   such query.
  --
  -- Side effects:
  --   hnsw.ef_search is set to 500 while the candidate queries run and is put
  --   back to its previous value before the function returns.
  CREATE OR REPLACE FUNCTION lantern.weighted_vector_search(
    relation_type anyelement,
    max_dist numeric,
    w1 numeric,
    col1 text,
    vec1 vector,
    w2 numeric = 0,
    col2 text = NULL,
    vec2 vector = NULL,
    w3 numeric = 0,
    col3 text = NULL,
    vec3 vector = NULL)
  RETURNS TABLE ("row" anyelement) AS
  $$
  DECLARE
    -- Per-candidate-query row limit; also used as the hnsw.ef_search value
    qlimit CONSTANT integer := 500;
    -- A term participates only when it has a positive weight, a column and a
    -- vector. COALESCE keeps a NULL weight equivalent to "inactive".
    use_term2 CONSTANT boolean := COALESCE(w2 > 0 AND col2 IS NOT NULL AND vec2 IS NOT NULL, FALSE);
    use_term3 CONSTANT boolean := COALESCE(w3 > 0 AND col3 IS NOT NULL AND vec3 IS NOT NULL, FALSE);
    joint_condition text;
    query_base text;
    -- GUC values are text; NULL means the setting was not defined at all
    old_hnsw_ef_search text;
  BEGIN
    -- Joint similarity metric condition
    joint_condition := format('(%s * (%I <=> %L))', w1, col1, vec1);
    IF use_term2 THEN
      joint_condition := joint_condition || format(' + (%s * (%I <=> %L))', w2, col2, vec2);
    END IF;
    IF use_term3 THEN
      joint_condition := joint_condition || format(' + (%s * (%I <=> %L))', w3, col3, vec3);
    END IF;

    -- Base query with joint similarity metric.
    -- SELECT * is required here: the relation is only known at call time.
    query_base := format('SELECT * FROM %s WHERE %s < %L',
                         pg_typeof(relation_type), joint_condition, max_dist);

    -- Widen the hnsw search for the candidate queries, remembering the
    -- caller's value so it does not leak into the rest of the session.
    -- If a candidate query raises, the enclosing transaction abort rolls the
    -- setting back, so no exception handler is needed for the restore.
    old_hnsw_ef_search := current_setting('hnsw.ef_search', true);
    PERFORM set_config('hnsw.ef_search', qlimit::text, false);

    -- Query 1: Order by first condition's weighted similarity
    RETURN QUERY EXECUTE
      format('%s ORDER BY %I <=> %L LIMIT %L', query_base, col1, vec1, qlimit);

    -- Query 2 and 3: Order by other conditions' weighted similarity, if applicable
    IF use_term2 THEN
      RETURN QUERY EXECUTE
        format('%s ORDER BY %I <=> %L LIMIT %L', query_base, col2, vec2, qlimit);
    END IF;

    IF use_term3 THEN
      RETURN QUERY EXECUTE
        format('%s ORDER BY %I <=> %L LIMIT %L', query_base, col3, vec3, qlimit);
    END IF;

    -- Restore the caller's hnsw.ef_search. RETURN QUERY has already
    -- materialised the result rows, so this cannot affect them.
    IF old_hnsw_ef_search IS NULL OR old_hnsw_ef_search = '' THEN
      RESET hnsw.ef_search;
    ELSE
      PERFORM set_config('hnsw.ef_search', old_hnsw_ef_search, false);
    END IF;
  END
  $$ LANGUAGE plpgsql;
END
$weighted_vector_search$ LANGUAGE plpgsql;
678+
679+
SELECT _lantern_internal.maybe_setup_weighted_vector_search();
680+
DROP FUNCTION _lantern_internal.maybe_setup_weighted_vector_search;
681+

test/expected/hnsw_vector.out

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55
-- pgvector if it is present on initialization
66
DROP EXTENSION IF EXISTS lantern;
77
CREATE EXTENSION IF NOT EXISTS vector;
8-
-- Setting min messages to ERROR so the WARNING about existing hnsw access method is NOT printed
9-
-- in tests. This makes sure that regression tests pass on pgvector <=0.4.4 as well as >=0.5.0
10-
SET client_min_messages=ERROR;
118
CREATE EXTENSION lantern;
129
RESET client_min_messages;
1310
SET lantern.pgvector_compat=FALSE;

test/expected/weighted_search.out

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
DROP EXTENSION IF EXISTS lantern;
2+
CREATE EXTENSION IF NOT EXISTS vector;
3+
\ir utils/sift1k_vector.sql
4+
CREATE TABLE IF NOT EXISTS sift_base1k (
5+
id SERIAL,
6+
v VECTOR(128)
7+
);
8+
COPY sift_base1k (v) FROM '/tmp/lantern/vector_datasets/sift_base1k.csv' WITH CSV;
9+
CREATE INDEX ON sift_base1k USING hnsw (v vector_l2_ops) WITH (M=5, ef_construction=20);
10+
-- Note: We drop the Lantern extension and re-create it because Lantern only supports
11+
-- pgvector if it is present on initialization
12+
CREATE EXTENSION lantern;
13+
-- create variables with 4th and 444th vector and find closest 10 IDs to each
14+
SELECT v as v4 FROM sift_Base1k WHERE id = 4 \gset
15+
SELECT v as v444 FROM sift_Base1k WHERE id = 444 \gset
16+
SELECT id, ROUND((v <-> :'v4')::numeric, 2) as dist FROM sift_Base1k ORDER BY v <-> :'v4' LIMIT 10;
17+
id | dist
18+
-----+--------
19+
4 | 0.00
20+
2 | 122.45
21+
15 | 141.39
22+
8 | 226.66
23+
163 | 244.31
24+
6 | 249.80
25+
63 | 257.77
26+
183 | 259.18
27+
254 | 263.45
28+
116 | 264.64
29+
(10 rows)
30+
31+
SELECT id, ROUND((v <-> :'v444')::numeric, 2) as dist FROM sift_Base1k ORDER BY v <-> :'v444' LIMIT 10;
32+
id | dist
33+
-----+--------
34+
444 | 0.00
35+
830 | 214.16
36+
854 | 298.42
37+
557 | 302.76
38+
62 | 305.77
39+
58 | 306.00
40+
30 | 307.27
41+
358 | 308.79
42+
60 | 309.94
43+
591 | 310.15
44+
(10 rows)
45+
46+
-- SELECT id, v <-> :'v4' as v4_dist, v <-> :'v444' v444_dist, weighted_dist FROM lantern.weighted_vector_search(CAST(NULL as "sift_base1k"), max_dist => 20., w1=> 1., col1=>'v'::text, vec1=>:'v4'::vector, w2=> 1., col2=>'v'::text, vec2=>:'v444'::vector) as v4_weighted_search LIMIT 10
47+
SELECT id, ROUND((v <-> :'v4')::numeric, 2) as v4_dist, ROUND((v <-> :'v444')::numeric, 2) v444_dist FROM lantern.weighted_vector_search(CAST(NULL as "sift_base1k"), max_dist => 20., w1=> 1., col1=>'v'::text, vec1=>:'v4'::vector, w2=> 1., col2=>'v'::text, vec2=>:'v444'::vector) as v4_weighted_search LIMIT 10
48+
id | v4_dist | v444_dist
49+
-----+---------+-----------
50+
4 | 0.00 | 390.54
51+
2 | 122.45 | 380.85
52+
15 | 141.39 | 404.62
53+
8 | 226.66 | 345.44
54+
163 | 244.31 | 401.58
55+
6 | 249.80 | 417.95
56+
63 | 257.77 | 338.20
57+
183 | 259.18 | 383.49
58+
254 | 263.45 | 392.81
59+
116 | 264.64 | 398.24
60+
(10 rows)
61+

test/schedule.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@
66
test: hnsw_config hnsw_correct hnsw_create hnsw_create_expr hnsw_dist_func hnsw_insert hnsw_select hnsw_todo hnsw_index_from_file hnsw_cost_estimate ext_relocation hnsw_failure_point hnsw_operators hnsw_blockmap_create hnsw_create_unlogged hnsw_insert_unlogged hnsw_logged_unlogged missing_outer_snapshot_portal hnsw_pq
77
test_pg_cron: async_tasks
88
test: hnsw_pq_index
9-
test_pgvector: hnsw_vector
9+
test_pgvector: hnsw_vector weighted_search
1010
test_extras: hnsw_extras

test/sql/hnsw_vector.sql

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@
66
-- pgvector if it is present on initialization
77
DROP EXTENSION IF EXISTS lantern;
88
CREATE EXTENSION IF NOT EXISTS vector;
9-
-- Setting min messages to ERROR so the WARNING about existing hnsw access method is NOT printed
10-
-- in tests. This makes sure that regression tests pass on pgvector <=0.4.4 as well as >=0.5.0
11-
SET client_min_messages=ERROR;
129
CREATE EXTENSION lantern;
1310
RESET client_min_messages;
1411
SET lantern.pgvector_compat=FALSE;

test/sql/weighted_search.sql

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
DROP EXTENSION IF EXISTS lantern;
2+
CREATE EXTENSION IF NOT EXISTS vector;
3+
4+
\ir utils/sift1k_vector.sql
5+
CREATE INDEX ON sift_base1k USING hnsw (v vector_l2_ops) WITH (M=5, ef_construction=20);
6+
7+
-- Note: We drop the Lantern extension and re-create it because Lantern only supports
8+
-- pgvector if it is present on initialization
9+
CREATE EXTENSION lantern;
10+
11+
-- create variables with 4th and 444th vector and find closest 10 IDs to each
12+
SELECT v as v4 FROM sift_Base1k WHERE id = 4 \gset
13+
SELECT v as v444 FROM sift_Base1k WHERE id = 444 \gset
14+
15+
SELECT id, ROUND((v <-> :'v4')::numeric, 2) as dist FROM sift_Base1k ORDER BY v <-> :'v4' LIMIT 10;
16+
17+
SELECT id, ROUND((v <-> :'v444')::numeric, 2) as dist FROM sift_Base1k ORDER BY v <-> :'v444' LIMIT 10;
18+
-- SELECT id, v <-> :'v4' as v4_dist, v <-> :'v444' v444_dist, weighted_dist FROM lantern.weighted_vector_search(CAST(NULL as "sift_base1k"), max_dist => 20., w1=> 1., col1=>'v'::text, vec1=>:'v4'::vector, w2=> 1., col2=>'v'::text, vec2=>:'v444'::vector) as v4_weighted_search LIMIT 10
19+
SELECT id, ROUND((v <-> :'v4')::numeric, 2) as v4_dist, ROUND((v <-> :'v444')::numeric, 2) v444_dist FROM lantern.weighted_vector_search(CAST(NULL as "sift_base1k"), max_dist => 20., w1=> 1., col1=>'v'::text, vec1=>:'v4'::vector, w2=> 1., col2=>'v'::text, vec2=>:'v444'::vector) as v4_weighted_search LIMIT 10

0 commit comments

Comments
 (0)