engine: refactoring, caching
This commit is contained in:
parent
23c71ac642
commit
aa9334d372
1 changed files with 72 additions and 42 deletions
114
engine.py
114
engine.py
|
|
@ -25,6 +25,9 @@ FIRST_SEP = ':'
|
||||||
OTHER_SEP = ','
|
OTHER_SEP = ','
|
||||||
LOC_SEP = ';'
|
LOC_SEP = ';'
|
||||||
|
|
||||||
|
cached_dists = {}
|
||||||
|
cached_series = {}
|
||||||
|
|
||||||
EncodedLocation = list[tuple[float, list[float]]]
|
EncodedLocation = list[tuple[float, list[float]]]
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -85,6 +88,37 @@ def spherical_dist(pos1, pos2, r=6378137):
|
||||||
return r * np.arccos(cos_lat_d - cos_lat1 * cos_lat2 * (1 - cos_lon_d))
|
return r * np.arccos(cos_lat_d - cos_lat1 * cos_lat2 * (1 - cos_lon_d))
|
||||||
|
|
||||||
|
|
||||||
|
def get_dist_matrix(brand: str):
    """Return the pairwise distance matrix for the locations of ``brand``.

    The matrix is computed from ``fetch_data(brand)`` on the first request
    and stored in the module-level ``cached_dists`` dict; subsequent calls
    for the same brand return the stored matrix.
    """
    if brand in cached_dists:
        return cached_dists[brand]

    locations = np.array(fetch_data(brand))
    count = len(locations)

    # Assuming fetch_data yields a 2-D list of coordinates (TODO confirm):
    # tiled[i][j] is location j and swapped[i][j] is location i, so feeding
    # both to spherical_dist pairs every location with every other one.
    tiled = np.tile(locations, (count, 1, 1))
    swapped = np.transpose(tiled, (1, 0, 2))

    matrix = spherical_dist(tiled, swapped)
    cached_dists[brand] = matrix
    return matrix
|
||||||
|
|
||||||
|
|
||||||
|
def get_dist_series_list(brand: str, n_locs: int):
    """Return, for each location of ``brand``, its ``n_locs`` nearest distances.

    Each row of the brand's distance matrix is sorted ascending and the first
    entry is dropped, leaving the next ``n_locs`` values as a ``pd.Series``
    that keeps the original column indices as its index.

    The result is memoised in the module-level ``cached_series`` dict as
    ``(n_locs, series_list)``, one entry per brand; asking for a different
    ``n_locs`` recomputes and overwrites that brand's entry.
    """
    cached = cached_series.get(brand)
    if cached is not None and cached[0] == n_locs:
        return cached[1]

    # Use the requested brand's matrix. Previously this always loaded the
    # "greggs" matrix, so other brands were cached with greggs data.
    dist_matrix = get_dist_matrix(brand)

    # Split the matrix into one sorted Series per row. Position 0 of each
    # sorted row is skipped (presumably the zero self-distance), and .iloc
    # keeps the slice positional regardless of the Series' index labels.
    dist_series_list = [
        pd.Series(row).sort_values().iloc[1:n_locs + 1]
        for row in dist_matrix
    ]

    cached_series[brand] = (n_locs, dist_series_list)
    return dist_series_list
|
||||||
|
|
||||||
|
|
||||||
# (lat, lon), dist
|
# (lat, lon), dist
|
||||||
StationT = tuple[tuple[float, float], float]
|
StationT = tuple[tuple[float, float], float]
|
||||||
|
|
||||||
|
|
@ -102,7 +136,39 @@ def trilaterate(stations: list[StationT], disp: bool = False) -> tuple[float, fl
|
||||||
Each station is of the format ((lat, lon), distance).
|
Each station is of the format ((lat, lon), distance).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return scipy.optimize.fmin(lambda pos: trilat_error(stations, pos), (0., 0.), disp=disp)
|
res = scipy.optimize.minimize(
|
||||||
|
lambda pos: trilat_error(stations, pos),
|
||||||
|
stations[0][0],
|
||||||
|
method='Nelder-Mead',
|
||||||
|
)
|
||||||
|
|
||||||
|
if res.success:
|
||||||
|
return res.x
|
||||||
|
else:
|
||||||
|
raise ValueError("Optimisation failed.")
|
||||||
|
|
||||||
|
|
||||||
|
def encode_greggs(loc: int) -> list[float]:
    """Encode the greggs with id ``loc`` as a list of distances.

    Takes row ``loc`` of the "greggs" distance matrix, sorts it ascending,
    skips the first entry and returns the following ``DISTS_COUNT`` values
    as plain Python floats.
    """
    row = get_dist_matrix("greggs")[loc]
    ordered = np.sort(row)
    nearest = ordered[1:DISTS_COUNT+1]
    return [float(d) for d in nearest]
|
||||||
|
|
||||||
|
|
||||||
|
def decode_greggs(distances: list[float]) -> int:
    """Return the id of the greggs whose nearest distances best match ``distances``.

    Every candidate's sorted nearest-distance series (same length as
    ``distances``) is compared by sum of squared differences; the index of
    the smallest error is returned. A warning is printed when even the best
    match has an error above 1.
    """
    candidates = get_dist_series_list("greggs", len(distances))

    # Squared-error score of each candidate against the supplied distances.
    errors = []
    for candidate in candidates:
        diff = candidate - distances
        errors.append(sum(diff ** 2))

    minerr = min(errors)
    if minerr > 1:
        print(f"warning: high error value of {minerr}")

    # list.index returns the first position holding the minimum.
    return errors.index(minerr)
|
||||||
|
|
||||||
|
|
||||||
def encode(location: tuple[float, float]) -> EncodedLocation:
|
def encode(location: tuple[float, float]) -> EncodedLocation:
|
||||||
|
|
@ -110,57 +176,21 @@ def encode(location: tuple[float, float]) -> EncodedLocation:
|
||||||
|
|
||||||
greggs = np.array(fetch_data("greggs"))
|
greggs = np.array(fetch_data("greggs"))
|
||||||
|
|
||||||
repeat_rows = np.tile(greggs, (len(greggs), 1, 1))
|
|
||||||
repeat_cols = np.transpose(repeat_rows, (1, 0, 2))
|
|
||||||
dist_matrix = spherical_dist(repeat_rows, repeat_cols)
|
|
||||||
|
|
||||||
repeated = np.tile(location, (len(greggs), 1))
|
repeated = np.tile(location, (len(greggs), 1))
|
||||||
distances = spherical_dist(repeated, greggs)
|
distances = pd.Series(spherical_dist(repeated, greggs)).sort_values()
|
||||||
distances = pd.Series(distances)
|
|
||||||
distances = distances.sort_values()
|
|
||||||
|
|
||||||
closest = distances.head(LOCS_COUNT)
|
closest = distances.head(LOCS_COUNT)
|
||||||
|
|
||||||
result: EncodedLocation = []
|
return [(v, encode_greggs(i)) for v, i in zip(closest.values, closest.index)]
|
||||||
for v, i in zip(closest.values, closest.index):
|
|
||||||
greggs_distances = np.sort(dist_matrix[i])[1:DISTS_COUNT+1]
|
|
||||||
|
|
||||||
result.append((v, list(map(float, greggs_distances))))
|
|
||||||
|
|
||||||
# Stub
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def decode(location: EncodedLocation, disp: bool = False) -> tuple[float, float]:
|
def decode(location: EncodedLocation, disp: bool = False) -> tuple[float, float]:
|
||||||
"""Decode into a location."""
|
"""Decode into a location."""
|
||||||
|
|
||||||
# form the distances matrix
|
|
||||||
greggs_raw = fetch_data("greggs")
|
|
||||||
greggs = np.array(greggs_raw)
|
|
||||||
repeat_rows = np.tile(greggs, (len(greggs), 1, 1))
|
|
||||||
repeat_cols = np.transpose(repeat_rows, (1, 0, 2))
|
|
||||||
dist_matrix = spherical_dist(repeat_rows, repeat_cols)
|
|
||||||
|
|
||||||
# split the distances matrix into a list of series, which allows us to sort each row
|
|
||||||
dist_series_list = []
|
|
||||||
for i in dist_matrix:
|
|
||||||
dist_series_list.append(pd.Series(i).sort_values().head(len(location[0][1])+1)[1:])
|
|
||||||
|
|
||||||
# part 1: find the ID of each gregg's
|
# part 1: find the ID of each gregg's
|
||||||
closest_greggs = []
|
closest_greggs = [decode_greggs(dists) for _, dists in location]
|
||||||
for loc in location:
|
|
||||||
dists = loc[1]
|
|
||||||
|
|
||||||
errors = [sum((j - dists) ** 2) for j in dist_series_list]
|
|
||||||
|
|
||||||
minerr = min(errors)
|
|
||||||
if minerr > 1:
|
|
||||||
print(f"warning: high error value of {minerr}")
|
|
||||||
|
|
||||||
closest_greggs.append(errors.index(min(errors)))
|
|
||||||
|
|
||||||
# part 2: trilaterate
|
# part 2: trilaterate
|
||||||
stations: list[StationT] = [(greggs_raw[g], location[i][0]) for i, g in enumerate(closest_greggs)]
|
greggs = fetch_data("greggs")
|
||||||
|
stations: list[StationT] = [(greggs[g], location[i][0]) for i, g in enumerate(closest_greggs)]
|
||||||
return trilaterate(stations, disp=disp)
|
return trilaterate(stations, disp=disp)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue