engine: refactoring, caching
This commit is contained in:
parent
23c71ac642
commit
aa9334d372
1 changed files with 72 additions and 42 deletions
114
engine.py
114
engine.py
|
|
@ -25,6 +25,9 @@ FIRST_SEP = ':'
|
|||
OTHER_SEP = ','
|
||||
LOC_SEP = ';'
|
||||
|
||||
# Memoised pairwise distance matrices, keyed by brand name; filled lazily by get_dist_matrix().
cached_dists = {}
|
||||
# Memoised per-row sorted distance series, keyed by brand name. Each entry is a
# (n_locs, dist_series_list) tuple so a request for a different n_locs is recomputed.
cached_series = {}
|
||||
|
||||
# One (distance, [distances...]) pair per reference location: as built by encode(),
# the first item is the distance from the encoded point to a nearby location and the
# list holds that location's own sorted distances to its neighbours.
EncodedLocation = list[tuple[float, list[float]]]
|
||||
|
||||
|
||||
|
|
@ -85,6 +88,37 @@ def spherical_dist(pos1, pos2, r=6378137):
|
|||
return r * np.arccos(cos_lat_d - cos_lat1 * cos_lat2 * (1 - cos_lon_d))
|
||||
|
||||
|
||||
def get_dist_matrix(brand: str):
    """Return the pairwise distance matrix for all locations of ``brand``.

    The matrix is computed once per brand with ``spherical_dist`` and then
    served from the module-level ``cached_dists`` dict on later calls.

    Args:
        brand: Name passed to ``fetch_data`` to load the brand's locations.

    Returns:
        The array produced by ``spherical_dist`` over every pair of
        locations; entry ``[i][j]`` relates location ``i`` to location ``j``.
    """
    if brand in cached_dists:
        return cached_dists[brand]

    # presumably an (n, 2) array of (lat, lon) pairs -- confirm against fetch_data
    locations = np.array(fetch_data(brand))
    count = len(locations)

    # Stack the location list `count` times so that by_col[i][j] is location j,
    # then swap the first two axes so that by_row[i][j] is location i.
    by_col = np.tile(locations, (count, 1, 1))
    by_row = by_col.transpose(1, 0, 2)

    matrix = spherical_dist(by_col, by_row)
    cached_dists[brand] = matrix
    return matrix
|
||||
|
||||
|
||||
def get_dist_series_list(brand: str, n_locs: int):
    """Return, for every location of ``brand``, its ``n_locs`` smallest distances.

    Each row of the brand's distance matrix is turned into a ``pd.Series``,
    sorted ascending, and trimmed to the entries at sorted positions
    ``1..n_locs`` (the first sorted entry is dropped -- presumably the
    location's zero distance to itself).

    The result is cached per brand together with the ``n_locs`` it was built
    for; a call with a different ``n_locs`` recomputes and replaces the entry.

    Args:
        brand: Brand whose distance matrix should be used.
        n_locs: Number of neighbour distances to keep per location.

    Returns:
        A list with one sorted ``pd.Series`` per location.
    """
    cached = cached_series.get(brand)
    if cached is not None and cached[0] == n_locs:
        return cached[1]

    # Use the requested brand's matrix; the result is cached under `brand`,
    # so loading any other brand's data here would poison the cache.
    dist_matrix = get_dist_matrix(brand)

    # Split the matrix into one Series per row so each row can be sorted
    # independently while keeping the original column indices as labels.
    dist_series_list = [
        pd.Series(row).sort_values().head(n_locs + 1)[1:]
        for row in dist_matrix
    ]

    cached_series[brand] = (n_locs, dist_series_list)
    return dist_series_list
|
||||
|
||||
|
||||
# (lat, lon), dist
|
||||
StationT = tuple[tuple[float, float], float]
|
||||
|
||||
|
|
@ -102,7 +136,39 @@ def trilaterate(stations: list[StationT], disp: bool = False) -> tuple[float, fl
|
|||
Each station is of the format ((lat, lon), distance).
|
||||
"""
|
||||
|
||||
return scipy.optimize.fmin(lambda pos: trilat_error(stations, pos), (0., 0.), disp=disp)
|
||||
res = scipy.optimize.minimize(
|
||||
lambda pos: trilat_error(stations, pos),
|
||||
stations[0][0],
|
||||
method='Nelder-Mead',
|
||||
)
|
||||
|
||||
if res.success:
|
||||
return res.x
|
||||
else:
|
||||
raise ValueError("Optimisation failed.")
|
||||
|
||||
|
||||
def encode_greggs(loc: int) -> list[float]:
    """Encode the greggs with id ``loc`` as a list of neighbour distances.

    Args:
        loc: Row index of the greggs in the "greggs" distance matrix.

    Returns:
        ``DISTS_COUNT`` distances taken from the row sorted ascending,
        skipping the first sorted entry (presumably the zero self-distance).
    """
    row = get_dist_matrix("greggs")[loc]
    nearest = np.sort(row)[1:DISTS_COUNT + 1]

    # Convert NumPy scalars to plain Python floats.
    return [float(d) for d in nearest]
|
||||
|
||||
|
||||
def decode_greggs(distances: list[float]) -> int:
    """Find the id of the greggs whose neighbour distances best match ``distances``.

    Every greggs is scored by the sum of squared differences between its own
    sorted neighbour distances and ``distances``; the index with the lowest
    score is returned. A warning is printed when even the best score is
    above 1.

    Args:
        distances: Neighbour distances as produced by ``encode_greggs``.

    Returns:
        Index of the best-matching greggs in the distance series list.
    """
    candidates = get_dist_series_list("greggs", len(distances))

    errors = []
    for neighbour_dists in candidates:
        # Re-wrap in a Series and subtract the plain list element-wise.
        residual = pd.Series(neighbour_dists) - distances
        errors.append(sum(residual ** 2))

    minerr = min(errors)
    if minerr > 1:
        print(f"warning: high error value of {minerr}")

    return errors.index(minerr)
|
||||
|
||||
|
||||
def encode(location: tuple[float, float]) -> EncodedLocation:
|
||||
|
|
@ -110,57 +176,21 @@ def encode(location: tuple[float, float]) -> EncodedLocation:
|
|||
|
||||
greggs = np.array(fetch_data("greggs"))
|
||||
|
||||
repeat_rows = np.tile(greggs, (len(greggs), 1, 1))
|
||||
repeat_cols = np.transpose(repeat_rows, (1, 0, 2))
|
||||
dist_matrix = spherical_dist(repeat_rows, repeat_cols)
|
||||
|
||||
repeated = np.tile(location, (len(greggs), 1))
|
||||
distances = spherical_dist(repeated, greggs)
|
||||
distances = pd.Series(distances)
|
||||
distances = distances.sort_values()
|
||||
|
||||
distances = pd.Series(spherical_dist(repeated, greggs)).sort_values()
|
||||
closest = distances.head(LOCS_COUNT)
|
||||
|
||||
result: EncodedLocation = []
|
||||
for v, i in zip(closest.values, closest.index):
|
||||
greggs_distances = np.sort(dist_matrix[i])[1:DISTS_COUNT+1]
|
||||
|
||||
result.append((v, list(map(float, greggs_distances))))
|
||||
|
||||
# Stub
|
||||
return result
|
||||
return [(v, encode_greggs(i)) for v, i in zip(closest.values, closest.index)]
|
||||
|
||||
|
||||
def decode(location: EncodedLocation, disp: bool = False) -> tuple[float, float]:
    """Decode an encoded location back into a position.

    Args:
        location: Encoded location; each item is ``(distance, dists)`` where
            ``dists`` identifies a greggs and ``distance`` is the distance
            from the encoded point to that greggs.
        disp: Forwarded unchanged to ``trilaterate``.

    Returns:
        The position returned by ``trilaterate`` for the recovered stations.
    """
    # Part 1: recover the id of each greggs from its distance fingerprint.
    # decode_greggs() uses the cached distance data, so nothing is rebuilt
    # inline here.
    closest_greggs = [decode_greggs(dists) for _, dists in location]

    # Part 2: pair each recovered greggs position with the encoded distance
    # to it and trilaterate.
    greggs = fetch_data("greggs")
    stations: list[StationT] = [
        (greggs[g], dist) for g, (dist, _) in zip(closest_greggs, location)
    ]
    return trilaterate(stations, disp=disp)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue