Skip to content
This repository was archived by the owner on Jan 8, 2026. It is now read-only.
Open
22 changes: 19 additions & 3 deletions gramex/handlers/mlhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from slugify import slugify
from tornado.gen import coroutine
from tornado.web import HTTPError
from sklearn.metrics import get_scorer
from sklearn.model_selection import cross_val_predict, cross_val_score

op = os.path
MLCLASS_MODULES = [
Expand All @@ -40,6 +42,7 @@
'nums': [],
'cats': [],
'target_col': None,
'cv': True,
}
ACTIONS = ['predict', 'score', 'append', 'train', 'retrain']
DEFAULT_TEMPLATE = op.join(op.dirname(__file__), '..', 'apps', 'mlhandler', 'template.html')
Expand Down Expand Up @@ -103,7 +106,6 @@ def setup(cls, data=None, model={}, config_dir='', **kwargs):

cls.set_opt('class', model.get('class'))
cls.set_opt('params', model.get('params', {}))

if op.exists(cls.model_path): # If the pkl exists, load it
cls.model = joblib.load(cls.model_path)
elif data is not None:
Expand All @@ -112,14 +114,23 @@ def setup(cls, data=None, model={}, config_dir='', **kwargs):
data = cls._filtercols(data)
data = cls._filterrows(data)
cls.model = cls._assemble_pipeline(data, mclass=mclass, params=params)

# train the model
target = data[target_col]
train = data[[c for c in data if c != target_col]]
# cross validation
cls.cross_validation(train,target)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not required here.

gramex.service.threadpool.submit(
_fit, cls.model, train, target, cls.model_path, cls.name)
cls.config_store.flush()


@classmethod
def cross_validation(cls, train, target):
    """Cross-validate the model's final pipeline step and report the mean score.

    The ``cv`` option (read via ``cls.get_opt``, defaulting to ``True``)
    controls the behaviour:

    - a falsy value skips cross validation entirely;
    - ``True`` means "use scikit-learn's default splitter";
    - anything else is forwarded unchanged as the ``cv`` argument of
      ``cross_val_score``.

    Parameters
    ----------
    train :
        Feature data, passed to ``cross_val_score`` as ``X``.
    target :
        Target values, passed to ``cross_val_score`` as ``y``.

    Returns
    -------
    The mean of the per-fold scores, or ``None`` when cross validation is
    disabled. The mean is also printed.
    """
    cv = cls.get_opt('cv', True)
    if not cv:
        return None

    # bool is an integral type, so passing True through would be read by
    # scikit-learn as "1 fold" and rejected (at least 2 splits are needed).
    # None selects the library's default splitter instead.
    if cv is True:
        cv = None

    # NOTE(review): only the last pipeline step is scored, so the pipeline's
    # transformers are not applied to `train` here -- confirm this is intended.
    estimator = cls.model.steps[-1][1]
    fold_scores = cross_val_score(estimator, X=train, y=target, cv=cv)
    mean_score = sum(fold_scores) / len(fold_scores)
    print('Cross Validation Score : ', mean_score)
    return mean_score
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CV should take place within the train method only.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if cv:
    cvscore = cross_val_score(mod, X=train, y=target, cv=cv)
else:
    # Do the usual .fit


@classmethod
def load_data(cls, default=pd.DataFrame()):
try:
Expand Down Expand Up @@ -268,6 +279,10 @@ def _predict(self, data=None, score_col=''):
self.model = cache.open(self.model_path, joblib.load)
try:
target = data.pop(score_col)
metric = self.get_argument('_metric', False)
if metric:
scorer = get_scorer(metric)
return scorer(self.model, data, target)
return self.model.score(data, target)
except KeyError:
# Set data in the same order as the transformer requests
Expand Down Expand Up @@ -347,6 +362,7 @@ def _train(self, data=None):
target = data[target_col]
train = data[[c for c in data if c != target_col]]
self.model = self._assemble_pipeline(data, force=True)
self.cross_validation(train,target)
_fit(self.model, train, target, self.model_path)
return {'score': self.model.score(train, target)}

Expand Down
5 changes: 5 additions & 0 deletions tests/test_mlhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,11 @@ def test_get_bulk_score(self):
data=self.df.to_json(orient='records'),
headers={'Content-Type': 'application/json'})
self.assertGreaterEqual(resp.json()['score'], self.ACC_TOL)
resp = self.get(
'/mlhandler?_action=score&_metric=f1_weighted', method='post',
data=self.df.to_json(orient='records'),
headers={'Content-Type': 'application/json'})
self.assertGreaterEqual(resp.json()['score'], self.ACC_TOL)

def test_get_cache(self):
df = pd.DataFrame.from_records(self.get('/mlhandler?_cache=true').json())
Expand Down