"""Benchmarks for IVF_PQ vector search performance."""
55
66import math
7+ import multiprocessing as mp
78import tempfile
9+ from concurrent .futures import ThreadPoolExecutor
810from pathlib import Path
911
1012import lance
# Labels identifying the top-k values exercised by the search benchmarks.
K_LABELS = ["k10", "k100"]
3638
3739
# Datasets are stored in fixed temporary directories and reused between runs
# to avoid retraining indexes.
4042
4143
4244def _generate_vector_dataset (num_rows : int , dim : int = 1024 ):
@@ -73,46 +75,57 @@ def _generate_vector_dataset(num_rows: int, dim: int = 1024):
def _get_or_create_dataset(num_rows: int, dim: int = 1024) -> str:
    """Get or create a dataset with the specified parameters.

    Uses a fixed temporary directory so datasets persist between benchmark runs.
    If the dataset exists and has the correct number of rows, it will be reused.
    Returns the URI to the dataset.
    """
    # Fixed directory derived from the parameters, so repeated benchmark runs
    # find the same dataset (and its already-trained index) on disk.
    tmpdir = Path(tempfile.gettempdir()) / f"lance_bench_{num_rows}_{dim}"
    tmpdir.mkdir(parents=True, exist_ok=True)
    dataset_uri = "file://" + str(tmpdir / "vector_dataset.lance")

    # Reuse the dataset only if it exists and has the expected row count.
    try:
        ds = lance.dataset(dataset_uri)
        existing_rows = ds.count_rows()  # count once; it may not be free
        if existing_rows == num_rows:
            print(f"Reusing existing dataset at {dataset_uri}")
            return dataset_uri
        print(
            "Dataset exists but has wrong row count "
            f"({existing_rows} vs {num_rows}), recreating..."
        )
    except Exception:
        # Opening failed — most likely the dataset does not exist yet.
        print(f"Creating new dataset at {dataset_uri}")

    # Create schema
    schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), dim)),
            pa.field("id", pa.int64()),
        ]
    )

    # Generate and write dataset
    data = _generate_vector_dataset(num_rows, dim)
    ds = lance.write_dataset(
        data,
        dataset_uri,
        schema=schema,
        mode="overwrite",  # Use overwrite to handle recreation
    )

    # Target ~4000 rows per partition, capped at sqrt(num_rows). Clamp to at
    # least 1: for num_rows < 4000 the original expression evaluated to 0,
    # which is not a valid partition count for index creation.
    num_partitions = max(1, min(num_rows // 4000, int(math.sqrt(num_rows))))

    # Create IVF_PQ index
    ds.create_index(
        "vector",
        index_type="IVF_PQ",
        num_partitions=num_partitions,
        num_sub_vectors=dim // 16,
    )

    return dataset_uri
116129
117130
118131@pytest .mark .parametrize ("num_rows" , DATASET_SIZES , ids = DATASET_SIZE_LABELS )
@@ -139,7 +152,7 @@ def test_ivf_pq_search(
139152
140153 Uses 1024-dimensional float32 vectors with IVF_PQ index.
141154 """
    # Get or create the dataset (reused from fixed temp directory between runs)
143156 dataset_uri = _get_or_create_dataset (num_rows , dim = VECTOR_DIM )
144157 ds = lance .dataset (dataset_uri )
145158
@@ -204,7 +217,7 @@ def test_ivf_pq_search_with_payload(
204217 Similar to test_ivf_pq_search but includes retrieving vector data
205218 along with results, which tests data loading performance.
206219 """
    # Get or create the dataset (reused from fixed temp directory between runs)
208221 dataset_uri = _get_or_create_dataset (num_rows , dim = VECTOR_DIM )
209222 ds = lance .dataset (dataset_uri )
210223
@@ -248,3 +261,57 @@ def bench():
248261 iterations = 1 ,
249262 setup = setup ,
250263 )
264+
265+
@pytest.mark.parametrize("use_cache", [True, False], ids=["cache", "no_cache"])
def test_ivf_pq_throughput(
    benchmark,
    use_cache: bool,
):
    """Benchmark IVF_PQ vector search throughput (with payload).

    Fans NUM_QUERIES nearest-neighbor queries out over a thread pool and
    measures total wall time. With ``use_cache=False`` the OS page cache for
    the dataset files is wiped before the measured round.
    """
    # Get or create the dataset (reused from fixed temp directory between runs)
    dataset_uri = _get_or_create_dataset(1_000_000, dim=768)
    ds = lance.dataset(dataset_uri)

    NUM_QUERIES = 1000

    # Generate query vectors
    query_vectors = [
        np.random.randn(768).astype(np.float32) for _ in range(NUM_QUERIES)
    ]

    def clear_cache():
        # Only meaningful for the no_cache variant.
        if not use_cache:
            wipe_os_cache(dataset_uri)

    def bench():
        # Leave two cores for the OS/harness, but never go below one worker:
        # 2 * (cpu_count - 2) is <= 0 on machines with <= 2 CPUs, and
        # ThreadPoolExecutor raises ValueError for max_workers <= 0.
        max_workers = max(1, 2 * (mp.cpu_count() - 2))
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(
                    ds.to_table,
                    nearest={
                        "column": "vector",
                        "q": query_vector,
                        "k": 50,
                        "nprobes": 20,
                        "refine_factor": 10,
                    },
                    columns=["vector", "_distance"],
                )
                for query_vector in query_vectors
            ]
            # Wait for completion and propagate any per-query errors.
            for future in futures:
                future.result()

    setup = None if use_cache else clear_cache

    benchmark.pedantic(
        bench,
        warmup_rounds=1,
        rounds=1,
        iterations=1,
        setup=setup,
    )