asigalov61 committed on
Commit
02ba06c
·
verified ·
1 Parent(s): ab15083

Upload 2 files

Browse files
Files changed (1) hide show
  1. TCUPY.py +136 -9
TCUPY.py CHANGED
@@ -58,20 +58,23 @@ import tqdm
58
 
59
  try:
60
  import cupy as cp
61
- import cupy as np
62
  print('=' * 70)
63
  print('CuPy is found!')
64
  print('Will use CuPy and GPU for processing!')
65
  print('=' * 70)
66
 
67
  except ImportError as e:
68
- print(f"Error: Could not import CuPy. Details: {e}")
69
- # Handle the error, such as providing a fallback or exiting the program
70
- # For example:
71
  print("Please make sure CuPy is installed.")
 
 
 
 
 
72
  print('=' * 70)
73
-
74
- raise RuntimeError("CuPy could not be loaded!") from e
75
 
76
  ################################################################################
77
 
@@ -264,7 +267,12 @@ void merge_pair_kernel(const long* input, long* output,
264
  }
265
  }
266
  '''
267
- merge_kernel = cp.RawKernel(merge_kernel_code, 'merge_pair_kernel')
 
 
 
 
 
268
 
269
  ###################################################################################
270
 
@@ -394,7 +402,12 @@ void fused_merge_kernel(long* data_in, long* data_out, long* lengths, const long
394
  }
395
  }
396
  '''
397
- fused_kernel = cp.RawKernel(fused_merge_kernel_code, 'fused_merge_kernel')
 
 
 
 
 
398
 
399
  ###################################################################################
400
 
@@ -1233,7 +1246,7 @@ def find_matches_fast(src_array, trg_array, seed: int = 0) -> int:
1233
 
1234
  ###################################################################################
1235
 
1236
- def find_repeating_non_overlapping_patterns(arr, min_len):
1237
  """
1238
  Finds all repeating non-overlapping patterns of min_len and longer.
1239
  GPU-Accelerated using CuPy with O(N) memory per length.
@@ -1342,6 +1355,120 @@ def find_repeating_non_overlapping_patterns(arr, min_len):
1342
 
1343
  ###################################################################################
1344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1345
  print('Module is loaded!')
1346
  print('Enjoy! :)')
1347
  print('=' * 70)
 
58
 
59
  try:
60
  import cupy as cp
61
+ import numpy as np
62
  print('=' * 70)
63
  print('CuPy is found!')
64
  print('Will use CuPy and GPU for processing!')
65
  print('=' * 70)
66
 
67
  except ImportError as e:
68
+ print('Error: Could not import CuPy!')
69
+ print(f'Details: {e}')
70
+ print('=' * 70)
71
  print("Please make sure CuPy is installed.")
72
+ print('pip install cupy-cuda13x')
73
+ print('=' * 70)
74
+ print('Will use NumPy for now...')
75
+ import numpy as cp
76
+ import numpy as np
77
  print('=' * 70)
 
 
78
 
79
  ################################################################################
80
 
 
267
  }
268
  }
269
  '''
270
+
271
# Compile the pair-merge CUDA kernel on a best-effort basis.
# When the module has fallen back to NumPy (`import numpy as cp`), `cp` has no
# `RawKernel` attribute and the lookup raises; the kernel is then simply left
# undefined and only the CPU code paths are usable.
try:
    merge_kernel = cp.RawKernel(merge_kernel_code, 'merge_pair_kernel')

except Exception:
    # Deliberately ignored: a missing GPU backend must not break the import.
    # `Exception` (rather than a bare `except:`) keeps KeyboardInterrupt and
    # SystemExit propagating normally.
    pass
276
 
277
  ###################################################################################
278
 
 
402
  }
403
  }
404
  '''
405
+
406
# Compile the fused-merge CUDA kernel on a best-effort basis.
# When the module has fallen back to NumPy (`import numpy as cp`), `cp` has no
# `RawKernel` attribute and the lookup raises; the kernel is then simply left
# undefined and only the CPU code paths are usable.
try:
    fused_kernel = cp.RawKernel(fused_merge_kernel_code, 'fused_merge_kernel')

except Exception:
    # Deliberately ignored: a missing GPU backend must not break the import.
    # `Exception` (rather than a bare `except:`) keeps KeyboardInterrupt and
    # SystemExit propagating normally.
    pass
411
 
412
  ###################################################################################
413
 
 
1246
 
1247
  ###################################################################################
1248
 
1249
+ def find_repeating_non_overlapping_patterns(arr, min_len=8):
1250
  """
1251
  Finds all repeating non-overlapping patterns of min_len and longer.
1252
  GPU-Accelerated using CuPy with O(N) memory per length.
 
1355
 
1356
  ###################################################################################
1357
 
1358
def find_repeating_non_overlapping_patterns_numpy(arr, min_len=8):
    """
    Finds all repeating non-overlapping patterns of min_len and longer.
    CPU implementation built on NumPy.

    Pattern lengths are scanned from the longest possible (len(arr) // 2)
    down to min_len. For each length, windows are bucketed by a rolling
    polynomial hash, then regrouped by their exact contents so that hash
    collisions cannot merge different patterns. Within each group,
    occurrences are picked greedily left to right so they do not overlap
    each other; positions covered by an accepted pattern are marked as
    consumed and cannot start an occurrence of a later (shorter) pattern.

    Args:
        arr: Sequence of integers; converted to a NumPy int64 array.
        min_len: Minimum pattern length. Values below 1 are treated as 1.

    Returns:
        dict mapping each pattern (tuple of Python ints) to the number of
        non-overlapping occurrences that were selected (always >= 2).
        An empty dict is returned when arr is shorter than 2 * min_len.
    """
    # A pattern of length 0 is meaningless, and scanning it would index one
    # element past the end of `consumed`; never scan below length 1.
    if min_len < 1:
        min_len = 1

    arr = np.asarray(arr, dtype=np.int64)
    n = len(arr)
    if n < min_len * 2:
        return {}

    max_len = n // 2
    consumed = np.zeros(n, dtype=bool)
    result = {}

    base = np.int64(1000000007)

    # ---------------------------------------------------------
    # 1. Powers of the hash base (int64 arithmetic wraps on overflow)
    # ---------------------------------------------------------
    powers = np.ones(max_len + 1, dtype=np.int64)
    with np.errstate(over='ignore'):
        powers[1:] = np.cumprod(np.full(max_len, base, dtype=np.int64))

    # ---------------------------------------------------------
    # 2. Prefix hashes: pref[i] is the polynomial hash of arr[:i]
    #    (sequential recurrence, so it stays a Python loop)
    # ---------------------------------------------------------
    pref = np.zeros(n + 1, dtype=np.int64)
    with np.errstate(over='ignore'):
        for i in range(1, n + 1):
            pref[i] = pref[i - 1] * base + arr[i - 1]

    # ---------------------------------------------------------
    # 3. Main loop over pattern lengths, longest first
    # ---------------------------------------------------------
    for pat_len in range(max_len, min_len - 1, -1):
        # pat_len <= n // 2, so there is always at least one window.
        n_hashes = n - pat_len + 1

        # 3A. Hash of every window of this length, then mix upper/lower bits
        with np.errstate(over='ignore'):
            raw = pref[pat_len:pat_len + n_hashes] - pref[:n_hashes] * powers[pat_len]
        hash_l = raw ^ (raw >> 32)

        # 3B. Sort the hashes; a window is a candidate when its hash equals
        #     that of a neighbour in sorted order (i.e. its hash occurs >= 2x)
        sort_idx = np.argsort(hash_l)
        sorted_hash = hash_l[sort_idx]

        same_as_next = sorted_hash[1:] == sorted_hash[:-1]
        if not same_as_next.any():
            continue

        in_group = np.zeros(n_hashes, dtype=bool)
        in_group[1:] |= same_as_next
        in_group[:-1] |= same_as_next

        # 3C. Start positions of all candidate windows, in sorted-hash order
        indices_flat = sort_idx[in_group]

        # 3D. Regroup by exact window contents (collision-free)
        groups = {}
        for idx in indices_flat.tolist():
            groups.setdefault(arr[idx:idx + pat_len].tobytes(), []).append(idx)

        # 3E. Greedy non-overlapping selection inside each group
        for starts in groups.values():
            if len(starts) < 2:
                # Lone window that only shared a hash with a different pattern.
                continue
            starts.sort()

            valid = []
            last_end = -1
            for i in starts:
                # NOTE(review): only the start position is tested against
                # `consumed`; a window may still extend into consumed cells.
                if consumed[i]:
                    continue
                if i >= last_end:
                    valid.append(i)
                    last_end = i + pat_len

            if len(valid) >= 2:
                first = valid[0]
                result[tuple(arr[first:first + pat_len].tolist())] = len(valid)

                # Reserve the cells of every selected occurrence.
                for i in valid:
                    consumed[i:i + pat_len] = True

    return result
1469
+
1470
+ ###################################################################################
1471
+
1472
  print('Module is loaded!')
1473
  print('Enjoy! :)')
1474
  print('=' * 70)