diff --git a/.flake8 b/.flake8 index 899b0b3..8216ee5 100644 --- a/.flake8 +++ b/.flake8 @@ -7,3 +7,4 @@ per-file-ignores = service/migrations/*:E501 service/views.py:C901 service/models.py:F403,F401 + service/tasks.py:E712 diff --git a/postamates/settings.py b/postamates/settings.py index a0343b9..3e12de7 100644 --- a/postamates/settings.py +++ b/postamates/settings.py @@ -171,3 +171,4 @@ PROJECT_NAME = 'postamates' CACHE_TIMEOUT = 0 DEFAULT_PLACEMENT_POINT_UPDATE_RADIUS = 500 AGE_DAY_LIMIT = 270 +AGE_DAY_BORDER = 30 diff --git a/requirements.txt b/requirements.txt index b0e619e..bfaeee9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -39,6 +39,7 @@ idna==3.4 inflection==0.5.1 itypes==1.2.0 Jinja2==3.1.2 +joblib==1.2.0 kiwisolver==1.4.4 kombu==5.2.4 MarkupSafe==2.1.2 @@ -66,12 +67,14 @@ PyYAML==6.0 requests==2.28.2 ruamel.yaml==0.17.21 ruamel.yaml.clib==0.2.7 +scikit-learn==1.2.2 scipy==1.10.0 shapely==2.0.1 six==1.16.0 SQLAlchemy==2.0.3 sqlparse==0.4.3 tenacity==8.2.1 +threadpoolctl==3.1.0 tqdm==4.64.0 typing_extensions==4.5.0 tzdata==2022.7 diff --git a/service/migrations/0018_placementpoint_fact_raw.py b/service/migrations/0018_placementpoint_fact_raw.py new file mode 100644 index 0000000..0d98fd6 --- /dev/null +++ b/service/migrations/0018_placementpoint_fact_raw.py @@ -0,0 +1,18 @@ +# Generated by Django 3.2 on 2023-03-26 12:47 +from django.db import migrations +from django.db import models + + +class Migration(migrations.Migration): + + dependencies = [ + ('service', '0017_rename_wkt_rivals_wkt'), + ] + + operations = [ + migrations.AddField( + model_name='placementpoint', + name='fact_raw', + field=models.IntegerField(blank=True, null=True), + ), + ] diff --git a/service/models.py b/service/models.py index 1eb3d82..98e5061 100644 --- a/service/models.py +++ b/service/models.py @@ -20,6 +20,7 @@ class PlacementPoint(models.Model): plan_first = models.IntegerField(null=True, blank=True, verbose_name='Плановый показатель начальный') plan_current = models.IntegerField(null=True, blank=True, verbose_name='Плановый показатель текущий') fact = models.IntegerField(null=True, blank=True, verbose_name='Фактический показатель') + fact_raw = models.IntegerField(null=True, blank=True) delta_first = models.IntegerField(null=True, blank=True, verbose_name='Разница начальная') delta_current = models.IntegerField(null=True, blank=True, verbose_name='Разница текущая') sample_trn = models.BooleanField(null=True, blank=True) diff --git a/service/service.py b/service/service.py index 51bfc2e..c5445dc 100644 --- a/service/service.py +++ b/service/service.py @@ -27,7 +27,7 @@ class PointService: pnts = models.PlacementPoint.objects.filter( geometry__distance_lt=(point.geometry, Distance(m=DEFAULT_PLACEMENT_POINT_UPDATE_RADIUS)), ) - pnts.update(target_post_cnt=F('target_post_cnt') + 1) + pnts.update(prediction_first=F('prediction_current'), target_post_cnt=F('target_post_cnt') + 1) raschet.delay() elif new_status == PointStatus.Cancelled.name or new_status == PointStatus.Pending.name: if point.status == PointStatus.Installation.name: @@ -36,6 +36,8 @@ class PointService: ) pnts.update(target_post_cnt=F('target_post_cnt') - 1 if F('target_post_cnt') != 0 else 0) raschet.delay() + elif new_status == PointStatus.Working.name and point.status == PointStatus.Pending.name: + raschet.delay() @staticmethod def update_status(qs: models.PlacementPoint, new_status: str) -> models.PlacementPoint: diff --git a/service/tasks.py b/service/tasks.py index 6df2dfb..457b405 100644 --- a/service/tasks.py +++ b/service/tasks.py @@ -1,7 +1,12 @@ -import time - +import catboost +import pandas as pd +import psycopg2 +import sqlalchemy from celery import shared_task from django.db.models import F +from sklearn import metrics +from sklearn import model_selection as ms +from sqlalchemy import text from postamates.settings import AGE_DAY_LIMIT from service.models import PlacementPoint @@ -14,9 +19,70 @@ from service.models import PlacementPoint @shared_task() def raschet(): - print('Hello. Celery task is running...') - time.sleep(5) - print('Finish') + conn = sqlalchemy.create_engine( + 'postgresql://sst_postamates_user:sst_postamates_pass@postnet-dev.selftech.ru:5487/sst_postamates_db', + connect_args={'options': '-csearch_path=public'}, + ) + query = text('select * from service_placementpoint') + pts = pd.read_sql(query, conn.connect()) + feats = ['id', 'popul_home', 'popul_job', 'other_post_cnt', 'yndxfood_sum', 'target_post_cnt'] + + # Записи для обучения + pts_trn = pts.loc[pts.sample_trn == True].reset_index(drop=True) + X_trn = pts_trn[feats].drop(columns=['id']) + Y_trn = pts_trn[['fact']] + + # Записи для инференса + pts_inf = pts.loc[(pts.status == 'Pending') | + (pts.status == 'Installation') | + (pts.status == 'Cancelled') | + ((pts.status == 'Pending') & (pts.sample_trn == False))].reset_index(drop=True) + X_inf = pts_inf[feats] + + # Обучение, инференс + r2_scores = [] + mapes = [] + y_infers = [] + + while len(r2_scores) < 5: + x_trn, x_test, y_trn, y_test = ms.train_test_split(X_trn, Y_trn, test_size=0.2) + model = catboost.CatBoostRegressor(cat_features=[]) + model.fit(x_trn, y_trn, verbose=False) + r2_score = metrics.r2_score(y_test, model.predict(x_test)) + mape = metrics.mean_absolute_percentage_error(y_test, model.predict(x_test)) + if ((r2_score > 0) & (mape < 0.5)): + r2_scores.append(r2_score) + mapes.append(mape) + y_infers.append(model.predict(X_inf.drop(columns=['id']))) + current_pred = sum(y_infers) / 5 + + # Обновление полей по результатам работы модели + update_fields = pts_inf[['id', 'delta_current', 'delta_first', 'plan_current', 'plan_first', 'prediction_first']] + update_fields = update_fields.join( + pd.concat( + [ + X_inf[['id']], + pd.DataFrame({'prediction_current': current_pred}), + ], + axis=1, + ).set_index('id'), + on='id', + ) + update_fields['prediction_current'] = update_fields['prediction_current'].astype(int) + + # Загрузка в базу обновленных значений + conn2 = psycopg2.connect( + database='sst_postamates_db', user='sst_postamates_user', password='sst_postamates_pass', + host='postnet-dev.selftech.ru', port='5487', options='-c search_path=public', + ) + cursor = conn2.cursor() + update_records1 = [] + for i in range(0, len(update_fields)): + update_records1.append((int(update_fields.prediction_current[i]), int(update_fields.id[i]))) + + sql_update_query = """Update service_placementpoint set prediction_current = %s where id = %s""" + cursor.executemany(sql_update_query, update_records1) + conn2.commit() @shared_task() diff --git a/service/views.py b/service/views.py index 43a07aa..b359a76 100644 --- a/service/views.py +++ b/service/views.py @@ -12,6 +12,7 @@ from rest_framework.response import Response from rest_framework.views import APIView from rest_framework.viewsets import ReadOnlyModelViewSet +from postamates.settings import AGE_DAY_BORDER from service import models from service import pagination from service import serializers @@ -161,7 +162,13 @@ class PlacementPointViewSet(ReadOnlyModelViewSet): qs = models.PlacementPoint.objects.filter(postamat_id=point_id) if not qs: return Response(status=HTTPStatus.NOT_FOUND) - qs.update(**{'fact': fact}) + for q in qs: + if q.age_day < AGE_DAY_BORDER: + qs.update(**{'fact': fact, 'fact_raw': fact}) + else: + new_fact = fact // (q.age_day / AGE_DAY_BORDER) + qs.update(fact=new_fact, fact_raw=fact) + qs.update(**{'fact_raw': fact}) return Response({'message': 'fact updated'}, status=HTTPStatus.OK) @action(detail=False, methods=['put'])