@@ -115,4 +115,292 @@ $async_tasks_related$ LANGUAGE plpgsql;
115
115
-- Invoke the one-shot setup helper and then drop it so it is not left behind in the
-- _lantern_internal schema. The helper is defined earlier in the file (outside this view),
-- so what exactly it sets up is not described here.
SELECT _lantern_internal .maybe_setup_lantern_tasks ();
116
116
DROP FUNCTION _lantern_internal .maybe_setup_lantern_tasks ();
117
117
118
- -- Asynchronous task scheduling BEGIN
118
+ -- ^^^^
119
+ -- Asynchronous task scheduling END
120
+
121
+ -- Weighted vector search
122
+
123
-- Replaces every single-quoted, square-bracketed literal in `arr` (for example
-- '[0.1,0.2,...]', with or without a trailing ::vector cast) with '[MASKED_VECTOR]',
-- so that long vector literals do not flood debug output.
--
-- arr:     query text or the text form of an EXPLAIN plan fragment
-- returns: the same text with the bracketed literals masked
CREATE OR REPLACE FUNCTION _lantern_internal.mask_arrays(arr text)
RETURNS text AS $$
BEGIN
    -- pattern, piece by piece:
    --   ''       a single quote (doubled to escape it inside this string literal)
    --   \[       an opening square bracket (escaped with a backslash)
    --   [^'']*   any run of characters that are NOT a single quote; this keeps a match
    --            inside one literal. The earlier '.*?' form could start at one literal
    --            (e.g. '[a]'::text) and run on to a later ]'::vector, swallowing
    --            everything in between.
    --   \]       a closing square bracket (escaped with a backslash)
    --   ''       a single quote (doubled to escape it)
    -- A trailing ::vector cast, when present, is left untouched after the masked literal,
    -- so one pass covers both the EXPLAIN form (with cast) and the plain query form.
    RETURN regexp_replace(arr, '''\[[^'']*\]''', '''[MASKED_VECTOR]''', 'g');
END
$$ LANGUAGE plpgsql;
140
+
141
+
142
-- Walks a plan produced by EXPLAIN (FORMAT JSON) and masks long vector literals in the
-- values stored under the "Order By", "Filter" and "Sort Key" keys, at any nesting depth.
--
-- json_data: the plan (or any sub-tree of it) as jsonb
-- returns:   the same structure with the affected values passed through
--            _lantern_internal.mask_arrays; NULL when json_data is NULL
CREATE OR REPLACE FUNCTION _lantern_internal.mask_order_by_in_plan(json_data jsonb)
RETURNS jsonb AS $$
DECLARE
    node_kind text;
    entry_key text;
    entry_val jsonb;
BEGIN
    -- NULL in, NULL out
    IF json_data IS NULL THEN
        RETURN NULL;
    END IF;

    node_kind := jsonb_typeof(json_data);

    IF node_kind = 'array' THEN
        -- arrays: recurse into every element and write the result back at the same position
        FOR pos IN 0 .. jsonb_array_length(json_data) - 1 LOOP
            json_data := jsonb_set(
                json_data,
                ARRAY[pos::text],
                _lantern_internal.mask_order_by_in_plan(json_data -> pos)
            );
        END LOOP;
        RETURN json_data;
    END IF;

    IF node_kind = 'object' THEN
        FOR entry_key, entry_val IN
            SELECT e.key, e.value FROM jsonb_each(json_data) AS e
        LOOP
            IF entry_key IN ('Order By', 'Filter', 'Sort Key') THEN
                -- expression-bearing keys: mask vector literals in the value's text form,
                -- then turn the masked text back into jsonb
                entry_val := _lantern_internal.mask_arrays(entry_val::text)::jsonb;
            ELSE
                -- any other key may hold nested objects or arrays: descend into it
                entry_val := _lantern_internal.mask_order_by_in_plan(entry_val);
            END IF;
            json_data := jsonb_set(json_data, ARRAY[entry_key], entry_val);
        END LOOP;
    END IF;

    -- scalars fall through unchanged
    RETURN json_data;
END;
$$ LANGUAGE plpgsql;
179
+
180
-- One-shot installer for the weighted vector search API.
-- Creates lantern.weighted_vector_search and its _cos / _l2sq shortcuts, but only when a
-- type named "vector" exists in pg_type; otherwise it raises a NOTICE and returns.
-- The function signatures below use the vector type, which is why they are created from
-- inside this guard instead of at the top level of the script.
CREATE OR REPLACE FUNCTION _lantern_internal.maybe_setup_weighted_vector_search() RETURNS VOID AS
$weighted_vector_search$
DECLARE
  pgvector_exists boolean;
BEGIN
  -- Check if the vector type from pgvector exists
  SELECT EXISTS (
    SELECT 1
    FROM pg_type
    WHERE typname = 'vector'
  ) INTO pgvector_exists;

  IF NOT pgvector_exists THEN
    RAISE NOTICE 'pgvector extension not found. Skipping lantern weighted vector search setup';
    RETURN;
  END IF;

  -- Returns rows of the relation identified by the type of relation_type, ordered by
  --   w1 * (col1 <op> vec1) [+ w2 * (col2 <op> vec2)] [+ w3 * (col3 <op> vec3)]
  --
  -- relation_type     only its type is used (via pg_typeof) to pick the relation to query
  -- w1, col1, vec1    weight, column name and query vector of the first term (always used)
  -- w2/col2/vec2,
  -- w3/col3/vec3      optional extra terms; a term is used only when its weight is > 0 and
  --                   both its column and vector are non-NULL
  -- ef                written to hnsw.ef_search and used as the LIMIT of every per-column
  --                   candidate query (ignored when exact is true)
  -- max_dist          when non-NULL, only rows whose weighted sum is < max_dist are kept
  -- distance_operator operator placed between each column and its vector
  -- id_col            column used in DISTINCT ON to de-duplicate the merged candidates
  -- exact             when true, run a single query ordered by the weighted sum, without
  --                   the per-column candidate queries
  -- debug_output      when true, report the (masked) query, its EXPLAIN plan and the
  --                   per-column candidate counts via RAISE WARNING
  -- analyze_output    when true, the EXPLAIN issued for debug_output also uses ANALYZE
  CREATE OR REPLACE FUNCTION lantern.weighted_vector_search(
    relation_type anyelement,
    w1 numeric,
    col1 text,
    vec1 vector,
    w2 numeric = 0,
    col2 text = NULL,
    vec2 vector = NULL,
    w3 numeric = 0,
    col3 text = NULL,
    vec3 vector = NULL,
    ef integer = 100,
    max_dist numeric = NULL,
    -- set l2 (pgvector) and l2sq (lantern) as default, as we do for lantern index.
    distance_operator text = '<->',
    id_col text = 'id',
    exact boolean = false,
    debug_output boolean = false,
    analyze_output boolean = false
  )
  -- N.B. Something seems strange about PL/pgSQL functions that return a table with anyelement:
  -- when a single "anyelement column" is returned (e.g. RETURNS TABLE ("row" anyelement)),
  -- that single "column" is properly spread into the source table's column names,
  -- but when returning ("row" anyelement, "anothercol" integer), things fall all over the place:
  -- the returned table then always has 2 columns, one that is a record of sorts, and one "anothercol"
  RETURNS TABLE ("row" anyelement) AS
  $$
  DECLARE
    joint_condition text;
    query_base text;
    query_final_where text = '';
    query1 text;
    query2 text;
    query3 text;
    -- variables for weighted columns
    wc1 text = NULL;
    wc2 text = NULL;
    wc3 text = NULL;
    cte_query text;
    maybe_unions_query text;
    final_query text;
    explain_query text;
    explain_output jsonb;
    debug_count integer;
    maybe_analyze text = '';
  BEGIN
    -- distance_operator is spliced into the dynamic SQL below verbatim (%s), because an
    -- operator cannot be quoted the way identifiers (%I) and literals (%L) can. Accept
    -- only text made of operator characters, and reject comment openers, so that it
    -- cannot smuggle in arbitrary SQL.
    -- NOTE(review): schema-qualified OPERATOR(schema.op) syntax is rejected by this check.
    IF distance_operator IS NULL
       OR distance_operator !~ '^[-+*/<>=~!@#%^&|`?]+$'
       OR position('--' in distance_operator) > 0
       OR position('/*' in distance_operator) > 0 THEN
      RAISE EXCEPTION 'invalid distance_operator "%": expected an operator such as <-> or <=>', distance_operator;
    END IF;

    -- TODO:: better validate inputs to throw nicer errors in case of wrong input:
    --    1. only allow valid column names
    --    2. throw an error on negative weights
    --    3. check that id_col column exists before proceeding

    IF analyze_output THEN
      maybe_analyze := 'ANALYZE,';
    END IF;

    -- Joint similarity metric condition: one weighted term per usable (weight, column, vector)
    wc1 := format('(%s * (%I %s %L))', w1, col1, distance_operator, vec1);
    IF w2 > 0 AND col2 IS NOT NULL AND vec2 IS NOT NULL THEN
      wc2 := format('(%s * (%I %s %L))', w2, col2, distance_operator, vec2);
    END IF;
    IF w3 > 0 AND col3 IS NOT NULL AND vec3 IS NOT NULL THEN
      wc3 := format('(%s * (%I %s %L))', w3, col3, distance_operator, vec3);
    END IF;

    -- '+' || NULL is NULL, so unused terms collapse to the empty string
    joint_condition := wc1 || COALESCE('+' || wc2, '') || COALESCE('+' || wc3, '');

    -- Base query with joint similarity metric
    query_base := format('SELECT * FROM %s ', pg_typeof(relation_type));
    IF max_dist IS NOT NULL THEN
      query_final_where := format(' WHERE %s < %L', joint_condition, max_dist);
    END IF;

    IF exact THEN
      -- exact mode: a single query ordered by the full weighted sum
      final_query := query_base || query_final_where || format(' ORDER BY %s', joint_condition);
      IF debug_output THEN
        explain_query := format('EXPLAIN (%s COSTS FALSE, FORMAT JSON) %s', maybe_analyze, final_query);
        EXECUTE explain_query INTO explain_output;

        RAISE WARNING 'Query: %', _lantern_internal.mask_arrays(final_query);

        explain_output := _lantern_internal.mask_order_by_in_plan(explain_output);
        RAISE WARNING 'weighted vector search explain(exact=true): %', jsonb_pretty(explain_output);
      END IF;
      RETURN QUERY EXECUTE final_query;
      -- the empty return below is crucial, to make sure the rest of the function is not executed after the return query above
      RETURN;
    END IF;

    -- NOTE(review): SET LOCAL is transaction-scoped per the PostgreSQL docs, so this value
    -- presumably stays in effect after the function returns; confirm that is intended.
    EXECUTE format('SET LOCAL hnsw.ef_search TO %L', ef);
    -- UNION ALL.. part of the final query that aggregates results from individual vector search queries
    maybe_unions_query := '';

    -- Query 1: Order by first condition's weighted similarity
    query1 := format('%s ORDER BY %I %s %L LIMIT %L', query_base || query_final_where, col1, distance_operator, vec1, ef);

    IF debug_output THEN
      EXECUTE format('SELECT count(*) FROM (%s) t', query1) INTO debug_count;
      RAISE WARNING 'col1 yielded % rows', debug_count;
    END IF;

    cte_query = format('WITH query1 AS (%s) ', query1);

    -- Query 2: Order by other conditions' weighted similarity, if applicable
    IF w2 > 0 AND col2 IS NOT NULL AND vec2 IS NOT NULL THEN
      query2 := format('%s ORDER BY %I %s %L LIMIT %L', query_base || query_final_where, col2, distance_operator, vec2, ef);
      cte_query := cte_query || format(', query2 AS (%s)', query2);
      maybe_unions_query := maybe_unions_query || format(' UNION ALL (SELECT * FROM query2) ');
      IF debug_output THEN
        EXECUTE format('SELECT count(*) FROM (%s) t', query2) INTO debug_count;
        RAISE WARNING 'col2 yielded % rows', debug_count;
      END IF;
    END IF;

    -- Query 3: same as query 2, for the third term
    IF w3 > 0 AND col3 IS NOT NULL AND vec3 IS NOT NULL THEN
      query3 := format('%s ORDER BY %I %s %L LIMIT %L', query_base || query_final_where, col3, distance_operator, vec3, ef);
      cte_query := cte_query || format(', query3 AS (%s)', query3);
      maybe_unions_query := maybe_unions_query || format(' UNION ALL (SELECT * FROM query3) ');
      IF debug_output THEN
        EXECUTE format('SELECT count(*) FROM (%s) t', query3) INTO debug_count;
        RAISE WARNING 'col3 yielded % rows', debug_count;
      END IF;
    END IF;

    -- Merge the per-column candidates, keep one row per id_col value, then apply the
    -- max_dist filter (if any) and order by the full weighted sum
    final_query := cte_query || format($final_cte_query$SELECT * FROM (
      SELECT DISTINCT ON (%I) * FROM (
        (SELECT * FROM query1)
        %s
      ) t
    )
    tt %s ORDER BY %s$final_cte_query$,
    id_col, maybe_unions_query, query_final_where, joint_condition);

    IF debug_output THEN
      explain_query := format('EXPLAIN (%s COSTS FALSE, FORMAT JSON) %s', maybe_analyze, final_query);
      EXECUTE explain_query INTO explain_output;

      RAISE WARNING 'Query: %', _lantern_internal.mask_arrays(final_query);

      explain_output := _lantern_internal.mask_order_by_in_plan(explain_output);
      RAISE WARNING 'weighted vector search explain: %', jsonb_pretty(explain_output);
    END IF;
    RETURN QUERY EXECUTE final_query;
  END
  $$ LANGUAGE plpgsql;

  -- setup API shortcuts

  -- Shortcut for lantern.weighted_vector_search with distance_operator fixed to '<=>'.
  -- All other parameters are forwarded unchanged; analyze_output is not exposed and
  -- keeps its default.
  CREATE OR REPLACE FUNCTION lantern.weighted_vector_search_cos(
    relation_type anyelement,
    w1 numeric,
    col1 text,
    vec1 vector,
    w2 numeric = 0,
    col2 text = NULL,
    vec2 vector = NULL,
    w3 numeric = 0,
    col3 text = NULL,
    vec3 vector = NULL,
    ef integer = 100,
    max_dist numeric = NULL,
    id_col text = 'id',
    exact boolean = false,
    debug_output boolean = false
  )
  -- N.B. Something seems strange about PL/pgSQL functions that return a table with anyelement:
  -- when a single "anyelement column" is returned (e.g. RETURNS TABLE ("row" anyelement)),
  -- that single "column" is properly spread into the source table's column names,
  -- but when returning ("row" anyelement, "anothercol" integer), things fall all over the place:
  -- the returned table then always has 2 columns, one that is a record of sorts, and one "anothercol"
  RETURNS TABLE ("row" anyelement) AS $$

  BEGIN
    RETURN QUERY SELECT * FROM lantern.weighted_vector_search(relation_type, w1, col1, vec1, w2, col2, vec2, w3, col3, vec3, ef, max_dist, '<=>', id_col, exact, debug_output);
  END $$ LANGUAGE plpgsql;

  -- Shortcut for lantern.weighted_vector_search with distance_operator fixed to '<->'.
  -- All other parameters are forwarded unchanged; analyze_output is not exposed and
  -- keeps its default.
  CREATE OR REPLACE FUNCTION lantern.weighted_vector_search_l2sq(
    relation_type anyelement,
    w1 numeric,
    col1 text,
    vec1 vector,
    w2 numeric = 0,
    col2 text = NULL,
    vec2 vector = NULL,
    w3 numeric = 0,
    col3 text = NULL,
    vec3 vector = NULL,
    ef integer = 100,
    max_dist numeric = NULL,
    id_col text = 'id',
    exact boolean = false,
    debug_output boolean = false
  )
  -- N.B. Something seems strange about PL/pgSQL functions that return a table with anyelement:
  -- when a single "anyelement column" is returned (e.g. RETURNS TABLE ("row" anyelement)),
  -- that single "column" is properly spread into the source table's column names,
  -- but when returning ("row" anyelement, "anothercol" integer), things fall all over the place:
  -- the returned table then always has 2 columns, one that is a record of sorts, and one "anothercol"
  RETURNS TABLE ("row" anyelement) AS $$

  BEGIN
    RETURN QUERY SELECT * FROM lantern.weighted_vector_search(relation_type, w1, col1, vec1, w2, col2, vec2, w3, col3, vec3, ef, max_dist, '<->', id_col, exact, debug_output);
  END $$ LANGUAGE plpgsql;


END
$weighted_vector_search$ LANGUAGE plpgsql;
403
+
404
-- Run the weighted vector search installer once, then drop it so the helper does not
-- linger in the _lantern_internal schema. The explicit empty argument list on DROP
-- matches the other setup-helper DROP in this script and names the exact signature.
SELECT _lantern_internal.maybe_setup_weighted_vector_search();
DROP FUNCTION _lantern_internal.maybe_setup_weighted_vector_search();
406
+
0 commit comments