engine: refactoring, caching

This commit is contained in:
Oliver Gaskell 2025-11-02 12:40:31 +00:00
parent 23c71ac642
commit aa9334d372
No known key found for this signature in database
GPG key ID: F971A08925FCC0AD

114
engine.py
View file

@ -25,6 +25,9 @@ FIRST_SEP = ':'
OTHER_SEP = ','
LOC_SEP = ';'
cached_dists = {}
cached_series = {}
EncodedLocation = list[tuple[float, list[float]]]
@ -85,6 +88,37 @@ def spherical_dist(pos1, pos2, r=6378137):
return r * np.arccos(cos_lat_d - cos_lat1 * cos_lat2 * (1 - cos_lon_d))
def get_dist_matrix(brand: str):
try:
return cached_dists[brand]
except KeyError:
greggs = np.array(fetch_data(brand))
repeat_rows = np.tile(greggs, (len(greggs), 1, 1))
repeat_cols = np.transpose(repeat_rows, (1, 0, 2))
dist_matrix = spherical_dist(repeat_rows, repeat_cols)
cached_dists[brand] = dist_matrix
return dist_matrix
def get_dist_series_list(brand: str, n_locs: int):
if brand in cached_series and cached_series[brand][0] == n_locs:
return cached_series[brand][1]
dist_matrix = get_dist_matrix("greggs")
# split the distances matrix into a list of series, which allows us to sort each row
dist_series_list = []
for i in dist_matrix:
dist_series_list.append(pd.Series(i).sort_values().head(n_locs+1)[1:])
cached_series[brand] = (n_locs, dist_series_list)
return dist_series_list
# (lat, lon), dist
StationT = tuple[tuple[float, float], float]
@ -102,7 +136,39 @@ def trilaterate(stations: list[StationT], disp: bool = False) -> tuple[float, fl
Each station is of the format ((lat, lon), distance).
"""
return scipy.optimize.fmin(lambda pos: trilat_error(stations, pos), (0., 0.), disp=disp)
res = scipy.optimize.minimize(
lambda pos: trilat_error(stations, pos),
stations[0][0],
method='Nelder-Mead',
)
if res.success:
return res.x
else:
raise ValueError("Optimisation failed.")
def encode_greggs(loc: int) -> list[float]:
"""Given the id of a greggs, encode as a list of distances."""
dist_matrix = get_dist_matrix("greggs")
greggs_distances = np.sort(dist_matrix[loc])[1:DISTS_COUNT+1]
return list(map(float, greggs_distances))
def decode_greggs(distances: list[float]) -> int:
"""Get the id of a greggs given a list of distances."""
dist_series_list = get_dist_series_list("greggs", len(distances))
errors = [sum((pd.Series(j) - distances) ** 2) for j in dist_series_list]
minerr = min(errors)
if minerr > 1:
print(f"warning: high error value of {minerr}")
return errors.index(min(errors))
def encode(location: tuple[float, float]) -> EncodedLocation:
@ -110,57 +176,21 @@ def encode(location: tuple[float, float]) -> EncodedLocation:
greggs = np.array(fetch_data("greggs"))
repeat_rows = np.tile(greggs, (len(greggs), 1, 1))
repeat_cols = np.transpose(repeat_rows, (1, 0, 2))
dist_matrix = spherical_dist(repeat_rows, repeat_cols)
repeated = np.tile(location, (len(greggs), 1))
distances = spherical_dist(repeated, greggs)
distances = pd.Series(distances)
distances = distances.sort_values()
distances = pd.Series(spherical_dist(repeated, greggs)).sort_values()
closest = distances.head(LOCS_COUNT)
result: EncodedLocation = []
for v, i in zip(closest.values, closest.index):
greggs_distances = np.sort(dist_matrix[i])[1:DISTS_COUNT+1]
result.append((v, list(map(float, greggs_distances))))
# Stub
return result
return [(v, encode_greggs(i)) for v, i in zip(closest.values, closest.index)]
def decode(location: EncodedLocation, disp: bool = False) -> tuple[float, float]:
"""Decode into a location."""
# form the distances matrix
greggs_raw = fetch_data("greggs")
greggs = np.array(greggs_raw)
repeat_rows = np.tile(greggs, (len(greggs), 1, 1))
repeat_cols = np.transpose(repeat_rows, (1, 0, 2))
dist_matrix = spherical_dist(repeat_rows, repeat_cols)
# split the distances matrix into a list of series, which allows us to sort each row
dist_series_list = []
for i in dist_matrix:
dist_series_list.append(pd.Series(i).sort_values().head(len(location[0][1])+1)[1:])
# part 1: find the ID of each gregg's
closest_greggs = []
for loc in location:
dists = loc[1]
errors = [sum((j - dists) ** 2) for j in dist_series_list]
minerr = min(errors)
if minerr > 1:
print(f"warning: high error value of {minerr}")
closest_greggs.append(errors.index(min(errors)))
closest_greggs = [decode_greggs(dists) for _, dists in location]
# part 2: trilaterate
stations: list[StationT] = [(greggs_raw[g], location[i][0]) for i, g in enumerate(closest_greggs)]
greggs = fetch_data("greggs")
stations: list[StationT] = [(greggs[g], location[i][0]) for i, g in enumerate(closest_greggs)]
return trilaterate(stations, disp=disp)