new_ml_func

dev
AlexP077 3 years ago committed by Dmitry Titov
parent e12dfd9c08
commit 3cfc80e034

@ -20,6 +20,7 @@ from service.models import PlacementPoint
# Запустить scheduler
# celery -A postamates beat -l INFO.
@shared_task()
def raschet():
conn = sqlalchemy.create_engine(
@ -27,11 +28,26 @@ def raschet():
connect_args={'options': '-csearch_path=public'},
)
query = text('select * from service_placementpoint')
pts = pd.read_sql(query, conn.connect())
feats = ['id', 'popul_home', 'popul_job', 'other_post_cnt', 'yndxfood_sum', 'target_post_cnt']
connection = conn.connect()
pts = pd.read_sql(query, connection)
pts.loc[pts.target_dist > 700, 'target_dist'] = 700
pts = pts.sort_values(by='id').reset_index(drop=True)
feats = [
'id', 'metro_dist', 'target_dist', 'property_price_bargains', 'property_price_offers',
'property_mean_floor',
'property_era', 'flats_cnt_2', 'flats_cnt', 'popul_home', 'popul_job', 'other_post_cnt', 'yndxfood_sum',
'yndxfood_cnt', 'school_cnt', 'kindergar_cnt', 'target_post_cnt', 'public_stop_cnt', 'sport_center_cnt',
'pharmacy_cnt', 'supermarket_cnt', 'supermarket_premium_cnt', 'clinic_cnt', 'bank_cnt', 'reca_cnt',
'lab_cnt', 'culture_cnt', 'attraction_cnt', 'mfc_cnt', 'bc_cnt', 'tc_cnt', 'rival_pvz_cnt',
'rival_post_cnt',
'business_activity', 'age_day', 'target_age_nearby_mean', 'target_cnt_ao_mean',
# 'target_cnt_nearby_mean'
]
# Записи для обучения
pts_trn = pts.loc[pts.sample_trn == True].reset_index(drop=True)
# pts_trn = pts_trn.loc[pts_trn.fact < 450].reset_index(drop=True)
X_trn = pts_trn[feats].drop(columns=['id'])
Y_trn = pts_trn[['fact']]
@ -39,24 +55,27 @@ def raschet():
pts_inf = pts.loc[(pts.status == 'Pending') |
(pts.status == 'Installation') |
(pts.status == 'Cancelled') |
((pts.status == 'Pending') & (pts.sample_trn == False))].reset_index(drop=True)
((pts.status == 'Working') & (pts.sample_trn == False))].reset_index(drop=True)
pts_inf['age_day'] = 240
X_inf = pts_inf[feats]
seeds = [39, 85, 15, 1, 59]
# Обучение, инференс
r2_scores = []
mapes = []
y_infers = []
while len(r2_scores) < 5:
x_trn, x_test, y_trn, y_test = ms.train_test_split(X_trn, Y_trn, test_size=0.2)
model = catboost.CatBoostRegressor(cat_features=[])
for i in seeds:
x_trn, x_test, y_trn, y_test = ms.train_test_split(X_trn, Y_trn, test_size=0.2, random_state=i)
model = catboost.CatBoostRegressor(cat_features=['property_era'], random_state=i)
model.fit(x_trn, y_trn, verbose=False)
r2_score = metrics.r2_score(y_test, model.predict(x_test))
mape = metrics.mean_absolute_percentage_error(y_test, model.predict(x_test))
if ((r2_score > 0) & (mape < 0.5)):
if ((r2_score > 0.45) & (mape < 0.25)):
r2_scores.append(r2_score)
mapes.append(mape)
y_infers.append(model.predict(X_inf.drop(columns=['id'])))
current_pred = sum(y_infers) / 5
# Обновление полей по результатам работы модели
@ -65,7 +84,7 @@ def raschet():
pd.concat(
[
X_inf[['id']],
pd.DataFrame({'prediction_current': current_pred}),
pd.DataFrame([{'prediction_current': current_pred}]),
],
axis=1,
).set_index('id'),
@ -86,6 +105,11 @@ def raschet():
update_records1.append((int(update_fields.prediction_current[i]), int(update_fields.id[i])))
sql_update_query = """Update service_placementpoint set prediction_current = %s where id = %s"""
try:
cursor.executemany(sql_update_query, update_records1)
conn2.commit()
except Exception:
cursor.execute('ROLLBACK')
cursor.executemany(sql_update_query, update_records1)
conn2.commit()

Loading…
Cancel
Save